[functorch] Updated some batching rules to use new API

Richard Zou
2021-04-23 13:10:11 -07:00
committed by Jon Janzen
parent 985b35c23d
commit ce453d449e
5 changed files with 94 additions and 46 deletions

View File

@@ -0,0 +1,33 @@
#include <functorch/csrc/BatchRulesHelper.h>
namespace at { namespace functorch {
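// If the tensor has a batch dimension, move it to the front (dim 0); otherwise return the tensor unchanged.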
Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
if (!maybe_batch_dim.has_value()) {
return tensor;
}
return tensor.movedim(maybe_batch_dim.value(), 0);
}
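// Logical rank of the tensor: its dimensionality not counting the batch dim, if one is present.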
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
int64_t result = tensor.dim();
if (maybe_batch_dim.has_value()) {
result -= 1;
}
return result;
}
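// If maybe_empty holds a value, replace it with new_val; otherwise propagate nullopt.
// Used to report where an output's batch dim ends up.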
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val) {
if (maybe_empty.has_value()) {
return new_val;
}
return nullopt;
}
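// Converts a logical dim into a physical dim, assuming any batch dim has already been moved to the front of the tensor.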
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
auto rank = rankWithoutBatchDim(tensor, bdim);
auto wrapped_dim = maybe_wrap_dim(logical_dim, rank);
return has_batch_dim ? wrapped_dim + 1 : wrapped_dim;
}
}}

View File

@@ -0,0 +1,21 @@
#pragma once
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/variable.h>
#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchingMetaprogramming.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/Constants.h>
namespace at { namespace functorch {
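// Helpers shared by the batch rule implementations; definitions live in BatchRulesHelper.cpp.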
Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);
}}

View File

@@ -0,0 +1,37 @@
#include <functorch/csrc/BatchRulesHelper.h>
namespace at { namespace functorch {
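// New-style batch rule for flatten.using_ints: the tensor arrives together with an optional batch dim,
// the rule computes on the batch-dim-at-front physical tensor, and returns the result plus the
// position of its batch dim (if any).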
std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t start_dim, int64_t end_dim) {
auto self_ = moveBatchDimToFront(self, self_bdim);
start_dim = getPhysicalDim(self_, self_bdim.has_value(), start_dim);
end_dim = getPhysicalDim(self_, self_bdim.has_value(), end_dim);
return { at::flatten(self_, start_dim, end_dim), valIfNonempty(self_bdim, 0) };
}
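// unsqueeze wraps its dim argument to (logical rank) + 1, mirroring native::unsqueeze,
// and then shifts it past the leading batch dim.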
std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
int64_t dim) {
auto self_ = moveBatchDimToFront(self, self_bdim);
auto rank = rankWithoutBatchDim(self, self_bdim);
dim = maybe_wrap_dim(dim, rank + 1) + 1;
return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}
TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
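// Registers `batch_rule` as the batching rule for `op`; PrimBatchRule7 adapts the
// (Tensor, optional<int64_t>) batch rule signature to the operator's schema.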
#define VMAP_SUPPORT(op, batch_rule) \
m.impl(op, PrimBatchRule7< \
decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
>::apply);
VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
#undef VMAP_SUPPORT
}
}}

View File

@@ -9,6 +9,7 @@
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/Constants.h>
#include <functorch/csrc/BatchRulesHelper.h>
namespace at {
namespace functorch {
@@ -360,22 +361,6 @@ std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntAr
return result;
}
Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
return at::unsqueeze(self, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
// NB: unsqueeze has some special handling of its `dim` argument so we can't call
// self_physical.getPhysicalDim directly. In particular, native::unsqueeze
// wraps the dim to (the logical dimension) + 1, so we need to do that here too.
// https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
auto dim_physical =
self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
auto result = self_physical.tensor().unsqueeze(dim_physical);
return self_physical.getPhysicalToLogicalMap().apply(result);
}
// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
for (int64_t idx = 0; idx < bdims.size(); idx++) {
@@ -730,18 +715,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
return self_physical.getPhysicalToLogicalMap().apply(result);
}
Tensor flatten_batching_rule(const Tensor& self, int64_t start_dim, int64_t end_dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
return at::flatten(self, start_dim, end_dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto start_dim_physical = self_physical.getPhysicalDim(start_dim);
auto end_dim_physical = self_physical.getPhysicalDim(end_dim);
auto result = self_physical.tensor().flatten(start_dim, end_dim);
return self_physical.getPhysicalToLogicalMap().apply(result);
}
Tensor view_as_complex_batching_rule(const Tensor& self) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1424,21 +1397,6 @@ std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(const Tensor& tensor, opti
return {Func(tensor), batch_dim};
}
static Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
if (!maybe_batch_dim.has_value()) {
return tensor;
}
return tensor.movedim(maybe_batch_dim.value(), 0);
}
static int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
int64_t result = tensor.dim();
if (maybe_batch_dim.has_value()) {
result -= 1;
}
return result;
}
static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
if (!has_bdim) {
return tensor;
@@ -1628,7 +1586,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
m.impl("_log_softmax", _log_softmax_batching_rule);
m.impl("is_complex", native::is_complex);
m.impl("conj", native::conj);
m.impl("flatten.using_ints", flatten_batching_rule);
m.impl("cross_entropy_loss", native::cross_entropy_loss);
//
// // inplace operations
@@ -1669,7 +1626,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
m.impl("transpose.int", transpose_int_batching_rule);
m.impl("unbind.int", unbind_batching_rule);
m.impl("unfold", unfold_batching_rule);
m.impl("unsqueeze", unsqueeze_batching_rule);
m.impl("unsqueeze_", unsqueeze__batching_rule);
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
@@ -1827,6 +1783,7 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
// // COMPARISON_POINTWISE(ne);
// //
// #undef COMPARISON_POINTWISE
#undef VMAP_SUPPORT
}
}

View File

@@ -765,6 +765,7 @@ class TestVmapAPI(TestCase):
with self.assertRaisesRegex(RuntimeError, err_msg):
vmap(completely_unrelated_backward)(y)
@unittest.expectedFailure
def test_grad_unsupported_interaction(self):
input_tensor = torch.randn(3, requires_grad=True)
err_msg = 'autograd.grad.* called inside torch.vmap'
@@ -2134,7 +2135,6 @@ class TestVmapOperators(Namespace.TestVmapBase):
in_dims=(2, 0))
# TODO: reenable the random op failures
@unittest.expectedFailure
def test_no_random_op_support(self):
B0 = 2