mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[BE][MPS] Build Metal kernels for macOS 14+ (#159733)
This makes the `#if __METAL_VERSION__ >= 310` guards around `bfloat` usage unnecessary. Rename `kernels_bfloat.metallib` to `kernels_basic.metallib` and remove the custom build/selection logic. Part of https://github.com/pytorch/pytorch/issues/159275 Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733 Approved by: https://github.com/dcci ghstack dependencies: #159731, #159732
This commit is contained in:
committed by
PyTorch MergeBot
parent
5116c49b52
commit
e2a5c42e7e
@ -704,21 +704,17 @@ if(USE_MPS)
|
||||
if(CAN_COMPILE_METAL)
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
cmake_path(GET SHADER STEM TGT_STEM)
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
|
||||
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
|
||||
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
|
||||
list(APPEND AIR_BASIC ${TGT_BASIC})
|
||||
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
|
||||
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
|
||||
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
|
||||
endforeach()
|
||||
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
|
||||
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
|
||||
add_custom_command(
|
||||
COMMAND echo "// $$(date)" > metallib_dummy.cpp
|
||||
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
|
||||
DEPENDS kernels_basic.metallib
|
||||
OUTPUT metallib_dummy.cpp
|
||||
COMMENT "Updating metallibs timestamp")
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
|
||||
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
|
||||
else()
|
||||
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
|
||||
foreach(SHADER ${native_mps_metal})
|
||||
|
@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
|
||||
if (C10_UNLIKELY(!library)) {
|
||||
auto device = MPSDevice::getInstance()->device();
|
||||
NSError* error = nil;
|
||||
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
|
||||
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
|
||||
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
|
||||
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
|
||||
}
|
||||
return library;
|
||||
|
@ -33,21 +33,15 @@ struct shrink_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardsigmoid_functor {
|
||||
template <typename T>
|
||||
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardsigmoid, float, float);
|
||||
REGISTER_UNARY_OP(hardsigmoid, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct hardswish_functor {
|
||||
template <typename T>
|
||||
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
|
||||
|
||||
REGISTER_UNARY_OP(hardswish, float, float);
|
||||
REGISTER_UNARY_OP(hardswish, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_OP(hardswish_backward, float, float);
|
||||
REGISTER_BINARY_OP(hardswish_backward, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
struct leaky_relu_functor {
|
||||
template <typename T>
|
||||
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
@ -113,18 +113,12 @@ kernel void ampUpdateScale(
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(float);
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
|
||||
#endif
|
||||
|
@ -590,9 +590,7 @@ kernel void attention(
|
||||
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(float);
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
#endif
|
||||
|
||||
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
|
||||
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
|
||||
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
|
||||
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(float);
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
|
||||
#endif
|
||||
|
@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
|
||||
};
|
||||
|
||||
// Functor returning the next representable value after `a` in the direction
// of `b`. Uses ::metal::nextafter on Metal 3.1+, and a manual bit-stepping
// fallback on older language versions.
struct nextafter_functor {
#if __METAL_VERSION__ < 310
  // Signed integer type used to reinterpret the bits of a floating-point
  // value in the fallback path (float -> int, half -> short).
  template <typename U>
  struct bit_type {};
  template <>
  struct bit_type<float> {
    using type = int;
  };
  template <>
  struct bit_type<half> {
    using type = short;
  };
#endif
  template <typename T>
  inline T operator()(const T a, const T b) {
#if __METAL_VERSION__ >= 310
    return static_cast<T>(::metal::nextafter(a, b));
#else
    using raw_t = typename bit_type<T>::type;
    // Already at the target: nothing to step over.
    if (a == b) {
      return a;
    }
    // Unordered operands (a NaN is involved) produce NaN.
    if (::metal::isunordered(a, b)) {
      return NAN;
    }
    // Leaving zero: the result is the value whose bit pattern is 1, with the
    // sign chosen by the direction of b.
    if (a == 0) {
      constexpr auto tiny = as_type<T>(static_cast<raw_t>(1));
      if (b > 0) {
        return tiny;
      }
      return -tiny;
    }
    auto raw = as_type<raw_t>(a);
    // The raw pattern is incremented when exactly one of (a > 0), (a > b)
    // holds, and decremented otherwise.
    const bool increment = (a > 0) != (a > b);
    if (increment) {
      raw++;
    } else {
      raw--;
    }
    return as_type<T>(raw);
#endif
  }
};
|
||||
|
||||
@ -344,13 +315,6 @@ struct fmod_functor {
|
||||
}
|
||||
};
|
||||
|
||||
// Some helper defines
|
||||
#if __METAL_VERSION__ >= 310
|
||||
#define _METAL_310_PLUS(x) x
|
||||
#else
|
||||
#define _METAL_310_PLUS(x)
|
||||
#endif
|
||||
|
||||
#define REGISTER_INTEGER_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, long, long); \
|
||||
REGISTER_BINARY_OP(NAME, int, int); \
|
||||
@ -370,12 +334,12 @@ struct fmod_functor {
|
||||
#define REGISTER_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
|
||||
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
|
||||
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
|
||||
|
||||
REGISTER_FLOAT_BINARY_OP(copysign);
|
||||
REGISTER_INT2FLOAT_BINARY_OP(copysign);
|
||||
@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
|
||||
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
// Complex binary functions
|
||||
REGISTER_BINARY_OP(polar, float, float2);
|
||||
|
@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
|
||||
REGISTER_SEARCHSORTED_OP(float, long);
|
||||
REGISTER_SEARCHSORTED_OP(half, int);
|
||||
REGISTER_SEARCHSORTED_OP(half, long);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, int);
|
||||
REGISTER_SEARCHSORTED_OP(bfloat, long);
|
||||
#endif
|
||||
REGISTER_SEARCHSORTED_OP(char, int);
|
||||
REGISTER_SEARCHSORTED_OP(char, long);
|
||||
REGISTER_SEARCHSORTED_OP(uchar, int);
|
||||
|
@ -96,6 +96,4 @@ kernel void col2im_kernel(
|
||||
INSTANTIATE_COL2IM(bool);
|
||||
INSTANTIATE_COL2IM(float);
|
||||
INSTANTIATE_COL2IM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_COL2IM(bfloat);
|
||||
#endif
|
||||
|
@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
|
||||
REGISTER_CROSS_FUNC(char);
|
||||
REGISTER_CROSS_FUNC(uchar);
|
||||
REGISTER_CROSS_FUNC(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_FUNC(bfloat);
|
||||
#endif
|
||||
|
||||
template <typename T, typename U>
|
||||
kernel void cross(
|
||||
@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
|
||||
REGISTER_CROSS_OP(char);
|
||||
REGISTER_CROSS_OP(uchar);
|
||||
REGISTER_CROSS_OP(bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_CROSS_OP(bfloat);
|
||||
#endif
|
||||
|
@ -1,11 +1,9 @@
|
||||
#include <metal_stdlib>
|
||||
|
||||
using metal::max;
|
||||
#if __METAL_VERSION__ >= 310
|
||||
// max overload for bfloat operands: yields `a` when it compares greater than
// `b`, and `b` in every other case.
bfloat max(bfloat a, bfloat b) {
  if (a > b) {
    return a;
  }
  return b;
}
|
||||
#endif
|
||||
|
||||
#define kmaxThreadGroups 32
|
||||
#define kmaxTensors 32
|
||||
@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
|
||||
REGISTER_ADAM_OPS_QUART(float, half);
|
||||
REGISTER_ADAM_OPS_QUART(half, float);
|
||||
REGISTER_ADAM_OPS_QUART(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_ADAM_OPS_QUART(float, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
|
||||
REGISTER_ADAM_OPS_QUART(bfloat, float);
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
inline void sgd_momentum_math(
|
||||
@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
|
||||
REGISTER_FUSED_SGD_OP(half);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_FUSED_SGD_OP(bfloat);
|
||||
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
|
||||
#endif
|
||||
|
@ -106,9 +106,7 @@ kernel void polygamma(
|
||||
constant int64_t& order [[buffer(2)]], \
|
||||
uint id [[thread_position_in_grid]]);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_GAMMA_KERNELS(half, half);
|
||||
INSTANTIATE_GAMMA_KERNELS(float, float);
|
||||
INSTANTIATE_GAMMA_KERNELS(bool, float);
|
||||
|
@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
|
||||
INSTANTIATE_IM2COL(float2);
|
||||
INSTANTIATE_IM2COL(half);
|
||||
INSTANTIATE_IM2COL(half2);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_IM2COL(bfloat);
|
||||
#endif
|
||||
|
@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
|
||||
REGISTER_INDEX_OP(put_accumulate, char, char);
|
||||
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
|
||||
REGISTER_INDEX_OP(put_accumulate, bool, bool);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
|
||||
#endif
|
||||
|
||||
template <typename StridesT, typename DataT>
|
||||
kernel void kernel_index_offsets(
|
||||
@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
|
||||
INSTANTIATE_INDEX_COPY(uchar, int);
|
||||
INSTANTIATE_INDEX_COPY(uchar, long);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INDEX_COPY(bfloat, int);
|
||||
INSTANTIATE_INDEX_COPY(bfloat, long);
|
||||
#endif
|
||||
INSTANTIATE_INDEX_COPY(float2, int);
|
||||
INSTANTIATE_INDEX_COPY(float2, long);
|
||||
INSTANTIATE_INDEX_COPY(half2, int);
|
||||
|
@ -288,7 +288,6 @@ kernel void layer_norm_looped(
|
||||
#define instantiate_layer_norm(DTYPE) \
|
||||
instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)
|
||||
|
||||
instantiate_layer_norm(float) instantiate_layer_norm(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_layer_norm(bfloat)
|
||||
#endif
|
||||
instantiate_layer_norm(float);
|
||||
instantiate_layer_norm(half);
|
||||
instantiate_layer_norm(bfloat);
|
||||
|
@ -635,9 +635,7 @@ kernel void applyPivots(
|
||||
|
||||
INSTANTIATE_NAIVE_MM(float);
|
||||
INSTANTIATE_NAIVE_MM(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_NAIVE_MM(bfloat);
|
||||
#endif
|
||||
|
||||
// Integral MM
|
||||
INSTANTIATE_NAIVE_MM(short);
|
||||
|
@ -453,6 +453,7 @@ kernel void avg_pool(
|
||||
|
||||
REGISTER_POOL_OP(float);
|
||||
REGISTER_POOL_OP(half);
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_POOL_OP(int);
|
||||
REGISTER_POOL_OP(long);
|
||||
REGISTER_POOL_OP(short);
|
||||
@ -462,8 +463,4 @@ REGISTER_POOL_OP(bool);
|
||||
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(float);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(half);
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_POOL_OP(bfloat);
|
||||
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
|
||||
#endif
|
||||
|
@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
|
||||
INSTANTIATE_INT4MV(half, 128);
|
||||
INSTANTIATE_INT4MV(float, 256);
|
||||
INSTANTIATE_INT4MV(half, 256);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_INT4MV(bfloat, 32);
|
||||
INSTANTIATE_INT4MV(bfloat, 64);
|
||||
INSTANTIATE_INT4MV(bfloat, 128);
|
||||
INSTANTIATE_INT4MV(bfloat, 256);
|
||||
#endif
|
||||
|
||||
// ------------------------------ int8 MM For M >= 12 ------------------------------------
|
||||
/**
|
||||
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
|
||||
using simdgroup_type8x8 = simdgroup_half8x8;
|
||||
using type4 = half4;
|
||||
};
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <> struct BlockType<bfloat> {
|
||||
using simdgroup_type8x8 = simdgroup_bfloat8x8;
|
||||
using type4 = bfloat4;
|
||||
};
|
||||
#endif
|
||||
|
||||
template<typename T>
|
||||
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
|
||||
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
|
||||
|
||||
INSTANTIATE_MM(float, char, get_scale_zero_q8);
|
||||
INSTANTIATE_MM(half, char, get_scale_zero_q8);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
|
||||
#endif
|
||||
// ------------------------------ int8 MM For M < 12 ------------------------------------
|
||||
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
|
||||
|
||||
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
|
||||
|
||||
INSTANTIATE_MV(float);
|
||||
INSTANTIATE_MV(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_MV(bfloat);
|
||||
#endif
|
||||
|
@ -192,6 +192,4 @@ template <typename T>
|
||||
|
||||
instantiate_rms(float)
|
||||
instantiate_rms(half)
|
||||
#if __METAL_VERSION__ >= 310
|
||||
instantiate_rms(bfloat)
|
||||
#endif // clang-format on
|
||||
|
@ -23,6 +23,4 @@ kernel void renorm(
|
||||
|
||||
REGISTER_RENORM_OP(float);
|
||||
REGISTER_RENORM_OP(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_RENORM_OP(bfloat);
|
||||
#endif
|
||||
|
@ -25,379 +25,6 @@ struct LogAddExp {
|
||||
};
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ < 310
|
||||
// Combining operator for running-minimum scans over elements of type T,
// evaluated in the accumulator type acc_t.
template <typename T, typename acc_t = accum_t<T>>
struct CumMinOp {
  // Combines two accumulator values with metal::min.
  static acc_t apply(acc_t a, acc_t b) {
    const acc_t smaller = metal::min(a, b);
    return smaller;
  }
  // Starting value of the scan: +infinity for floating-point T, otherwise
  // numeric_limits<T>::max().
  static acc_t identity() {
    if (metal::is_floating_point_v<T>) {
      return static_cast<acc_t>(metal::numeric_limits<T>::infinity());
    }
    return static_cast<acc_t>(metal::numeric_limits<T>::max());
  }
};
|
||||
|
||||
// Combining operator for running-maximum scans over elements of type T,
// evaluated in the accumulator type acc_t.
template <typename T, typename acc_t = accum_t<T>>
struct CumMaxOp {
  // Combines two accumulator values with metal::max.
  static acc_t apply(acc_t a, acc_t b) {
    const acc_t larger = metal::max(a, b);
    return larger;
  }
  // Starting value of the scan: -infinity for floating-point T, otherwise
  // numeric_limits<T>::lowest().
  static acc_t identity() {
    if (metal::is_floating_point_v<T>) {
      return static_cast<acc_t>(-metal::numeric_limits<T>::infinity());
    }
    return static_cast<acc_t>(metal::numeric_limits<T>::lowest());
  }
};
|
||||
|
||||
// Combining operator for logcumsumexp scans: pairs of accumulator values are
// merged with the LogAddExp functor.
template <typename T, typename acc_t = accum_t<T>>
struct LogCumSumExpOp {
  // Delegates to LogAddExp for the pairwise combination.
  static acc_t apply(acc_t x, acc_t y) {
    LogAddExp combine;
    return combine(x, y);
  }
  // Starting value of the scan: negative infinity in the accumulator type.
  static acc_t identity() {
    const acc_t inf = metal::numeric_limits<acc_t>::infinity();
    return -inf;
  }
};
|
||||
|
||||
// Inclusive scan over the innermost dimension of a contiguous tensor.
// Each thread owns one row of `row_size` consecutive elements and writes the
// running Op::apply result for every position of that row.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_innermost_dim(
    constant T* input [[buffer(0)]],
    device T* output [[buffer(1)]],
    constant uint& num_rows [[buffer(2)]],
    constant uint& row_size [[buffer(3)]],
    uint row [[thread_position_in_grid]]) {
  // Threads past the last row have no work.
  if (row >= num_rows) {
    return;
  }

  const uint row_start = row * row_size;
  acc_t running = Op::identity();

  for (uint i = 0; i < row_size; ++i) {
    const uint pos = row_start + i;
    // Accumulate in acc_t, store back in the element type.
    running = Op::apply(running, static_cast<acc_t>(input[pos]));
    output[pos] = static_cast<T>(running);
  }
}
|
||||
|
||||
// Inclusive scan over a non-innermost dimension of a contiguous tensor.
// The flat thread index is split into an outer row (index / num_irows) and an
// inner row (index % num_irows); consecutive scan elements are num_irows
// apart in memory.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_outer_dim(
    constant T* input [[buffer(0)]],
    device T* output [[buffer(1)]],
    constant uint& num_orows [[buffer(2)]],
    constant uint& num_irows [[buffer(3)]],
    constant uint& row_size [[buffer(4)]],
    uint thread_index [[thread_position_in_grid]]) {
  const uint outer = thread_index / num_irows;
  const uint inner = thread_index % num_irows;

  // Threads past the last outer row have no work.
  if (outer >= num_orows) {
    return;
  }

  acc_t running = Op::identity();
  uint pos = outer * row_size * num_irows + inner;

  for (uint step = 0; step < row_size; ++step) {
    running = Op::apply(running, static_cast<acc_t>(input[pos]));
    output[pos] = static_cast<T>(running);
    pos += num_irows;
  }
}
|
||||
|
||||
// Inclusive scan with indices over the innermost dimension of a contiguous
// tensor. For every position of a row the kernel writes the selected value so
// far and the column at which it was selected.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_innermost_dim(
    constant T* input [[buffer(0)]],
    device T* values [[buffer(1)]],
    device int64_t* indices [[buffer(2)]],
    constant uint& num_rows [[buffer(3)]],
    constant uint& row_size [[buffer(4)]],
    uint row [[thread_position_in_grid]]) {
  // Threads past the last row have no work.
  if (row >= num_rows) {
    return;
  }

  const uint row_start = row * row_size;
  acc_t best = Op::identity();
  int64_t best_col = 0;

  for (uint i = 0; i < row_size; ++i) {
    const uint pos = row_start + i;
    const acc_t candidate = static_cast<acc_t>(input[pos]);
    // The first element is always taken; afterwards the candidate replaces
    // the current selection whenever Op::apply returns a value equal to it.
    const bool take = (i == 0) || (Op::apply(candidate, best) == candidate);
    if (take) {
      best = candidate;
      best_col = i;
    }
    values[pos] = static_cast<T>(best);
    indices[pos] = best_col;
  }
}
|
||||
|
||||
// Inclusive scan with indices over a non-innermost dimension of a contiguous
// tensor. The flat thread index is split into an outer row
// (index / num_irows) and an inner row (index % num_irows); consecutive scan
// elements are num_irows apart in memory.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_outer_dim(
    constant T* input [[buffer(0)]],
    device T* values [[buffer(1)]],
    device int64_t* indices [[buffer(2)]],
    constant uint& num_orows [[buffer(3)]],
    constant uint& num_irows [[buffer(4)]],
    constant uint& row_size [[buffer(5)]],
    uint thread_index [[thread_position_in_grid]]) {
  const uint outer = thread_index / num_irows;
  const uint inner = thread_index % num_irows;

  // Threads past the last outer row have no work.
  if (outer >= num_orows) {
    return;
  }

  acc_t best = Op::identity();
  int64_t best_col = 0;
  uint pos = outer * row_size * num_irows + inner;

  for (uint step = 0; step < row_size; ++step) {
    const acc_t candidate = static_cast<acc_t>(input[pos]);
    // The first element is always taken; afterwards the candidate replaces
    // the current selection whenever Op::apply returns a value equal to it.
    const bool take =
        (step == 0) || (Op::apply(candidate, best) == candidate);
    if (take) {
      best = candidate;
      best_col = step;
    }
    values[pos] = static_cast<T>(best);
    indices[pos] = best_col;
    pos += num_irows;
  }
}
|
||||
|
||||
// Shared utility functions for strided kernels

// Returns the product of all entries of `sizes` except the one at `scan_dim`.
inline long calculate_non_scan_elements(
    constant long* sizes,
    uint ndim,
    uint scan_dim) {
  long count = 1;
  for (uint d = 0; d < ndim; ++d) {
    if (d == scan_dim) {
      continue;
    }
    count *= sizes[d];
  }
  return count;
}
|
||||
|
||||
// Decomposes a flat index into per-dimension coordinates written to `pos`.
// Dimensions are peeled off starting at dimension 0 (index % size, then
// index / size); the coordinate of `scan_dim` is always set to 0.
inline void thread_index_to_coordinates(
    uint index,
    int pos[c10::metal::max_ndim],
    constant long* sizes,
    uint ndim,
    uint scan_dim) {
  long rest = index;
  for (uint d = 0; d < ndim; ++d) {
    if (d == scan_dim) {
      pos[d] = 0;
      continue;
    }
    const long extent = sizes[d];
    pos[d] = rest % extent;
    rest /= extent;
  }
}
|
||||
|
||||
// Returns the sum of pos[d] * strides[d] over every dimension except
// `scan_dim`.
inline long calculate_base_offset(
    int pos[c10::metal::max_ndim],
    constant long* strides,
    uint ndim,
    uint scan_dim) {
  long base = 0;
  for (uint d = 0; d < ndim; ++d) {
    if (d == scan_dim) {
      continue;
    }
    base += pos[d] * strides[d];
  }
  return base;
}
|
||||
|
||||
// Generic inclusive scan for arbitrarily strided tensors.
// Each thread handles one combination of the non-scan coordinates and walks
// the scan dimension using the input and output strides.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_strided(
    constant T* input [[buffer(0)]],
    device T* output [[buffer(1)]],
    constant long* sizes [[buffer(2)]],
    constant long* input_strides [[buffer(3)]],
    constant long* output_strides [[buffer(4)]],
    constant uint& ndim [[buffer(5)]],
    constant uint& scan_dim [[buffer(6)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Threads beyond the number of scan lines have no work.
  const long num_lines = calculate_non_scan_elements(sizes, ndim, scan_dim);
  if (thread_index >= num_lines) {
    return;
  }

  // Coordinates of this thread's scan line (scan_dim coordinate is 0).
  int coords[c10::metal::max_ndim];
  thread_index_to_coordinates(thread_index, coords, sizes, ndim, scan_dim);

  const long in_base =
      calculate_base_offset(coords, input_strides, ndim, scan_dim);
  const long out_base =
      calculate_base_offset(coords, output_strides, ndim, scan_dim);

  const long length = sizes[scan_dim];
  const long in_step = input_strides[scan_dim];
  const long out_step = output_strides[scan_dim];

  acc_t running = Op::identity();
  for (long k = 0; k < length; ++k) {
    const long src = in_base + k * in_step;
    const long dst = out_base + k * out_step;
    running = Op::apply(running, static_cast<acc_t>(input[src]));
    output[dst] = static_cast<T>(running);
  }
}
|
||||
|
||||
// Generic inclusive scan with indices for arbitrarily strided tensors.
// Each thread handles one combination of the non-scan coordinates and walks
// the scan dimension, writing the selected value so far and the position
// along the scan dimension at which it was selected.
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_strided(
    constant T* input [[buffer(0)]],
    device T* values [[buffer(1)]],
    device int64_t* indices [[buffer(2)]],
    constant long* sizes [[buffer(3)]],
    constant long* input_strides [[buffer(4)]],
    constant long* values_strides [[buffer(5)]],
    constant long* indices_strides [[buffer(6)]],
    constant uint& ndim [[buffer(7)]],
    constant uint& scan_dim [[buffer(8)]],
    uint thread_index [[thread_position_in_grid]]) {
  // Threads beyond the number of scan lines have no work.
  const long num_lines = calculate_non_scan_elements(sizes, ndim, scan_dim);
  if (thread_index >= num_lines) {
    return;
  }

  // Coordinates of this thread's scan line (scan_dim coordinate is 0).
  int coords[c10::metal::max_ndim];
  thread_index_to_coordinates(thread_index, coords, sizes, ndim, scan_dim);

  const long in_base =
      calculate_base_offset(coords, input_strides, ndim, scan_dim);
  const long val_base =
      calculate_base_offset(coords, values_strides, ndim, scan_dim);
  const long idx_base =
      calculate_base_offset(coords, indices_strides, ndim, scan_dim);

  const long length = sizes[scan_dim];
  const long in_step = input_strides[scan_dim];
  const long val_step = values_strides[scan_dim];
  const long idx_step = indices_strides[scan_dim];

  acc_t best = Op::identity();
  int64_t best_pos = 0;

  for (long k = 0; k < length; ++k) {
    const acc_t candidate = static_cast<acc_t>(input[in_base + k * in_step]);
    // The first element is always taken; afterwards the candidate replaces
    // the current selection whenever Op::apply returns a value equal to it.
    const bool take = (k == 0) || (Op::apply(candidate, best) == candidate);
    if (take) {
      best = candidate;
      best_pos = k;
    }
    values[val_base + k * val_step] = static_cast<T>(best);
    indices[idx_base + k * idx_step] = best_pos;
  }
}
|
||||
|
||||
// Explicitly instantiates the three value-only scan kernels (contiguous
// innermost, contiguous outer, strided) for OP_CLASS<DTYPE>. The host-visible
// names are "<OP_NAME>_contiguous_innermost_<DTYPE>",
// "<OP_NAME>_contiguous_outer_<DTYPE>" and "<OP_NAME>_strided_<DTYPE>".
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE)                             \
  template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
  scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>(                       \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * output [[buffer(1)]],                                     \
      constant uint & num_rows [[buffer(2)]],                                  \
      constant uint & row_size [[buffer(3)]],                                  \
      uint row [[thread_position_in_grid]]);                                   \
                                                                               \
  template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void     \
  scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>(                           \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * output [[buffer(1)]],                                     \
      constant uint & num_orows [[buffer(2)]],                                 \
      constant uint & num_irows [[buffer(3)]],                                 \
      constant uint & row_size [[buffer(4)]],                                  \
      uint thread_index [[thread_position_in_grid]]);                          \
                                                                               \
  template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void              \
  scan_strided<DTYPE, OP_CLASS<DTYPE>>(                                        \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * output [[buffer(1)]],                                     \
      constant long* sizes [[buffer(2)]],                                      \
      constant long* input_strides [[buffer(3)]],                              \
      constant long* output_strides [[buffer(4)]],                             \
      constant uint& ndim [[buffer(5)]],                                       \
      constant uint& scan_dim [[buffer(6)]],                                   \
      uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
// Explicitly instantiates the three scan-with-indices kernels (contiguous
// innermost, contiguous outer, strided) for OP_CLASS<DTYPE>. The host-visible
// names are "<OP_NAME>_contiguous_innermost_<DTYPE>",
// "<OP_NAME>_contiguous_outer_<DTYPE>" and "<OP_NAME>_strided_<DTYPE>".
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE)                \
  template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
  scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>(          \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * values [[buffer(1)]],                                     \
      device int64_t* indices [[buffer(2)]],                                   \
      constant uint& num_rows [[buffer(3)]],                                   \
      constant uint& row_size [[buffer(4)]],                                   \
      uint row [[thread_position_in_grid]]);                                   \
                                                                               \
  template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void     \
  scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>(              \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * values [[buffer(1)]],                                     \
      device int64_t* indices [[buffer(2)]],                                   \
      constant uint& num_orows [[buffer(3)]],                                  \
      constant uint& num_irows [[buffer(4)]],                                  \
      constant uint& row_size [[buffer(5)]],                                   \
      uint thread_index [[thread_position_in_grid]]);                          \
                                                                               \
  template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void              \
  scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>(                           \
      constant DTYPE * input [[buffer(0)]],                                    \
      device DTYPE * values [[buffer(1)]],                                     \
      device int64_t* indices [[buffer(2)]],                                   \
      constant long* sizes [[buffer(3)]],                                      \
      constant long* input_strides [[buffer(4)]],                              \
      constant long* values_strides [[buffer(5)]],                             \
      constant long* indices_strides [[buffer(6)]],                            \
      constant uint& ndim [[buffer(7)]],                                       \
      constant uint& scan_dim [[buffer(8)]],                                   \
      uint thread_index [[thread_position_in_grid]]);
|
||||
|
||||
// Value-only scans: logcumsumexp is instantiated for float and half.
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);

// Scans that also produce indices: cummin and cummax are instantiated for the
// same eight element types, so the list is spelled out once in a local helper.
#define REGISTER_SCAN_WITH_INDICES_ALL_DTYPES(OP_NAME, OP_CLASS) \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, float);       \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, half);        \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, long);        \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, int);         \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, short);       \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, char);        \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, uchar);       \
  REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, bool)

REGISTER_SCAN_WITH_INDICES_ALL_DTYPES(cummin, CumMinOp);

REGISTER_SCAN_WITH_INDICES_ALL_DTYPES(cummax, CumMaxOp);

#undef REGISTER_SCAN_WITH_INDICES_ALL_DTYPES
|
||||
|
||||
#else // __METAL_VERSION__ >= 310
|
||||
|
||||
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
|
||||
|
||||
// The remainder of this file contains cummin and cummax implementations adapted
|
||||
@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
|
||||
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
|
||||
|
||||
#endif
|
||||
|
@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
|
||||
REGISTER_SPECIAL(int, float);
|
||||
REGISTER_SPECIAL(long, float);
|
||||
REGISTER_SPECIAL(half, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_SPECIAL(bfloat, bfloat);
|
||||
#endif
|
||||
|
@ -100,9 +100,7 @@ kernel void triul(
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half, int);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_TRIUL_KERNELS(float2, int);
|
||||
INSTANTIATE_TRIUL_KERNELS(half2, int);
|
||||
|
@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
|
||||
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
|
||||
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(neg, bfloat, bfloat);
|
||||
REGISTER_UNARY_OP(abs, bfloat, bfloat);
|
||||
#endif
|
||||
INSTANTIATE_UNARY_KERNELS2(half, half);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, float);
|
||||
INSTANTIATE_UNARY_KERNELS2(float, bool);
|
||||
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
|
||||
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
|
||||
#endif
|
||||
|
@ -70,6 +70,4 @@ kernel void unfold_backward(
|
||||
|
||||
INSTANTIATE_UNFOLD_BACKWARD(float);
|
||||
INSTANTIATE_UNFOLD_BACKWARD(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
|
||||
#endif
|
||||
|
@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
|
||||
INSTANTIATE_UPSAMPLE_3D(uchar);
|
||||
INSTANTIATE_UPSAMPLE_ALL(float);
|
||||
INSTANTIATE_UPSAMPLE_ALL(half);
|
||||
#if __METAL_VERSION__ >= 310
|
||||
INSTANTIATE_UPSAMPLE_ALL(bfloat);
|
||||
#endif
|
||||
|
@ -85,7 +85,6 @@ struct AtomicType<uchar> {
|
||||
}
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct AtomicType<bfloat> {
|
||||
using type = ::metal::atomic<uint>;
|
||||
@ -93,7 +92,6 @@ struct AtomicType<bfloat> {
|
||||
atomic_add_helper<bfloat>(data, offset, value);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
// Metal supports atomic_store_explicit for bools, but
|
||||
// sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to
|
||||
|
@ -9,7 +9,6 @@
|
||||
#define C10_METAL_CONSTEXPR constexpr
|
||||
#endif
|
||||
|
||||
#if !defined(__METAL__) || __METAL_VERSION__ >= 310
|
||||
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
|
||||
_(Byte, 0) \
|
||||
_(Char, 1) \
|
||||
@ -22,19 +21,6 @@
|
||||
_(ComplexFloat, 9) \
|
||||
_(Bool, 11) \
|
||||
_(BFloat16, 15)
|
||||
#else
|
||||
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
|
||||
_(Byte, 0) \
|
||||
_(Char, 1) \
|
||||
_(Short, 2) \
|
||||
_(Int, 3) \
|
||||
_(Long, 4) \
|
||||
_(Half, 5) \
|
||||
_(Float, 6) \
|
||||
_(ComplexHalf, 8) \
|
||||
_(ComplexFloat, 9) \
|
||||
_(Bool, 11)
|
||||
#endif
|
||||
|
||||
namespace c10 {
|
||||
namespace metal {
|
||||
|
@ -186,10 +186,8 @@ inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
|
||||
return cast_to<T>(val_at_offs<float>(ptr, offs));
|
||||
case ScalarType::Half:
|
||||
return cast_to<T>(val_at_offs<half>(ptr, offs));
|
||||
#if __METAL_VERSION__ >= 310
|
||||
case ScalarType::BFloat16:
|
||||
return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
|
||||
#endif
|
||||
// Complex
|
||||
case ScalarType::ComplexHalf:
|
||||
return cast_to<T>(val_at_offs<half2>(ptr, offs));
|
||||
|
@ -15,12 +15,10 @@ struct simd_type {
|
||||
template <typename T>
|
||||
using simd_type_t = typename simd_type<T>::t;
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct simd_type<bfloat> {
|
||||
using t = float;
|
||||
};
|
||||
#endif
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
|
@ -24,14 +24,12 @@ struct vectypes<half> {
|
||||
using type2 = half2;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct vectypes<bfloat> {
|
||||
using type4 = bfloat4;
|
||||
using type3 = bfloat3;
|
||||
using type2 = bfloat2;
|
||||
};
|
||||
#endif
|
||||
|
||||
template <>
|
||||
struct vectypes<short> {
|
||||
@ -79,12 +77,10 @@ struct OpMathType<uchar> {
|
||||
using type = int;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
struct OpMathType<bfloat> {
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
// Type promotion structure for higher precision accumulation
|
||||
template <typename T>
|
||||
@ -98,13 +94,11 @@ struct AccumulationType<half> {
|
||||
using type = float;
|
||||
};
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
// Specialization for bfloat - promote to float for accumulation
|
||||
template <>
|
||||
struct AccumulationType<bfloat> {
|
||||
using type = float;
|
||||
};
|
||||
#endif
|
||||
|
||||
} // namespace detail
|
||||
|
||||
@ -130,7 +124,6 @@ min(T a, U b) {
|
||||
return ::metal::min(a, static_cast<T>(b));
|
||||
}
|
||||
|
||||
#if __METAL_VERSION__ >= 310
|
||||
template <>
|
||||
inline bfloat min(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
@ -142,7 +135,6 @@ inline bfloat max(bfloat a, bfloat b) {
|
||||
return bfloat(
|
||||
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using vec2type_t = typename detail::vectypes<T>::type2;
|
||||
|
@ -825,7 +825,6 @@ if(USE_MPS)
|
||||
if(CAN_COMPILE_METAL)
|
||||
add_dependencies(torch_cpu metallibs)
|
||||
target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib)
|
||||
target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_bfloat,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_bfloat.metallib)
|
||||
else()
|
||||
target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS)
|
||||
endif()
|
||||
|
Reference in New Issue
Block a user