#include <torch/library.h>
#include <ATen/native/ResizeCommon.h>
#include <ATen/ATen.h>
#include <torch/csrc/autograd/variable.h>

#include <functorch/csrc/DynamicLayer.h>
#include <functorch/csrc/TensorWrapper.h>
#include <functorch/csrc/BatchingMetaprogramming.h>
#include <functorch/csrc/VmapTransforms.h>
#include <functorch/csrc/BatchedFallback.h>
#include <functorch/csrc/Constants.h>

namespace at {
namespace functorch {

// NOTE: [What is a batching rule?]
//
// A *batching rule* implements the logic of how to call an operator on inputs
// that have zero or more additional batch dimensions. When one does a vmap, the
// dimension(s) being vmap'ed over get recorded as batch dimensions.
//
// For example, vmap(torch.add)(x, y)
// 1. wraps `x` into batched_x = BatchedTensor(x, bdims=[(lvl=1, dim=0)]);
// 2. wraps `y` into batched_y = BatchedTensor(y, bdims=[(lvl=1, dim=0)]);
// 3. and then runs `torch.add(batched_x, batched_y)`.

// NOTE: [When should I add a batching rule?]
// When you are adding a new operator, you'll need to add a batching rule so
// that vmap can work efficiently with said operator. If you do not, we'll attempt
// to generate a slow fallback for the batching rule.

// NOTE: [How to write batching rules?]
// The signature of a batching rule should look exactly like the C++ signature
// of its operator.
//
// First, see NOTE: [Logical vs physical args] in VmapTransforms.h for terminology.
//
// At a high level, what a batching rule does is the following:
// 1. Converts (logical) BatchedTensors to views on physical tensors.
// 2. Converts logical arguments (e.g. dimension indexes, shapes) to physical
//    arguments that correspond to the physical tensors.
// 3. Calls at:: operations on the physical tensors and arguments to produce
//    some physical results.
// 4. Converts physical results back to BatchedTensors.
//
// Steps 1, 2, and 4 differ for operators with different batching behaviors. When
// writing a new batching rule, please select a VmapTransform that matches the
// batching behavior of your operation. The VmapTransform provides helper functions
// to do steps (1), (2), and (4).
// (see NOTE: [What is a VmapTransform?] in VmapTransforms.h)

// Note: [Future plans]
// The API for writing a batching rule isn't stable. In the future, we'd like
// to think about the problem of translating these batching rules to TorchScript.
// Ideally batching rules in eager mode vs TorchScript would look pretty similar,
// if not use the same mechanism. In order to accomplish that we might have to
// do some refactoring.

// PyTorch allows operations to specify dim 0 and dim -1 on a scalar tensor.
static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
  return dim == 0 || dim == -1;
}

// This check should probably go into the dispatcher...
static bool participatesInCurrentLevel(const Tensor& self) {
  auto maybe_level = maybeCurrentDynamicLayer();
  TORCH_INTERNAL_ASSERT(maybe_level.has_value());
  auto current_level = maybe_level->layerId();
  auto* maybe_batched_impl = maybeGetBatchedImpl(self);
  if (!maybe_batched_impl) {
    return false;
  }
  const auto& bdims = maybe_batched_impl->bdims();
  TORCH_INTERNAL_ASSERT(bdims.size() == 1);
  auto self_level = bdims.back().level();
  TORCH_INTERNAL_ASSERT(self_level <= current_level);
  return self_level == current_level;
}
static bool participatesInCurrentLevel(const Tensor& self, const Tensor& other) {
  return participatesInCurrentLevel(self) || participatesInCurrentLevel(other);
}

static bool participatesInCurrentLevel(TensorList self) {
  for (const Tensor& tensor : self) {
    if (participatesInCurrentLevel(tensor)) {
      return true;
    }
  }
  return false;
}
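
// To make the recipe in NOTE: [How to write batching rules?] concrete, here is
// a minimal, hypothetical sketch of a batching rule for a simple unary op.
// It is for illustration only and is not registered anywhere in this file.
Tensor sin_batching_rule_example(const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    // Not batched at this level: step aside and let the regular kernel run.
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::sin(self);
  }
  // (1) logical BatchedTensor -> view on the physical tensor
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // (2) a unary op has no dim/shape arguments to translate
  // (3) call the at:: operation on the physical tensor
  auto result = at::sin(self_physical.tensor());
  // (4) rewrap the physical result as a BatchedTensor
  return self_physical.getPhysicalToLogicalMap().apply(result);
}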

Tensor mean_batching_rule(const Tensor& self, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.mean(dtype);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  VmapDimVector dims;
  for (int64_t i = 1; i < self_physical.tensor().dim(); i++) {
    dims.push_back(i);
  }
  auto result = at::mean(self_physical.tensor(), dims, /*keepdim*/false, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
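
// For example (shapes only, with a single batch dim B0 at the front of the
// physical tensor):
// >>> x = torch.randn(B0, 3, 5)
// >>> vmap(torch.mean)(x)
// reduces physical dims [1, 2], giving a physical result of shape [B0] that
// maps back to a logical per-example scalar.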

Tensor log_softmax_batching_rule(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::log_softmax(self, dim, dtype);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::log_softmax(self_physical.tensor(), dim_physical, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor _log_softmax_batching_rule(const Tensor& self, int64_t dim, bool half_to_float) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::_log_softmax(self, dim, half_to_float);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::_log_softmax(self_physical.tensor(), dim_physical, half_to_float);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::tuple<Tensor,Tensor> max_pool2d_with_indices_batching_rule(
    const Tensor & self, IntArrayRef kernel_size, IntArrayRef stride,
    IntArrayRef padding, IntArrayRef dilation, bool ceil_mode) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::max_pool2d_with_indices(
        self, kernel_size, stride, padding, dilation, ceil_mode);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  TORCH_INTERNAL_ASSERT(self_physical.tensor().dim() == 5);

  auto N = self_physical.tensor().size(0);
  auto M = self_physical.tensor().size(1);
  auto physical = self_physical.tensor().flatten(0, 1);

  auto result = max_pool2d_with_indices_batching_rule(physical,
      kernel_size, stride, padding, dilation, ceil_mode);

  auto first = std::get<0>(result).unflatten(0, {N, M});
  auto second = std::get<1>(result).unflatten(0, {N, M});

  first = self_physical.getPhysicalToLogicalMap().apply(first);
  second = self_physical.getPhysicalToLogicalMap().apply(second);
  return std::make_tuple<Tensor, Tensor>(std::move(first), std::move(second));
}

Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.sum(dims, keepdim, dtype);
  }
  // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
  // and instead returns a new scalar tensor (this also happens for dim=-1)
  // If the following happens:
  // >>> x = torch.randn(B0) # the per-examples are all scalars
  // >>> vmap(partial(torch.sum, dim=0), x)
  // then we replicate the behavior of sum(scalar_tensor, dim=0).
  if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) {
    return self.clone();
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  if (logical_tensor.dim() > 0) {
    return false;
  }
  auto* batched = maybeGetBatchedImpl(logical_tensor);
  if (batched) {
    return false;
  }
  return true;
}
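
// To illustrate the distinction this helper draws: a plain torch.randn([]) is
// a *physical* scalar tensor, but inside vmap(f)(torch.randn(B0)) each
// per-example input is only a *logical* scalar; its BatchedTensor still
// carries the physical batch dim B0, so isPhysicalScalarTensor returns false.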

template <typename F, F Func, typename... ExtraArgs>
Tensor binary_pointwise_batching_rule(
    const Tensor& self, const Tensor& other, ExtraArgs... args) {
  if (!participatesInCurrentLevel(self, other)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return Func(self, other, std::forward<ExtraArgs>(args)...);
  }
  if (self.dim() > 0 && other.dim() > 0) {
    auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
    auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
    return physical_args[0].getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(self)) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = Func(self, other_physical.tensor(), args...);
    return other_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (isPhysicalScalarTensor(other)) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = Func(self_physical.tensor(), other, args...);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  // At this point, we know at least one of the operands is a logical Scalar tensor.
  // Here we must emulate TensorIterator's special behavior on Scalars.
  //
  // As a motivating example, consider the following:
  //   x = torch.randn(3, 10)
  //   y = torch.randn(3, dtype=torch.double)
  //   vmap(torch.mul)(x, y)
  //
  // At a per-example level, we are multiplying FloatTensor[10] and DoubleTensor[];
  // Type Promotion dictates that the result should be FloatTensor[10].
  // This means we cannot directly pass the physical tensors (x and y) to
  // TensorIterator (if we did, it would promote them to DoubleTensor).
  //
  // FIXME(rzou): I didn't want to go down the slippery slope of emulating
  // everything TensorIterator does (it would be better to refactor out the
  // TensorIterator logic). The one thing that this code doesn't handle
  // is cross-device logical scalar tensors.
  //   cpu_tensor = torch.randn(3)
  //   cuda_tensor = torch.randn(3, 10, device='cuda')
  //   vmap(torch.mul)(cpu_tensor, cuda_tensor)
  //
  // At a per-example level, we are multiplying CPUTensor[] and CUDATensor[10].
  // TensorIterator allows for this cross-device operation because one of the
  // tensors is a Scalar CPU tensor. However, the following code will throw an
  // error in that case. I don't expect to see many use cases for this, so
  // this is probably fine as-is.
  auto logical_self = self;
  auto logical_other = other;
  auto result_type = at::native::result_type(logical_self, logical_other);
  if (logical_self.scalar_type() != result_type) {
    logical_self = logical_self.to(result_type);
  }
  if (logical_other.scalar_type() != result_type) {
    logical_other = logical_other.to(result_type);
  }
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical(
      {logical_self, logical_other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor(), args...);
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}
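
// A rough sketch of how a rule like the one above gets wired up: the operator
// registrations near the bottom of this file pass the concrete op through the
// template parameters, along the lines of (type alias name is illustrative):
//   m.impl("add.Tensor",
//          binary_pointwise_batching_rule<TensorTensorScalarType, at::add, Scalar>);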

Tensor expand_batching_rule(const Tensor& self, IntArrayRef size, bool implicit) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.expand(size, implicit);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto self_physical_dim = self_physical.tensor().dim();

  TORCH_CHECK(self_physical_dim <= size_physical.size(),
      "expand: the number of sizes provided (", /*logical*/size.size(), ") ",
      "must be greater or equal to the number of dimensions in the tensor (",
      /*logical dim*/self.dim(), ")");

  if (self_physical_dim == size_physical.size()) {
    auto result = self_physical.tensor().expand(size_physical, implicit);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }

  TORCH_INTERNAL_ASSERT(self_physical_dim < size_physical.size());
  // Here, we know we are expanding a (logical) tensor to a larger number
  // of dimensions. We have to be careful because we can't call expand directly
  // due to the presence of batch dimensions.
  //
  // As an example, let B0 be a batch dimension and consider expand(Tensor[B0, 3], [2, 3]).
  // The result should be a tensor of size [B0, 2, 3].
  // A physical view of size [B0, 3] can't directly be expanded to size [B0, 2, 3]
  // so the strategy here is to view it first as a tensor of size [B0, 1, 3] and
  // then expand.
  auto self_physical_size = self_physical.tensor().sizes();
  auto extra_dims = size_physical.size() - self_physical_dim;
  VmapDimVector view_shape(size_physical.size(), 1);
  std::copy(self_physical_size.begin(),
            self_physical_size.begin() + self_physical.numBatchDims(),
            view_shape.begin());
  std::copy(self_physical_size.begin() + self_physical.numBatchDims(),
            self_physical_size.end(),
            view_shape.begin() + self_physical.numBatchDims() + extra_dims);
  auto result = self_physical.tensor().view(view_shape).expand(size_physical, implicit);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> chunk_batching_rule(const Tensor& self, int64_t chunks, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.chunk(chunks, dim);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::chunk(self_physical.tensor(), chunks, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor clamp_batching_rule(const Tensor& self, optional<Scalar> min, optional<Scalar> max) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.clamp(min, max);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp(self_physical.tensor(), min, max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_min_batching_rule(const Tensor& self, Scalar min) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::clamp_min(self, min);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_min(self_physical.tensor(), min);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor clamp_max_batching_rule(const Tensor& self, Scalar max) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::clamp_max(self, max);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::clamp_max(self_physical.tensor(), max);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> tensor_split_sections_batching_rule(const Tensor& self, int64_t sections, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::tensor_split(self, sections, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), sections, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> tensor_split_indices_batching_rule(const Tensor& self, IntArrayRef indices, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::tensor_split(self, indices, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::tensor_split(self_physical.tensor(), indices, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unsqueeze_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::unsqueeze(self, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // NB: unsqueeze has some special handling of its `dim` argument so we can't call
  // self_physical.getPhysicalDim directly. In particular, native::unsqueeze
  // wraps the dim to (the logical dimension) + 1, so we need to do that here too.
  // https://github.com/pytorch/pytorch/blob/b623bdeabb0aa8da44285d303246e7f8ac06c2a9/aten/src/ATen/native/TensorShape.cpp#L1413
  auto dim_physical =
      self_physical.numBatchDims() + maybe_wrap_dim(dim, /*logical_dim*/self.dim() + 1);
  auto result = self_physical.tensor().unsqueeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
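
// Worked example of the wrapping above: for a logical tensor with 2 dims and
// one batch dim, dim=-1 wraps against (logical dim + 1) = 3 to give 2, and the
// physical dim is numBatchDims() + 2 = 3, i.e. the new size-1 dim is appended
// after the per-example dims, as expected.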

// Checks if the batch dims in `bdims` appear at the front of the tensor.
static bool areBdimsAtFrontInOrder(BatchDimsRef bdims) {
  for (int64_t idx = 0; idx < bdims.size(); idx++) {
    if (bdims[idx].dim() != idx) {
      return false;
    }
  }
  return true;
}

Tensor& squeeze_dim__batching_rule(Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.squeeze_(dim);
  }
  auto* batched = maybeGetBatchedImpl(self);
  TORCH_CHECK(areBdimsAtFrontInOrder(batched->bdims()), "NYI: squeeze_ with bdims not at front");
  auto num_bdims = batched->bdims().size();
  auto logical_dim = self.dim();
  auto dim_physical = num_bdims + maybe_wrap_dim(dim, logical_dim);
  batched->value().squeeze_(dim_physical);

  // Also need to change some metadata...
  batched->refreshSizesAndStrides();
  return self;
}

Tensor& unsqueeze__batching_rule(Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.unsqueeze_(dim);
  }
  auto* batched = maybeGetBatchedImpl(self);
  TORCH_CHECK(areBdimsAtFrontInOrder(batched->bdims()), "NYI: unsqueeze_ with bdims not at front");
  auto num_bdims = batched->bdims().size();
  auto logical_dim = self.dim();
  auto dim_physical = num_bdims + maybe_wrap_dim(dim, logical_dim + 1);
  batched->value().unsqueeze_(dim_physical);

  // Also need to change some metadata...
  batched->refreshSizesAndStrides();
  return self;
}

Tensor& fill_inplace_scalar_batching_rule(Tensor& self, Scalar value) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.fill_(value);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().fill_(value);
  return self;
}

Tensor& fill_inplace_tensor_batching_rule(Tensor& self, const Tensor& value) {
  auto value_batched = isBatchedTensor(value);

  if (value_batched) {
    auto physical_args =
        BroadcastingVmapTransform::logicalToPhysical({self, value});
    physical_args[0].tensor().copy_(physical_args[1].tensor());
  } else {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    self_physical.tensor().fill_(value);
  }
  return self;
}

Tensor& zero_inplace_batching_rule(Tensor &self) {
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  self_physical.tensor().zero_();
  return self;
}

Tensor squeeze_batching_rule(const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.squeeze();
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_sizes = self_physical.tensor().sizes();

  // Don't squeeze the batch dims!
  VmapDimVector squeezed_sizes;
  int64_t num_batch_dims = self_physical.numBatchDims();
  squeezed_sizes.insert(
      squeezed_sizes.end(),
      physical_sizes.begin(),
      physical_sizes.begin() + num_batch_dims);
  for (auto it = physical_sizes.begin() + num_batch_dims; it != physical_sizes.end(); ++it) {
    if (*it != 1) {
      squeezed_sizes.push_back(*it);
    }
  }

  auto result = self_physical.tensor().view(squeezed_sizes);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
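
// For example (shapes only): with one batch dim, a physical tensor of size
// [B0, 1, 3, 1] is viewed as [B0, 3]. Size-1 example dims are squeezed away,
// but a batch dim of size 1 is kept because it sits in the preserved prefix.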

Tensor squeeze_dim_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.squeeze(dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().squeeze(dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_batching_rule(const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.trace();
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  // Batched Diagonal View
  auto self_diag = at::diagonal(self_physical.tensor(), /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  auto result = at::sum(self_diag, -1);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor trace_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes) {
  if (!participatesInCurrentLevel(grad)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::trace_backward(grad, input_sizes);
  }
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  // Batched Diagonal View
  auto grad_input_diag = at::diagonal(grad_input, /*offset*/0, /*dim1*/-2, /*dim2*/-1);
  // Append a dimension of size one to the grad output
  auto grad_physical_tensor = grad_physical.tensor().unsqueeze(-1);
  grad_input_diag.copy_(grad_physical_tensor);
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor transpose_int_batching_rule(const Tensor& self, int64_t dim0, int64_t dim1) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::transpose(self, dim0, dim1);
  }
  // PyTorch has a special case where scalar_tensor.transpose(dim0, dim1) works
  // for dim0, dim1 in {0, -1} and returns the scalar tensor. If the following happens:
  // >>> x = torch.randn(B0) # the per-examples are all scalars
  // >>> vmap(lambda x: x.transpose(0, -1), x)
  // then we replicate this behavior.
  if (/*logical*/self.dim() == 0 && is_allowed_dim_on_scalar_tensor(dim0) &&
      is_allowed_dim_on_scalar_tensor(dim1)) {
    return self;
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim0_physical = self_physical.getPhysicalDim(dim0);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto result = self_physical.tensor().transpose(dim0_physical, dim1_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor permute_batching_rule(const Tensor& self, IntArrayRef dims) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.permute(dims);
  }

  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);

  VmapDimVector all_dims_physical;
  all_dims_physical.reserve(self_physical.tensor().dim());
  for (int64_t bdim = 0; bdim < self_physical.numBatchDims(); bdim++) {
    all_dims_physical.push_back(bdim);
  }
  all_dims_physical.insert(
      all_dims_physical.end(),
      dims_physical.begin(),
      dims_physical.end());
  auto result = self_physical.tensor().permute(all_dims_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor select_batching_rule(const Tensor& self, int64_t dim, int64_t index) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::select(self, dim, index);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().select(dim_physical, index);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

static int64_t getGradInputPhysicalDim(int64_t dim, IntArrayRef input_sizes, int64_t num_batch_dims) {
  return maybe_wrap_dim(dim, input_sizes.size()) + num_batch_dims;
}
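
// For example: input_sizes = [4, 5] (logical), dim = -1, num_batch_dims = 1
// gives maybe_wrap_dim(-1, 2) = 1, so the corresponding dim in the physical
// grad_input is 2.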

Tensor select_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t index) {
  if (!participatesInCurrentLevel(grad)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::select_backward(grad, input_sizes, dim, index);
  }
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.select(physical_dim, index).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor slice_batching_rule(
    const Tensor& self,
    int64_t dim,
    c10::optional<int64_t> start,
    c10::optional<int64_t> end,
    int64_t step) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::slice(self, dim, start, end, step);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().slice(dim_physical, start, end, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor slice_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) {
  if (!participatesInCurrentLevel(grad)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::slice_backward(grad, input_sizes, dim, start, end, step);
  }
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto physical_dim = getGradInputPhysicalDim(dim, input_sizes, grad_physical.numBatchDims());
  grad_input.slice(physical_dim, start, end, step).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor diagonal_batching_rule(const Tensor& self, int64_t offset, int64_t dim1, int64_t dim2) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::diagonal(self, offset, dim1, dim2);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim1_physical = self_physical.getPhysicalDim(dim1);
  auto dim2_physical = self_physical.getPhysicalDim(dim2);
  auto result = at::diagonal(self_physical.tensor(), offset, dim1_physical, dim2_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor diagonal_backward_batching_rule(const Tensor& grad, IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) {
  if (!participatesInCurrentLevel(grad)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::diagonal_backward(grad, input_sizes, offset, dim1, dim2);
  }
  auto grad_physical = MultiBatchVmapTransform::logicalToPhysical(grad);
  auto grad_input = at::zeros(grad_physical.getPhysicalShape(input_sizes), grad.options());
  auto dim1_physical = getGradInputPhysicalDim(dim1, input_sizes, grad_physical.numBatchDims());
  auto dim2_physical = getGradInputPhysicalDim(dim2, input_sizes, grad_physical.numBatchDims());
  grad_input.diagonal(offset, dim1_physical, dim2_physical).copy_(grad_physical.tensor());
  return grad_physical.getPhysicalToLogicalMap().apply(grad_input);
}

Tensor movedim_batching_rule(const Tensor& self, IntArrayRef source, IntArrayRef destination) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::movedim(self, source, destination);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto source_physical = self_physical.getPhysicalDims(source);
  auto destination_physical = self_physical.getPhysicalDims(destination);
  auto result = at::movedim(self_physical.tensor(), source_physical, destination_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor reshape_batching_rule(const Tensor& self, IntArrayRef shape) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::reshape(self, shape);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto shape_physical = self_physical.getPhysicalShape(shape);
  auto result = self_physical.tensor().reshape(shape_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

std::vector<Tensor> split_batching_rule(const Tensor& self, int64_t split_size, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::split(self, split_size, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split(self_physical.tensor(), split_size, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> split_with_sizes_batching_rule(const Tensor& self, IntArrayRef split_sizes, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::split_with_sizes(self, split_sizes, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::split_with_sizes(self_physical.tensor(), split_sizes, dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

std::vector<Tensor> unbind_batching_rule(const Tensor& self, int64_t dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::unbind(self, dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = at::unbind(self_physical.tensor(), dim_physical);
  self_physical.getPhysicalToLogicalMap().applyInplace(result);
  return result;
}

Tensor unfold_batching_rule(const Tensor& self, int64_t dim, int64_t size, int64_t step) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.unfold(dim, size, step);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dim_physical = self_physical.getPhysicalDim(dim);
  auto result = self_physical.tensor().unfold(dim_physical, size, step);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor contiguous_batching_rule(const Tensor& self, MemoryFormat memory_format) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.contiguous(memory_format);
  }
  TORCH_CHECK(memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.contiguous(...) inside of vmap for memory_format other ",
      "than torch.contiguous_format");
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = physical_view.tensor().contiguous(memory_format);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor view_batching_rule(const Tensor& self, IntArrayRef size) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.view(size);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto size_physical = self_physical.getPhysicalShape(size);
  auto result = self_physical.tensor().view(size_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor flatten_batching_rule(const Tensor& self, int64_t start_dim, int64_t end_dim) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::flatten(self, start_dim, end_dim);
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto start_dim_physical = self_physical.getPhysicalDim(start_dim);
  auto end_dim_physical = self_physical.getPhysicalDim(end_dim);
  auto result = self_physical.tensor().flatten(start_dim_physical, end_dim_physical);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor view_as_complex_batching_rule(const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::view_as_complex(self);
  }
  // guard against the user passing in a batch of scalar tensors with batch
  // size equal to 2.
  TORCH_CHECK(self.sizes().size() != 0, "Input tensor must have one or more dimensions");
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto result = at::view_as_complex(self_physical.tensor());
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

// Checks that the smallest batch stride is greater than the largest example
// stride. This is something we can support but we choose not to because it's
// potentially error prone.
static void checkBatchDimsAtFrontInLayout(IntArrayRef physical_strides, int64_t num_batch_dims) {
  auto smallest_batch_stride = std::min_element(
      physical_strides.begin(), physical_strides.begin() + num_batch_dims);
  auto largest_example_stride = std::max_element(
      physical_strides.begin() + num_batch_dims, physical_strides.end());
  if (largest_example_stride == physical_strides.end()) {
    // No example dimensions
    return;
  }
  TORCH_CHECK(*smallest_batch_stride >= *largest_example_stride,
    "vmap: Calling Tensor.as_strided is not supported unless the batch dims being ",
    "vmapped over are at the front of the tensor (in memory layout). When they are ",
    "not at the front of the tensor this operation can be error prone so we "
    "actively discourage it; please file us a bug report and/or try to ",
    "express the as_strided operation in terms of PyTorch view operations");
}
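
// For example (one batch dim): physical strides [15, 5, 1] pass this check
// (batch stride 15 >= largest example stride 5), whereas strides [1, 5, 15]
// would be rejected because the batch dim is not outermost in memory.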

// given (sizes, strides, storage_offset) returns the maximum location that
// can be indexed (or nullopt if such a location doesn't exist, e.g., tensors
// with zero-size dims).
static optional<int64_t> maximum_indexable_location(
    IntArrayRef sizes, IntArrayRef strides, int64_t storage_offset) {
  auto result = native::storage_size_for(sizes, strides);
  if (result == 0) {
    return nullopt;
  }
  return result + storage_offset;
}
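
// Worked example: for sizes = [2, 3], strides = [3, 1], storage_offset = 7,
// the farthest element lives at index 7 + (2-1)*3 + (3-1)*1 = 12 and this
// helper returns 13 (= native::storage_size_for(sizes, strides) + 7), i.e. one
// past that element. The checks below only ever compare values produced by
// this helper against each other, so the convention is consistent.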

// Let x be the "first slice" of physical_tensor.
// This checks that the range of possible memory locations accessible by
// x.as_strided(sizes, strides, maybe_storage_offset)
// are within the bounds of possible memory locations accessible by x.
static void checkBasicAsStridedValidForSlice(
    const Tensor& physical_tensor,
    int64_t num_batch_dims,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> maybe_storage_offset) {
  auto slice_sizes = physical_tensor.sizes().slice(num_batch_dims);
  auto slice_strides = physical_tensor.strides().slice(num_batch_dims);
  auto base_offset = physical_tensor.storage_offset();

  auto storage_offset = maybe_storage_offset.value_or(base_offset);

  auto max_as_strided_loc = maximum_indexable_location(sizes, strides, storage_offset);
  auto max_slice_loc = maximum_indexable_location(slice_sizes, slice_strides, base_offset);

  if (!max_as_strided_loc.has_value()) {
    return;
  }
  if (!max_slice_loc.has_value()) {
    TORCH_CHECK(false,
        "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
        "can access memory outside of `tensor`. `tensor` has no storage but the ",
        "passed-in (size, stride, storage_offset) imply a result with some storage. ",
        "This is not supported inside of vmap, please try to rewrite the ",
        "`as_strided` call as a sequence of PyTorch view operations");
  }

  TORCH_CHECK(
      *max_as_strided_loc <= *max_slice_loc && base_offset <= storage_offset,
      "result = tensor.as_strided(", sizes, ",", strides, ",", storage_offset, ") ",
      "can access memory outside of `tensor`. `result` can access some ",
      "memory in range [", storage_offset, ", ", *max_as_strided_loc, "], but ",
      "`tensor` can only access some memory in range [", base_offset, ", ",
      *max_slice_loc, "]. This is not supported inside of vmap, please try to ",
      "rewrite the `as_strided` call as a sequence of PyTorch view operations");
}

// What are the semantics of as_strided inside of vmap?
// y = vmap(lambda x: x.as_strided(sizes, strides, offset))(xs)
// This returns a view on `xs`, `y`, such that each y[i] has:
// - sizes: `sizes`
// - strides: `strides`
// - storage_offset: offset + i * xs.stride(batch_dim)
//
// In other words, it is as if we had treated each x[i] as having storage
// offset equal to xs.offset() and called as_strided(sizes, strides, offset).
// (that is equivalent to x[i].as_strided(
//    sizes, strides, offset + x[i].storage_offset() - xs.offset()) for all i)
//
// Note that this *may* be different from actually running as_strided
// in a for-loop. This is due to how as_strided takes in `offset` to be
// an *absolute* offset. As an example, consider:
// >>> x = torch.tensor([0., 1., 2., 3., 4.]).as_strided([4], [1], 1)
// >>> z = [x[i].as_strided([1], [1], 1) for i in range(4)]
// Each z[i] is actually the same view on x (z[i] == torch.tensor([1.]))!
// However, we consider the above for-loop comprehension to be a user error:
// a user should have written the following if they wanted to use as_strided
// in a per-sample way:
// >>> z = [x[i].as_strided([1], [1], 1 + x[i].storage_offset() - 1) for i in range(4)]
Tensor as_strided_batching_rule(
    const Tensor& tensor,
    IntArrayRef sizes,
    IntArrayRef strides,
    optional<int64_t> storage_offset) {
  if (!participatesInCurrentLevel(tensor)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::as_strided(tensor, sizes, strides, storage_offset);
  }
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(tensor);
  auto num_batch_dims = physical_view.numBatchDims();
  auto physical_sizes = physical_view.getPhysicalShape(sizes);
  const auto& physical_tensor = physical_view.tensor();

  // We can't rely on the physical as_strided call to do this for us because
  // we do some sanity checks on the size/strides before calling into as_strided.
  TORCH_CHECK(sizes.size() == strides.size(),
      "Tensor.as_strided(size, stride, ...): size and stride must have the ",
      "same length! Got size ", sizes, " and stride ", strides);

  // Sanity checks:
  // 1. All batch dims are at the front in memory layout (not necessary for
  //    correctness, but we are worried the user might be doing crazy things)
  // 2. as_strided(sizes, strides, storage_offset + tensor[i].offset() - tensor.offset())
  //    is valid for a slice of the input tensor.
  //    See Note: [When will the as_strided batching rule fail?] for details.
  checkBatchDimsAtFrontInLayout(physical_tensor.strides(), num_batch_dims);
  checkBasicAsStridedValidForSlice(
      physical_tensor, num_batch_dims, sizes, strides, storage_offset);

  // physical_strides = physical tensor's batch strides + (logical) strides
  auto batch_strides = physical_tensor.strides().slice(0, num_batch_dims);
  VmapDimVector physical_strides;
  physical_strides.reserve(num_batch_dims + strides.size());
  physical_strides.insert(
      physical_strides.end(), batch_strides.begin(), batch_strides.end());
  physical_strides.insert(
      physical_strides.end(), strides.begin(), strides.end());

  // If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
  // is valid for all i, then it turns out that
  // xs.as_strided(physical_sizes, physical_strides, offset) always succeeds
  // and creates a tensor y such that each y[i] references the same memory
  // locations as zi. See NOTE: [When will the as_strided batching rule fail?]
  auto result = physical_view.tensor().as_strided(
      physical_sizes, physical_strides, storage_offset);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

// NOTE: [When will the as_strided batching rule fail?]
// If zi = xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid for all i, then it turns out that
// xs.as_strided(physical_sizes, physical_strides, offset) always succeeds and
// creates a tensor y such that each y[i] refers to the same memory as zi.
//
// Let's say we have xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()).
// Furthermore, let's say that as a part of being "valid" this as_strided call
// does not return a result that can index memory not indexable by xs[i].
//
// WLOG, assume that there's only one batch dim and it is at the front of the
// `xs` tensor. Let B be the batch size and S be the stride of the batch dim.
// - If the batch dim isn't at the front of the tensor, then we can just move it
//   to the front with movedim/permute. This is always valid because it just swaps
//   some strides around.
// - This proof also works for tensors with multiple batch dims. We just have to
//   do a little accounting:
//   - instead of [B], we'd have [B0, B1, ..., Bk].
//   - instead of [S], we'd have [S0, S1, ..., Sk].
//   - instead of i, we'd have a list of indices [I0, I1, ..., Ik]
//   - instead of S * I, we'd have \sum_{i=0}^k S_i * I_i
//
// [Equation 1]
// xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset()) has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
//
// x.as_strided itself checks that:
// - (sizes, strides, offset) are in bounds for `x`'s storage.
// - strides are positive
// - offset is positive
//
// Claim 1: if xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
// is valid, then
// ([B] + sizes, [S] + strides, offset + xs.offset()) are in bounds for `xs`'s storage.
//
// If we have the claim, then xs.as_strided([B] + sizes, [S] + strides, offset)
// won't error out. So all we need to check is that the memory locations are
// what we expected. See [Hand-wavy proof of Claim 1] for proof (it's not very important)
//
// xs.as_strided(physical_sizes, physical_strides, offset) is equivalent to
// xs.as_strided([B] + sizes, [S] + strides, offset)
//
// xs.as_strided([B] + sizes, [S] + strides, offset) has:
// - sizes: [B] + sizes
// - strides: [S] + strides
// - offset: offset
//
// xs.as_strided([B] + sizes, [S] + strides, offset)[i] has:
// - sizes: sizes
// - strides: strides
// - offset: offset + S * i
// These memory locations are exactly the same as what we got for [Equation 1],
// so the xs.as_strided([B] + sizes, [S] + strides, offset) is valid.
//
// [Hand-wavy proof of Claim 1]
// Part of our definition of being valid is that xs[i].as_strided(...)
// must return a tensor that only uses memory indexable by xs[i].
// This means that (sizes, strides, offset + xs[i].offset() - xs.offset()) satisfies:
//    offset + xs[i].offset() - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs[i].offset() + 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
// (the largest-index memory location of xs[i].as_strided(...) must be \leq
// the largest-index memory location of xs[i])
//
// Fiddling that inequality gives us:
//    offset - xs.offset() + 1 + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + (B-1)*S + \sum_j (xs[i].size(j) - 1) * xs[i].stride(j)
//
//    offset - xs.offset() + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
//
//    offset + 1 + (B-1)*S + \sum_j (sizes[j] - 1) * strides[j]
//    <= xs.offset() + 1 + \sum_j (xs.size(j) - 1) * xs.stride(j)
// (the largest-index memory location of xs.as_strided(size, stride, offset)
// is \leq than the largest-index memory location of xs)
// Under the assumptions we've made, the lower bound (lowest indexed memory)
// is trivially within the storage.
//
// Therefore ([B] + sizes, [S] + strides, offset) are in bounds for
// `xs`'s storage.

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call(const Tensor& input, ExtraArgs... args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return Func(input, args...);
  }
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = Func(input_batched->value(), args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

template <typename F, F Func, typename... ExtraArgs>
Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  if (!participatesInCurrentLevel(input)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return (input.*Func)(extra_args...);
  }
  auto* input_batched = unsafeGetBatchedImpl(input);
  auto output_physical = (input_batched->value().*Func)(extra_args...);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor pow_scalar_Tensor_batching_rule(Scalar other, const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::pow(other, self);
  }
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::pow(other, self_batched->value());
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Tensor ones_like_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
//   if (!participatesInCurrentLevel(self)) {
//     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
//     return at::ones_like(self, memory_format);
//   }
//
//   TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
//       || memory_format == MemoryFormat::Contiguous,
//       "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
//       "memory_format torch.preserve_format or torch.contiguous_format (got ",
//       *memory_format, ")");
//
//   if (memory_format == MemoryFormat::Contiguous) {
//     auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
//     auto output_physical = at::clone(physical_view.tensor(), memory_format);
//     return physical_view.getPhysicalToLogicalMap().apply(output_physical);
//   }
//
//   TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
//   auto* self_batched = unsafeGetBatchedImpl(self);
//   auto output_physical = at::clone(self_batched->value(), memory_format);
//   auto old_bdims = self_batched->bdims();
//   return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
// }

Tensor clone_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::clone(self, memory_format);
  }
  // Memory format support is a little tricky because vmap is allowed to move
  // around batch dimensions and some memory formats are rank-dependent.
  // Another weird case is:
  // - a tensor with MemoryFormat::ChannelsLast MUST have 4 dimensions. Do we
  //   allow the user to clone a Tensor with 3 logical dimensions and 1 batch
  //   dim into a ChannelsLast Tensor? What about a Tensor with 3 logical dims
  //   and N>1 batch dims?
  TORCH_CHECK(!memory_format.has_value() || memory_format == MemoryFormat::Preserve
      || memory_format == MemoryFormat::Contiguous,
      "NYI: Tensor.clone(memory_format) inside vmap is only supported with ",
      "memory_format torch.preserve_format or torch.contiguous_format (got ",
      *memory_format, ")");

  if (memory_format == MemoryFormat::Contiguous) {
    // There is an ambiguity here when the batch dims are not at the front of
    // the tensor.
    // >>> x = torch.randn(3, B0, 5)
    // >>> y = vmap(lambda x: x.clone(torch.contiguous_format), in_dims=1, out_dims=0)(x)
    // >>> y[0].is_contiguous()
    // ???
    // Should we make the whole tensor contiguous, or should we
    // make the non-batch dims contiguous? We've chosen the latter because
    // philosophically vmap hides the batch dims and operates on a per-sample level.
    auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
    auto output_physical = at::clone(physical_view.tensor(), memory_format);
    return physical_view.getPhysicalToLogicalMap().apply(output_physical);
  }

  TORCH_INTERNAL_ASSERT(!memory_format.has_value() || memory_format == MemoryFormat::Preserve);
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::clone(self_batched->value(), memory_format);
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Note [Batching rules for matmul-like operators]
// at::matmul doesn't "de-expand" arguments to get better performance (maybe
// it should). In the batching rules for matmul-like operators (dot, mv, mm),
// we should be careful not to expand any unnecessary dimensions. e.g., if
// only one of the two arguments is a BatchedTensor, then we should try
// not to expand batch dimensions onto the other arg.
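
// For example, in mv below, if only `self` is batched we compute
// at::matmul([B0, L, K], [K]) directly instead of first expanding `other` to
// [B0, K]; the unbatched argument is passed through untouched.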
Tensor mv_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  // A shape checking API would be nice...
  TORCH_CHECK(self.dim() == 2 && other.dim() == 1,
      "mv(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor(), other);
    return self_physical.getPhysicalToLogicalMap().apply(result);
  }
  if (!self_batched && other_batched) {
    // self_physical: [L, K], other_physical: [..., K]
    // We view the tensors as [L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., L, K], other_physical: [..., K]
    // We view the tensors as [..., L, K], [..., K, 1], perform matmul to get
    // a tensor of size [..., L, 1], and squeeze the last dim.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor(),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor dot_batching_rule(const Tensor& self, const Tensor& other) {
  auto self_batched = isBatchedTensor(self);
  auto other_batched = isBatchedTensor(other);

  TORCH_CHECK(/*logical*/self.dim() == 1 && /*logical*/other.dim() == 1,
      "dot(self, other): Shape mismatch: vector "
      "(got `self` of size ", self.sizes(), ") ",
      "and vector (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    // self_physical: [..., K], other_physical: [K]
    // View the tensors as [..., 1, K] and [K], perform matmul, and squeeze.
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    auto result = at::matmul(self_physical.tensor().unsqueeze(-2), other);
    return self_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (!self_batched && other_batched) {
    // self_physical: [K], other_physical: [..., K]
    // View the tensors as [K] and [..., K, 1], perform matmul, and squeeze.
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    auto result = at::matmul(self, other_physical.tensor().unsqueeze(-1));
    return other_physical.getPhysicalToLogicalMap().apply(result.squeeze(-1));
  }
  if (self_batched && other_batched) {
    // self_physical: [..., K], other_physical: [..., K]
    // View the tensors as [..., 1, K] and [..., K, 1], perform matmul, and squeeze twice.
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    auto result = at::matmul(
        physical_args[0].tensor().unsqueeze(-2),
        physical_args[1].tensor().unsqueeze(-1));
    return physical_args[0].getPhysicalToLogicalMap().apply(result.squeeze(-1).squeeze(-1));
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor bmm_batching_rule(const Tensor& self, const Tensor& other) {
  TORCH_CHECK(/*logical*/self.dim() == 3 && /*logical*/other.dim() == 3,
      "bmm(self, other): Shape mismatch: expected 3D `self` "
      "(got `self` of size ", self.sizes(), ") ",
      "and 3D `other` (got `other` of size ", other.sizes(), ")");

  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

Tensor mm_batching_rule(const Tensor& self, const Tensor& other) {
  if (!participatesInCurrentLevel(self, other)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::mm(self, other);
  }

  auto self_batched = participatesInCurrentLevel(self);
  auto other_batched = participatesInCurrentLevel(other);

  TORCH_CHECK(/*logical*/self.dim() == 2 && /*logical*/other.dim() == 2,
      "mm(self, other): Shape mismatch: expected matrix "
      "(got `self` of size ", self.sizes(), ") ",
      "and matrix (got `other` of size ", other.sizes(), ")");

  // See Note [Batching rules for matmul-like operators] for why we have cases
  if (self_batched && !other_batched) {
    auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto result = at::matmul(self_physical.tensor(), other);
    result = self_physical.getPhysicalToLogicalMap().apply(result);
    TORCH_INTERNAL_ASSERT(result.dim() == 2);
    return result;
  }
  if (!self_batched && other_batched) {
    auto other_physical = MultiBatchVmapTransform::logicalToPhysical(other);
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto result = at::matmul(self, other_physical.tensor());
    result = other_physical.getPhysicalToLogicalMap().apply(result);
    TORCH_INTERNAL_ASSERT(result.dim() == 2);
    return result;
  }
  if (self_batched && other_batched) {
    auto physical_args = MultiBatchVmapTransform::logicalToPhysical({self, other});
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto result = at::matmul(physical_args[0].tensor(), physical_args[1].tensor());
    TORCH_INTERNAL_ASSERT(result.dim() == 3);
    result = physical_args[0].getPhysicalToLogicalMap().apply(result);
    TORCH_INTERNAL_ASSERT(result.dim() == 2);
    return result;
  }
  TORCH_INTERNAL_ASSERT(false, "either self or other must be a BatchedTensor");
}

Tensor cat_batching_rule(TensorList tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::cat(tensors, dim);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
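  // getPhysicalDim maps the logical cat dim to its physical counterpart. For
  // example (illustrative), concatenating logical [3, 4] inputs along dim=1
  // under a single vmap level turns into concatenating [B, 3, 4] physical
  // tensors along physical dim 2.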
  auto result = at::cat(physical_tensors, physical_views[0].getPhysicalDim(dim));
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

Tensor stack_batching_rule(TensorList tensors, int64_t dim) {
  if (!participatesInCurrentLevel(tensors)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::stack(tensors, dim);
  }
  auto physical_views = MultiBatchVmapTransform::logicalToPhysical(tensors);
  auto physical_tensors = fmap(
      physical_views, [](const VmapPhysicalView& view) -> Tensor { return view.tensor(); });
  TORCH_INTERNAL_ASSERT(
      tensors.size() > 0, "The dispatcher should not have dispatched here otherwise.");
  // NB: stack wraps the dimensionality to (logical dim + 1), so we have to
  // manually handle that here.
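  // For example (illustrative): stacking logical 2-D inputs with dim=-1 wraps
  // to logical dim 2 (out of logical dim + 1 = 3 valid positions); with one
  // batch dim in front, the physical stack dim becomes 1 + 2 = 3.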
  auto dim_physical =
      physical_views[0].numBatchDims() + maybe_wrap_dim(dim, /*logical*/tensors[0].dim() + 1);
  auto result = at::stack(physical_tensors, dim_physical);
  return physical_views[0].getPhysicalToLogicalMap().apply(result);
}

// I am quite sad that we need to register operators with exploded TensorOptions,
// even though the native:: implementations can use TensorOptions&.
// This also makes it hard to metaprogram: i.e., we can't use
// unwrap_and_call<..., at::to> because at::to takes TensorOptions& (!!)
Tensor to_dtype_layout_batching_rule(
    const Tensor& self,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory,
    bool non_blocking, bool copy,
    optional<MemoryFormat> memory_format) {
  auto options = TensorOptions()
    .dtype(dtype)
    .layout(layout)
    .device(device)
    .pinned_memory(pin_memory);
  auto* input_batched = unsafeGetBatchedImpl(self);
  auto output_physical = input_batched->value().to(options, non_blocking, copy, memory_format);
  auto old_bdims = input_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor new_zeros_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
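  // getPhysicalShape prepends the batch sizes to the requested logical size,
  // e.g. (illustrative) size=[3, 4] under a vmap of batch size B becomes a
  // physical size of [B, 3, 4].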
  auto options = TensorOptions()
    .dtype(dtype)
    .layout(layout)
    .device(device)
    .pinned_memory(pin_memory);
  auto result = physical_view.tensor().new_zeros(physical_size, options);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    c10::optional<ScalarType> dtype,
    c10::optional<Layout> layout,
    c10::optional<Device> device,
    c10::optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);
  auto result = physical_view.tensor().new_empty(
      physical_size,
      TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
  // Decomposition that is probably not very fast...
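  // (It leans on the existing mul, mm, and add batching rules rather than a
  // dedicated addmm rule, trading the fused addmm kernel for separate dispatches.)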
  return at::add(self * beta, at::mm(mat1, mat2), alpha);
}

Tensor ones_like_batching_rule(
    const Tensor& self,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory,
    optional<MemoryFormat> memory_format) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    auto options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
    return at::ones_like(self, options, memory_format);
  }
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
  auto result = at::ones_like(physical_view.tensor(), options, memory_format);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor new_empty_strided_batching_rule(
    const Tensor& self,
    IntArrayRef size,
    IntArrayRef stride,
    optional<ScalarType> dtype,
    optional<Layout> layout,
    optional<Device> device,
    optional<bool> pin_memory) {
  auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
  auto physical_size = physical_view.getPhysicalShape(size);

  // Let [B0, B1, B2] be the shape of the batch dims. We're going to create
  // the batch dimensions at the front of the tensor (in memory layout),
  // irrespective of whether or not they are actually at the front (in memory layout)
  // in the original `self` tensor. This is because when a user calls
  // `new_empty_strided` in general, the `strides` they provide are for a new
  // tensor and have no relation to the strides of the original tensor.
  //
  // So, the physical shape of the result should be ([B0, B1, B2] + size),
  // but what about the physical strides?
  //
  // We're actually free to pick whatever stride we want:
  // e.g., for size=[5, 3], stride=[0, 1], we could decide to use
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [9999*B1*B2, 9999*B2, 9999, 0, 1]
  //
  // Let's select some reasonable strides such that:
  // - The batch dims are "contiguous" with respect to each other
  // - if empty_strided(size, stride) would have created a contiguous Tensor,
  //   then this new physical Tensor (with batch dims) is also contiguous
  //
  // Let S be the size of the storage if one were to construct a tensor
  // with `size` and `stride` via empty_strided(size, stride).
  // Then the physical sizes/strides should be:
  // - physical size: [B0, B1, B2, 5, 3]
  // - physical stride: [B1 * B2 * S, B2 * S, S, 0, 1]
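  //
  // Concretely (an illustrative example): with B0 = B1 = B2 = 2, size=[5, 3],
  // stride=[3, 1], the storage size is S = 1 + 4*3 + 2*1 = 15, so the physical
  // size is [2, 2, 2, 5, 3] and the physical stride is
  // [4*15, 2*15, 15, 3, 1] = [60, 30, 15, 3, 1].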
  auto batch_shape = IntArrayRef(
      physical_view.tensor().sizes().begin(), physical_view.numBatchDims());

  // physical_strides = [B1 * B2 * S, B2 * S, S]
  auto physical_strides = at::detail::defaultStrides(batch_shape);
  TORCH_CHECK(size.size() == stride.size(),
        "new_empty_strided(sizes, strides): dimensionality of sizes (",
        size.size(), ") must match dimensionality of strides (",
        stride.size(), ")");
  auto storage_size = native::storage_size_for(size, stride);
  for (auto& physical_stride : physical_strides) {
    physical_stride *= storage_size;
  }

  // physical_strides = [B1 * B2 * S, B2 * S, S] + strides
  physical_strides.insert(physical_strides.end(), stride.begin(), stride.end());

  auto result = physical_view.tensor().new_empty_strided(
      physical_size, physical_strides, dtype, layout, device, pin_memory);
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

template <typename F, F Func>
Tensor comparison_pointwise_batching_rule(const Tensor& self, const Tensor& other) {
  auto physical_args = BroadcastingVmapTransform::logicalToPhysical({self, other});
  auto result = Func(physical_args[0].tensor(), physical_args[1].tensor());
  return physical_args[0].getPhysicalToLogicalMap().apply(result);
}

bool BatchedTensor_is_leaf(const Tensor& self) {
  if (torch::autograd::impl::get_autograd_meta(self)) {
    return torch::autograd::impl::get_autograd_meta(self)->grad_fn_ == nullptr;
  } else {
    return true;
  }
}

Tensor& BatchedTensor_requires_grad_(Tensor& self, bool requires_grad) {
  self.set_requires_grad(requires_grad);
  return self;
}


TORCH_LIBRARY_IMPL(_, BatchedOutOfTree, m) {
  m.fallback(torch::CppFunction::makeFromBoxedFunction<&batchedTensorForLoopFallback>());
}

// // debug_t<tail_t<tail_t<typelist<Tensor, optional<int64_t>>>>> dt;
// debug_t<remove_batch_dim_after_tensor_t<typelist<Tensor, optional<int64_t>>>> dt;


std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, optional<int64_t> batch_dim) {
  return {tensor.abs(), batch_dim};
}

template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(const Tensor& tensor, optional<int64_t> batch_dim) {
  return {Func(tensor), batch_dim};
}

static Tensor moveBatchDimToFront(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  if (!maybe_batch_dim.has_value()) {
    return tensor;
  }
  return tensor.movedim(maybe_batch_dim.value(), 0);
}

static int64_t rankWithoutBatchDim(const Tensor& tensor, optional<int64_t> maybe_batch_dim) {
  int64_t result = tensor.dim();
  if (maybe_batch_dim.has_value()) {
    result -= 1;
  }
  return result;
}
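
// Pads a batched tensor with size-1 dims (right after the batch dim) until its
// logical rank matches `logical_rank`. For example (illustrative), a physical
// [B, 3] tensor padded to logical rank 3 becomes [B, 1, 1, 3]. Tensors without
// a batch dim are returned unchanged; ordinary broadcasting handles them.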
static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view(new_sizes);
}

static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}

template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
    const Tensor& other, optional<int64_t> other_batch_dim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
  auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
  if (tensor_is_logical_scalar && !other_is_logical_scalar) {
    handleScalarTypePromotion(tensor_, other_);
  }
  if (other_is_logical_scalar && !tensor_is_logical_scalar) {
    handleScalarTypePromotion(other_, tensor_);
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  auto result = Func(tensor_, other_);
  auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
    ? optional<int64_t>{0} : nullopt;
  return { std::move(result), std::move(result_batch_dim) };
}

Tensor matmul_decomposed(
    const Tensor& tensor1,
    const Tensor& tensor2) {
  auto dim_tensor1 = tensor1.dim();
  auto dim_tensor2 = tensor2.dim();

  if (dim_tensor1 == 1 && dim_tensor2 == 1) {
    return tensor1.dot(tensor2);
  } else if (dim_tensor1 == 2 && dim_tensor2 == 1) {
    return tensor1.mv(tensor2);
  } else if (dim_tensor1 == 1 && dim_tensor2 == 2) {
    return tensor1.unsqueeze(0).mm(tensor2).squeeze_(0);
  } else if (dim_tensor1 == 2 && dim_tensor2 == 2) {
    return tensor1.mm(tensor2);
  } else if (dim_tensor1 >= 3 && (dim_tensor2 == 1 || dim_tensor2 == 2)) {
    // optimization: use mm instead of bmm by folding tensor1's batch into
    // its leading matrix dimension.
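    // For example (illustrative), a [10, 4, 5] @ [5, 6] matmul is computed as
    // a single [40, 5] @ [5, 6] mm whose result is viewed back as [10, 4, 6].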

    Tensor t2 = dim_tensor2 == 1 ? tensor2.unsqueeze(-1) : tensor2;
    auto size1 = tensor1.sizes();
    auto size2 = t2.sizes();
    std::vector<int64_t> output_size;
    output_size.insert(output_size.end(), size1.begin(), size1.end() - 1);
    if (dim_tensor2 > 1) {
      output_size.push_back(size2[dim_tensor2 - 1]);
    }

    // fold the batch into the first dimension
    Tensor t1 = tensor1.reshape({-1, size1[size1.size() - 1]});
    Tensor output = t1.mm(t2).view(output_size);
    return output;
  } else if ((dim_tensor1 == 1 || dim_tensor1 == 2) && dim_tensor2 >= 3) {
    // optimization: transpose the inner dimensions of the arguments, call
    // matmul on the swapped arguments, then transpose the inner dimensions
    // of the result.
    const int64_t n = dim_tensor1 == 2 ? tensor1.size(-2) : 1;
    const int64_t m = tensor1.size(-1);
    const int64_t p = tensor2.size(-1);

    const Tensor t2_T = tensor2.transpose(-1, -2);
    const Tensor t1_T = dim_tensor1 == 2 ? tensor1.t() : tensor1.reshape({n, m}).t();
    const Tensor res_T = at::matmul(t2_T, t1_T);

    if (dim_tensor1 == 2) {
      return res_T.transpose(-1, -2);
    } else {
      std::vector<int64_t> shape = tensor2.sizes().slice(0, dim_tensor2 - 2).vec();
      shape.push_back(p);

      Tensor res = res_T.reshape(shape);
      return res;
    }
  } else if ((dim_tensor1 >= 1 && dim_tensor2 >= 1) && (dim_tensor1 >= 3 || dim_tensor2 >= 3)) {
    // We are multiplying b1 x n x m1 by b2 x m2 x p (where b1 and b2 can be lists);
    // we track m1 vs m2 separately even though they must match, for nicer error messages.
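    // For example (illustrative), [2, 1, 4, 5] @ [3, 5, 6]: the batch portions
    // [2, 1] and [3] broadcast to [2, 3], so we bmm a [6, 4, 5] view with a
    // [6, 5, 6] view and reshape the [6, 4, 6] result back to [2, 3, 4, 6].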
    int64_t n = dim_tensor1 > 1 ? tensor1.size(-2) : 1;
    int64_t m1 = tensor1.size(-1);
    IntArrayRef batch_tensor1(tensor1.sizes().data(), std::max<int64_t>(dim_tensor1 - 2, 0));
    int64_t m2 = dim_tensor2 > 1 ? tensor2.size(-2) : 1;
    int64_t p = tensor2.size(-1);
    IntArrayRef batch_tensor2(tensor2.sizes().data(), std::max<int64_t>(dim_tensor2 - 2, 0));

    // expand the batch portion (i.e. cut off matrix dimensions and expand rest)
    std::vector<int64_t> expand_batch_portion = infer_size(batch_tensor1, batch_tensor2);

    std::vector<int64_t> tensor1_expand_size(expand_batch_portion);
    tensor1_expand_size.insert(tensor1_expand_size.end(), {n, m1});

    std::vector<int64_t> tensor2_expand_size(expand_batch_portion);
    tensor2_expand_size.insert(tensor2_expand_size.end(), {m2, p});

    const int64_t expand_batch_product =
        c10::multiply_integers(expand_batch_portion);

    std::vector<int64_t> tensor1_bmm_view({expand_batch_product});
    tensor1_bmm_view.insert(tensor1_bmm_view.end(), {n, m1});

    std::vector<int64_t> tensor2_bmm_view({expand_batch_product});
    tensor2_bmm_view.insert(tensor2_bmm_view.end(), {m2, p});

    // flatten expanded batches
    Tensor tensor1_expanded = tensor1.expand(tensor1_expand_size).reshape(tensor1_bmm_view);
    Tensor tensor2_expanded = tensor2.expand(tensor2_expand_size).reshape(tensor2_bmm_view);

    // reshape batches back into result
    std::vector<int64_t> output_shape(expand_batch_portion);
    if (dim_tensor1 > 1) {
      output_shape.push_back(n);
    }
    if (dim_tensor2 > 1) {
      output_shape.push_back(p);
    }

    Tensor output = tensor1_expanded.bmm(tensor2_expanded).view(output_shape);
    return output;
  }

  AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
           dim_tensor1, "D and ", dim_tensor2, "D");
}

TORCH_LIBRARY_IMPL(aten, BatchedOutOfTree, m) {

#define VMAP_SUPPORT(op, batch_rule) \
    m.impl(op, PrimBatchRule7< \
        decltype(&batch_rule), &batch_rule, to_operator_t<decltype(batch_rule)> \
    >::apply);

  // VMAP_OUTPLACE_OP("abs", abs_batch_rule);
  // m.impl("abs", PrimBatchRule7<decltype(&abs_batch_rule), &abs_batch_rule, to_operator_t<decltype(abs_batch_rule)>>::apply);

  // NB: Ideally we would like some operators, like size.int, to "fallthrough"
  // to the underlying implementation. However, because a BatchedTensor is a
  // Tensor wrapper, it only has one dispatch key (Batched) on it. The resolution
  // here is to just directly call the underlying implementation.
  m.impl("size.int", static_cast<int64_t (*)(const Tensor&, int64_t)>(native::size));
  // m.impl("_add_batch_dim", native::_add_batch_dim);
  // m.impl("_remove_batch_dim", native::_remove_batch_dim);

  m.impl("max_pool2d", at::native::max_pool2d); // composite
  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);

  m.impl("mean", mean_batching_rule);
  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("log_softmax.int", log_softmax_batching_rule);
  m.impl("_log_softmax", _log_softmax_batching_rule);
  m.impl("is_complex", native::is_complex);
  m.impl("conj", native::conj);
  m.impl("flatten.using_ints", flatten_batching_rule);
  m.impl("cross_entropy_loss", native::cross_entropy_loss);
  //
  // // inplace operations
  // m.impl("fill_.Scalar", fill_inplace_scalar_batching_rule);
  // m.impl("fill_.Tensor", fill_inplace_tensor_batching_rule);
  // m.impl("zero_", zero_inplace_batching_rule);

  // // autograd things...
  // m.impl("is_leaf", BatchedTensor_is_leaf);
  // m.impl("requires_grad_", BatchedTensor_requires_grad_);

  // view operations
  m.impl("as_strided", as_strided_batching_rule);
  m.impl("chunk", chunk_batching_rule);
  m.impl("tensor_split.sections", tensor_split_sections_batching_rule);
  m.impl("tensor_split.indices", tensor_split_indices_batching_rule);
  m.impl("diagonal", diagonal_batching_rule);
  m.impl("expand", expand_batching_rule);
  m.impl("expand_as", native::expand_as); // composite wrt autograd
  m.impl("movedim.intlist", movedim_batching_rule);
  m.impl("movedim.int", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t)>(native::movedim)); // composite wrt autograd
  // NB: static_cast because there's another variant of narrow. However, we don't
  // want to support the other variant yet bc it isn't documented...
  m.impl("narrow", static_cast<Tensor(*)(const Tensor&,int64_t,int64_t,int64_t)>(native::narrow)); // composite wrt autograd
  m.impl("numpy_T", native::numpy_T); // composite wrt autograd
  m.impl("permute", permute_batching_rule);
  m.impl("reshape", reshape_batching_rule);
  m.impl("reshape_as", native::reshape_as); // composite wrt autograd
  m.impl("select.int", select_batching_rule);
  m.impl("slice.Tensor", slice_batching_rule);
  m.impl("split.Tensor", split_batching_rule);
  m.impl("split_with_sizes", split_with_sizes_batching_rule);
  m.impl("squeeze", squeeze_batching_rule);
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("squeeze_.dim", squeeze_dim__batching_rule);
  m.impl("t", native::t); // composite wrt autograd
  m.impl("trace", trace_batching_rule);
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
  m.impl("unsqueeze", unsqueeze_batching_rule);
  m.impl("unsqueeze_", unsqueeze__batching_rule);
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as); // composite wrt autograd

  // m.impl("addmm", addmm_batching_rule);
  m.impl("matmul", matmul_decomposed);
  //
  // clamp operations
  // m.impl("clamp", clamp_batching_rule);
  // m.impl("clamp_min", clamp_min_batching_rule);
  // m.impl("clamp_max", clamp_max_batching_rule);

  // unary pointwise, out-of-place, no additional arguments.
#define UNARY_POINTWISE_BATCH_RULE(op) unwrap_and_call2<decltype(&op), &op>

#define UNARY_POINTWISE(op) VMAP_SUPPORT(#op, UNARY_POINTWISE_BATCH_RULE(at::op));
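  // For example (illustrative), UNARY_POINTWISE(abs) expands to
  //   VMAP_SUPPORT("abs", unwrap_and_call2<decltype(&at::abs), &at::abs>);
  // i.e. the batch rule simply calls at::abs on the value and forwards the batch dim.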
  UNARY_POINTWISE(abs);
  UNARY_POINTWISE(acos);
  UNARY_POINTWISE(asin);
  UNARY_POINTWISE(atan);
  UNARY_POINTWISE(ceil);
  UNARY_POINTWISE(cos);
  UNARY_POINTWISE(cosh);
  UNARY_POINTWISE(_conj);
  UNARY_POINTWISE(digamma);
  UNARY_POINTWISE(exp);
  UNARY_POINTWISE(expm1);
  UNARY_POINTWISE(floor);
  UNARY_POINTWISE(frac);
  UNARY_POINTWISE(lgamma);
  UNARY_POINTWISE(log);
  UNARY_POINTWISE(log10);
  UNARY_POINTWISE(log1p);
  UNARY_POINTWISE(log2);
  UNARY_POINTWISE(neg);
  UNARY_POINTWISE(reciprocal);
  UNARY_POINTWISE(relu);
  UNARY_POINTWISE(round);
  UNARY_POINTWISE(rsqrt);
  UNARY_POINTWISE(sigmoid);
  UNARY_POINTWISE(sign);
  UNARY_POINTWISE(sin);
  UNARY_POINTWISE(sinh);
  UNARY_POINTWISE(sqrt);
  UNARY_POINTWISE(tan);
  UNARY_POINTWISE(tanh);
  UNARY_POINTWISE(trunc);
#undef UNARY_POINTWISE
#define TO_BATCHING_RULE(name, ...) \
  { \
    using to_type = Tensor(Tensor::*)(__VA_ARGS__) const; \
    m.impl(name, unwrap_and_call_method< \
        to_type, &Tensor::to, __VA_ARGS__>);\
  }
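  // For example (illustrative), TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
  // registers the Tensor::to(ScalarType, bool, bool, optional<MemoryFormat>) const
  // overload through the unwrap_and_call_method helper.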
  TO_BATCHING_RULE("to.device", Device, ScalarType, bool, bool, optional<MemoryFormat>)
  TO_BATCHING_RULE("to.dtype", ScalarType, bool, bool, optional<MemoryFormat>)
  TO_BATCHING_RULE("to.other", const Tensor&, bool, bool, optional<MemoryFormat>)
  m.impl("to.dtype_layout", to_dtype_layout_batching_rule);
#undef TO_BATCHING_RULE
  m.impl("clone", clone_batching_rule);
  // m.impl("ones_like", ones_like_batching_rule);

  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, Scalar);

  // #define BINARY_POINTWISE(op) \
  //   m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
  // #define BINARY_POINTWISE_VA(op, ...) \
  //   { \
  //     using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
  //     using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \
  //     m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
  //     m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, Scalar, __VA_ARGS__>); \
  //   }

#define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
#define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
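  // For example (illustrative), BINARY_POINTWISE(mul) expands to
  //   VMAP_SUPPORT("mul.Tensor", binary_pointwise_batch_rule<TensorTensorType, &at::mul>);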
  // BINARY_POINTWISE_VA(add, Scalar);
  // BINARY_POINTWISE_VA(sub, Scalar);
  // BINARY_POINTWISE_VA(rsub, Scalar);
  BINARY_POINTWISE(mul);
  VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
  // BINARY_POINTWISE(div);
  //
  // // at::pow has three out-of-place overloads
  // m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  // m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, Scalar>);
  // m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
  //
  // m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  // m.impl(
  //     "threshold_backward",
  //     binary_pointwise_batching_rule<
  //       TensorTensorScalarType,
  //       at::threshold_backward,
  //       Scalar>);
  //
  // for at::result_type, call the native::result_type implementation.
  // We don't have to do anything special because native::result_type operates
  // on the logical shape of the tensors.
  m.impl("result_type.Tensor", static_cast<ScalarType (*)(const Tensor&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar", static_cast<ScalarType (*)(const Tensor&, const Scalar&)>(native::result_type));
  m.impl("result_type.Scalar_Tensor", static_cast<ScalarType (*)(const Scalar&, const Tensor&)>(native::result_type));
  m.impl("result_type.Scalar_Scalar", static_cast<ScalarType (*)(const Scalar&, const Scalar&)>(native::result_type));
  //
  // #undef BINARY_POINTWISE_VA
  // #undef BINARY_POINTWISE
  //
  //
#define TRIVIAL_OP(op) m.impl(#op, \
    unwrap_and_call<Tensor (*)(const Tensor&), at::op>);
  // complex number view operators
  TRIVIAL_OP(imag)
  TRIVIAL_OP(real);
  TRIVIAL_OP(view_as_real);
  m.impl("view_as_complex", view_as_complex_batching_rule);
  // #undef TRIVIAL
  // //
  // // // matmul-like operators
  // // m.impl("mv", mv_batching_rule);
  // // m.impl("dot", dot_batching_rule);
  // // m.impl("bmm", bmm_batching_rule);
  m.impl("mm", mm_batching_rule);
  // //
  // cat/stack
  m.impl("cat", cat_batching_rule);
  m.impl("stack", stack_batching_rule);
  // //
  // // // backward operators
  // // m.impl("select_backward", select_backward_batching_rule);
  // // m.impl("slice_backward", slice_backward_batching_rule);
  // // m.impl("trace_backward", trace_backward_batching_rule);
  // // m.impl("diagonal_backward", diagonal_backward_batching_rule);
  // //
  // // // Tensor.new_* operators
  // m.impl("ones_like", ones_like_batching_rule);
  // // m.impl("new_empty", new_empty_batching_rule);
  // m.impl("new_empty_strided", new_empty_strided_batching_rule);
  // // m.impl("new_zeros", new_zeros_batching_rule);
  // //
  // // m.impl("contiguous", contiguous_batching_rule);
  // //
  // // // Comparison ops
  // // #define COMPARISON_POINTWISE(op) \
  // //   m.impl(#op".Tensor", comparison_pointwise_batching_rule<TensorTensorType, at::op>); \
  // //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
  // //
  // // COMPARISON_POINTWISE(eq);
  // // COMPARISON_POINTWISE(gt);
  // // COMPARISON_POINTWISE(ge);
  // // COMPARISON_POINTWISE(le);
  // // COMPARISON_POINTWISE(lt);
  // // COMPARISON_POINTWISE(ne);
  // //
  // #undef COMPARISON_POINTWISE
}

} // namespace functorch
} // namespace at