[BE] Define REGISTER_UNARY_TI_DISPATCH (#155081)

That creates _kernel_mps function that takes iterator and calls stub for
it
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155081
Approved by: https://github.com/dcci
ghstack dependencies: #154936, #155002
This commit is contained in:
Nikita Shulga
2025-06-03 17:49:00 -07:00
committed by PyTorch MergeBot
parent 50de6ae253
commit d8e4c1c363

View File

@ -13,89 +13,31 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
#include <ATen/native/mps/UnaryKernel_metallib.h>
#endif
static void erfinv_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "erfinv");
}
static void exp_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "exp");
}
static void sinc_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "sinc");
}
static void tanh_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "tanh");
}
static void sin_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "sin");
}
static void cos_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "cos");
}
static void tan_kernel(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "tan");
}
#define REGISTER_UNARY_TI_DISPATCH(NAME) \
static void NAME##_kernel_mps(TensorIteratorBase& iter) { \
lib.exec_unary_kernel(iter, #NAME); \
} \
REGISTER_DISPATCH(NAME##_stub, NAME##_kernel_mps)
static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
lib.exec_unary_kernel(iter, "round_decimals", decimals);
}
static void exp2_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "exp2");
}
static void sqrt_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "sqrt");
}
static void rsqrt_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "rsqrt");
}
static void neg_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "neg");
}
static void bitwise_not_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "bitwise_not");
}
static void log10_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "log10");
}
static void log2_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "log2");
}
static void log_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "log");
}
static void log1p_kernel_mps(TensorIteratorBase& iter) {
lib.exec_unary_kernel(iter, "log1p");
}
REGISTER_DISPATCH(exp_stub, exp_kernel);
REGISTER_DISPATCH(erfinv_stub, erfinv_kernel);
REGISTER_DISPATCH(sinc_stub, sinc_kernel);
REGISTER_DISPATCH(tanh_stub, tanh_kernel);
REGISTER_DISPATCH(sin_stub, sin_kernel);
REGISTER_DISPATCH(cos_stub, cos_kernel);
REGISTER_DISPATCH(tan_stub, tan_kernel);
REGISTER_UNARY_TI_DISPATCH(exp);
REGISTER_UNARY_TI_DISPATCH(erfinv);
REGISTER_UNARY_TI_DISPATCH(sinc);
REGISTER_UNARY_TI_DISPATCH(tanh);
REGISTER_UNARY_TI_DISPATCH(sin);
REGISTER_UNARY_TI_DISPATCH(cos);
REGISTER_UNARY_TI_DISPATCH(tan);
REGISTER_UNARY_TI_DISPATCH(sqrt);
REGISTER_UNARY_TI_DISPATCH(rsqrt);
REGISTER_UNARY_TI_DISPATCH(neg);
REGISTER_UNARY_TI_DISPATCH(exp2);
REGISTER_UNARY_TI_DISPATCH(log10);
REGISTER_UNARY_TI_DISPATCH(log2);
REGISTER_UNARY_TI_DISPATCH(log);
REGISTER_UNARY_TI_DISPATCH(log1p);
REGISTER_UNARY_TI_DISPATCH(bitwise_not);
REGISTER_DISPATCH(round_decimals_stub, round_decimals_kernel);
REGISTER_DISPATCH(sqrt_stub, sqrt_kernel_mps);
REGISTER_DISPATCH(rsqrt_stub, rsqrt_kernel_mps);
REGISTER_DISPATCH(exp2_stub, exp2_kernel_mps);
REGISTER_DISPATCH(neg_stub, neg_kernel_mps);
REGISTER_DISPATCH(bitwise_not_stub, bitwise_not_kernel_mps);
REGISTER_DISPATCH(log10_stub, log10_kernel_mps);
REGISTER_DISPATCH(log2_stub, log2_kernel_mps);
REGISTER_DISPATCH(log_stub, log_kernel_mps);
REGISTER_DISPATCH(log1p_stub, log1p_kernel_mps);
} // namespace at::native