[BE][MPS] Build metal kernels for macOS-14+ (#159733)

This makes the `#if __METAL_VERSION__ >= 310` guards for `bfloat` support unnecessary.
Rename `kernels_bfloat.metallib` into `kernels_basic` and remove custom build/selection logic.

Part of https://github.com/pytorch/pytorch/issues/159275
Pull Request resolved: https://github.com/pytorch/pytorch/pull/159733
Approved by: https://github.com/dcci
ghstack dependencies: #159731, #159732
This commit is contained in:
Nikita Shulga
2025-08-03 12:23:22 -07:00
committed by PyTorch MergeBot
parent 5116c49b52
commit e2a5c42e7e
31 changed files with 11 additions and 538 deletions

View File

@ -704,21 +704,17 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
foreach(SHADER ${native_mps_metal})
cmake_path(GET SHADER STEM TGT_STEM)
string(CONCAT TGT_BASIC ${TGT_STEM} "_30.air")
string(CONCAT TGT_BFLOAT ${TGT_STEM} "_31.air")
string(CONCAT TGT_BASIC ${TGT_STEM} "_31.air")
list(APPEND AIR_BASIC ${TGT_BASIC})
list(APPEND AIR_BFLOAT ${TGT_BFLOAT})
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.0")
metal_to_air(${SHADER} ${TGT_BFLOAT} "-std=metal3.1")
metal_to_air(${SHADER} ${TGT_BASIC} "-std=metal3.1")
endforeach()
air_to_metallib(kernels_basic.metallib ${AIR_BASIC})
air_to_metallib(kernels_bfloat.metallib ${AIR_BFLOAT})
add_custom_command(
COMMAND echo "// $$(date)" > metallib_dummy.cpp
DEPENDS kernels_basic.metallib kernels_bfloat.metallib
DEPENDS kernels_basic.metallib
OUTPUT metallib_dummy.cpp
COMMENT "Updating metallibs timestamp")
add_custom_target(metallibs DEPENDS kernels_basic.metallib kernels_bfloat.metallib metallib_dummy.cpp)
add_custom_target(metallibs DEPENDS kernels_basic.metallib metallib_dummy.cpp)
else()
file(MAKE_DIRECTORY "${CMAKE_CURRENT_BINARY_DIR}/native/mps")
foreach(SHADER ${native_mps_metal})

View File

@ -953,8 +953,7 @@ class BundledShaderLibary : public MetalShaderLibrary {
if (C10_UNLIKELY(!library)) {
auto device = MPSDevice::getInstance()->device();
NSError* error = nil;
auto section_name = is_macos_13_or_newer(MacOSVersion::MACOS_VER_14_0_PLUS) ? "metal_bfloat" : "metal_basic";
library = [device newLibraryWithData:getSectionData(section_name) error:&error];
library = [device newLibraryWithData:getSectionData("metal_basic") error:&error];
TORCH_CHECK(library, "Failed to create metal library, error: ", [[error description] UTF8String]);
}
return library;

View File

@ -33,21 +33,15 @@ struct shrink_backward_functor {
REGISTER_UNARY_ALPHA_OP(hardshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(hardshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(hardshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_UNARY_ALPHA_OP(softshrink, float, float, float);
REGISTER_UNARY_ALPHA_OP(softshrink, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(softshrink, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(shrink_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(shrink_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(shrink_backward, bfloat, bfloat, bfloat);
#endif
struct hardsigmoid_functor {
template <typename T>
@ -67,15 +61,11 @@ struct hardsigmoid_backward_functor {
REGISTER_UNARY_OP(hardsigmoid, float, float);
REGISTER_UNARY_OP(hardsigmoid, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardsigmoid, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardsigmoid_backward, float, float);
REGISTER_BINARY_OP(hardsigmoid_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardsigmoid_backward, bfloat, bfloat);
#endif
struct hardswish_functor {
template <typename T>
@ -103,15 +93,11 @@ struct hardswish_backward_functor {
REGISTER_UNARY_OP(hardswish, float, float);
REGISTER_UNARY_OP(hardswish, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_OP(hardswish, bfloat, bfloat);
#endif
REGISTER_BINARY_OP(hardswish_backward, float, float);
REGISTER_BINARY_OP(hardswish_backward, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(hardswish_backward, bfloat, bfloat);
#endif
struct leaky_relu_functor {
template <typename T>
@ -135,12 +121,8 @@ struct leaky_relu_backward_functor {
REGISTER_UNARY_ALPHA_OP(leaky_relu, float, float, float);
REGISTER_UNARY_ALPHA_OP(leaky_relu, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(leaky_relu, bfloat, bfloat, bfloat);
#endif
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, float, float, float);
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, half, half, half);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(leaky_relu_backward, bfloat, bfloat, bfloat);
#endif

View File

@ -113,18 +113,12 @@ kernel void ampUpdateScale(
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE(bfloat);
#endif
INSTANTIATE_AMP_UPDATE_SCALE(float);
INSTANTIATE_AMP_UPDATE_SCALE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_UPDATE_SCALE(bfloat);
#endif
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(float);
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_AMP_NONFINITE_CHECK_AND_UNSCALE_SINGLE(bfloat);
#endif

View File

@ -590,9 +590,7 @@ kernel void attention(
INSTANTIATE_SDPA_VECTOR_HEADS(float);
INSTANTIATE_SDPA_VECTOR_HEADS(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
#endif
#define INSTANTIATE_ATTN(DTYPE, bq, bk, bd, wm, wn) \
template [[host_name("attention_" #DTYPE "_bq" #bq "_bk" #bk "_bd" #bd \
@ -621,6 +619,4 @@ INSTANTIATE_SDPA_VECTOR_HEADS(bfloat);
INSTANTIATE_ATTN_SHAPES_HELPER(float);
INSTANTIATE_ATTN_SHAPES_HELPER(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_ATTN_SHAPES_HELPER(bfloat);
#endif

View File

@ -209,38 +209,9 @@ struct hermite_polynomial_he_functor {
};
struct nextafter_functor {
#if __METAL_VERSION__ < 310
template <typename U>
struct bit_type {};
template <>
struct bit_type<float> {
using type = int;
};
template <>
struct bit_type<half> {
using type = short;
};
#endif
template <typename T>
inline T operator()(const T a, const T b) {
#if __METAL_VERSION__ >= 310
return static_cast<T>(::metal::nextafter(a, b));
#else
using U = typename bit_type<T>::type;
if (a == b) {
return a;
}
if (::metal::isunordered(a, b)) {
return NAN;
}
if (a == 0) {
constexpr auto eps = as_type<T>(static_cast<U>(1));
return b > 0 ? eps : -eps;
}
auto bits = as_type<U>(a);
(a > 0) ^ (a > b) ? bits++ : bits--;
return as_type<T>(bits);
#endif
}
};
@ -344,13 +315,6 @@ struct fmod_functor {
}
};
// Some helper defines
#if __METAL_VERSION__ >= 310
#define _METAL_310_PLUS(x) x
#else
#define _METAL_310_PLUS(x)
#endif
#define REGISTER_INTEGER_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, long, long); \
REGISTER_BINARY_OP(NAME, int, int); \
@ -370,12 +334,12 @@ struct fmod_functor {
#define REGISTER_FLOAT_BINARY_OP(NAME) \
REGISTER_BINARY_OP(NAME, float, float); \
REGISTER_BINARY_OP(NAME, half, half); \
_METAL_310_PLUS(REGISTER_BINARY_OP(NAME, bfloat, bfloat))
REGISTER_BINARY_OP(NAME, bfloat, bfloat)
#define REGISTER_OPMATH_FLOAT_BINARY_OP(NAME) \
REGISTER_OPMATH_BINARY_OP(NAME, float, float); \
REGISTER_OPMATH_BINARY_OP(NAME, half, half); \
_METAL_310_PLUS(REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat))
REGISTER_OPMATH_BINARY_OP(NAME, bfloat, bfloat)
REGISTER_FLOAT_BINARY_OP(copysign);
REGISTER_INT2FLOAT_BINARY_OP(copysign);
@ -447,11 +411,9 @@ REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat, bfloat);
#endif
// Complex binary functions
REGISTER_BINARY_OP(polar, float, float2);

View File

@ -180,10 +180,8 @@ REGISTER_SEARCHSORTED_OP(float, int);
REGISTER_SEARCHSORTED_OP(float, long);
REGISTER_SEARCHSORTED_OP(half, int);
REGISTER_SEARCHSORTED_OP(half, long);
#if __METAL_VERSION__ >= 310
REGISTER_SEARCHSORTED_OP(bfloat, int);
REGISTER_SEARCHSORTED_OP(bfloat, long);
#endif
REGISTER_SEARCHSORTED_OP(char, int);
REGISTER_SEARCHSORTED_OP(char, long);
REGISTER_SEARCHSORTED_OP(uchar, int);

View File

@ -96,6 +96,4 @@ kernel void col2im_kernel(
INSTANTIATE_COL2IM(bool);
INSTANTIATE_COL2IM(float);
INSTANTIATE_COL2IM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_COL2IM(bfloat);
#endif

View File

@ -20,9 +20,7 @@ REGISTER_CROSS_FUNC(short);
REGISTER_CROSS_FUNC(char);
REGISTER_CROSS_FUNC(uchar);
REGISTER_CROSS_FUNC(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_FUNC(bfloat);
#endif
template <typename T, typename U>
kernel void cross(
@ -68,6 +66,4 @@ REGISTER_CROSS_OP(short);
REGISTER_CROSS_OP(char);
REGISTER_CROSS_OP(uchar);
REGISTER_CROSS_OP(bool);
#if __METAL_VERSION__ >= 310
REGISTER_CROSS_OP(bfloat);
#endif

View File

@ -1,11 +1,9 @@
#include <metal_stdlib>
using metal::max;
#if __METAL_VERSION__ >= 310
bfloat max(bfloat a, bfloat b) {
return a > b ? a : b;
}
#endif
#define kmaxThreadGroups 32
#define kmaxTensors 32
@ -306,11 +304,9 @@ REGISTER_ADAM_OPS_QUART(float, float);
REGISTER_ADAM_OPS_QUART(float, half);
REGISTER_ADAM_OPS_QUART(half, float);
REGISTER_ADAM_OPS_QUART(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_ADAM_OPS_QUART(float, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, bfloat);
REGISTER_ADAM_OPS_QUART(bfloat, float);
#endif
template <typename T>
inline void sgd_momentum_math(
@ -460,7 +456,5 @@ REGISTER_FUSED_SGD_OP(float);
REGISTER_FUSED_SGD_OP(half);
REGISTER_FUSED_SGD_MOMENTUM_OP(float);
REGISTER_FUSED_SGD_MOMENTUM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_FUSED_SGD_OP(bfloat);
REGISTER_FUSED_SGD_MOMENTUM_OP(bfloat);
#endif

View File

@ -106,9 +106,7 @@ kernel void polygamma(
constant int64_t& order [[buffer(2)]], \
uint id [[thread_position_in_grid]]);
#if __METAL_VERSION__ >= 310
INSTANTIATE_GAMMA_KERNELS(bfloat, bfloat);
#endif
INSTANTIATE_GAMMA_KERNELS(half, half);
INSTANTIATE_GAMMA_KERNELS(float, float);
INSTANTIATE_GAMMA_KERNELS(bool, float);

View File

@ -76,6 +76,4 @@ INSTANTIATE_IM2COL(float);
INSTANTIATE_IM2COL(float2);
INSTANTIATE_IM2COL(half);
INSTANTIATE_IM2COL(half2);
#if __METAL_VERSION__ >= 310
INSTANTIATE_IM2COL(bfloat);
#endif

View File

@ -240,9 +240,7 @@ REGISTER_INDEX_OP(put_accumulate, short, short);
REGISTER_INDEX_OP(put_accumulate, char, char);
REGISTER_INDEX_OP(put_accumulate, uchar, uchar);
REGISTER_INDEX_OP(put_accumulate, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_INDEX_OP(put_accumulate, bfloat, bfloat);
#endif
template <typename StridesT, typename DataT>
kernel void kernel_index_offsets(
@ -477,10 +475,8 @@ INSTANTIATE_INDEX_COPY(char, long);
INSTANTIATE_INDEX_COPY(uchar, int);
INSTANTIATE_INDEX_COPY(uchar, long);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INDEX_COPY(bfloat, int);
INSTANTIATE_INDEX_COPY(bfloat, long);
#endif
INSTANTIATE_INDEX_COPY(float2, int);
INSTANTIATE_INDEX_COPY(float2, long);
INSTANTIATE_INDEX_COPY(half2, int);

View File

@ -288,7 +288,6 @@ kernel void layer_norm_looped(
#define instantiate_layer_norm(DTYPE) \
instantiate_layer_norm_single_row(DTYPE) instantiate_layer_norm_looped(DTYPE)
instantiate_layer_norm(float) instantiate_layer_norm(half)
#if __METAL_VERSION__ >= 310
instantiate_layer_norm(bfloat)
#endif
instantiate_layer_norm(float);
instantiate_layer_norm(half);
instantiate_layer_norm(bfloat);

View File

@ -635,9 +635,7 @@ kernel void applyPivots(
INSTANTIATE_NAIVE_MM(float);
INSTANTIATE_NAIVE_MM(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_NAIVE_MM(bfloat);
#endif
// Integral MM
INSTANTIATE_NAIVE_MM(short);

View File

@ -453,6 +453,7 @@ kernel void avg_pool(
REGISTER_POOL_OP(float);
REGISTER_POOL_OP(half);
REGISTER_POOL_OP(bfloat);
REGISTER_POOL_OP(int);
REGISTER_POOL_OP(long);
REGISTER_POOL_OP(short);
@ -462,8 +463,4 @@ REGISTER_POOL_OP(bool);
REGISTER_MAX_POOL_BACKWARD_OP(float);
REGISTER_MAX_POOL_BACKWARD_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_POOL_OP(bfloat);
REGISTER_MAX_POOL_BACKWARD_OP(bfloat);
#endif

View File

@ -197,12 +197,10 @@ INSTANTIATE_INT4MV(float, 128);
INSTANTIATE_INT4MV(half, 128);
INSTANTIATE_INT4MV(float, 256);
INSTANTIATE_INT4MV(half, 256);
#if __METAL_VERSION__ >= 310
INSTANTIATE_INT4MV(bfloat, 32);
INSTANTIATE_INT4MV(bfloat, 64);
INSTANTIATE_INT4MV(bfloat, 128);
INSTANTIATE_INT4MV(bfloat, 256);
#endif
// ------------------------------ int8 MM For M >= 12 ------------------------------------
/**
@ -234,12 +232,10 @@ template <> struct BlockType<half> {
using simdgroup_type8x8 = simdgroup_half8x8;
using type4 = half4;
};
#if __METAL_VERSION__ >= 310
template <> struct BlockType<bfloat> {
using simdgroup_type8x8 = simdgroup_bfloat8x8;
using type4 = bfloat4;
};
#endif
template<typename T>
float2 get_scale_zero_q8(constant T * scalesAndZeros, uint2 index) {
@ -490,9 +486,7 @@ kernel void kernel_mul_mm<DTYPE, WDTYPE, DEQUANT_FUNC>( \
INSTANTIATE_MM(float, char, get_scale_zero_q8);
INSTANTIATE_MM(half, char, get_scale_zero_q8);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MM(bfloat, char, get_scale_zero_q8);
#endif
// ------------------------------ int8 MM For M < 12 ------------------------------------
/* Matrix vector multiplication, used for small M size for matrix multiplication as well.
@ -646,6 +640,4 @@ kernel void kernel_mul_mv<DTYPE>(
INSTANTIATE_MV(float);
INSTANTIATE_MV(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_MV(bfloat);
#endif

View File

@ -192,6 +192,4 @@ template <typename T>
instantiate_rms(float)
instantiate_rms(half)
#if __METAL_VERSION__ >= 310
instantiate_rms(bfloat)
#endif // clang-format on

View File

@ -23,6 +23,4 @@ kernel void renorm(
REGISTER_RENORM_OP(float);
REGISTER_RENORM_OP(half);
#if __METAL_VERSION__ >= 310
REGISTER_RENORM_OP(bfloat);
#endif

View File

@ -25,379 +25,6 @@ struct LogAddExp {
};
};
#if __METAL_VERSION__ < 310
template <typename T, typename acc_t = accum_t<T>>
struct CumMinOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::min(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::max());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct CumMaxOp {
static acc_t apply(acc_t a, acc_t b) {
return metal::max(a, b);
}
static acc_t identity() {
return static_cast<acc_t>(
metal::is_floating_point_v<T> ? -metal::numeric_limits<T>::infinity()
: metal::numeric_limits<T>::lowest());
}
};
template <typename T, typename acc_t = accum_t<T>>
struct LogCumSumExpOp {
static acc_t apply(acc_t x, acc_t y) {
return LogAddExp{}(x, y);
}
static acc_t identity() {
return -metal::numeric_limits<acc_t>::infinity();
}
};
// Inclusive scan along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_rows [[buffer(2)]],
constant uint& row_size [[buffer(3)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[offset + col] = static_cast<T>(accumulator);
}
}
// Inclusive scan along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant uint& num_orows [[buffer(2)]],
constant uint& num_irows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[idx] = static_cast<T>(accumulator);
}
}
// Inclusive scan with indices along innermost dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_innermost_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_rows [[buffer(3)]],
constant uint& row_size [[buffer(4)]],
uint row [[thread_position_in_grid]]) {
if (row >= num_rows)
return;
const uint offset = row * row_size;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
for (uint col = 0; col < row_size; col++) {
T val = input[offset + col];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[offset + col] = static_cast<T>(accumulator);
indices[offset + col] = best_idx;
}
}
// Inclusive scan with indices along outer dimension for contiguous tensors
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_contiguous_outer_dim(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant uint& num_orows [[buffer(3)]],
constant uint& num_irows [[buffer(4)]],
constant uint& row_size [[buffer(5)]],
uint thread_index [[thread_position_in_grid]]) {
const uint orow = thread_index / num_irows;
const uint irow = thread_index % num_irows;
if (orow >= num_orows)
return;
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const uint idx_base = orow * row_size * num_irows + irow;
for (uint col = 0, idx = idx_base; col < row_size; col++, idx += num_irows) {
T val = input[idx];
acc_t accum_val = static_cast<acc_t>(val);
if (col == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = col;
}
values[idx] = static_cast<T>(accumulator);
indices[idx] = best_idx;
}
}
// Shared utility functions for strided kernels
inline long calculate_non_scan_elements(
constant long* sizes,
uint ndim,
uint scan_dim) {
long total = 1;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
total *= sizes[i];
}
}
return total;
}
inline void thread_index_to_coordinates(
uint index,
int pos[c10::metal::max_ndim],
constant long* sizes,
uint ndim,
uint scan_dim) {
long remaining_index = index;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
pos[i] = remaining_index % sizes[i];
remaining_index /= sizes[i];
} else {
pos[i] = 0;
}
}
}
inline long calculate_base_offset(
int pos[c10::metal::max_ndim],
constant long* strides,
uint ndim,
uint scan_dim) {
long offset = 0;
for (uint i = 0; i < ndim; ++i) {
if (i != scan_dim) {
offset += pos[i] * strides[i];
}
}
return offset;
}
// Generic strided scan kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_strided(
constant T* input [[buffer(0)]],
device T* output [[buffer(1)]],
constant long* sizes [[buffer(2)]],
constant long* input_strides [[buffer(3)]],
constant long* output_strides [[buffer(4)]],
constant uint& ndim [[buffer(5)]],
constant uint& scan_dim [[buffer(6)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long output_base_offset =
calculate_base_offset(pos, output_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long output_scan_stride = output_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long output_offset =
output_base_offset + scan_idx * output_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
accumulator = Op::apply(accumulator, accum_val);
output[output_offset] = static_cast<T>(accumulator);
}
}
// Generic strided scan with indices kernel
template <typename T, typename Op, typename acc_t = accum_t<T>>
kernel void scan_with_indices_strided(
constant T* input [[buffer(0)]],
device T* values [[buffer(1)]],
device int64_t* indices [[buffer(2)]],
constant long* sizes [[buffer(3)]],
constant long* input_strides [[buffer(4)]],
constant long* values_strides [[buffer(5)]],
constant long* indices_strides [[buffer(6)]],
constant uint& ndim [[buffer(7)]],
constant uint& scan_dim [[buffer(8)]],
uint thread_index [[thread_position_in_grid]]) {
const long total_non_scan_elements =
calculate_non_scan_elements(sizes, ndim, scan_dim);
if (thread_index >= total_non_scan_elements) {
return;
}
int pos[c10::metal::max_ndim];
thread_index_to_coordinates(thread_index, pos, sizes, ndim, scan_dim);
const long input_base_offset =
calculate_base_offset(pos, input_strides, ndim, scan_dim);
const long values_base_offset =
calculate_base_offset(pos, values_strides, ndim, scan_dim);
const long indices_base_offset =
calculate_base_offset(pos, indices_strides, ndim, scan_dim);
acc_t accumulator = Op::identity();
int64_t best_idx = 0;
const long scan_size = sizes[scan_dim];
const long input_scan_stride = input_strides[scan_dim];
const long values_scan_stride = values_strides[scan_dim];
const long indices_scan_stride = indices_strides[scan_dim];
for (long scan_idx = 0; scan_idx < scan_size; scan_idx++) {
const long input_offset = input_base_offset + scan_idx * input_scan_stride;
const long values_offset =
values_base_offset + scan_idx * values_scan_stride;
const long indices_offset =
indices_base_offset + scan_idx * indices_scan_stride;
T val = input[input_offset];
acc_t accum_val = static_cast<acc_t>(val);
if (scan_idx == 0 || Op::apply(accum_val, accumulator) == accum_val) {
accumulator = accum_val;
best_idx = scan_idx;
}
values[values_offset] = static_cast<T>(accumulator);
indices[indices_offset] = best_idx;
}
}
#define REGISTER_SCAN_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_rows [[buffer(2)]], \
constant uint & row_size [[buffer(3)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant uint & num_orows [[buffer(2)]], \
constant uint & num_irows [[buffer(3)]], \
constant uint & row_size [[buffer(4)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * output [[buffer(1)]], \
constant long* sizes [[buffer(2)]], \
constant long* input_strides [[buffer(3)]], \
constant long* output_strides [[buffer(4)]], \
constant uint& ndim [[buffer(5)]], \
constant uint& scan_dim [[buffer(6)]], \
uint thread_index [[thread_position_in_grid]]);
#define REGISTER_SCAN_WITH_INDICES_OP(OP_NAME, OP_CLASS, DTYPE) \
template [[host_name(#OP_NAME "_contiguous_innermost_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_innermost_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_rows [[buffer(3)]], \
constant uint& row_size [[buffer(4)]], \
uint row [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_contiguous_outer_" #DTYPE)]] kernel void \
scan_with_indices_contiguous_outer_dim<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant uint& num_orows [[buffer(3)]], \
constant uint& num_irows [[buffer(4)]], \
constant uint& row_size [[buffer(5)]], \
uint thread_index [[thread_position_in_grid]]); \
\
template [[host_name(#OP_NAME "_strided_" #DTYPE)]] kernel void \
scan_with_indices_strided<DTYPE, OP_CLASS<DTYPE>>( \
constant DTYPE * input [[buffer(0)]], \
device DTYPE * values [[buffer(1)]], \
device int64_t* indices [[buffer(2)]], \
constant long* sizes [[buffer(3)]], \
constant long* input_strides [[buffer(4)]], \
constant long* values_strides [[buffer(5)]], \
constant long* indices_strides [[buffer(6)]], \
constant uint& ndim [[buffer(7)]], \
constant uint& scan_dim [[buffer(8)]], \
uint thread_index [[thread_position_in_grid]]);
// Simple scan operations
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, float);
REGISTER_SCAN_OP(logcumsumexp, LogCumSumExpOp, half);
// Scan operations with indices
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummin, CumMinOp, bool);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, float);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, half);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, long);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, int);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool);
#else // __METAL_VERSION__ >= 310
C10_METAL_CONSTEXPR auto simd_size = c10::metal::simdgroup_size;
// The reminder of this file contains cummin and cummax implementations adapted
@ -1159,5 +786,3 @@ REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, short, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, char, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, uchar, 4);
REGISTER_SCAN_WITH_INDICES_OP(cummax, CumMaxOp, bool, 4);
#endif

View File

@ -89,6 +89,4 @@ REGISTER_SPECIAL(short, float);
REGISTER_SPECIAL(int, float);
REGISTER_SPECIAL(long, float);
REGISTER_SPECIAL(half, half);
#if __METAL_VERSION__ >= 310
REGISTER_SPECIAL(bfloat, bfloat);
#endif

View File

@ -100,9 +100,7 @@ kernel void triul(
INSTANTIATE_TRIUL_KERNELS(float, int);
INSTANTIATE_TRIUL_KERNELS(half, int);
#if __METAL_VERSION__ >= 310
INSTANTIATE_TRIUL_KERNELS(bfloat, int);
#endif
INSTANTIATE_TRIUL_KERNELS(float2, int);
INSTANTIATE_TRIUL_KERNELS(half2, int);

View File

@ -556,11 +556,9 @@ REGISTER_UNARY_OP(abs, half, half);
REGISTER_UNARY_OP(acos, DTYPE1, DTYPE0); \
REGISTER_UNARY_OP(atan, DTYPE1, DTYPE0)
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNARY_KERNELS2(bfloat, bfloat);
REGISTER_UNARY_OP(neg, bfloat, bfloat);
REGISTER_UNARY_OP(abs, bfloat, bfloat);
#endif
INSTANTIATE_UNARY_KERNELS2(half, half);
INSTANTIATE_UNARY_KERNELS2(float, float);
INSTANTIATE_UNARY_KERNELS2(float, bool);
@ -600,6 +598,4 @@ INSTANTIATE_UNARY_KERNELS_VEC2(float);
REGISTER_UNARY_ALPHA_OP(round_decimals, float, long, float);
REGISTER_UNARY_ALPHA_OP(round_decimals, half, long, half);
#if __METAL_VERSION__ >= 310
REGISTER_UNARY_ALPHA_OP(round_decimals, bfloat, long, bfloat);
#endif

View File

@ -70,6 +70,4 @@ kernel void unfold_backward(
INSTANTIATE_UNFOLD_BACKWARD(float);
INSTANTIATE_UNFOLD_BACKWARD(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UNFOLD_BACKWARD(bfloat);
#endif

View File

@ -852,6 +852,4 @@ INSTANTIATE_UPSAMPLE_2D(bilinear2d, uchar);
INSTANTIATE_UPSAMPLE_3D(uchar);
INSTANTIATE_UPSAMPLE_ALL(float);
INSTANTIATE_UPSAMPLE_ALL(half);
#if __METAL_VERSION__ >= 310
INSTANTIATE_UPSAMPLE_ALL(bfloat);
#endif

View File

@ -85,7 +85,6 @@ struct AtomicType<uchar> {
}
};
#if __METAL_VERSION__ >= 310
template <>
struct AtomicType<bfloat> {
using type = ::metal::atomic<uint>;
@ -93,7 +92,6 @@ struct AtomicType<bfloat> {
atomic_add_helper<bfloat>(data, offset, value);
}
};
#endif
// Metal supports atomic_store_explicit for bools, but
// sizeof(::metal::atomic_bool) is 4 Therefore it could not be used to

View File

@ -9,7 +9,6 @@
#define C10_METAL_CONSTEXPR constexpr
#endif
#if !defined(__METAL__) || __METAL_VERSION__ >= 310
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
@ -22,19 +21,6 @@
_(ComplexFloat, 9) \
_(Bool, 11) \
_(BFloat16, 15)
#else
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
_(Short, 2) \
_(Int, 3) \
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11)
#endif
namespace c10 {
namespace metal {

View File

@ -186,10 +186,8 @@ inline T val_at_offs(constant void* ptr, long offs, ScalarType type) {
return cast_to<T>(val_at_offs<float>(ptr, offs));
case ScalarType::Half:
return cast_to<T>(val_at_offs<half>(ptr, offs));
#if __METAL_VERSION__ >= 310
case ScalarType::BFloat16:
return cast_to<T>(val_at_offs<bfloat>(ptr, offs));
#endif
// Complex
case ScalarType::ComplexHalf:
return cast_to<T>(val_at_offs<half2>(ptr, offs));

View File

@ -15,12 +15,10 @@ struct simd_type {
template <typename T>
using simd_type_t = typename simd_type<T>::t;
#if __METAL_VERSION__ >= 310
template <>
struct simd_type<bfloat> {
using t = float;
};
#endif
} // namespace detail
template <typename T>

View File

@ -24,14 +24,12 @@ struct vectypes<half> {
using type2 = half2;
};
#if __METAL_VERSION__ >= 310
template <>
struct vectypes<bfloat> {
using type4 = bfloat4;
using type3 = bfloat3;
using type2 = bfloat2;
};
#endif
template <>
struct vectypes<short> {
@ -79,12 +77,10 @@ struct OpMathType<uchar> {
using type = int;
};
#if __METAL_VERSION__ >= 310
template <>
struct OpMathType<bfloat> {
using type = float;
};
#endif
// Type promotion structure for higher precision accumulation
template <typename T>
@ -98,13 +94,11 @@ struct AccumulationType<half> {
using type = float;
};
#if __METAL_VERSION__ >= 310
// Specialization for bfloat - promote to float for accumulation
template <>
struct AccumulationType<bfloat> {
using type = float;
};
#endif
} // namespace detail
@ -130,7 +124,6 @@ min(T a, U b) {
return ::metal::min(a, static_cast<T>(b));
}
#if __METAL_VERSION__ >= 310
template <>
inline bfloat min(bfloat a, bfloat b) {
return bfloat(
@ -142,7 +135,6 @@ inline bfloat max(bfloat a, bfloat b) {
return bfloat(
::metal::isunordered(a, b) ? NAN : ::metal::max(float(a), float(b)));
}
#endif
template <typename T>
using vec2type_t = typename detail::vectypes<T>::type2;

View File

@ -825,7 +825,6 @@ if(USE_MPS)
if(CAN_COMPILE_METAL)
add_dependencies(torch_cpu metallibs)
target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_basic,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_basic.metallib)
target_link_options(torch_cpu PRIVATE -Wl,-sectcreate,__TEXT,metal_bfloat,${CMAKE_CURRENT_BINARY_DIR}/aten/src/ATen/kernels_bfloat.metallib)
else()
target_compile_definitions(torch_cpu PRIVATE PYTORCH_JIT_COMPILE_SHADERS)
endif()