mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
[functorch] More batch rules
This commit is contained in:
29
functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
Normal file
29
functorch/functorch/csrc/BatchRulesLinearAlgebra.cpp
Normal file
@ -0,0 +1,29 @@
|
||||
#include <functorch/csrc/BatchRulesHelper.h>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>>
|
||||
slogdet_batch_rule(const Tensor& self, optional<int64_t> self_bdim) {
|
||||
if (!self_bdim.has_value()) {
|
||||
auto result = at::slogdet(self);
|
||||
return {
|
||||
std::move(std::get<0>(result)), nullopt,
|
||||
std::move(std::get<1>(result)), nullopt
|
||||
};
|
||||
}
|
||||
|
||||
// slogdet supports arbitrary dims at the front
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto result = at::slogdet(self_);
|
||||
return {
|
||||
std::move(std::get<0>(result)), 0,
|
||||
std::move(std::get<1>(result)), 0
|
||||
};
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("slogdet", slogdet_batch_rule);
|
||||
}
|
||||
|
||||
}}
|
||||
|
51
functorch/functorch/csrc/BatchRulesReduceOps.cpp
Normal file
51
functorch/functorch/csrc/BatchRulesReduceOps.cpp
Normal file
@ -0,0 +1,51 @@
|
||||
#include <functorch/csrc/BatchRulesHelper.h>
|
||||
|
||||
namespace at { namespace functorch {
|
||||
|
||||
// [start, start + 1, ..., stop - 1]
|
||||
static VmapDimVector range(int64_t start, int64_t stop) {
|
||||
TORCH_INTERNAL_ASSERT(stop > start);
|
||||
VmapDimVector dims;
|
||||
dims.reserve(stop - start);
|
||||
for (int64_t i = start; i < stop; i++) {
|
||||
dims.emplace_back(i);
|
||||
}
|
||||
return dims;
|
||||
}
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>> sum_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim, optional<ScalarType> dtype) {
|
||||
if (!self_bdim.has_value()) {
|
||||
return { self.sum(dtype), nullopt };
|
||||
}
|
||||
auto self_dim = self.dim();
|
||||
if (self_dim == 1) {
|
||||
return { self.clone(), 0 };
|
||||
}
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto dims = range(1, self_dim);
|
||||
auto result = at::sum(self_, dims, /*keepdim*/false, dtype);
|
||||
return { result, 0 };
|
||||
}
|
||||
|
||||
std::tuple<Tensor,optional<int64_t>> mean_batch_rule(
|
||||
const Tensor& self, optional<int64_t> self_bdim, optional<ScalarType> dtype) {
|
||||
if (!self_bdim.has_value()) {
|
||||
return { self.sum(dtype), nullopt };
|
||||
}
|
||||
auto self_dim = self.dim();
|
||||
if (self_dim == 1) {
|
||||
return { self.clone(), 0 };
|
||||
}
|
||||
auto self_ = moveBatchDimToFront(self, self_bdim);
|
||||
auto dims = range(1, self_dim);
|
||||
auto result = at::mean(self_, dims, /*keepdim*/false, dtype);
|
||||
return { result, 0 };
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
VMAP_SUPPORT("sum", sum_batch_rule);
|
||||
VMAP_SUPPORT("mean", mean_batch_rule);
|
||||
}
|
||||
|
||||
}}
|
@ -91,20 +91,6 @@ static bool participatesInCurrentLevel(TensorList self) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Tensor mean_batching_rule(const Tensor& self, optional<ScalarType> dtype) {
|
||||
if (!participatesInCurrentLevel(self)) {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
|
||||
return self.mean(dtype);
|
||||
}
|
||||
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
|
||||
VmapDimVector dims;
|
||||
for (int64_t i = 1; i < self_physical.tensor().dim(); i++) {
|
||||
dims.push_back(i);
|
||||
}
|
||||
auto result = at::mean(self_physical.tensor(), dims, /*keepdim*/false, dtype);
|
||||
return self_physical.getPhysicalToLogicalMap().apply(result);
|
||||
}
|
||||
|
||||
Tensor log_softmax_batching_rule(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
|
||||
if (!participatesInCurrentLevel(self)) {
|
||||
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
|
||||
@ -1525,7 +1511,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
|
||||
m.impl("max_pool2d", at::native::max_pool2d); // composite
|
||||
m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);
|
||||
|
||||
m.impl("mean", mean_batching_rule);
|
||||
m.impl("mean.dim", mean_int_batching_rule);
|
||||
m.impl("sum.dim_IntList", sum_batching_rule);
|
||||
m.impl("log_softmax.int", log_softmax_batching_rule);
|
||||
|
@ -1935,6 +1935,27 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
def test_mean_dim(self):
|
||||
self._test_mean_sum_dim(torch.mean)
|
||||
|
||||
def _test_sum_mean(self, op):
|
||||
test = self._vmap_test
|
||||
B0, B1 = 5, 7
|
||||
|
||||
# Single vmap, various in_dims / out_dims
|
||||
test(op, [torch.randn([B0])])
|
||||
test(op, [torch.randn([B0, 3])])
|
||||
test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
|
||||
test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
|
||||
|
||||
# Doubly nested vmap
|
||||
test(vmap(op), [torch.randn([B0, B1])])
|
||||
test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])])
|
||||
test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2)
|
||||
|
||||
def test_sum(self):
|
||||
self._test_sum_mean(torch.sum)
|
||||
|
||||
def test_mean(self):
|
||||
self._test_sum_mean(torch.mean)
|
||||
|
||||
def test_repeat(self):
|
||||
test = self._vmap_test
|
||||
B0 = 7
|
||||
@ -1942,6 +1963,15 @@ class TestVmapOperators(Namespace.TestVmapBase):
|
||||
test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
|
||||
test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)
|
||||
|
||||
def test_slogdet(self):
|
||||
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
||||
B0 = 7
|
||||
op = torch.linalg.slogdet
|
||||
test(op, (torch.rand(B0, 1, 1),))
|
||||
test(op, (torch.rand(B0, 2, 2),))
|
||||
test(op, (torch.rand(B0, 3, 2, 2),))
|
||||
test(op, (torch.rand(3, 2, 2, B0),), in_dims=3)
|
||||
|
||||
def test_reshape(self):
|
||||
test = self._vmap_test
|
||||
B0, B1, B2 = 7, 11, 13
|
||||
|
Reference in New Issue
Block a user