mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] Added batching rules for some in-place pointwise ops
This commit is contained in:
@ -1,45 +1,98 @@
|
||||
#include <functorch/csrc/BatchRulesHelper.h>
|
||||
#include <functorch/csrc/OutOfPlacePlumbing.hpp>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
#define INVOKE(object,ptrToMember) ((object).*(ptrToMember))
|
||||
|
||||
template <typename F, F BatchRule>
|
||||
static Tensor& unary_inplace_plumbing(Tensor& self) {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
|
||||
auto maybe_layer = maybeCurrentDynamicLayer();
|
||||
TORCH_INTERNAL_ASSERT(maybe_layer.has_value());
|
||||
int64_t cur_level = maybe_layer->layerId();
|
||||
Tensor self_value;
|
||||
optional<int64_t> self_bdim;
|
||||
std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level);
|
||||
BatchRule(self_value, self_bdim);
|
||||
return self;
|
||||
}
|
||||
|
||||
template <typename F, F Method>
|
||||
static Tensor& unary_inplace_batch_rule(Tensor& self, optional<int64_t>) {
|
||||
INVOKE(self, Method)();
|
||||
return self;
|
||||
}
|
||||
|
||||
template <typename F, F Func>
|
||||
static Tensor& unary_inplace_func_batch_rule(Tensor& self, optional<int64_t>) {
|
||||
Func(self);
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& test(Tensor& self, optional<int64_t>) {
|
||||
unary_inplace_batch_rule<decltype(&Tensor::abs), &Tensor::abs>(self, {});
|
||||
return self;
|
||||
}
|
||||
|
||||
Tensor& test2(Tensor& self) {
|
||||
return unary_inplace_plumbing<decltype(&test), &test>(self);
|
||||
}
|
||||
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
#define SINGLE_ARG(...) __VA_ARGS__
|
||||
|
||||
using UnaryInplaceBRType = Tensor& (*)(Tensor&, optional<int64_t>);
|
||||
#define UNARY_POINTWISE_(op) \
|
||||
m.impl(#op, unary_inplace_plumbing<UnaryInplaceBRType, &unary_inplace_batch_rule<decltype(&Tensor::op), &Tensor::op>>);
|
||||
#define UNARY_POINTWISE_FUNC_(op) \
|
||||
m.impl(#op, unary_inplace_plumbing<UnaryInplaceBRType, &unary_inplace_func_batch_rule<decltype(&at::op), &at::op>>);
|
||||
|
||||
#define UNARY_POINTWISE(op) \
|
||||
VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule<decltype(&at::op), &at::op>));
|
||||
|
||||
UNARY_POINTWISE(abs);
|
||||
UNARY_POINTWISE(acos);
|
||||
UNARY_POINTWISE(asin);
|
||||
UNARY_POINTWISE(atan);
|
||||
UNARY_POINTWISE(ceil);
|
||||
UNARY_POINTWISE(cos);
|
||||
UNARY_POINTWISE(cosh);
|
||||
#define UNARY_POINTWISE_ALL(op) \
|
||||
UNARY_POINTWISE_(op ## _); \
|
||||
VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule<decltype(&at::op), &at::op>));
|
||||
|
||||
UNARY_POINTWISE_ALL(abs);
|
||||
UNARY_POINTWISE_ALL(acos);
|
||||
UNARY_POINTWISE_ALL(asin);
|
||||
UNARY_POINTWISE_ALL(atan);
|
||||
UNARY_POINTWISE_ALL(ceil);
|
||||
UNARY_POINTWISE_ALL(cos);
|
||||
UNARY_POINTWISE_ALL(cosh);
|
||||
UNARY_POINTWISE(_conj);
|
||||
UNARY_POINTWISE(digamma);
|
||||
UNARY_POINTWISE(exp);
|
||||
UNARY_POINTWISE(expm1);
|
||||
UNARY_POINTWISE(floor);
|
||||
UNARY_POINTWISE(frac);
|
||||
UNARY_POINTWISE(lgamma);
|
||||
UNARY_POINTWISE(log);
|
||||
UNARY_POINTWISE(log10);
|
||||
UNARY_POINTWISE(log1p);
|
||||
UNARY_POINTWISE(log2);
|
||||
UNARY_POINTWISE(neg);
|
||||
UNARY_POINTWISE(reciprocal);
|
||||
UNARY_POINTWISE(relu);
|
||||
UNARY_POINTWISE(round);
|
||||
UNARY_POINTWISE(rsqrt);
|
||||
UNARY_POINTWISE(sigmoid);
|
||||
UNARY_POINTWISE(sign);
|
||||
UNARY_POINTWISE(sin);
|
||||
UNARY_POINTWISE(sinh);
|
||||
UNARY_POINTWISE(sqrt);
|
||||
UNARY_POINTWISE(tan);
|
||||
UNARY_POINTWISE(tanh);
|
||||
UNARY_POINTWISE(trunc);
|
||||
UNARY_POINTWISE_ALL(digamma);
|
||||
UNARY_POINTWISE_ALL(exp);
|
||||
UNARY_POINTWISE_ALL(expm1);
|
||||
UNARY_POINTWISE_ALL(floor);
|
||||
UNARY_POINTWISE_ALL(frac);
|
||||
UNARY_POINTWISE_ALL(lgamma);
|
||||
UNARY_POINTWISE_ALL(log);
|
||||
UNARY_POINTWISE_ALL(log10);
|
||||
UNARY_POINTWISE_ALL(log1p);
|
||||
UNARY_POINTWISE_ALL(log2);
|
||||
UNARY_POINTWISE_ALL(neg);
|
||||
UNARY_POINTWISE_ALL(reciprocal);
|
||||
UNARY_POINTWISE_ALL(relu);
|
||||
UNARY_POINTWISE_ALL(round);
|
||||
UNARY_POINTWISE_ALL(rsqrt);
|
||||
UNARY_POINTWISE_ALL(sigmoid);
|
||||
UNARY_POINTWISE_ALL(sign);
|
||||
UNARY_POINTWISE_ALL(sin);
|
||||
UNARY_POINTWISE_ALL(sinh);
|
||||
UNARY_POINTWISE_ALL(sqrt);
|
||||
UNARY_POINTWISE_ALL(tan);
|
||||
UNARY_POINTWISE_ALL(tanh);
|
||||
UNARY_POINTWISE_ALL(trunc);
|
||||
|
||||
#undef UNARY_POINTWISE
|
||||
#undef UNARY_POINTWISE_
|
||||
#undef UNARY_POINTWISE_ALL
|
||||
|
||||
}
|
||||
|
||||
#undef INVOKE
|
||||
}}
|
||||
|
@ -1093,9 +1093,14 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
(torch.tanh, TensorFactory.rand),
|
||||
(torch.trunc, TensorFactory.randn),
|
||||
]
|
||||
|
||||
for op, getter in cases:
|
||||
self._test_unary(op, getter, 'cpu')
|
||||
|
||||
# test in-place
|
||||
method = getattr(Tensor, f'{op.__name__ + "_"}')
|
||||
self._test_unary(method, getter, 'cpu', check_propagates_grad=False)
|
||||
|
||||
def test_clone(self):
|
||||
# Some basic tests
|
||||
self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu')
|
||||
|
Reference in New Issue
Block a user