Implement metal kernel for basic MPS arithmetic ops using TensorIterator (#147644)

Add metal kernels for add, subtract, & lerp ops using TensorIterator. Should help resolve: https://github.com/pytorch/pytorch/issues/143874
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147644
Approved by: https://github.com/malfet
This commit is contained in:
Siddharth Kotapati
2025-04-29 14:24:49 +00:00
committed by PyTorch MergeBot
parent 2fb62f8288
commit 663bcb68ba
7 changed files with 390 additions and 4 deletions

View File

@ -29,6 +29,9 @@ struct TensorIteratorBase;
namespace at::native::mps {
// Forward declaration of MPSScalar - for exec_binary_alpha_kernel()
struct MPSScalar;
namespace detail {
template <typename T>
class has_size_type {
@ -138,6 +141,10 @@ class MetalShaderLibrary {
const std::string& name,
std::optional<int64_t> extra = std::nullopt);
void exec_binary_kernel(TensorIteratorBase& iter, const std::string& name);
void exec_binary_alpha_kernel(
TensorIteratorBase& iter,
const std::string& name,
const MPSScalar& alpha);
protected:
virtual MTLLibrary_t getLibrary();

View File

@ -1078,6 +1078,68 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
});
}
void MetalShaderLibrary::exec_binary_alpha_kernel(TensorIteratorBase& iter,
const std::string& name,
const MPSScalar& alpha) {
TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
Tensor input = iter.input(0);
Tensor other = iter.input(1);
Tensor out = iter.output();
id<MTLDevice> device = MPSDevice::getInstance()->device();
MPSStream* mpsStream = getCurrentMPSStream();
const uint32_t nDim = iter.ndim();
constexpr uint32_t nOffsets = 3;
const uint32_t numThreads = iter.numel();
const auto cast_needed = input.scalar_type() != other.scalar_type();
const auto suffix = iter.is_contiguous() ? "dense" : "strided";
// TODO: Implicitly pass both input and output types to non-cast kernels
const auto kernel_name = cast_needed
? fmt::format("{}_alpha_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
: fmt::format("{}_alpha_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
@autoreleasepool {
auto computeEncoder = mpsStream->commandEncoder();
auto binaryPSO = getPipelineStateForFunc(kernel_name);
// this function call is a no-op if MPS Profiler is not enabled
getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
[computeEncoder setComputePipelineState:binaryPSO];
// Iterator is contiguous if all of its elements are dense in storage,
// i.e. it's true for both row-first and column-first tensors
if (iter.is_contiguous()) {
mtl_setArgs(computeEncoder, out, input, other, alpha);
if (cast_needed) {
std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
static_cast<int>(c10::elementSize(other.scalar_type())),
static_cast<int>(input.scalar_type()),
static_cast<int>(other.scalar_type())};
mtl_setBytes(computeEncoder, size_and_types, 4);
}
} else {
// Please note that shapes and strides of the iterator might be
// different than that of its operands, for example binary op
// between 4x4 tensor and scalar will result in 1D 16 element iterator
std::array<int, 3> ndim_and_types = {
iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
mtl_setArgs(computeEncoder,
out,
input,
other,
alpha,
iter.shape(),
iter.strides(0),
iter.strides(1),
iter.strides(2),
ndim_and_types);
}
mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
getMPSProfiler().endProfileKernel(binaryPSO);
}
});
}
MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
static BundledShaderLibary l;
return l;

View File

@ -4,6 +4,48 @@
#include <metal_stdlib>
using namespace metal;
struct add_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(a + b);
}
};
struct sub_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(a - b);
}
};
struct lerp_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(b);
}
};
struct add_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a + (alpha * b));
}
};
struct sub_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a - (alpha * b));
}
};
struct lerp_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return static_cast<T>(a + (alpha * (b - a)));
}
};
struct fmax_functor {
template <typename T>
inline T operator()(const T a, const T b) {
@ -152,6 +194,48 @@ struct complex_mul_functor {
}
};
struct complex_add_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return T(
a.x + (alpha.x * b.x - alpha.y * b.y),
a.y + (alpha.x * b.y + alpha.y * b.x));
}
};
struct complex_add_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(a.x + b.x, a.y + b.y);
}
};
struct complex_sub_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
return T(
a.x - (alpha.x * b.x - alpha.y * b.y),
a.y - (alpha.x * b.y + alpha.y * b.x));
}
};
struct complex_lerp_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
auto intr = T(b.x - a.x, b.y - a.y);
return T(
a.x + (alpha.x * intr.x - intr.y * intr.y),
a.y + (alpha.x * intr.y + alpha.y * intr.x));
}
};
struct complex_lerp_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(b.x, b.y);
}
};
REGISTER_BINARY_OP(copysign, long, float);
REGISTER_BINARY_OP(copysign, int, float);
REGISTER_BINARY_OP(copysign, float, float);
@ -182,6 +266,54 @@ REGISTER_BINARY_OP(hermite_polynomial_h, float, float);
REGISTER_BINARY_OP(hermite_polynomial_h, half, half);
REGISTER_BINARY_OP(hermite_polynomial_he, float, float);
REGISTER_BINARY_OP(hermite_polynomial_he, half, half);
REGISTER_BINARY_OP(add, long, long);
REGISTER_BINARY_OP(add, int, int);
REGISTER_BINARY_OP(add, float, float);
REGISTER_BINARY_OP(add, half, half);
REGISTER_BINARY_OP(add, short, short);
REGISTER_BINARY_OP(add, uchar, uchar);
REGISTER_BINARY_OP(add, char, char);
REGISTER_BINARY_OP(add, bool, bool);
REGISTER_BINARY_OP(sub, long, long);
REGISTER_BINARY_OP(sub, int, int);
REGISTER_BINARY_OP(sub, float, float);
REGISTER_BINARY_OP(sub, half, half);
REGISTER_BINARY_OP(sub, short, short);
REGISTER_BINARY_OP(sub, uchar, uchar);
REGISTER_BINARY_OP(sub, char, char);
REGISTER_BINARY_OP(sub, bool, bool);
REGISTER_BINARY_OP(lerp, long, long);
REGISTER_BINARY_OP(lerp, int, int);
REGISTER_BINARY_OP(lerp, float, float);
REGISTER_BINARY_OP(lerp, half, half);
REGISTER_BINARY_OP(lerp, short, short);
REGISTER_BINARY_OP(lerp, uchar, uchar);
REGISTER_BINARY_OP(lerp, char, char);
REGISTER_BINARY_OP(lerp, bool, bool);
REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
REGISTER_BINARY_ALPHA_OP(add_alpha, half, half);
REGISTER_BINARY_ALPHA_OP(add_alpha, short, short);
REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(add_alpha, char, char);
REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool);
REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long);
REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int);
REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float);
REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half);
REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short);
REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool);
#if __METAL_VERSION__ >= 310
REGISTER_BINARY_OP(copysign, bfloat, bfloat);
@ -196,6 +328,12 @@ REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat);
REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat);
REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
REGISTER_BINARY_OP(add, bfloat, bfloat);
REGISTER_BINARY_OP(sub, bfloat, bfloat);
REGISTER_BINARY_OP(lerp, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
#endif
// Complex binary functions
@ -205,3 +343,15 @@ REGISTER_BINARY_OP(make_complex, float, float2);
REGISTER_BINARY_OP(make_complex, half, half2);
REGISTER_BINARY_OP(complex_mul, float2, float2);
REGISTER_BINARY_OP(complex_mul, half2, half2);
REGISTER_BINARY_OP(add, float2, float2);
REGISTER_BINARY_OP(add, half2, half2);
REGISTER_BINARY_OP(sub, float2, float2);
REGISTER_BINARY_OP(sub, half2, half2);
REGISTER_BINARY_OP(lerp, float2, float2);
REGISTER_BINARY_OP(lerp, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2);

View File

@ -1,8 +1,14 @@
#pragma once
namespace at::native::mps {
void binary_op_kernel(
const std::string func_name,
const Tensor& input,
const Tensor& other,
const Tensor& output,
const std::optional<MPSScalar>& alpha = std::nullopt);
void complex_mul_out(
const Tensor& input,
const Tensor& other,
const Tensor& output);
}
} // namespace at::native::mps

