mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] Added maximum/minimum/clamp batching rules
This commit is contained in:
@ -123,28 +123,11 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT(#op".Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(op, Tensor)), &at::op>));
|
||||
|
||||
BINARY_POINTWISE_WITH_SCALAR(add);
|
||||
BINARY_POINTWISE_WITH_SCALAR(sub);
|
||||
BINARY_POINTWISE_WITH_SCALAR(rsub);
|
||||
BINARY_POINTWISE(mul);
|
||||
VMAP_SUPPORT("add.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(add, Scalar)), &at::add, const Scalar&, const Scalar&>));
|
||||
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
|
||||
VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
|
||||
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
|
||||
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
|
||||
BINARY_POINTWISE(div);
|
||||
VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
|
||||
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
|
||||
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
|
||||
VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
|
||||
binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
|
||||
|
||||
VMAP_SUPPORT("atan2", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(atan2)), &at::atan2>));
|
||||
|
||||
// at::pow has three out-of-place overloads
|
||||
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
|
||||
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
|
||||
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
|
||||
|
||||
VMAP_SUPPORT("clamp",
|
||||
SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp)), &at::clamp, const optional<Scalar>&, const optional<Scalar>&>));
|
||||
VMAP_SUPPORT("clamp_min.Tensor",
|
||||
SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(clamp_min, Tensor)), &at::clamp_min>));
|
||||
VMAP_SUPPORT("clamp_min",
|
||||
@ -155,6 +138,32 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN(clamp_max)), &at::clamp_max, const Scalar&>));
|
||||
|
||||
|
||||
BINARY_POINTWISE(div);
|
||||
VMAP_SUPPORT("div.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(div, Scalar)), &at::div, const Scalar&>));
|
||||
|
||||
VMAP_SUPPORT("maximum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(maximum)), &at::maximum>));
|
||||
VMAP_SUPPORT("minimum", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN(minimum)), &at::minimum>));
|
||||
|
||||
BINARY_POINTWISE(mul);
|
||||
VMAP_SUPPORT("mul.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(mul, Scalar)), &at::mul, const Scalar&>));
|
||||
|
||||
// at::pow has three out-of-place overloads
|
||||
VMAP_SUPPORT("pow.Tensor_Tensor", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Tensor)), &at::pow>));
|
||||
VMAP_SUPPORT("pow.Tensor_Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(pow, Tensor_Scalar)), &at::pow, const Scalar&>));
|
||||
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
|
||||
|
||||
BINARY_POINTWISE_WITH_SCALAR(sub);
|
||||
VMAP_SUPPORT("sub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(sub, Scalar)), &at::sub, const Scalar&, const Scalar&>));
|
||||
|
||||
BINARY_POINTWISE_WITH_SCALAR(rsub);
|
||||
VMAP_SUPPORT("rsub.Scalar", SINGLE_ARG(basic_unary_batch_rule<decltype(&ATEN_FN2(rsub, Scalar)), &at::rsub, const Scalar&, const Scalar&>));
|
||||
|
||||
VMAP_SUPPORT("sigmoid_backward", SINGLE_ARG(
|
||||
binary_pointwise_batch_rule<decltype(&at::sigmoid_backward), &at::sigmoid_backward>));
|
||||
VMAP_SUPPORT("tanh_backward", SINGLE_ARG(binary_pointwise_batch_rule<decltype(&at::tanh_backward), &at::tanh_backward>));
|
||||
VMAP_SUPPORT("threshold_backward", SINGLE_ARG(
|
||||
binary_pointwise_batch_rule<decltype(&at::threshold_backward), &at::threshold_backward, const Scalar&>));
|
||||
|
||||
using TensorScalarInplaceT = Tensor& (Tensor::*)(const Tensor&, const Scalar&) const;
|
||||
using ScalarScalarInplaceT = Tensor& (Tensor::*)(const Scalar&, const Scalar&) const;
|
||||
using TensorInplaceT = Tensor& (Tensor::*)(const Tensor&) const;
|
||||
|
@ -540,7 +540,7 @@ class TestVmapAPI(TestCase):
|
||||
# NB: One day we will implement a batching rule for torch.atan2.
|
||||
# If/when we do, this test should be replaced to test the fallback
|
||||
# path on another operator to avoid bitrot.
|
||||
op = torch.maximum
|
||||
op = torch.copysign
|
||||
x = torch.randn(11)
|
||||
y = torch.randn(11)
|
||||
with warnings.catch_warnings(record=True) as wa:
|
||||
@ -575,7 +575,7 @@ class TestVmapAPI(TestCase):
|
||||
# NB: One day we will implement a batching rule for torch.atan2.
|
||||
# If/when we do, this test should be replaced to test the fallback
|
||||
# path on another operator to avoid bitrot.
|
||||
op = torch.maximum
|
||||
op = torch.copysign
|
||||
x = torch.randn(11)
|
||||
y = torch.randn(11)
|
||||
self._assert_uses_vmap_fallback((op,), (x, y))
|
||||
@ -606,7 +606,7 @@ class TestVmapAPI(TestCase):
|
||||
# NB: One day we will implement a batching rule for torch.atan2.
|
||||
# If/when we do, this test should be replaced to test the fallback
|
||||
# path on another operator to avoid bitrot.
|
||||
op = torch.maximum
|
||||
op = torch.copysign
|
||||
|
||||
x = torch.randn(5, 7, 11)
|
||||
y = torch.randn(5, 7, 11)
|
||||
|
Reference in New Issue
Block a user