pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp
Yuanyuan Chen ef50c9b557 Remove unnecessary "static" for definitions in anonymous namespace (#165035)
This PR removes unnecessary "static" for C++ functions and variables in anonymous namespaces, as detected by clang-tidy. This enhances code readability. The related rules are planned to be enabled in follow-up PRs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165035
Approved by: https://github.com/Skylion007
2025-10-11 00:04:23 +00:00


// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
#include <torch/library.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <ATen/native/TensorShape.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/functorch/DynamicLayer.h>
#include <ATen/functorch/TensorWrapper.h>
#include <ATen/functorch/BatchingMetaprogramming.h>
#include <ATen/functorch/LegacyVmapTransforms.h>
#include <ATen/functorch/BatchedFallback.h>
#include <ATen/functorch/BatchRulesHelper.h>
#include <utility>
namespace at::functorch {
// NOTE: [What is a batching rule?]
//
// NB: the following description only applies to this file and is about
// the legacy (deprecated) batching rule API. Please see writing_batch_rules.md
// for how to write new-style batching rules.
//
// This file contains batching rules written with the legacy (now-deprecated)
// batching rule API.
// Please try to use the new-style batching rule API (see writing_batch_rules.md).
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.
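//
// A minimal end-to-end sketch of the user-facing call that triggers this
// machinery (assuming torch.func.vmap; shapes are illustrative):
// >>> import torch
// >>> from torch.func import vmap
// >>> x = torch.randn(3, 4)
// >>> y = torch.randn(3, 4)
// >>> out = vmap(torch.add)(x, y)   # dim 0 of x and y becomes the batch dim
// >>> assert torch.allclose(out, x + y)
// Inside the vmap, `torch.add` sees BatchedTensors and dispatches to the
// FuncTorchBatched key, which is where the rules in this file (or the
// new-style rules) run.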
// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.
// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
// arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
// some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is an VmapTransform?] in VmapTransforms.h)
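//
// As a minimal illustrative sketch (not a registered rule), a legacy-style
// batching rule for a hypothetical operator `at::my_op(Tensor, int64_t dim)`
// would follow the steps above like so:
//
//   Tensor my_op_batching_rule(const Tensor& self, int64_t dim) {
//     auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self); // step 1
//     auto dim_physical = self_physical.getPhysicalDim(dim);                 // step 2
//     auto result = at::my_op(self_physical.tensor(), dim_physical);         // step 3
//     return self_physical.getPhysicalToLogicalMap().apply(result);          // step 4
//   }
//
// The concrete rules below (e.g. split_batching_rule) follow this same pattern.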
namespace {
// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
return dim == 0 || dim == -1;
}
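// For example, torch.tensor(1.).squeeze(0) and torch.tensor(1.).transpose(0, -1)
// both succeed and return the scalar tensor unchanged.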
int64_t get_current_level() {
auto maybe_level = maybeCurrentDynamicLayer();
TORCH_INTERNAL_ASSERT(maybe_level.has_value());
return maybe_level->layerId();
}
// This check should probably go into the dispatcher...
bool participatesInCurrentLevel(const Tensor& self) {
auto current_level = get_current_level();
auto* maybe_batched_impl = maybeGetBatchedImpl(self);
if (!maybe_batched_impl) {
return false;
}
auto self_level = maybe_batched_impl->level();
TORCH_INTERNAL_ASSERT(self_level <= current_level);
return self_level == current_level;
}
bool participatesInCurrentLevel(ITensorListRef self) {
for (const Tensor& tensor : self) {
if (participatesInCurrentLevel(tensor)) {
return true;
}
}
return false;
}
Tensor& squeeze_dims__batching_rule(Tensor& self, IntArrayRef dims) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return self.squeeze_(dims);
}
auto* batched = maybeGetBatchedImpl(self);
const auto bdim = batched->bdim();
auto logical_dim = self.dim();
if (logical_dim == 0) {
TORCH_CHECK(
dims.empty() || (dims.size() == 1 && dims[0] == 0),
"Dimension is out of range (expected to be in range of [-1, 0], but got ", dims, ")");
return self;
}
// Adjust any dimensions higher than the batch dimension
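// For instance (illustrative shapes): if the physical value has shape
// [1, 3, B, 4] with bdim = 2, the logical shape is [1, 3, 4], and squeezing
// logical dims {0, 2} works out as follows:
// - logical dim 0 hits the if-branch: it maps to physical dim 0, and since
//   that dim has size 1 it will be dropped, so the bdim shifts from 2 to 1;
// - logical dim 2 hits the else-branch: it maps to physical dim 3 (one past
//   the bdim), and the bdim does not move.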
DimVector adjusted_dims(dims.begin(), dims.end());
int64_t updated_batch_idx = bdim;
for (auto &d : adjusted_dims) {
auto actual_dim = c10::maybe_wrap_dim(d, logical_dim);
if (actual_dim < bdim) {
d = actual_dim;
if (batched->value().sym_size(actual_dim) == 1) {
// A size-1 dimension before the batch dimension will be dropped, so adjust the bdim accordingly.
--updated_batch_idx;
}
} else {
// The dimension to be squeezed is at or past the batch dimension, so shift it by one
// to account for the batch dimension. In this case the batch dimension won't move.
d = actual_dim + 1;
}
}
batched->value().squeeze_(adjusted_dims);
if (updated_batch_idx != bdim) {
batched->unsafe_set_bdim(updated_batch_idx);
}
batched->refreshTensorMetadata();
return self;
}
Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
return squeeze_dims__batching_rule(self, {dim});
}
Tensor& squeeze__batching_rule(Tensor& self) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return self.squeeze_();
}
auto* batched = maybeGetBatchedImpl(self);
// Need to find out how many dimensions of size 1 are before the bdim
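// For instance (illustrative shapes): a physical value of shape [1, B, 1, 5]
// with bdim = 1 has one size-1 dim before the bdim, so new_bdim = 0. A plain
// squeeze_() yields [B, 5] (assuming B != 1), and the bdim is recorded as 0.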
const auto bdim = batched->bdim();
const auto physical_shape = batched->value().sizes();
auto how_many_dims_of_size_1_before_bdim = 0;
for (const auto i : c10::irange(0, physical_shape.size())) {
if ((int64_t)i == bdim) {
break;
}
if (physical_shape[i] == 1) {
how_many_dims_of_size_1_before_bdim++;
}
}
int64_t new_bdim = bdim - how_many_dims_of_size_1_before_bdim;
if (physical_shape[bdim] != 1) {
// if the bdim's size is not 1, squeeze_() won't remove it, so we can just call squeeze_()
batched->value().squeeze_();
} else {
// otherwise, squeeze_() is going to get rid of the bdim too.
// We "fix it up" by calling unsqueeze_.
batched->value().squeeze_();
batched->value().unsqueeze_(new_bdim);
}
// Refresh metadata
batched->unsafe_set_bdim(new_bdim);
batched->refreshTensorMetadata();
return self;
}
Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return self.unsqueeze_(dim);
}
auto* batched = maybeGetBatchedImpl(self);
auto logical_dim = self.dim();
int64_t dim_physical = maybe_wrap_dim(dim, logical_dim + 1);
if (dim_physical >= batched->bdim()) {
dim_physical = 1 + dim_physical;
} else {
batched->unsafe_set_bdim(batched->bdim() + 1);
}
batched->value().unsqueeze_(dim_physical);
// Also need to change some metadata...
batched->refreshTensorMetadata();
return self;
}
Tensor& transpose__batching_rule(Tensor& self, int64_t dim0, int64_t dim1) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return self.transpose_(dim0, dim1);
}
auto* batched = maybeGetBatchedImpl(self);
auto logical_dim = self.dim();
// PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
// for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
// >>> x = torch.randn(B0) # the per-examples are all scalars
// >>> vmap(lambda x: x.transpose_(0, -1))(x)
// then we replicate this behavior.
if (logical_dim == 0 &&
is_allowed_dim_on_scalar_tensor(dim0) &&
is_allowed_dim_on_scalar_tensor(dim1)) {
// No transposing happened :P
return self;
}
dim0 = maybe_wrap_dim(dim0, logical_dim);
dim1 = maybe_wrap_dim(dim1, logical_dim);
dim0 = dim0 >= batched->bdim() ? dim0 + 1 : dim0;
dim1 = dim1 >= batched->bdim() ? dim1 + 1 : dim1;
batched->value().transpose_(dim0, dim1);
// Also need to change some metadata...
batched->refreshTensorMetadata();
return self;
}
std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::split(self, split_size, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = at::split(self_physical.tensor(), split_size, dim_physical);
self_physical.getPhysicalToLogicalMap().applyInplace(result);
return result;
}
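// A usage sketch (shapes illustrative): with xs of shape [B, 6],
// vmap(lambda x: torch.split(x, 2, dim=0))(xs) logically splits each
// per-sample [6] tensor into three [2] chunks; physically, the split above
// runs along dim 1 (dim 0 + one batch dim), producing tensors of shape [B, 2]
// that are then re-wrapped as BatchedTensors.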
std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return split_with_sizes_symint(self, split_sizes, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = split_with_sizes_symint(self_physical.tensor(), split_sizes, dim_physical);
self_physical.getPhysicalToLogicalMap().applyInplace(result);
return result;
}
std::vector<Tensor> split_with_sizes_copy_batching_rule(const Tensor& self, SymIntArrayRef split_sizes, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return split_with_sizes_copy_symint(self, split_sizes, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = split_with_sizes_copy_symint(self_physical.tensor(), split_sizes, dim_physical);
self_physical.getPhysicalToLogicalMap().applyInplace(result);
return result;
}
std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::unbind(self, dim);
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dim_physical = self_physical.getPhysicalDim(dim);
auto result = at::unbind(self_physical.tensor(), dim_physical);
self_physical.getPhysicalToLogicalMap().applyInplace(result);
return result;
}
// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
std::optional<c10::SymInt> maximum_indexable_location(
c10::SymIntArrayRef sizes, c10::SymIntArrayRef strides, const c10::SymInt& storage_offset) {
auto result = native::storage_size_for(sizes, strides);
if (result == 0) {
return std::nullopt;
}
return result + storage_offset;
}
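// For instance: sizes = [2, 3], strides = [3, 1], storage_offset = 5 gives
// storage_size_for = 1 + (2-1)*3 + (3-1)*1 = 6, so the returned value is 11;
// sizes = [0, 3] gives storage_size_for = 0 and the function returns nullopt.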
// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// are within the bounds of possible memory locations accessible by x.
void checkBasicAsStridedValidForSlice(
const Tensor& physical_tensor,
int64_t num_batch_dims,
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides,
const std::optional<c10::SymInt>& maybe_storage_offset) {
auto slice_sizes = physical_tensor.sym_sizes().slice(num_batch_dims);
auto slice_strides = physical_tensor.sym_strides().slice(num_batch_dims);
auto base_offset = physical_tensor.sym_storage_offset();
auto storage_offset = maybe_storage_offset.value_or(base_offset);
auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);
if (!max_as_strided_loc.has_value()) {
return;
}
if (!max_slice_loc.has_value()) {
TORCH_CHECK(false,
"result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
"can access memory outside of `tensor`. `tensor` has no storage but the ",
"passed-in (size, stride, storage_offset) imply a result with some storage. ",
"This is not supported inside of vmap, please try to rewrite the ",
"`as_strided` call as a sequence of PyTorch view operations");
}
TORCH_CHECK(
*max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
"result = tensor.as_strided(", sizes, ", ", strides, ", ", storage_offset, ") ",
"can access memory outside of `tensor`. `result` can access some ",
"memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
"`tensor` can only access some memory in range [", base_offset, ", ",
*max_slice_loc, "]. This is not supported inside of vmap, please try to ",
"rewrite the `as_strided` call as a sequence of PyTorch view operations");
}
// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `x`, `y`, such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * x.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
// sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
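//
// A usage sketch of these semantics (assuming torch.func.vmap):
// >>> xs = torch.arange(6.).reshape(2, 3)          # batch of two rows
// >>> y = vmap(lambda x: x.as_strided([2], [1], 1))(xs)
// Here each y[i] has sizes [2], strides [1], and storage offset
// 1 + i * xs.stride(0) = 1 + 3*i, so y equals xs[:, 1:3].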
Tensor as_strided_batching_rule(
const Tensor& tensor,
c10::SymIntArrayRef sizes,
c10::SymIntArrayRef strides,
std::optional<c10::SymInt> storage_offset) {
if (!participatesInCurrentLevel(tensor)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::as_strided_symint(tensor, sizes, strides, std::move(storage_offset));
}
auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
auto num_batch_dims = physical_view.numBatchDims();
auto physical_sizes = physical_view.getPhysicalShape(sizes);
const auto& physical_tensor = physical_view.tensor();
// We can't rely on the physical as_strided call to do this for us because
// we do some sanity checks on the size/strides before calling into as_strided.
TORCH_CHECK(sizes.size() == strides.size(),
"Tensor.as_strided(size, stride, ...): size and stride must have the ",
"same length! Got size ", sizes, " and stride ", strides);
// Sanity checks:
// 1. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
// is valid for a slice of the input tensor.
// See Note: [When will the as_strided batching rule fail?] for details.
checkBasicAsStridedValidForSlice(
physical_tensor, num_batch_dims, sizes, strides, storage_offset);
// physical_strides = physical tensor's batch strides + (logical) strides
auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
SymDimVector physical_strides;
physical_strides.reserve(num_batch_dims + strides.size());
physical_strides.insert(
physical_strides.end(), batch_strides.begin(), batch_strides.end());
physical_strides.insert(
physical_strides.end(), strides.begin(), strides.end());
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
// and creates a tensor y such that each y[i] references the same memory
// locations as zi. See NOTE: [When will the as_strided batching rule fail?]
auto result = physical_view.tensor().as_strided_symint(
physical_sizes, physical_strides, std::move(storage_offset));
return physical_view.getPhysicalToLogicalMap().apply(result);
}
// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
// to the front with movedim/permute. This is always valid because it just swaps
// some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
// do a little accounting:
// - instead of [B], we'd have [B0, B1, ..., Bk].
// - instead of [S], we'd have [S0, S1, ..., Sk].
// - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
// - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are non-negative
// - offset is non-negative
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
// offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
// <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
// offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
// offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
// offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
// <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq than the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.
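//
// A concrete (illustrative) instance of the argument above: take xs with
// physical sizes [2, 3], strides [3, 1], offset 0 (so B = 2, S = 3), and
// as_strided arguments sizes = [2], strides = [1], offset = 1.
// - Per [Equation 1], zi = xs[i].as_strided([2], [1], 1 + 3*i) has
//   largest-index location (1 + 3*i) + (2-1)*1 = 2 + 3*i, while xs[i] can
//   reach up to 3*i + (3-1)*1 = 2 + 3*i, so each zi is valid.
// - xs.as_strided([2, 2], [3, 1], 1) then reaches up to
//   1 + (2-1)*3 + (2-1)*1 = 5, which is within xs's own bound
//   0 + (2-1)*3 + (3-1)*1 = 5, and its i-th slice has offset 1 + 3*i,
//   matching zi exactly.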
template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
if (!participatesInCurrentLevel(input)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return Func(input, args...);
}
// guard against the user passing in a batch of scalar tensors with batch
auto* input_batched = unsafeGetBatchedImpl(input);
auto output_physical = Func(input_batched->value(), args...);
return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}
template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
if (!participatesInCurrentLevel(input)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return (input.*Func)(extra_args...);
}
auto* input_batched = unsafeGetBatchedImpl(input);
auto output_physical = (input_batched->value().*Func)(extra_args...);
return makeBatched(output_physical, input_batched->bdim(), input_batched->level());
}
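// A hypothetical instantiation, for illustration only: given a free function
// `Tensor my_unary_op(const Tensor&)`, one could write
//   unwrap_and_call<decltype(&my_unary_op), &my_unary_op>(t)
// which unwraps the BatchedTensor `t`, calls my_unary_op on the underlying
// value, and re-wraps the result with the same bdim and level. The method
// variant works the same way for Tensor member functions.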
Tensor cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
if (!participatesInCurrentLevel(tensors)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::cat(tensors, dim);
}
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
// NB: Probably bad for perf that we're allocating std::vectors for each level, but
// what can you do.
auto materialized = tensors.materialize();
dim = at::legacy_cat_wrap_dim(dim, materialized);
// Strategy:
// we're going to unwrap tensors, move their batch dims to the front,
// and put them into `tensors_to_cat`. Tensors that don't have a batch dim
// will get one forced onto them.
//
// Then, we'll do at::cat(tensors_to_cat, ...).
//
// There's a special case where at::cat ignores tensors that have logical shape
// [0]. If we see a Tensor that has logical shape [0] (but physical shape [B, 0]),
// we'll just slice the tensor to get a Tensor of shape [0] to pass to at::cat.
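// As an illustration of the special case (shapes are made up): with
// xs batched as [B, 2, 3] and z = torch.zeros(0) unbatched,
// vmap(lambda x: torch.cat([x, z]))(xs) should behave like the per-sample
// torch.cat([xs[i], z]), where cat simply ignores z because its logical
// shape is [0]; the remaining tensors determine bdim_size below.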
std::vector<Tensor> tensors_to_cat;
tensors_to_cat.reserve(tensors.size());
std::optional<int64_t> bdim_size = std::nullopt;
// find the bdim size. Might not exist if all BatchedTensors should be skipped
// by cat's special case.
for (const auto& tensor : tensors) {
if (!participatesInCurrentLevel(tensor)) {
continue;
}
if (at::native::cat_should_skip_tensor(tensor)) {
continue;
}
const auto* batched = unsafeGetBatchedImpl(tensor);
bdim_size = batched->value().size(batched->bdim());
break;
}
// unwrap batchedtensors; expand out bdims
for (const auto& tensor : tensors) {
if (!participatesInCurrentLevel(tensor)) {
if (at::native::cat_should_skip_tensor(tensor) || !bdim_size.has_value()) {
tensors_to_cat.emplace_back(tensor);
continue;
}
tensors_to_cat.emplace_back(ensure_has_bdim(tensor, /*has_bdim*/false, *bdim_size));
continue;
}
const auto* batched = unsafeGetBatchedImpl(tensor);
if (at::native::cat_should_skip_tensor(tensor)) {
// Special case: slice the tensor to get something of shape [0] to pass to cat
// We slice instead of allocate a new tensor to propagate requires_gradness...
tensors_to_cat.emplace_back(batched->value().select(/*dim=*/batched->bdim(), /*index=*/0));
continue;
}
tensors_to_cat.emplace_back(moveBatchDimToFront(batched->value(), batched->bdim()));
}
auto new_dim = bdim_size.has_value() ? dim + 1 : dim;
std::optional<int64_t> new_bdim = bdim_size.has_value() ? std::make_optional((int64_t)0) : std::nullopt;
auto result = at::cat(tensors_to_cat, new_dim);
return makeBatched(result, new_bdim, get_current_level());
}
Tensor block_diag_batching_rule(TensorList tensors) {
if (!participatesInCurrentLevel(tensors)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::block_diag(tensors);
}
auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
auto physical_tensors = fmap(
physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
TORCH_INTERNAL_ASSERT(
!tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
// Implementing this as a dummy for loop for now, since I'm not sure how to do it any better.
// I'm probably not accounting for potentially multiple batched dimensions?
auto bdim = physical_tensors[0].size(0);
std::vector<Tensor> batched_outputs;
batched_outputs.reserve(bdim);
for (const auto& i : c10::irange(bdim)) {
std::vector<Tensor> inputs_for_batch;
inputs_for_batch.reserve(physical_tensors.size());
for (const auto& t : physical_tensors) {
inputs_for_batch.push_back(t[i]);
}
auto out_for_batch = at::block_diag(inputs_for_batch);
batched_outputs.push_back(out_for_batch.unsqueeze(0));
}
auto result = at::cat(batched_outputs);
return physical_views[0].getPhysicalToLogicalMap().apply(result);
}
Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
if (!participatesInCurrentLevel(tensors)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return at::stack(tensors, dim);
}
auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
auto physical_tensors = fmap(
physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
TORCH_INTERNAL_ASSERT(
!tensors.empty(), "The dispatcher should not have dispatched here otherwise.");
// NB: stack wraps the dimensionality to (logical dim + 1), so we have to
// manually handle that here.
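// For instance (illustrative): if each per-sample tensor has logical dim 2
// and there is one batch dim, dim = -1 wraps against logical dim + 1 = 3 to
// give 2, so dim_physical = 1 + 2 = 3 on the rank-3 physical tensors.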
auto dim_physical =
physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
auto result = at::stack(physical_tensors, dim_physical);
return physical_views[0].getPhysicalToLogicalMap().apply(result);
}
Tensor new_empty_strided_batching_rule(
const Tensor& self,
SymIntArrayRef sym_size,
SymIntArrayRef sym_stride,
std::optional<ScalarType> dtype,
std::optional<Layout> layout,
std::optional<Device> device,
std::optional<bool> pin_memory) {
auto size = C10_AS_INTARRAYREF_SLOW(sym_size);
auto stride = C10_AS_INTARRAYREF_SLOW(sym_stride);
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched);
return self.new_empty_strided(
size, stride, dtype, layout, device, pin_memory);
}
auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
auto physical_size = physical_view.getPhysicalShape(size);
// Let [B0, B1, B2] be the shape of the batch dims. We're going to create
// the batch dimensions at the front of the tensor (in memory layout),
// irrespective of whether or not they are actually at the front
// in the original `self` tensor. This is because when a user calls
// `new_empty_strided` in general, the `strides` they provide are for a new
// tensor and have no relation to the strides of the original tensor.
//
// So, the physical shape of the result should be ([B0, B1, B2] + size),
// but what about the physical strides?
//
// We're actually free to pick whatever stride we want:
// e.g., for size=[5, 3], stride=[0, 1], we could decide to
// use
// - physical size: [B0, B1, B2, 5, 3]
// - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
//
// Let's select some reasonable strides such that:
// - The batch dims are "contiguous" with respect to each other
// - if empty_strided(size, stride) would have created a contiguous Tensor,
// then this new physical Tensor (with batch dims) is also contiguous
//
// Let S be the size of the storage if one were to construct a tensor
// with `size` and `stride` via empty_strided(size, stride).
// Then the physical sizes/strides should be:
// - physical size: [B0, B1, B2, 5, 3]
// - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
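// For instance (continuing the example above): with size = [5, 3],
// stride = [0, 1], the storage size S = 1 + (5-1)*0 + (3-1)*1 = 3, so with
// batch dims [B0, B1, B2] the defaultStrides [B1*B2, B2, 1] get scaled to
// [3*B1*B2, 3*B2, 3], and the final physical strides are
// [3*B1*B2, 3*B2, 3, 0, 1].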
auto batch_shape = IntArrayRef(
physical_view.tensor().sizes().begin(), physical_view.numBatchDims());
// physical_strides = [B1 * B2 * S, B2 * S, S]
auto physical_strides = at::detail::defaultStrides(batch_shape);
TORCH_CHECK(size.size() == stride.size(),
"new_empty_strided(sizes, strides): dimensionality of sizes (",
size.size(), ") must match dimensionality of strides (",
stride.size(), ")");
auto storage_size = native::storage_size_for(size, stride);
for (auto& physical_stride : physical_strides) {
physical_stride *= storage_size;
}
// physical_strides = [B1 * B2 * S, B2 * S, S] + strides
physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());
auto result = physical_view.tensor().new_empty_strided(
physical_size, physical_strides, dtype, layout, device, pin_memory);
return physical_view.getPhysicalToLogicalMap().apply(result);
}
Tensor nested_cat_batching_rule(const ITensorListRef& tensors, int64_t dim) {
TORCH_CHECK(!tensors.empty(), "cat() not supported on empty tensor list");
std::vector<std::vector<Tensor>> unbound;
for (const auto & tensor : tensors) {
auto* maybe_batched_impl = maybeGetBatchedImpl(tensor);
TORCH_CHECK(maybe_batched_impl, "Tried to run batching rule for cat() on a non-batched tensor");
auto nt = maybe_batched_impl->value();
TORCH_CHECK(nt.is_nested(), "Tried to run batching rule for cat() on a non-nested tensor");
c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::BatchedNestedTensor);
auto this_unbound = nt.unbind();
if (!unbound.empty()) {
TORCH_INTERNAL_ASSERT(unbound.front().size() == this_unbound.size(),
"cat() not supported for differently-sized nested arguments");
}
unbound.push_back(this_unbound);
}
// Do a cat for each set of zipped unbound components
const auto num_components = unbound.front().size();
std::vector<Tensor> outputs;
for (auto i : c10::irange(num_components)) {
std::vector<Tensor> arg_list;
for (auto j : c10::irange(unbound.size())) {
arg_list.push_back(unbound[j][i]);
}
outputs.push_back(at::cat(arg_list, dim));
}
// NB: NTs only support batching over dim 0
auto out_nt = at::_nested_tensor_from_tensor_list(outputs);
return makeBatched(out_nt, 0, get_current_level());
}
} // namespace
TORCH_LIBRARY_IMPL(_, FuncTorchBatched, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}
TORCH_LIBRARY_IMPL(aten, FuncTorchBatched, m) {
// still legacy b/c returns multiple tensors
m.impl("split.Tensor", split_batching_rule);
m.impl("split_with_sizes", split_with_sizes_batching_rule);
m.impl("split_with_sizes_copy", split_with_sizes_copy_batching_rule);
m.impl("unbind.int", unbind_batching_rule);
m.impl("cat", cat_batching_rule);
m.impl("block_diag", block_diag_batching_rule);
m.impl("stack", stack_batching_rule);
// still legacy b/c needs special inplace rules
m.impl("squeeze_", squeeze__batching_rule);
m.impl("squeeze_.dim", squeeze_dim__batching_rule);
m.impl("squeeze_.dims", squeeze_dims__batching_rule);
m.impl("unsqueeze_", unsqueeze__batching_rule);
m.impl("transpose_", transpose__batching_rule);
// still legacy because these are ridiculously complicated
m.impl("as_strided", as_strided_batching_rule);
m.impl("new_empty_strided", new_empty_strided_batching_rule);
}
TORCH_LIBRARY_IMPL(_, BatchedNestedTensor, m) {
m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedNestedTensorForLoopFallback>());
}
// TODO: Move this somewhere better?
TORCH_LIBRARY_IMPL(aten, BatchedNestedTensor, m) {
m.impl("cat", nested_cat_batching_rule);
}
} // namespace at::functorch