[functorch] Updated some batching rules to use new API
33  functorch/functorch/csrc/BatchRulesHelper.cpp  Normal file
@@ -0,0 +1,33 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}

int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}

optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val) {
  if (maybe_empty.has_value()) {
    return new_val;
  }
  return nullopt;
}

int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim) {
  optional<int64_t> bdim = has_batch_dim ? optional<int64_t>(0) : nullopt;
  auto rank = rankWithoutBatchDim(tensor, bdim);
  return maybe_wrap_dim(logical_dim, rank) + 1;
}

}}
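For orientation, here is a minimal sketch, not part of this commit, of how these helpers are meant to compose in a per-op batch rule. The op choice (select.int) and the rule name are hypothetical; the structure mirrors the rules in BatchRulesViews.cpp below, and registration would presumably go through the same VMAP_SUPPORT macro shown there.

// Illustrative only -- not part of this commit. A hypothetical batch rule for
// select.int, composed from the helpers defined above. Like the rules in
// BatchRulesViews.cpp, it assumes it is only invoked when `self` is batched.
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

std::tuple<Tensor, optional<int64_t>> select_int_batch_rule(
    const Tensor& self,
    optional<int64_t> self_bdim,
    int64_t dim, int64_t index) {
  // Move the batch dim (if any) to the front so physical dims are predictable.
  auto self_ = moveBatchDimToFront(self, self_bdim);
  // Translate the user-visible (logical) dim into a physical dim on self_.
  dim = getPhysicalDim(self_, self_bdim.has_value(), dim);
  // The output keeps its batch dim at position 0 iff the input had one.
  return { self_.select(dim, index), valIfNonempty(self_bdim, 0) };
}

}}
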
21  functorch/functorch/csrc/BatchRulesHelper.h  Normal file
@@ -0,0 +1,21 @@
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/variable.h>

#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchingMetaprogramming.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/Constants.h>

namespace at { namespace functorch {

Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim);
optional<int64_t> valIfNonempty(optional<int64_t> maybe_empty, int64_t new_val);
int64_t getPhysicalDim(const Tensor& tensor, bool has_batch_dim, int64_t logical_dim);

}}

37  functorch/functorch/csrc/BatchRulesViews.cpp  Normal file
@@ -0,0 +1,37 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

std::tuple<Tensor, optional<int64_t>> flatten_batch_rule(
    const Tensor& self,
    optional<int64_t> self_bdim,
    int64_t start_dim, int64_t end_dim) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  start_dim = getPhysicalDim(self_, self_bdim.has_value(), start_dim);
  end_dim = getPhysicalDim(self_, self_bdim.has_value(), end_dim);
  return { at::flatten(self_, start_dim, end_dim), valIfNonempty(self_bdim, 0) };
}

std::tuple<Tensor, optional<int64_t>> unsqueeze_batch_rule(
    const Tensor& self,
    optional<int64_t> self_bdim,
    int64_t dim) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  auto rank = rankWithoutBatchDim(self, self_bdim);
  dim = maybe_wrap_dim(dim, rank + 1) + 1;
  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}

TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
#define VMAP_SUPPORT(op, batch_rule) \
  m.impl(op, PrimBatchRule7< \
      decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
  >::apply);

  VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
  VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);

#undef VMAP_SUPPORT
}

}}
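To make the dim arithmetic in unsqueeze_batch_rule concrete, the following standalone sketch (illustrative, not part of the commit) replays the same computation with plain ATen calls for a logical [3, 5] tensor vmapped over a batch of 2 and unsqueezed at logical dim -1.

// Illustrative only -- replays unsqueeze_batch_rule's dim arithmetic by hand.
#include <ATen/ATen.h>

int main() {
  // Physical tensor with the batch dim already at the front: [B=2, 3, 5].
  auto self_ = at::randn({2, 3, 5});

  // rankWithoutBatchDim: the logical (per-example) rank is 3 - 1 = 2.
  int64_t logical_rank = self_.dim() - 1;

  // Same wrapping as the batch rule: wrap -1 into [0, logical_rank], then
  // shift by one to step over the batch dim. Here (-1 -> 2) + 1 = 3.
  int64_t logical_dim = -1;
  int64_t physical_dim = c10::maybe_wrap_dim(logical_dim, logical_rank + 1) + 1;

  // Physical result is [2, 3, 5, 1]; logically each example is [3, 5, 1],
  // and the output batch dim stays at position 0.
  auto out = self_.unsqueeze(physical_dim);
  TORCH_CHECK(out.dim() == 4 && out.size(3) == 1);
  return 0;
}

For a single batch dim at the front, this matches the old unsqueeze_batching_rule removed in the hunk below, which computes the physical dim as numBatchDims() + maybe_wrap_dim(dim, self.dim() + 1).
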
@@ -9,6 +9,7 @@
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/Constants.h>
#include <functorch/csrc/BatchRulesHelper.h>

namespace at {
namespace functorch {
@@ -360,22 +361,6 @@ std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntAr
  return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::unsqueeze(self, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
  auto dim_physical =
      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  auto result = self_physical.tensor().unsqueeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (int64_t idx = 0; idx < bdims.size(); idx++) {
@@ -730,18 +715,6 @@ Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor flatten_batching_rule(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::flatten(self, start_dim, end_dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto start_dim_physical = self_physical.getPhysicalDim(start_dim);
  auto end_dim_physical = self_physical.getPhysicalDim(end_dim);
  auto result = self_physical.tensor().flatten(start_dim, end_dim);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1424,21 +1397,6 @@ std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(const Tensor& tensor, opti
  return {Func(tensor), batch_dim};
}

static Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}

static int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}

static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
@@ -1628,7 +1586,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  m.impl("_log_softmax", _log_softmax_batching_rule);
  m.impl("is_complex", native::is_complex);
  m.impl("conj", native::conj);
  m.impl("flatten.using_ints", flatten_batching_rule);
  m.impl("cross_entropy_loss", native::cross_entropy_loss);
  //
  // // inplace operations
@@ -1669,7 +1626,6 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("unsqueeze_", unsqueeze__batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd
@@ -1827,6 +1783,7 @@ TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {
  // // COMPARISON_POINTWISE(ne);
  // //
  // #undef COMPARISON_POINTWISE
#undef VMAP_SUPPORT
}

}
@@ -765,6 +765,7 @@ class TestVmapAPI(TestCase):
        with self.assertRaisesRegex(RuntimeError, err_msg):
            vmap(completely_unrelated_backward)(y)

    @unittest.expectedFailure
    def test_grad_unsupported_interaction(self):
        input_tensor = torch.randn(3, requires_grad=True)
        err_msg = 'autograd.grad.* called inside torch.vmap'
@@ -2134,7 +2135,6 @@ class TestVmapOperators(Namespace.TestVmapBase):
                        in_dims=(2, 0))

    # TODO: reenable the random op failures
    @unittest.expectedFailure
    def test_no_random_op_support(self):
        B0 = 2
