diff --git a/functorch/functorch/csrc/BatchRulesUnaryOps.cpp b/functorch/functorch/csrc/BatchRulesUnaryOps.cpp index 0457e7b7e9b0..2357581ce496 100644 --- a/functorch/functorch/csrc/BatchRulesUnaryOps.cpp +++ b/functorch/functorch/csrc/BatchRulesUnaryOps.cpp @@ -1,45 +1,98 @@ #include +#include namespace at { namespace functorch { +#define INVOKE(object,ptrToMember) ((object).*(ptrToMember)) + +template +static Tensor& unary_inplace_plumbing(Tensor& self) { + c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey); + auto maybe_layer = maybeCurrentDynamicLayer(); + TORCH_INTERNAL_ASSERT(maybe_layer.has_value()); + int64_t cur_level = maybe_layer->layerId(); + Tensor self_value; + optional self_bdim; + std::tie(self_value, self_bdim) = unwrapTensorAtLevel(self, cur_level); + BatchRule(self_value, self_bdim); + return self; +} + +template +static Tensor& unary_inplace_batch_rule(Tensor& self, optional) { + INVOKE(self, Method)(); + return self; +} + +template +static Tensor& unary_inplace_func_batch_rule(Tensor& self, optional) { + Func(self); + return self; +} + +Tensor& test(Tensor& self, optional) { + unary_inplace_batch_rule(self, {}); + return self; +} + +Tensor& test2(Tensor& self) { + return unary_inplace_plumbing(self); +} + + TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) { #define SINGLE_ARG(...) __VA_ARGS__ + + using UnaryInplaceBRType = Tensor& (*)(Tensor&, optional); +#define UNARY_POINTWISE_(op) \ + m.impl(#op, unary_inplace_plumbing>); +#define UNARY_POINTWISE_FUNC_(op) \ + m.impl(#op, unary_inplace_plumbing>); + #define UNARY_POINTWISE(op) \ VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule)); - UNARY_POINTWISE(abs); - UNARY_POINTWISE(acos); - UNARY_POINTWISE(asin); - UNARY_POINTWISE(atan); - UNARY_POINTWISE(ceil); - UNARY_POINTWISE(cos); - UNARY_POINTWISE(cosh); +#define UNARY_POINTWISE_ALL(op) \ + UNARY_POINTWISE_(op ## _); \ + VMAP_SUPPORT(#op, SINGLE_ARG(basic_unary_batch_rule)); + + UNARY_POINTWISE_ALL(abs); + UNARY_POINTWISE_ALL(acos); + UNARY_POINTWISE_ALL(asin); + UNARY_POINTWISE_ALL(atan); + UNARY_POINTWISE_ALL(ceil); + UNARY_POINTWISE_ALL(cos); + UNARY_POINTWISE_ALL(cosh); UNARY_POINTWISE(_conj); - UNARY_POINTWISE(digamma); - UNARY_POINTWISE(exp); - UNARY_POINTWISE(expm1); - UNARY_POINTWISE(floor); - UNARY_POINTWISE(frac); - UNARY_POINTWISE(lgamma); - UNARY_POINTWISE(log); - UNARY_POINTWISE(log10); - UNARY_POINTWISE(log1p); - UNARY_POINTWISE(log2); - UNARY_POINTWISE(neg); - UNARY_POINTWISE(reciprocal); - UNARY_POINTWISE(relu); - UNARY_POINTWISE(round); - UNARY_POINTWISE(rsqrt); - UNARY_POINTWISE(sigmoid); - UNARY_POINTWISE(sign); - UNARY_POINTWISE(sin); - UNARY_POINTWISE(sinh); - UNARY_POINTWISE(sqrt); - UNARY_POINTWISE(tan); - UNARY_POINTWISE(tanh); - UNARY_POINTWISE(trunc); + UNARY_POINTWISE_ALL(digamma); + UNARY_POINTWISE_ALL(exp); + UNARY_POINTWISE_ALL(expm1); + UNARY_POINTWISE_ALL(floor); + UNARY_POINTWISE_ALL(frac); + UNARY_POINTWISE_ALL(lgamma); + UNARY_POINTWISE_ALL(log); + UNARY_POINTWISE_ALL(log10); + UNARY_POINTWISE_ALL(log1p); + UNARY_POINTWISE_ALL(log2); + UNARY_POINTWISE_ALL(neg); + UNARY_POINTWISE_ALL(reciprocal); + UNARY_POINTWISE_ALL(relu); + UNARY_POINTWISE_ALL(round); + UNARY_POINTWISE_ALL(rsqrt); + UNARY_POINTWISE_ALL(sigmoid); + UNARY_POINTWISE_ALL(sign); + UNARY_POINTWISE_ALL(sin); + UNARY_POINTWISE_ALL(sinh); + UNARY_POINTWISE_ALL(sqrt); + UNARY_POINTWISE_ALL(tan); + UNARY_POINTWISE_ALL(tanh); + UNARY_POINTWISE_ALL(trunc); #undef UNARY_POINTWISE +#undef UNARY_POINTWISE_ +#undef UNARY_POINTWISE_ALL + } +#undef INVOKE }} diff --git a/functorch/test/test_vmap.py b/functorch/test/test_vmap.py index a70f24c75ce1..a01900591114 100644 --- a/functorch/test/test_vmap.py +++ b/functorch/test/test_vmap.py @@ -1093,9 +1093,14 @@ class TestVmapOperators(Namespace.TestVmapBase): (torch.tanh, TensorFactory.rand), (torch.trunc, TensorFactory.randn), ] + for op, getter in cases: self._test_unary(op, getter, 'cpu') + # test in-place + method = getattr(Tensor, f'{op.__name__ + "_"}') + self._test_unary(method, getter, 'cpu', check_propagates_grad=False) + def test_clone(self): # Some basic tests self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu')