Implement metal kernel for basic MPS arithmetic ops using TensorIterator (#147644)
Add metal kernels for add, subtract, & lerp ops using TensorIterator.

Should help resolve: https://github.com/pytorch/pytorch/issues/143874

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147644
Approved by: https://github.com/malfet
commit 663bcb68ba
parent 2fb62f8288
committed by PyTorch MergeBot
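For context, a minimal usage sketch at the ATen C++ level showing the call paths these kernels serve (illustrative only; assumes an MPS-enabled build, and the shapes and scalar values are made up):

#include <ATen/ATen.h>

int main() {
  auto a = at::rand({4, 4}, at::device(at::kMPS).dtype(at::kFloat));
  auto b = at::rand({4, 4}, at::device(at::kMPS).dtype(at::kFloat));
  auto c = at::add(a, b, /*alpha=*/2.0);    // a + alpha * b
  auto d = at::sub(a, b);                   // alpha defaults to 1
  auto e = at::lerp(a, b, /*weight=*/0.25); // a + weight * (b - a)
  return 0;
}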
@@ -29,6 +29,9 @@ struct TensorIteratorBase;
 
 namespace at::native::mps {
 
+// Forward declaration of MPSScalar - for exec_binary_alpha_kernel()
+struct MPSScalar;
+
 namespace detail {
 template <typename T>
 class has_size_type {
@@ -138,6 +141,10 @@ class MetalShaderLibrary {
       const std::string& name,
       std::optional<int64_t> extra = std::nullopt);
   void exec_binary_kernel(TensorIteratorBase& iter, const std::string& name);
+  void exec_binary_alpha_kernel(
+      TensorIteratorBase& iter,
+      const std::string& name,
+      const MPSScalar& alpha);
 
  protected:
   virtual MTLLibrary_t getLibrary();
@@ -1078,6 +1078,68 @@ void MetalShaderLibrary::exec_binary_kernel(TensorIteratorBase& iter, const std:
   });
 }
 
+void MetalShaderLibrary::exec_binary_alpha_kernel(TensorIteratorBase& iter,
+                                                  const std::string& name,
+                                                  const MPSScalar& alpha) {
+  TORCH_CHECK(iter.common_dtype() != at::kDouble, "float64 is not supported on MPS");
+  TORCH_CHECK(iter.can_use_32bit_indexing(), "Can't be indexed using 32-bit iterator");
+
+  Tensor input = iter.input(0);
+  Tensor other = iter.input(1);
+  Tensor out = iter.output();
+
+  id<MTLDevice> device = MPSDevice::getInstance()->device();
+  MPSStream* mpsStream = getCurrentMPSStream();
+  const uint32_t nDim = iter.ndim();
+  constexpr uint32_t nOffsets = 3;
+  const uint32_t numThreads = iter.numel();
+  const auto cast_needed = input.scalar_type() != other.scalar_type();
+  const auto suffix = iter.is_contiguous() ? "dense" : "strided";
+  // TODO: Implicitly pass both input and output types to non-cast kernels
+  const auto kernel_name = cast_needed
+      ? fmt::format("{}_alpha_{}_cast_{}", name, suffix, scalarToMetalTypeString(out))
+      : fmt::format("{}_alpha_{}_{}_{}", name, suffix, scalarToMetalTypeString(out), scalarToMetalTypeString(input));
+  dispatch_sync_with_rethrow(mpsStream->queue(), ^() {
+    @autoreleasepool {
+      auto computeEncoder = mpsStream->commandEncoder();
+      auto binaryPSO = getPipelineStateForFunc(kernel_name);
+      // this function call is a no-op if MPS Profiler is not enabled
+      getMPSProfiler().beginProfileKernel(binaryPSO, kernel_name, {input, other});
+      [computeEncoder setComputePipelineState:binaryPSO];
+      // Iterator is contiguous if all of its elements are dense in storage,
+      // i.e. it's true for both row-first and column-first tensors
+      if (iter.is_contiguous()) {
+        mtl_setArgs(computeEncoder, out, input, other, alpha);
+        if (cast_needed) {
+          std::array<int, 4> size_and_types = {static_cast<int>(c10::elementSize(input.scalar_type())),
+                                               static_cast<int>(c10::elementSize(other.scalar_type())),
+                                               static_cast<int>(input.scalar_type()),
+                                               static_cast<int>(other.scalar_type())};
+          mtl_setBytes(computeEncoder, size_and_types, 4);
+        }
+      } else {
+        // Please note that shapes and strides of the iterator might be
+        // different than that of its operands, for example binary op
+        // between 4x4 tensor and scalar will result in 1D 16 element iterator
+        std::array<int, 3> ndim_and_types = {
+            iter.ndim(), static_cast<int>(input.scalar_type()), static_cast<int>(other.scalar_type())};
+        mtl_setArgs(computeEncoder,
+                    out,
+                    input,
+                    other,
+                    alpha,
+                    iter.shape(),
+                    iter.strides(0),
+                    iter.strides(1),
+                    iter.strides(2),
+                    ndim_and_types);
+      }
+      mtl_dispatch1DJob(computeEncoder, binaryPSO, numThreads);
+      getMPSProfiler().endProfileKernel(binaryPSO);
+    }
+  });
+}
+
 MetalShaderLibrary& MetalShaderLibrary::getBundledLibrary() {
   static BundledShaderLibary l;
   return l;
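The fmt::format calls above produce four kernel-name shapes, depending on whether the iterator is contiguous and whether the operand dtypes differ. A standalone sketch of the naming scheme (illustrative; "add" and "float" are example values):

#include <fmt/format.h>
#include <iostream>

int main() {
  // same dtypes, contiguous  -> add_alpha_dense_float_float
  std::cout << fmt::format("{}_alpha_{}_{}_{}", "add", "dense", "float", "float") << "\n";
  // same dtypes, strided     -> add_alpha_strided_float_float
  std::cout << fmt::format("{}_alpha_{}_{}_{}", "add", "strided", "float", "float") << "\n";
  // mixed dtypes, contiguous -> add_alpha_dense_cast_float (output dtype only)
  std::cout << fmt::format("{}_alpha_{}_cast_{}", "add", "dense", "float") << "\n";
  // mixed dtypes, strided    -> add_alpha_strided_cast_float
  std::cout << fmt::format("{}_alpha_{}_cast_{}", "add", "strided", "float") << "\n";
  return 0;
}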
@@ -4,6 +4,48 @@
 #include <metal_stdlib>
 using namespace metal;
 
+struct add_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(a + b);
+  }
+};
+
+struct sub_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(a - b);
+  }
+};
+
+struct lerp_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return static_cast<T>(b);
+  }
+};
+
+struct add_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a + (alpha * b));
+  }
+};
+
+struct sub_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a - (alpha * b));
+  }
+};
+
+struct lerp_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return static_cast<T>(a + (alpha * (b - a)));
+  }
+};
+
 struct fmax_functor {
   template <typename T>
   inline T operator()(const T a, const T b) {
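A host-side C++ mirror of the functor algebra above (an illustrative sanity check, not part of the diff). Note that the non-alpha lerp_functor can simply return b because, as the add_sub_lerp_template hunk below shows, it is only dispatched when the weight equals 1, and lerp(a, b, 1) == b:

#include <cassert>

int main() {
  float a = 2.0f, b = 5.0f, alpha = 0.25f;
  assert(a + alpha * b == 3.25f);       // add_alpha_functor
  assert(a - alpha * b == 0.75f);       // sub_alpha_functor
  assert(a + alpha * (b - a) == 2.75f); // lerp_alpha_functor
  assert(a + 1.0f * (b - a) == b);      // lerp with weight 1 reduces to b
  return 0;
}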
@@ -152,6 +194,48 @@ struct complex_mul_functor {
   }
 };
 
+struct complex_add_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return T(
+        a.x + (alpha.x * b.x - alpha.y * b.y),
+        a.y + (alpha.x * b.y + alpha.y * b.x));
+  }
+};
+
+struct complex_add_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(a.x + b.x, a.y + b.y);
+  }
+};
+
+struct complex_sub_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    return T(
+        a.x - (alpha.x * b.x - alpha.y * b.y),
+        a.y - (alpha.x * b.y + alpha.y * b.x));
+  }
+};
+
+struct complex_lerp_alpha_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b, const T alpha) {
+    auto intr = T(b.x - a.x, b.y - a.y);
+    return T(
+        a.x + (alpha.x * intr.x - alpha.y * intr.y),
+        a.y + (alpha.x * intr.y + alpha.y * intr.x));
+  }
+};
+
+struct complex_lerp_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(b.x, b.y);
+  }
+};
+
 REGISTER_BINARY_OP(copysign, long, float);
 REGISTER_BINARY_OP(copysign, int, float);
 REGISTER_BINARY_OP(copysign, float, float);
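In the complex variants, alpha * b and alpha * (b - a) are full complex multiplications, (ar*br - ai*bi, ar*bi + ai*br). A worked std::complex check of the same algebra (illustrative; values chosen so all results are exact in float):

#include <cassert>
#include <complex>

int main() {
  std::complex<float> a(1, 2), b(3, -1), alpha(0, 1); // alpha = i
  // complex_add_alpha: a + alpha * b, where i * (3 - i) = 1 + 3i
  assert(a + alpha * b == std::complex<float>(2, 5));
  // complex_lerp_alpha: a + alpha * (b - a), where i * (2 - 3i) = 3 + 2i
  assert(a + alpha * (b - a) == std::complex<float>(4, 4));
  return 0;
}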
@@ -182,6 +266,54 @@ REGISTER_BINARY_OP(hermite_polynomial_h, float, float);
 REGISTER_BINARY_OP(hermite_polynomial_h, half, half);
 REGISTER_BINARY_OP(hermite_polynomial_he, float, float);
 REGISTER_BINARY_OP(hermite_polynomial_he, half, half);
+REGISTER_BINARY_OP(add, long, long);
+REGISTER_BINARY_OP(add, int, int);
+REGISTER_BINARY_OP(add, float, float);
+REGISTER_BINARY_OP(add, half, half);
+REGISTER_BINARY_OP(add, short, short);
+REGISTER_BINARY_OP(add, uchar, uchar);
+REGISTER_BINARY_OP(add, char, char);
+REGISTER_BINARY_OP(add, bool, bool);
+REGISTER_BINARY_OP(sub, long, long);
+REGISTER_BINARY_OP(sub, int, int);
+REGISTER_BINARY_OP(sub, float, float);
+REGISTER_BINARY_OP(sub, half, half);
+REGISTER_BINARY_OP(sub, short, short);
+REGISTER_BINARY_OP(sub, uchar, uchar);
+REGISTER_BINARY_OP(sub, char, char);
+REGISTER_BINARY_OP(sub, bool, bool);
+REGISTER_BINARY_OP(lerp, long, long);
+REGISTER_BINARY_OP(lerp, int, int);
+REGISTER_BINARY_OP(lerp, float, float);
+REGISTER_BINARY_OP(lerp, half, half);
+REGISTER_BINARY_OP(lerp, short, short);
+REGISTER_BINARY_OP(lerp, uchar, uchar);
+REGISTER_BINARY_OP(lerp, char, char);
+REGISTER_BINARY_OP(lerp, bool, bool);
+REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(add_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(add_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(add_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(add_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bool, bool);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bool, bool);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, long, long);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, int, int);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, float, float);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, half, half);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, short, short);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, uchar, uchar);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, char, char);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bool, bool);
 
 #if __METAL_VERSION__ >= 310
 REGISTER_BINARY_OP(copysign, bfloat, bfloat);
@@ -196,6 +328,12 @@ REGISTER_BINARY_OP(chebyshev_polynomial_v, bfloat, bfloat);
 REGISTER_BINARY_OP(chebyshev_polynomial_w, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
+REGISTER_BINARY_OP(add, bfloat, bfloat);
+REGISTER_BINARY_OP(sub, bfloat, bfloat);
+REGISTER_BINARY_OP(lerp, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
+REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
 #endif
 
 // Complex binary functions
@@ -205,3 +343,15 @@ REGISTER_BINARY_OP(make_complex, float, float2);
 REGISTER_BINARY_OP(make_complex, half, half2);
 REGISTER_BINARY_OP(complex_mul, float2, float2);
 REGISTER_BINARY_OP(complex_mul, half2, half2);
+REGISTER_BINARY_OP(add, float2, float2);
+REGISTER_BINARY_OP(add, half2, half2);
+REGISTER_BINARY_OP(sub, float2, float2);
+REGISTER_BINARY_OP(sub, half2, half2);
+REGISTER_BINARY_OP(lerp, float2, float2);
+REGISTER_BINARY_OP(lerp, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, half2, half2);
+REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, float2, float2);
+REGISTER_BINARY_ALPHA_OP(complex_lerp_alpha, half2, half2);
@@ -1,8 +1,14 @@
 #pragma once
 
 namespace at::native::mps {
+void binary_op_kernel(
+    const std::string func_name,
+    const Tensor& input,
+    const Tensor& other,
+    const Tensor& output,
+    const std::optional<MPSScalar>& alpha = std::nullopt);
 void complex_mul_out(
     const Tensor& input,
     const Tensor& other,
     const Tensor& output);
-}
+} // namespace at::native::mps
@@ -50,6 +50,30 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
   lib.exec_binary_kernel(iter, "complex_mul");
 }
 
+void binary_op_kernel(const std::string func_name,
+                      const Tensor& input,
+                      const Tensor& other,
+                      const Tensor& output,
+                      const std::optional<MPSScalar>& alpha) {
+  auto new_size = at::infer_size(input.sizes(), other.sizes());
+  if (!output.sizes().equals(new_size)) {
+    output.resize_(new_size);
+  }
+  uint32_t length = output.numel();
+  if (length == 0) {
+    return;
+  }
+
+  auto iter =
+      TensorIteratorConfig().add_output(output).add_input(input).add_input(other).check_all_same_dtype(false).build();
+
+  if (alpha) {
+    lib.exec_binary_alpha_kernel(iter, func_name, *alpha);
+  } else {
+    lib.exec_binary_kernel(iter, func_name);
+  }
+}
+
 } // namespace mps
 
 static void fmax_mps_kernel(TensorIteratorBase& iter) {
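The at::infer_size call above applies NumPy-style broadcasting to the two input shapes before the output is resized. A minimal standalone re-implementation of that rule (a sketch, not the ATen code):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Align shapes from the right; each dim pair must match or contain a 1.
std::vector<int64_t> broadcast_shapes(const std::vector<int64_t>& a,
                                      const std::vector<int64_t>& b) {
  std::vector<int64_t> out(std::max(a.size(), b.size()), 1);
  for (size_t i = 0; i < out.size(); ++i) {
    const int64_t da = i < out.size() - a.size() ? 1 : a[i - (out.size() - a.size())];
    const int64_t db = i < out.size() - b.size() ? 1 : b[i - (out.size() - b.size())];
    if (da != db && da != 1 && db != 1)
      throw std::runtime_error("shapes are not broadcastable");
    out[i] = std::max(da, db);
  }
  return out; // e.g. {4, 1} and {3} broadcast to {4, 3}
}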
@@ -265,9 +265,21 @@ static void add_sub_lerp_template(const Tensor& self,
   }
 
   const bool alpha_has_value = alpha.toDouble() != 1.0;
-  if (alpha_has_value) {
-    auto commonDtype = at::result_type(self, other);
-    at::native::alpha_check(commonDtype, alpha);
+  auto self_complex = c10::isComplexType(self.scalar_type());
+  auto other_complex = c10::isComplexType(other.scalar_type());
+  auto commonDtype = at::result_type(self, other);
+  if (self.is_mps() && other.is_mps() && (output.scalar_type() == commonDtype) && (self_complex == other_complex)) {
+    if (alpha_has_value) {
+      at::native::alpha_check(commonDtype, alpha);
+      mps::binary_op_kernel((self_complex || other_complex) ? "complex_" + op_name : op_name,
+                            self,
+                            other,
+                            output,
+                            getMPSScalar(alpha, commonDtype));
+    } else {
+      mps::binary_op_kernel(op_name, self, other, output);
+    }
+    return;
   }
 
   if (!alpha_has_value && op_name == "lerp") {
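A compact restatement of the gate above (an illustrative sketch; the parameter names are descriptive, not ATen's): the new Metal path is taken only when both operands live on MPS, the output dtype already equals the promoted common dtype, and the operands are either both complex or both real; otherwise control falls through to the pre-existing implementation later in the function.

bool use_metal_fast_path(bool self_on_mps, bool other_on_mps,
                         bool output_is_common_dtype,
                         bool self_complex, bool other_complex) {
  return self_on_mps && other_on_mps && output_is_common_dtype &&
      (self_complex == other_complex);
}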
@@ -165,6 +165,29 @@ kernel void binary_strided(
   ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
 }
 
+template <typename T, typename F>
+kernel void alpha_binary_strided(
+    device void* output [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant T* alpha [[buffer(3)]],
+    constant long* sizes [[buffer(4)]],
+    constant long* output_strides [[buffer(5)]],
+    constant long* input_strides [[buffer(6)]],
+    constant long* other_strides [[buffer(7)]],
+    constant uint3& ndim [[buffer(8)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim.x);
+  const auto a = val_at_offs<T>(input, input_offs);
+  const auto b = val_at_offs<T>(other, other_offs);
+  ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
+}
+
 template <typename T, typename F>
 kernel void binary_strided_cast(
     device void* output [[buffer(0)]],
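The strided kernels unravel the linear thread index into an N-D coordinate against the iterator shape, then project that coordinate onto each operand's strides. A host-side C++ mirror of that index math (an illustrative sketch of what pos_from_thread_index and offset_from_coord presumably do; max_ndim's actual value is not shown in this diff):

constexpr int kMaxNdim = 16; // assumption; stands in for the kernels' max_ndim

// Dim 0 is treated as the fastest-varying coordinate here (assumption).
void pos_from_thread_index(int idx, int pos[kMaxNdim], const long* sizes, int ndim) {
  for (int i = 0; i < ndim; ++i) {
    pos[i] = idx % static_cast<int>(sizes[i]);
    idx /= static_cast<int>(sizes[i]);
  }
}

// With void* operands, the strides are byte strides, so this yields a byte offset.
long offset_from_coord(const int pos[kMaxNdim], const long* strides, int ndim) {
  long offs = 0;
  for (int i = 0; i < ndim; ++i)
    offs += pos[i] * strides[i];
  return offs;
}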
@@ -189,6 +212,31 @@ kernel void binary_strided_cast(
   ref_at_offs<result_of<F, T, T>>(output, output_offs) = f(a, b);
 }
 
+template <typename T, typename F>
+kernel void alpha_binary_strided_cast(
+    device void* output [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant T* alpha [[buffer(3)]],
+    constant long* sizes [[buffer(4)]],
+    constant long* output_strides [[buffer(5)]],
+    constant long* input_strides [[buffer(6)]],
+    constant long* other_strides [[buffer(7)]],
+    constant uint3& ndim_types [[buffer(8)]],
+    uint index [[thread_position_in_grid]]) {
+  F f;
+  int pos[max_ndim];
+  pos_from_thread_index(int(index), pos, sizes, ndim_types.x);
+  const auto input_offs = offset_from_coord(pos, input_strides, ndim_types.x);
+  const auto other_offs = offset_from_coord(pos, other_strides, ndim_types.x);
+  const auto output_offs = offset_from_coord(pos, output_strides, ndim_types.x);
+  const auto a =
+      val_at_offs<T>(input, input_offs, static_cast<ScalarType>(ndim_types.y));
+  const auto b =
+      val_at_offs<T>(other, other_offs, static_cast<ScalarType>(ndim_types.z));
+  ref_at_offs<result_of<F, T, T, T>>(output, output_offs) = f(a, b, *alpha);
+}
+
 template <typename T, typename F>
 kernel void binary_dense(
     device result_of<F, T, T>* out [[buffer(0)]],
@@ -199,6 +247,17 @@ kernel void binary_dense(
   out[tid] = f(input[tid], other[tid]);
 }
 
+template <typename T, typename F>
+kernel void alpha_binary_dense(
+    device result_of<F, T, T, T>* out [[buffer(0)]],
+    constant T* input [[buffer(1)]],
+    constant T* other [[buffer(2)]],
+    constant T* alpha [[buffer(3)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  out[tid] = f(input[tid], other[tid], *alpha);
+}
+
 template <typename T, typename F>
 kernel void binary_dense_cast(
     device result_of<F, T, T>* out [[buffer(0)]],
@@ -214,6 +273,22 @@ kernel void binary_dense_cast(
   out[tid] = f(a, b);
 }
 
+template <typename T, typename F>
+kernel void alpha_binary_dense_cast(
+    device result_of<F, T, T, T>* out [[buffer(0)]],
+    constant void* input [[buffer(1)]],
+    constant void* other [[buffer(2)]],
+    constant T* alpha [[buffer(3)]],
+    constant uint4& sizes_types [[buffer(4)]],
+    uint tid [[thread_position_in_grid]]) {
+  F f;
+  const auto a = val_at_offs<T>(
+      input, tid * sizes_types.x, static_cast<ScalarType>(sizes_types.z));
+  const auto b = val_at_offs<T>(
+      other, tid * sizes_types.y, static_cast<ScalarType>(sizes_types.w));
+  out[tid] = f(a, b, *alpha);
+}
+
 #define REGISTER_BINARY_OP(NAME, DTYPEI, DTYPEO) \
   static_assert(                                 \
       ::metal::is_same_v<                        \
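In the dense-cast kernel above, sizes_types packs the element sizes of the two inputs (x, y) and their ScalarType codes (z, w), so tid * sizes_types.x is the byte offset of element tid and val_at_offs<T> reads the stored dtype and converts it to the compute type T. A hypothetical host-side analogue of that read (val_at_offs itself is not shown in this diff):

#include <cstdint>
#include <cstring>

enum class ScalarType { Float, Int }; // illustrative subset of dtype codes

// Read the value stored at a byte offset as `stype`, then convert to T.
template <typename T>
T val_at_offs(const void* data, int64_t byte_offs, ScalarType stype) {
  const char* p = static_cast<const char*>(data) + byte_offs;
  switch (stype) {
    case ScalarType::Float: { float v; std::memcpy(&v, p, sizeof(v)); return static_cast<T>(v); }
    case ScalarType::Int:   { int32_t v; std::memcpy(&v, p, sizeof(v)); return static_cast<T>(v); }
  }
  return T{};
}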
@@ -257,5 +332,55 @@ kernel void binary_dense_cast(
       constant void* other,                                                \
       constant uint4& sizes_types,                                         \
       uint tid)
+
+#define REGISTER_BINARY_ALPHA_OP(NAME, DTYPEI, DTYPEO)                       \
+  static_assert(                                                             \
+      ::metal::is_same_v<                                                    \
+          DTYPEO,                                                            \
+          ::c10::metal::result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI>>,  \
+      "Output dtype mismatch for binary op " #NAME " and input " #DTYPEI);   \
+  template [[host_name(#NAME "_strided_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
+      c10::metal::alpha_binary_strided<DTYPEI, NAME##_functor>(              \
+          device void* out,                                                  \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant long* sizes,                                              \
+          constant long* output_strides,                                     \
+          constant long* input_strides,                                      \
+          constant long* other_strides,                                      \
+          constant uint3& ndim,                                              \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_strided_cast_" #DTYPEI)]] kernel void ::c10:: \
+      metal::alpha_binary_strided_cast<DTYPEI, NAME##_functor>(              \
+          device void* out,                                                  \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant long* sizes,                                              \
+          constant long* output_strides,                                     \
+          constant long* input_strides,                                      \
+          constant long* other_strides,                                      \
+          constant uint3& ndim_types,                                        \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_dense_" #DTYPEO "_" #DTYPEI)]] kernel void :: \
+      c10::metal::alpha_binary_dense<DTYPEI, NAME##_functor>(                \
+          device ::c10::metal::                                              \
+              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> *            \
+                  out_,                                                      \
+          constant DTYPEI * input_,                                          \
+          constant DTYPEI * other_,                                          \
+          constant DTYPEI * alpha,                                           \
+          uint tid);                                                         \
+  template [[host_name(#NAME "_dense_cast_" #DTYPEI)]] kernel void ::c10::   \
+      metal::alpha_binary_dense_cast<DTYPEI, NAME##_functor>(                \
+          device ::c10::metal::                                              \
+              result_of<NAME##_functor, DTYPEI, DTYPEI, DTYPEI> *            \
+                  out_,                                                      \
+          constant void* input,                                              \
+          constant void* other,                                              \
+          constant DTYPEI* alpha,                                            \
+          constant uint4& sizes_types,                                       \
+          uint tid)
 } // namespace metal
 } // namespace c10
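For orientation, a hand-expanded sketch of what REGISTER_BINARY_ALPHA_OP(add_alpha, float, float) instantiates (illustrative, derived from the macro above). The four host_name strings line up with the names that exec_binary_alpha_kernel's fmt::format dispatch constructs:

// add_alpha_strided_float_float, add_alpha_strided_cast_float,
// add_alpha_dense_float_float, add_alpha_dense_cast_float.
// For example, the dense non-cast instantiation expands to roughly:
template [[host_name("add_alpha_dense_float_float")]] kernel void
::c10::metal::alpha_binary_dense<float, add_alpha_functor>(
    device float* out_, // result_of<add_alpha_functor, float, float, float>
    constant float* input_,
    constant float* other_,
    constant float* alpha,
    uint tid);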