mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[BE] Define REGISTER_UNARY_TI_DISPATCH
(#155081)
That creates _kernel_mps function that takes iterator and calls stub for it Pull Request resolved: https://github.com/pytorch/pytorch/pull/155081 Approved by: https://github.com/dcci ghstack dependencies: #154936, #155002
This commit is contained in:
committed by
PyTorch MergeBot
parent
50de6ae253
commit
d8e4c1c363
@ -13,89 +13,31 @@ static auto& lib = mps::MetalShaderLibrary::getBundledLibrary();
|
||||
#include <ATen/native/mps/UnaryKernel_metallib.h>
|
||||
#endif
|
||||
|
||||
static void erfinv_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "erfinv");
|
||||
}
|
||||
|
||||
static void exp_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "exp");
|
||||
}
|
||||
|
||||
static void sinc_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "sinc");
|
||||
}
|
||||
|
||||
static void tanh_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "tanh");
|
||||
}
|
||||
|
||||
static void sin_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "sin");
|
||||
}
|
||||
|
||||
static void cos_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "cos");
|
||||
}
|
||||
|
||||
static void tan_kernel(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "tan");
|
||||
}
|
||||
#define REGISTER_UNARY_TI_DISPATCH(NAME) \
|
||||
static void NAME##_kernel_mps(TensorIteratorBase& iter) { \
|
||||
lib.exec_unary_kernel(iter, #NAME); \
|
||||
} \
|
||||
REGISTER_DISPATCH(NAME##_stub, NAME##_kernel_mps)
|
||||
|
||||
static void round_decimals_kernel(TensorIteratorBase& iter, int64_t decimals) {
|
||||
lib.exec_unary_kernel(iter, "round_decimals", decimals);
|
||||
}
|
||||
|
||||
static void exp2_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "exp2");
|
||||
}
|
||||
|
||||
static void sqrt_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "sqrt");
|
||||
}
|
||||
|
||||
static void rsqrt_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "rsqrt");
|
||||
}
|
||||
|
||||
static void neg_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "neg");
|
||||
}
|
||||
|
||||
static void bitwise_not_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "bitwise_not");
|
||||
}
|
||||
|
||||
static void log10_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "log10");
|
||||
}
|
||||
|
||||
static void log2_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "log2");
|
||||
}
|
||||
|
||||
static void log_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "log");
|
||||
}
|
||||
|
||||
static void log1p_kernel_mps(TensorIteratorBase& iter) {
|
||||
lib.exec_unary_kernel(iter, "log1p");
|
||||
}
|
||||
|
||||
REGISTER_DISPATCH(exp_stub, exp_kernel);
|
||||
REGISTER_DISPATCH(erfinv_stub, erfinv_kernel);
|
||||
REGISTER_DISPATCH(sinc_stub, sinc_kernel);
|
||||
REGISTER_DISPATCH(tanh_stub, tanh_kernel);
|
||||
REGISTER_DISPATCH(sin_stub, sin_kernel);
|
||||
REGISTER_DISPATCH(cos_stub, cos_kernel);
|
||||
REGISTER_DISPATCH(tan_stub, tan_kernel);
|
||||
REGISTER_UNARY_TI_DISPATCH(exp);
|
||||
REGISTER_UNARY_TI_DISPATCH(erfinv);
|
||||
REGISTER_UNARY_TI_DISPATCH(sinc);
|
||||
REGISTER_UNARY_TI_DISPATCH(tanh);
|
||||
REGISTER_UNARY_TI_DISPATCH(sin);
|
||||
REGISTER_UNARY_TI_DISPATCH(cos);
|
||||
REGISTER_UNARY_TI_DISPATCH(tan);
|
||||
REGISTER_UNARY_TI_DISPATCH(sqrt);
|
||||
REGISTER_UNARY_TI_DISPATCH(rsqrt);
|
||||
REGISTER_UNARY_TI_DISPATCH(neg);
|
||||
REGISTER_UNARY_TI_DISPATCH(exp2);
|
||||
REGISTER_UNARY_TI_DISPATCH(log10);
|
||||
REGISTER_UNARY_TI_DISPATCH(log2);
|
||||
REGISTER_UNARY_TI_DISPATCH(log);
|
||||
REGISTER_UNARY_TI_DISPATCH(log1p);
|
||||
REGISTER_UNARY_TI_DISPATCH(bitwise_not);
|
||||
REGISTER_DISPATCH(round_decimals_stub, round_decimals_kernel);
|
||||
REGISTER_DISPATCH(sqrt_stub, sqrt_kernel_mps);
|
||||
REGISTER_DISPATCH(rsqrt_stub, rsqrt_kernel_mps);
|
||||
REGISTER_DISPATCH(exp2_stub, exp2_kernel_mps);
|
||||
REGISTER_DISPATCH(neg_stub, neg_kernel_mps);
|
||||
REGISTER_DISPATCH(bitwise_not_stub, bitwise_not_kernel_mps);
|
||||
REGISTER_DISPATCH(log10_stub, log10_kernel_mps);
|
||||
REGISTER_DISPATCH(log2_stub, log2_kernel_mps);
|
||||
REGISTER_DISPATCH(log_stub, log_kernel_mps);
|
||||
REGISTER_DISPATCH(log1p_stub, log1p_kernel_mps);
|
||||
} // namespace at::native
|
||||
|
Reference in New Issue
Block a user