[functorch] More batch rules

Richard Zou
2021-04-28 13:35:09 -07:00
committed by Jon Janzen
parent d7d266f51e
commit c24314c09b
4 changed files with 110 additions and 15 deletions

View File

@@ -0,0 +1,29 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
  if (!self_bdim.has_value()) {
    auto result = at::slogdet(self);
    return {
      std::move(std::get<0>(result)), nullopt,
      std::move(std::get<1>(result)), nullopt
    };
  }

  // slogdet supports arbitrary dims at the front
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto result = at::slogdet(self_);
  return {
    std::move(std::get<0>(result)), 0,
    std::move(std::get<1>(result)), 0
  };
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("slogdet", slogdet_batch_rule);
}

}}
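
Since slogdet already treats any leading dimensions as batch dimensions, the rule only has to move the vmapped dimension to the front and call through. A minimal sketch of the resulting behavior (assuming a functorch build that includes this rule):

    import torch
    from functorch import vmap

    x = torch.randn(5, 3, 3)
    sign, logabsdet = vmap(torch.linalg.slogdet)(x)
    # Matches calling the natively batched op directly:
    sign_ref, logabsdet_ref = torch.linalg.slogdet(x)
    assert torch.allclose(sign, sign_ref)
    assert torch.allclose(logabsdet, logabsdet_ref)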

View File

@@ -0,0 +1,51 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

// [start, start + 1, ..., stop - 1]
static VmapDimVector range(int64_t start, int64_t stop) {
  TORCH_INTERNAL_ASSERT(stop > start);
  VmapDimVector dims;
  dims.reserve(stop - start);
  for (int64_t i = start; i < stop; i++) {
    dims.emplace_back(i);
  }
  return dims;
}

std::tuple<Tensor,optional<int64_t>> sum_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim, optional<ScalarType> dtype) {
  if (!self_bdim.has_value()) {
    return { self.sum(dtype), nullopt };
  }
  auto self_dim = self.dim();
  if (self_dim == 1) {
    // Each batch element is a scalar; the reduction is the identity.
    return { self.clone(), 0 };
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto dims = range(1, self_dim);
  auto result = at::sum(self_, dims, /*keepdim*/false, dtype);
  return { result, 0 };
}

std::tuple<Tensor,optional<int64_t>> mean_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim, optional<ScalarType> dtype) {
  if (!self_bdim.has_value()) {
    return { self.mean(dtype), nullopt };
  }
  auto self_dim = self.dim();
  if (self_dim == 1) {
    // Each batch element is a scalar; the reduction is the identity.
    return { self.clone(), 0 };
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto dims = range(1, self_dim);
  auto result = at::mean(self_, dims, /*keepdim*/false, dtype);
  return { result, 0 };
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("sum", sum_batch_rule);
  VMAP_SUPPORT("mean", mean_batch_rule);
}

}}
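
These rules turn a full reduction under vmap into a multi-dim reduction over every dimension except the batch dimension, which is what range(1, self_dim) computes after the batch dim is moved to the front. A rough sketch of the invariant (assuming a functorch build that includes these rules):

    import torch
    from functorch import vmap

    x = torch.randn(5, 2, 3)
    assert torch.allclose(vmap(torch.sum)(x), x.sum(dim=(1, 2)))
    assert torch.allclose(vmap(torch.mean)(x), x.mean(dim=(1, 2)))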

View File

@@ -91,20 +91,6 @@ static bool participatesInCurrentLevel(TensorList self) {
  return false;
}

Tensor mean_batching_rule(const Tensor& self, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.mean(dtype);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  VmapDimVector dims;
  for (int64_t i = 1; i < self_physical.tensor().dim(); i++) {
    dims.push_back(i);
  }
  auto result = at::mean(self_physical.tensor(), dims, /*keepdim*/false, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor log_softmax_batching_rule(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1525,7 +1511,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("max_pool2d", at::native::max_pool2d); // composite
  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);
  m.impl("mean", mean_batching_rule);
  m.impl("mean.dim", mean_int_batching_rule);
  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("log_softmax.int", log_softmax_batching_rule);

View File

@@ -1935,6 +1935,27 @@ class TestVmapOperators(Namespace.TestVmapBase):
    def test_mean_dim(self):
        self._test_mean_sum_dim(torch.mean)

    def _test_sum_mean(self, op):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Single vmap, various in_dims / out_dims
        test(op, [torch.randn([B0])])
        test(op, [torch.randn([B0, 3])])
        test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
        # Doubly nested vmap
        test(vmap(op), [torch.randn([B0, B1])])
        test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])])
        test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2)

    def test_sum(self):
        self._test_sum_mean(torch.sum)

    def test_mean(self):
        self._test_sum_mean(torch.mean)
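
    # Illustrative sketch (hypothetical, not from the test file): vmap(vmap(op))
    # maps a full reduction over two batch dims, so a doubly nested case reduces
    # each [B0, B1] slice over its own dims:
    #
    #   x = torch.randn(5, 7, 2, 3)
    #   out = vmap(vmap(torch.sum))(x)   # shape [5, 7]
    #   ref = x.sum(dim=(2, 3))
    #   assert torch.allclose(out, ref)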
    def test_repeat(self):
        test = self._vmap_test
        B0 = 7
@@ -1942,6 +1963,15 @@ class TestVmapOperators(Namespace.TestVmapBase):
        test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
        test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)

    def test_slogdet(self):
        test = functools.partial(self._vmap_test, check_propagates_grad=False)
        B0 = 7
        op = torch.linalg.slogdet
        test(op, (torch.rand(B0, 1, 1),))
        test(op, (torch.rand(B0, 2, 2),))
        test(op, (torch.rand(B0, 3, 2, 2),))
        test(op, (torch.rand(3, 2, 2, B0),), in_dims=3)
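
    # Illustrative sketch (hypothetical, not from the test file): in_dims=3
    # maps over dim 3, so the last case is equivalent to moving that dim to
    # the front before calling the op:
    #
    #   x = torch.rand(3, 2, 2, B0)
    #   out = vmap(torch.linalg.slogdet, in_dims=3)(x)
    #   ref = torch.linalg.slogdet(x.movedim(3, 0))
    #   # out and ref hold the same (sign, logabsdet) pair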
    def test_reshape(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13