[functorch] Refactory UnaryOps

This commit is contained in:
Richard Zou
2021-05-05 13:33:44 -07:00
committed by Jon Janzen
parent cb51b16e59
commit 66918c4a3e
2 changed files with 45 additions and 49 deletions

View File

@ -0,0 +1,45 @@
#include <functorch/csrc/BatchRulesHelper.h>
namespace at { namespace functorch {
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
#define SINGLE_ARG(...) __VA_ARGS__
#define UNARY_POINTWISE(op) \
VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule<decltype(&at::op), &at::op>));
UNARY_POINTWISE(abs);
UNARY_POINTWISE(acos);
UNARY_POINTWISE(asin);
UNARY_POINTWISE(atan);
UNARY_POINTWISE(ceil);
UNARY_POINTWISE(cos);
UNARY_POINTWISE(cosh);
UNARY_POINTWISE(_conj);
UNARY_POINTWISE(digamma);
UNARY_POINTWISE(exp);
UNARY_POINTWISE(expm1);
UNARY_POINTWISE(floor);
UNARY_POINTWISE(frac);
UNARY_POINTWISE(lgamma);
UNARY_POINTWISE(log);
UNARY_POINTWISE(log10);
UNARY_POINTWISE(log1p);
UNARY_POINTWISE(log2);
UNARY_POINTWISE(neg);
UNARY_POINTWISE(reciprocal);
UNARY_POINTWISE(relu);
UNARY_POINTWISE(round);
UNARY_POINTWISE(rsqrt);
UNARY_POINTWISE(sigmoid);
UNARY_POINTWISE(sign);
UNARY_POINTWISE(sin);
UNARY_POINTWISE(sinh);
UNARY_POINTWISE(sqrt);
UNARY_POINTWISE(tan);
UNARY_POINTWISE(tanh);
UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
}
}}

View File

@ -1373,20 +1373,6 @@ TORCH_LIBRARY_IMPL(_, FT_BATCHED_KEY, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
// // debug_t<tail_t<tail_t<typelist<Tensor, optional<int64_t>>>>> dt;
// debug_t<remove_batch_dim_after_tensor_t<typelist<Tensor, optional<int64_t>>>> dt;
std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim) {
return {tensor.abs(), batch_dim};
}
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(
const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}
Tensor matmul_decomposed(
const Tensor& tensor1,
const Tensor& tensor2) {
@ -1563,41 +1549,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
// m.impl("clamp_max", clamp_max_batching_rule);
// unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE_BATCH_RULE(op) unwrap_and_call2<decltype(&op), &op>
#define UNARY_POINTWISE(op) VMAP_SUPPORT(#op, UNARY_POINTWISE_BATCH_RULE(at::op));
UNARY_POINTWISE(abs);
UNARY_POINTWISE(acos);
UNARY_POINTWISE(asin);
UNARY_POINTWISE(atan);
UNARY_POINTWISE(ceil);
UNARY_POINTWISE(cos);
UNARY_POINTWISE(cosh);
UNARY_POINTWISE(_conj);
UNARY_POINTWISE(digamma);
UNARY_POINTWISE(exp);
UNARY_POINTWISE(expm1);
UNARY_POINTWISE(floor);
UNARY_POINTWISE(frac);
UNARY_POINTWISE(lgamma);
UNARY_POINTWISE(log);
UNARY_POINTWISE(log10);
UNARY_POINTWISE(log1p);
UNARY_POINTWISE(log2);
UNARY_POINTWISE(neg);
UNARY_POINTWISE(reciprocal);
UNARY_POINTWISE(relu);
UNARY_POINTWISE(round);
UNARY_POINTWISE(rsqrt);
UNARY_POINTWISE(sigmoid);
UNARY_POINTWISE(sign);
UNARY_POINTWISE(sin);
UNARY_POINTWISE(sinh);
UNARY_POINTWISE(sqrt);
UNARY_POINTWISE(tan);
UNARY_POINTWISE(tanh);
UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
{ \
using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \