Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
[functorch] Refactor UnaryOps
45  functorch/functorch/csrc/BatchRulesUnaryOps.cpp  Normal file
@@ -0,0 +1,45 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#define SINGLE_ARG(...) __VA_ARGS__
#define UNARY_POINTWISE(op) \
  VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule<decltype(&at::op), &at::op>));

  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(_conj);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);

#undef UNARY_POINTWISE
}

}}
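The new file delegates every registration to basic_unary_batch_rule, declared in functorch/csrc/BatchRulesHelper.h but not shown in this diff. Since a unary pointwise op is indifferent to where the batch dimension sits, such a rule only needs to apply the op to the physical tensor and pass the batch dimension through, which is what the soon-to-be-removed per-op helpers further down do. A minimal sketch under that assumption (the real helper may differ):

#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <tuple>

// Hypothetical sketch of basic_unary_batch_rule; the real definition lives in
// functorch/csrc/BatchRulesHelper.h and is not part of this diff.
// A unary pointwise op treats every element independently, so the batched
// (physical) tensor can be fed to the op as-is and the batch dimension index
// is returned unchanged.
template <typename F, F Func>
std::tuple<at::Tensor, c10::optional<int64_t>> basic_unary_batch_rule(
    const at::Tensor& tensor, c10::optional<int64_t> batch_dim) {
  return {Func(tensor), batch_dim};
}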
@@ -1373,20 +1373,6 @@ TORCH_LIBRARY_IMPL(_, FT_BATCHED_KEY, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

// // debug_t<tail_t<tail_t<typelist<Tensor, optional<int64_t>>>>> dt;
// debug_t<remove_batch_dim_after_tensor_t<typelist<Tensor, optional<int64_t>>>> dt;

std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim) {
  return {tensor.abs(), batch_dim};
}

template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(
    const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
  return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}

Tensor matmul_decomposed(
    const Tensor& tensor1,
    const Tensor& tensor2) {
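The removed unwrap_and_call2 template is the generic form of abs_batch_rule above: it invokes the wrapped op on the physical tensor, forwards any trailing arguments untouched, and passes the batch dimension straight through. A hedged usage sketch, assuming the definitions from this file are in scope and using at::leaky_relu purely as an illustration (this diff does not register it this way):

// Hypothetical instantiation of unwrap_and_call2 for an op that takes one extra
// argument. at::leaky_relu(const Tensor&, const Scalar&) has the Scalar forwarded
// unchanged while the batch dimension is passed through.
auto leaky_relu_rule =
    unwrap_and_call2<decltype(&at::leaky_relu), &at::leaky_relu, const at::Scalar&>;

// Example call with the vmapped dimension at index 0:
//   at::Tensor x = at::randn({3, 5});
//   auto out = leaky_relu_rule(x, /*batch_dim=*/0, /*negative_slope=*/0.01);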
@@ -1563,41 +1549,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  // m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE_BATCH_RULE(op) unwrap_and_call2<decltype(&op), &op>

#define UNARY_POINTWISE(op) VMAP_SUPPORT(#op, UNARY_POINTWISE_BATCH_RULE(at::op));
  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(_conj);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
  { \
    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
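For comparison, UNARY_POINTWISE(abs) expands differently before and after this commit. The sketch below assumes only that VMAP_SUPPORT associates the stringified op name with the given batch rule; its definition is not part of this diff:

// Old expansion (this file, pre-refactor): registration goes through the local
// unwrap_and_call2 template.
//   VMAP_SUPPORT("abs", unwrap_and_call2<decltype(&at::abs), &at::abs>);
//
// New expansion (BatchRulesUnaryOps.cpp): the shared basic_unary_batch_rule
// helper from BatchRulesHelper.h is used instead; SINGLE_ARG only shields the
// comma inside the template argument list from splitting the macro argument.
//   VMAP_SUPPORT("abs", basic_unary_batch_rule<decltype(&at::abs), &at::abs>);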