View File

@ -50,6 +50,30 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
lib.exec_binary_kernel(iter, "complex_mul");
}
void binary_op_kernel(const std::string func_name,
const Tensor& input,
const Tensor& other,
const Tensor& output,
const std::optional<MPSScalar>& alpha) {
auto new_size = at::infer_size(input.sizes(), other.sizes());
if (!output.sizes().equals(new_size)) {
output.resize_(new_size);
}
uint32_t length = output.numel();
if (length == 0) {
return;
}
auto iter =
TensorIteratorConfig().add_output(output).add_input(input).add_input(other).check_all_same_dtype(false).build();
if (alpha) {
lib.exec_binary_alpha_kernel(iter, func_name, *alpha);
} else {
lib.exec_binary_kernel(iter, func_name);
}
}
} // namespace mps
static void fmax_mps_kernel(TensorIteratorBase& iter) {

View File

@ -265,9 +265,21 @@ static void add_sub_lerp_template(const Tensor& self,
}
const bool alpha_has_value = alpha.toDouble() != 1.0;
if (alpha_has_value) {
auto commonDtype = at::result_type(self, other);
at::native::alpha_check(commonDtype, alpha);
auto self_complex = c10::isComplexType(self.scalar_type());
auto other_complex = c10::isComplexType(other.scalar_type());
auto commonDtype = at::result_type(self, other);
if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
if (alpha_has_value) {
at::native::alpha_check(commonDtype, alpha);
mps::binary_op_kernel((self_complex || other_complex) ? "complex_" + op_name : op_name,
self,
other,
output,
getMPSScalar(alpha, commonDtype));
} else {
mps::binary_op_kernel(op_name, self, other, output);
}
return;
}
if (!alpha_has_value && op_name == "lerp") {

View File

@ -165,6 +165,29 @@ kernel void binary_strided(
ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
}
template <typename T, typename F>
kernel void alpha_binary_strided(
device void* output [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other [[buffer(2)]],
constant T* alpha [[buffer(3)]],
constant long* sizes [[buffer(4)]],
constant long* output_strides [[buffer(5)]],
constant long* input_strides [[buffer(6)]],
constant long* other_strides [[buffer(7)]],
constant uint3& ndim [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
F f;
int pos[max_ndim];
pos_from_thread_index(int(index), pos, sizes, ndim.x);
const auto input_offs = offset_from_coord(pos, input_strides, ndim.x);
const auto other_offs = offset_from_coord(pos, other_strides, ndim.x);
const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
const auto a = val_at_offs<T>(input, input_offs);
const auto b = val_at_offs<T>(other, other_offs);
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
}
template <typename T, typename F>
kernel void binary_strided_cast(
device void* output [[buffer(0)]],
@ -189,6 +212,31 @@ kernel void binary_strided_cast(
ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
}
template <typename T, typename F>
kernel void alpha_binary_strided_cast(
device void* output [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other [[buffer(2)]],
constant T* alpha [[buffer(3)]],
constant long* sizes [[buffer(4)]],
constant long* output_strides [[buffer(5)]],
constant long* input_strides [[buffer(6)]],
constant long* other_strides [[buffer(7)]],
constant uint3& ndim_types [[buffer(8)]],
uint index [[thread_position_in_grid]]) {
F f;
int pos[max_ndim];
pos_from_thread_index(int(index), pos, sizes, ndim_types.x);
const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x);
const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x);
const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x);
const auto a =
val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
const auto b =
val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
}
template <typename T, typename F>
kernel void binary_dense(
device result_of<F, T, T>* out [[buffer(0)]],
@ -199,6 +247,17 @@ kernel void binary_dense(
out[tid] = f(input[tid], other[tid]);
}
template <typename T, typename F>
kernel void alpha_binary_dense(
device result_of<F, T, T, T>* out [[buffer(0)]],
constant T* input [[buffer(1)]],
constant T* other [[buffer(2)]],
constant T* alpha [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
F f;
out[tid] = f(input[tid], other[tid], *alpha);
}
template <typename T, typename F>
kernel void binary_dense_cast(
device result_of<F, T, T>* out [[buffer(0)]],
@ -214,6 +273,22 @@ kernel void binary_dense_cast(
out[tid] = f(a, b);
}
template <typename T, typename F>
kernel void alpha_binary_dense_cast(
device result_of<F, T, T, T>* out [[buffer(0)]],
constant void* input [[buffer(1)]],
constant void* other [[buffer(2)]],
constant T* alpha [[buffer(3)]],
constant uint4& sizes_types [[buffer(4)]],
uint tid [[thread_position_in_grid]]) {
F f;
const auto a = val_at_offs<T>(
input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
const auto b = val_at_offs<T>(
other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
out[tid] = f(a, b, *alpha);
}
#define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \
static_assert( \
::metal::is_same_v< \
@ -257,5 +332,55 @@ kernel void binary_dense_cast(
constant void* other, \
constant uint4& sizes_types, \
uint tid)
#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO) \
static_assert( \
::metal::is_same_v< \
DTYPEO, \
::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>, \
"Output dtype mismatch for binary op " #NAME " and input " #DTYPEI); \
template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
c10::metal::alpha_binary_strided<DTYPEI, NAME##_functor>( \
device void* out, \
constant void* input, \
constant void* other, \
constant DTYPEI* alpha, \
constant long* sizes, \
constant long* output_strides, \
constant long* input_strides, \
constant long* other_strides, \
constant uint3& ndim, \
uint tid); \
template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
metal::alpha_binary_strided_cast<DTYPEI, NAME##_functor>( \
device void* out, \
constant void* input, \
constant void* other, \
constant DTYPEI* alpha, \
constant long* sizes, \
constant long* output_strides, \
constant long* input_strides, \
constant long* other_strides, \
constant uint3& ndim_types, \
uint tid); \
template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
c10::metal::alpha_binary_dense<DTYPEI, NAME##_functor>( \
device ::c10::metal:: \
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
out_, \
constant DTYPEI * input_, \
constant DTYPEI * other_, \
constant DTYPEI * alpha, \
uint tid); \
template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10:: \
metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>( \
device ::c10::metal:: \
result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> * \
out_, \
constant void* input, \
constant void* other, \
constant DTYPEI* alpha, \
constant uint4& sizes_types, \
uint tid)
} // namespace metal
} // namespace c10