// Metal helper functions #pragma once #include #include namespace c10 { namespace metal { namespace detail { template struct vectypes {}; template <> struct vectypes { using type4 = float4; using type3 = float3; using type2 = float2; }; template <> struct vectypes { using type4 = half4; using type3 = half3; using type2 = half2; }; template <> struct vectypes { using type4 = bfloat4; using type3 = bfloat3; using type2 = bfloat2; }; template <> struct vectypes { using type4 = short4; using type3 = short3; using type2 = short2; }; template <> struct vectypes { using type4 = int4; using type3 = int3; using type2 = int2; }; template <> struct vectypes { using type4 = short4; using type3 = short3; using type2 = short2; }; template struct OpMathType { using type = T; }; template <> struct OpMathType { using type = float; }; template <> struct OpMathType { using type = int; }; template <> struct OpMathType { using type = int; }; template <> struct OpMathType { using type = int; }; template <> struct OpMathType { using type = float; }; // Type promotion structure for higher precision accumulation template struct AccumulationType { using type = T; }; // Specialization for half - promote to float for accumulation template <> struct AccumulationType { using type = float; }; // Specialization for bfloat - promote to float for accumulation template <> struct AccumulationType { using type = float; }; } // namespace detail template ::metal::enable_if_t<::metal::is_floating_point_v, T> max(T a, T b) { return ::metal::isunordered(a, b) ? NAN : ::metal::max(a, b); } template ::metal::enable_if_t<::metal::is_integral_v&& ::metal::is_integral_v, T> max(T a, U b) { return ::metal::max(a, static_cast(b)); } template ::metal::enable_if_t<::metal::is_floating_point_v, T> min(T a, T b) { return ::metal::isunordered(a, b) ? NAN : ::metal::min(a, b); } template ::metal::enable_if_t<::metal::is_integral_v&& ::metal::is_integral_v, T> min(T a, U b) { return ::metal::min(a, static_cast(b)); } template <> inline bfloat min(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::min(float(a), float(b))); } template <> inline bfloat max(bfloat a, bfloat b) { return bfloat( ::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b))); } template using vec2type_t = typename detail::vectypes::type2; template using vec4type_t = typename detail::vectypes::type4; template using opmath_t = typename detail::OpMathType::type; template using accum_t = typename detail::AccumulationType::type; // TODO: Move it to type_traits header may be template using result_of = decltype(::metal::declval()(::metal::declval()...)); template constexpr constant bool is_complex_v = ::metal::is_same_v || ::metal::is_same_v; template constexpr constant bool is_scalar_floating_point_v = ::metal::is_floating_point_v && ::metal::is_scalar_v; template constexpr constant bool is_scalar_integral_v = ::metal::is_integral_v && ::metal::is_scalar_v; template using common_dtype = decltype(U(0) + V(0)); // floor_divide template < typename T, typename U, ::metal::enable_if_t< is_scalar_integral_v && is_scalar_integral_v, bool> = true> inline common_dtype floor_divide(T x, U y) { const auto quot = x / y; return (x < 0) == (y < 0) ? quot : (x % y != 0) ? quot - 1 : quot; } template < typename T, typename U, ::metal::enable_if_t< is_scalar_floating_point_v && is_scalar_floating_point_v, bool> = true> inline common_dtype floor_divide(T x, U y) { return ::metal::floor(x / y); } // fmod template < typename T, typename U, ::metal::enable_if_t< is_scalar_integral_v && is_scalar_integral_v, bool> = true> inline common_dtype fmod(T x, U y) { return x % y; } template < typename T, typename U, ::metal::enable_if_t< is_scalar_floating_point_v && is_scalar_floating_point_v, bool> = true> inline common_dtype fmod(T x, U y) { return ::metal::fmod(x, y); } // cast_to primitives // - No-op if types as the same template < typename T, typename U, ::metal::enable_if_t<::metal::is_same_v, bool> = true> inline T cast_to(const U from) { return from; } // - Simple cast between scalar and complex dtypes template < typename T, typename U, ::metal::enable_if_t< !::metal::is_same_v && (is_complex_v == is_complex_v), bool> = true> inline T cast_to(const U from) { return static_cast(from); } // - Scalar to complex template < typename T, typename U, ::metal::enable_if_t && !is_complex_v, bool> = true> inline T cast_to(const U from) { return T(float(from), 0.0); } // - Complex to scalar (should not really be used, but exists for compliteness) template < typename T, typename U, ::metal::enable_if_t && is_complex_v, bool> = true> inline T cast_to(const U from) { return static_cast(from.x); } // Generalizable math operators (used for both scalar and complex) template < typename T, typename U, ::metal::enable_if_t, bool> = true> inline common_dtype mul(const T x, const U y) { return x * y; } template < typename T, typename U, ::metal::enable_if_t && is_complex_v, bool> = true> inline common_dtype mul(const T x, const U y) { return T(x.x * y.x - x.y * y.y, x.x * y.y + x.y * y.x); } template < typename T, typename U, ::metal::enable_if_t, bool> = true> inline common_dtype div(const T x, const U y) { return x / y; } template < typename T, typename U, ::metal::enable_if_t && is_complex_v, bool> = true> inline common_dtype div(const T x, const U y) { return T(::metal::dot(x, y), x.y * y.x - x.x * y.y) / ::metal::dot(y, y); } // Remainder operator template < typename T, typename U, ::metal::enable_if_t< is_scalar_floating_point_v || is_scalar_floating_point_v, bool> = true> inline float remainder(const T x, const U y) { const auto x_f = static_cast(x); const auto y_f = static_cast(y); return x_f - y_f * floor_divide(x_f, y_f); } template < typename T, typename U, ::metal::enable_if_t< is_scalar_integral_v && is_scalar_integral_v, bool> = true> inline common_dtype remainder(const T x, const U y) { auto rc = x % y; return rc == 0 || (x ^ y) > 0 ? rc : rc + y; } // Based on algorithm described in // https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202 inline float log1p(float x) { const auto xp1 = 1.0f + x; // First two elements of Taylor series for log(1+x) in Horner's form are: // log(1+x) = x * (1 - x * (.5 ...)), but if 1 + x == x, then it's just x if (xp1 == 1.0f) { return x; } auto rc = ::metal::precise::log(xp1); if (x > -.5 && x < .5) { // Order of operations is important here for higher precision rc *= x / (xp1 - 1.0f); } return rc; } template struct pair { T1 first; T2 second; }; #define INSTANTIATE_FOR_ALL_TYPES(MACRO) \ MACRO(float); \ MACRO(half); \ MACRO(bfloat); \ MACRO(float2); \ MACRO(long); \ MACRO(char); \ MACRO(uchar); \ MACRO(short); \ MACRO(int); #define INSTANTIATE_FOR_FLOAT_TYPES(MACRO) \ MACRO(float); \ MACRO(half); \ MACRO(bfloat); } // namespace metal } // namespace c10