diff --git a/aten/src/ATen/native/mps/MetalShaderLibrary.h b/aten/src/ATen/native/mps/MetalShaderLibrary.h
index acd2bf66101f..5bdf2f9fa771 100644
--- a/aten/src/ATen/native/mps/MetalShaderLibrary.h
+++ b/aten/src/ATen/native/mps/MetalShaderLibrary.h
@@ -29,6 +29,9 @@ struct TensorIteratorBase;
 
 namespace at::native::mps {
 
+// Forward declaration of MPSScalar - for exec_binary_alpha_kernel()
+struct MPSScalar;
+
 namespace detail {
 template <typename T>
 class has_size_type {
@@ -138,6 +141,10 @@ class MetalShaderLibrary {
       const std::string& name,
       std::optional<c10::Scalar> extra = std::nullopt);
   void exec_binary_kernel(TensorIteratorBase& iter, const std::string& name);
+  void exec_binary_alpha_kernel(
+      TensorIteratorBase& iter,
+      const std::string& name,
+      const MPSScalar& alpha);
 
 protected:
   virtual MTLLibrary_t getLibrary();
diff --git a/aten/src/ATen/native/mps/OperationUtils.mm b/aten/src/ATen/native/mps/OperationUtils.mm
index 0a3a91ebe2f7..316ce8a6ea39 100644
--- a/aten/src/ATen/native/mps/OperationUtils.mm
+++ b/aten/src/ATen/native/mps/OperationUtils.mm
@@ -1078,6 +1078,68 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std::string& name) {
   });
 }
 
+void MetalShaderLibrary::exec_binary_alpha_kernel(TensorIteratorBase& iter,
+                                                  const std::string& name,
+                                                  const MPSScalar& alpha) {
+  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
+  TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
+
+  Tensor input = iter.input(0);
+  Tensor other = iter.input(1);
+  Tensor out = iter.output();
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  const uint32_t nDim = iter.ndim();
+  constexpr uint32_t nOffsets = 3;
+  const uint32_t numThreads = iter.numel();
+  const auto cast_needed = input.scalar_type() != other.scalar_type();
+  const auto suffix = iter.is_contiguous() ? "dense" : "strided";
+  // TODO: Implicitly pass both input and output types to non-cast kernels
+  const auto kernel_name = cast_needed
+      ? fmt::format("{}_alpha_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
+      : fmt::format("{}_alpha_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      auto computeEncoder = mpsStream->commandEncoder();
+      auto binaryPSO = getPipelineStateForFunc(kernel_name);
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
+      [computeEncoder setComputePipelineState:binaryPSO];
+      // Iterator is contiguous if all of its elements are dense in storage,
+      // i.e. it's true for both row-first and column-first tensors
+      if (iter.is_contiguous()) {
+        mtl_setArgs(computeEncoder, out, input, other, alpha);
+        if (cast_needed) {
+          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
+                                               static_cast<int>(c10::elementSize(other.scalar_type())),
+                                               static_cast<int>(input.scalar_type()),
+                                               static_cast<int>(other.scalar_type())};
+          mtl_setBytes(computeEncoder, size_and_types, 4);
+        }
+      } else {
+        // Please note that shapes and strides of the iterator might be
+        // different than that of its operands, for example binary op
+        // between 4x4 tensor and scalar will result in 1D 16 element iterator
+        std::array<int, 3> ndim_and_types = {
+            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
+        mtl_setArgs(computeEncoder,
+                    out,
+                    input,
+                    other,
+                    alpha,
+                    iter.shape(),
+                    iter.strides(0),
+                    iter.strides(1),
+                    iter.strides(2),
+                    ndim_and_types);
+      }
+      mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
+      getMPSProfiler().endProfileKernel(binaryPSO);
+    }
+  });
+}
+
 MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
   static BundledShaderLibary l;
   return l;
diff --git a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
index 858f62c258ac..e2a6d5e05a31 100644
--- a/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
+++ b/aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -4,6 +4,48 @@
 #include <metal_stdlib>
 using namespace metal;
 
+struct add_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(a + b);
+  }
+};
+
+struct sub_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(a - b);
+  }
+};
+
+struct lerp_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(b);
+  }
+};
+
+struct add_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a + (alpha * b));
+  }
+};
+
+struct sub_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a - (alpha * b));
+  }
+};
+
+struct lerp_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a + (alpha * (b - a)));
+  }
+};
+
 struct fmax_functor {
   template <typename T>
   inline T operator()(const T a, const T b) {
@@ -152,6 +194,48 @@ struct complex_mul_functor {
   }
 };
 
+struct complex_add_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return T(
+        a.x + (alpha.x * b.x - alpha.y * b.y),
+        a.y + (alpha.x * b.y + alpha.y * b.x));
+  }
+};
+
+struct complex_add_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(a.x + b.x, a.y + b.y);
+  }
+};
+
+struct complex_sub_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return T(
+        a.x - (alpha.x * b.x - alpha.y * b.y),
+        a.y - (alpha.x * b.y + alpha.y * b.x));
+  }
+};
+
+struct complex_lerp_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    auto intr = T(b.x - a.x, b.y - a.y);
+    return T(
+        a.x + (alpha.x * intr.x - alpha.y * intr.y),
+        a.y + (alpha.x * intr.y + alpha.y * intr.x));
+  }
+};
+
+struct complex_lerp_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(b.x, b.y);
+  }
+};
+
 REGISTER_BINARY_OP(copysign, long, float);
 REGISTER_BINARY_OP(copysign, int, float);
 REGISTER_BINARY_OP(copysign, float, float);
@@ -182,6 +266,54 @@ REGISTER_BINARY_OP(hermite_polynomial_h, float, float);
 REGISTER_BINARY_OP(hermite_polynomial_h, half, half);
 REGISTER_BINARY_OP(hermite_polynomial_he, float, float);
 REGISTER_BINARY_OP(hermite_polynomial_he, half, half);
+REGISTER_BINARY_OP(add, long, long);
+REGISTER_BINARY_OP(add, int, int);
+REGISTER_BINARY_OP(add, float, float);
+REGISTER_BINARY_OP(add, half, half);
+REGISTER_BINARY_OP(add, short, short);
+REGISTER_BINARY_OP(add, uchar, uchar);
+REGISTER_BINARY_OP(add, char, char);
+REGISTER_BINARY_OP(add, bool, bool);
+REGISTER_BINARY_OP(sub, long, long);
+REGISTER_BINARY_OP(sub, int, int);
+REGISTER_BINARY_OP(sub, float, float);
+REGISTER_BINARY_OP(sub, half, half);
+REGISTER_BINARY_OP(sub, short, short);
+REGISTER_BINARY_OP(sub, uchar, uchar);
+REGISTER_BINARY_OP(sub, char, char);
+REGISTER_BINARY_OP(sub, bool, bool);
+REGISTER_BINARY_OP(lerp, long, long);
+REGISTER_BINARY_OP(lerp, int, int);
+REGISTER_BINARY_OP(lerp, float, float);
+REGISTER_BINARY_OP(lerp, half, half);
+REGISTER_BINARY_OP(lerp, short, short);
+REGISTER_BINARY_OP(lerp, uchar, uchar);
+REGISTER_BINARY_OP(lerp, char, char);
+REGISTER_BINARY_OP(lerp, bool, bool);
+REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(add_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(add_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(add_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool);
 
 #if __METAL_VERSION__ >= 310
 REGISTER_BINARY_OP(copysign, bfloat, bfloat);
@@ -196,6 +328,12 @@ REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat);
 REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
+REGISTER_BINARY_OP(add, bfloat, bfloat);
+REGISTER_BINARY_OP(sub, bfloat, bfloat);
+REGISTER_BINARY_OP(lerp, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
 #endif
 
 // Complex binary functions
@@ -205,3 +343,15 @@ REGISTER_BINARY_OP(make_complex, float, float2);
 REGISTER_BINARY_OP(make_complex, half, half2);
 REGISTER_BINARY_OP(complex_mul, float2, float2);
 REGISTER_BINARY_OP(complex_mul, half2, half2);
+REGISTER_BINARY_OP(add, float2, float2);
+REGISTER_BINARY_OP(add, half2, half2);
+REGISTER_BINARY_OP(sub, float2, float2);
+REGISTER_BINARY_OP(sub, half2, half2);
+REGISTER_BINARY_OP(lerp, float2, float2);
+REGISTER_BINARY_OP(lerp, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2);
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.h b/aten/src/ATen/native/mps/operations/BinaryKernel.h
index 6ee63360cc41..c4b462c5dd8f 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.h
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.h
@@ -1,8 +1,14 @@
 #pragma once
 
 namespace at::native::mps {
+void binary_op_kernel(
+    const std::string func_name,
+    const Tensor& input,
+    const Tensor& other,
+    const Tensor& output,
+    const std::optional<MPSScalar>& alpha = std::nullopt);
 void complex_mul_out(
     const Tensor& input,
     const Tensor& other,
     const Tensor& output);
-}
+} // namespace at::native::mps
diff --git a/aten/src/ATen/native/mps/operations/BinaryKernel.mm b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
index cb6c5c073b9a..9f5631b5813a 100644
--- a/aten/src/ATen/native/mps/operations/BinaryKernel.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryKernel.mm
@@ -50,6 +50,30 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output) {
   lib.exec_binary_kernel(iter, "complex_mul");
 }
 
+void binary_op_kernel(const std::string func_name,
+                      const Tensor& input,
+                      const Tensor& other,
+                      const Tensor& output,
+                      const std::optional<MPSScalar>& alpha) {
+  auto new_size = at::infer_size(input.sizes(), other.sizes());
+  if (!output.sizes().equals(new_size)) {
+    output.resize_(new_size);
+  }
+  uint32_t length = output.numel();
+  if (length == 0) {
+    return;
+  }
+
+  auto iter =
+      TensorIteratorConfig().add_output(output).add_input(input).add_input(other).check_all_same_dtype(false).build();
+
+  if (alpha) {
+    lib.exec_binary_alpha_kernel(iter, func_name, *alpha);
+  } else {
+    lib.exec_binary_kernel(iter, func_name);
+  }
+}
+
 } // namespace mps
 
 static void fmax_mps_kernel(TensorIteratorBase& iter) {
diff --git a/aten/src/ATen/native/mps/operations/BinaryOps.mm b/aten/src/ATen/native/mps/operations/BinaryOps.mm
index 864d8c2ca7cf..690988dfc74d 100644
--- a/aten/src/ATen/native/mps/operations/BinaryOps.mm
+++ b/aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -265,9 +265,21 @@ static void add_sub_lerp_template(const Tensor& self,
   }
 
   const bool alpha_has_value = alpha.toDouble() != 1.0;
-  if (alpha_has_value) {
-    auto commonDtype = at::result_type(self, other);
-    at::native::alpha_check(commonDtype, alpha);
+  auto self_complex = c10::isComplexType(self.scalar_type());
+  auto other_complex = c10::isComplexType(other.scalar_type());
+  auto commonDtype = at::result_type(self, other);
+  if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
+    if (alpha_has_value) {
+      at::native::alpha_check(commonDtype, alpha);
"complex_" + op_name : op_name, + self, + other, + output, + getMPSScalar(alpha, commonDtype)); + } else { + mps::binary_op_kernel(op_name, self, other, output); + } + return; } if (!alpha_has_value && op_name == "lerp") { diff --git a/c10/metal/indexing.h b/c10/metal/indexing.h index 2fe7b96bf430..367a34596f7c 100644 --- a/c10/metal/indexing.h +++ b/c10/metal/indexing.h @@ -165,6 +165,29 @@ kernel void binary_strided( ref_at_offs>(output, output_offs) = f(a, b); } +template +kernel void alpha_binary_strided( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T* alpha [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other_strides [[buffer(7)]], + constant uint3& ndim [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim.x); + const auto a = val_at_offs(input, input_offs); + const auto b = val_at_offs(other, other_offs); + ref_at_offs>(output, output_offs) = f(a, b, *alpha); +} + template kernel void binary_strided_cast( device void* output [[buffer(0)]], @@ -189,6 +212,31 @@ kernel void binary_strided_cast( ref_at_offs>(output, output_offs) = f(a, b); } +template +kernel void alpha_binary_strided_cast( + device void* output [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T* alpha [[buffer(3)]], + constant long* sizes [[buffer(4)]], + constant long* output_strides [[buffer(5)]], + constant long* input_strides [[buffer(6)]], + constant long* other_strides [[buffer(7)]], + constant uint3& ndim_types [[buffer(8)]], + uint index [[thread_position_in_grid]]) { + F f; + int pos[max_ndim]; + pos_from_thread_index(int(index), pos, sizes, ndim_types.x); + const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x); + const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x); + const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x); + const auto a = + val_at_offs(input, input_offs, static_cast(ndim_types.y)); + const auto b = + val_at_offs(other, other_offs, static_cast(ndim_types.z)); + ref_at_offs>(output, output_offs) = f(a, b, *alpha); +} + template kernel void binary_dense( device result_of* out [[buffer(0)]], @@ -199,6 +247,17 @@ kernel void binary_dense( out[tid] = f(input[tid], other[tid]); } +template +kernel void alpha_binary_dense( + device result_of* out [[buffer(0)]], + constant T* input [[buffer(1)]], + constant T* other [[buffer(2)]], + constant T* alpha [[buffer(3)]], + uint tid [[thread_position_in_grid]]) { + F f; + out[tid] = f(input[tid], other[tid], *alpha); +} + template kernel void binary_dense_cast( device result_of* out [[buffer(0)]], @@ -214,6 +273,22 @@ kernel void binary_dense_cast( out[tid] = f(a, b); } +template +kernel void alpha_binary_dense_cast( + device result_of* out [[buffer(0)]], + constant void* input [[buffer(1)]], + constant void* other [[buffer(2)]], + constant T* alpha [[buffer(3)]], + constant uint4& sizes_types [[buffer(4)]], + uint tid [[thread_position_in_grid]]) { + F f; + const auto a = val_at_offs( + input, tid * sizes_types.x, static_cast(sizes_types.z)); + 
+  const auto b = val_at_offs<T>(
+      other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
+  out[tid] = f(a, b, *alpha);
+}
+
 #define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO)                             \
   static_assert(                                                             \
       ::metal::is_same_v<                                                    \
           DTYPEO,                                                            \
           ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI>>,          \
       "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI);   \
@@ -257,5 +332,55 @@ kernel void binary_dense_cast(
       constant void* other,                                                  \
       constant uint4& sizes_types,                                           \
       uint tid)
+
+#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO)                       \
+  static_assert(                                                             \
+      ::metal::is_same_v<                                                    \
+          DTYPEO,                                                            \
+          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI>>,          \
+      "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI);   \
+  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void ::\
+      c10::metal::alpha_binary_strided<DTYPEI, NAME##_functor>(              \
+          device void* out,                                                  \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant long* sizes,                                              \
+          constant long* output_strides,                                     \
+          constant long* input_strides,                                      \
+          constant long* other_strides,                                      \
+          constant uint3& ndim,                                              \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
+      metal::alpha_binary_strided_cast<DTYPEI, NAME##_functor>(              \
+          device void* out,                                                  \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant long* sizes,                                              \
+          constant long* output_strides,                                     \
+          constant long* input_strides,                                      \
+          constant long* other_strides,                                      \
+          constant uint3& ndim_types,                                        \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
+      c10::metal::alpha_binary_dense<DTYPEI, NAME##_functor>(                \
+          device ::c10::metal::                                              \
+              result_of<NAME##_functor, DTYPEI, DTYPEI>*                     \
+                  out_,                                                      \
+          constant DTYPEI * input_,                                          \
+          constant DTYPEI * other_,                                          \
+          constant DTYPEI * alpha,                                           \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10::   \
+      metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>(                \
+          device ::c10::metal::                                              \
+              result_of<NAME##_functor, DTYPEI, DTYPEI>*                     \
+                  out_,                                                      \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant uint4& sizes_types,                                       \
+          uint tid)
 } // namespace metal
 } // namespace c10
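
Note for reviewers (not part of the diff): the host side looks kernels up by name, so the `fmt::format` calls in `exec_binary_alpha_kernel` must match the `[[host_name(...)]]` strings stamped out by `REGISTER_BINARY_ALPHA_OP`. The sketch below only illustrates that naming contract; it assumes `fmt` formatting as used in `OperationUtils.mm`, and the helper name `alpha_kernel_name` is hypothetical, not a function in the tree.

```cpp
#include <string>
#include <fmt/format.h>

// Illustrative only: mirrors the kernel-name construction performed by
// MetalShaderLibrary::exec_binary_alpha_kernel(). "alpha_kernel_name" is a
// made-up helper for this example.
static std::string alpha_kernel_name(const std::string& op,
                                     bool contiguous,
                                     bool cast_needed,
                                     const std::string& out_dtype,
                                     const std::string& in_dtype) {
  const auto suffix = contiguous ? "dense" : "strided";
  // Cast kernels encode only the output dtype; non-cast kernels encode both.
  return cast_needed
      ? fmt::format("{}_alpha_{}_cast_{}", op, suffix, out_dtype)
      : fmt::format("{}_alpha_{}_{}_{}", op, suffix, out_dtype, in_dtype);
}

// alpha_kernel_name("add", /*contiguous=*/true, /*cast_needed=*/false, "float", "float")
// yields "add_alpha_dense_float_float", which is the [[host_name(...)]] produced by
// REGISTER_BINARY_ALPHA_OP(add_alpha, float, float) for the alpha_binary_dense instantiation.
```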