[functorch] Added maximum/minimum/clamp batching rules

Author: Horace He
Date: 2021-06-24 23:50:27 -07:00
Committed by: Jon Janzen
Parent: 9b00f55a46
Commit: ec25616e0c
2 changed files with 31 additions and 22 deletions
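For context, a batching rule lets functorch's vmap handle an operator directly instead of falling back to a slow per-example loop. A minimal sketch (not part of this commit) of what the new rules cover, assuming functorch is installed and exposes vmap at the top level:

import torch
from functorch import vmap

x = torch.randn(3, 5)
y = torch.randn(3, 5)

# These calls now dispatch to the batch rules added below instead of the vmap fallback.
out_max = vmap(torch.maximum)(x, y)
out_min = vmap(torch.minimum)(x, y)
# clamp with Scalar bounds goes through the basic_unary rule registered in this commit.
out_clamp = vmap(lambda t: torch.clamp(t, 0.0, 1.0))(x)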

@@ -123,28 +123,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>));
BINARY_POINTWISE_WITH_SCALAR(add);
BINARY_POINTWISE_WITH_SCALAR(sub);
BINARY_POINTWISE_WITH_SCALAR(rsub);
BINARY_POINTWISE(mul);
VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(add, Scalar)), &at::add, const Scalar&, const Scalar&>));
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
BINARY_POINTWISE(div);
VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
VMAP_SUPPORT("atan2", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(atan2)), &at::atan2>));
// at::pow has three out-of-place overloads
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
VMAP_SUPPORT("clamp",
SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp)), &at::clamp, const optional<Scalar>&, const optional<Scalar>&>));
VMAP_SUPPORT("clamp_min.Tensor",
SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_min, Tensor)), &at::clamp_min>));
VMAP_SUPPORT("clamp_min",
@@ -155,6 +138,32 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp_max)), &at::clamp_max, const Scalar&>));
BINARY_POINTWISE(div);
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
VMAP_SUPPORT("maximum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(maximum)), &at::maximum>));
VMAP_SUPPORT("minimum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(minimum)), &at::minimum>));
BINARY_POINTWISE(mul);
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
// at::pow has three out-of-place overloads
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
BINARY_POINTWISE_WITH_SCALAR(sub);
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
BINARY_POINTWISE_WITH_SCALAR(rsub);
VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
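The binary_pointwise_batch_rule registrations above (maximum, minimum, the Tensor overloads of clamp_min/clamp_max, and friends) must match the per-example semantics that the fallback implements with a loop. A rough illustration of that contract in Python, using functorch's vmap with in_dims to mark which argument carries the batch dimension; the loop-based reference below is only for comparison, not part of the commit:

import torch
from functorch import vmap

x = torch.randn(4, 3)   # batched along dim 0
y = torch.randn(3)      # unbatched, broadcast against every example

# Per-example reference: apply the op slice by slice and stack the results.
expected = torch.stack([torch.maximum(x[i], y) for i in range(x.shape[0])])

# The batch rule has to produce the same answer in a single batched call.
result = vmap(torch.maximum, in_dims=(0, None))(x, y)
assert torch.allclose(result, expected)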

@@ -540,7 +540,7 @@ class TestVmapAPI(TestCase):
# NB: One day we will implement a batching rule for torch.atan2.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
op = torch.maximum
op = torch.copysign
x = torch.randn(11)
y = torch.randn(11)
with warnings.catch_warnings(record=True) as wa:
@@ -575,7 +575,7 @@ class TestVmapAPI(TestCase):
# NB: One day we will implement a batching rule for torch.atan2.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
op = torch.maximum
op = torch.copysign
x = torch.randn(11)
y = torch.randn(11)
self._assert_uses_vmap_fallback((op,), (x, y))
@@ -606,7 +606,7 @@ class TestVmapAPI(TestCase):
# NB: One day we will implement a batching rule for torch.atan2.
# If/when we do, this test should be replaced to test the fallback
# path on another operator to avoid bitrot.
op = torch.maximum
op = torch.copysign
x = torch.randn(5, 7, 11)
y = torch.randn(5, 7, 11)
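These test changes follow from the new rules: torch.maximum no longer exercises the slow vmap fallback, so the fallback tests switch to torch.copysign, which still lacks a batch rule at this point. A standalone sketch of the kind of check the tests perform, assuming the fallback path emits a performance-related UserWarning (the exact message text is an implementation detail):

import warnings
import torch
from functorch import vmap

op = torch.copysign
x = torch.randn(11)
y = torch.randn(11)

with warnings.catch_warnings(record=True) as wa:
    warnings.simplefilter("always")
    vmap(op)(x, y)

# With no batch rule registered, the fallback warns about the performance drop;
# an op with a batch rule (e.g. torch.maximum after this commit) would not.
assert len(wa) > 0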