[functorch] Grab bag of batch rules
functorch/functorch/csrc/BatchRulesBinaryOps.cpp (new file, 113 lines)
@@ -0,0 +1,113 @@
#include <functorch/csrc/BatchRulesHelper.h>

namespace at { namespace functorch {

template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
    const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
  return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}

static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}
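For context (not part of the commit): PyTorch promotes 0-dim tensors differently from dimensioned tensors of the same dtype category, and under vmap a logical scalar is physically a 1-D tensor of shape [B]. A minimal Python sketch of the behavior the helper preserves by consulting result_type(logical_scalar_tensor[0], second):

import torch

int_vec = torch.ones(3, dtype=torch.int32)
long_0d = torch.tensor(1, dtype=torch.int64)       # a true 0-dim scalar tensor
print((int_vec + long_0d).dtype)                   # torch.int32: the 0-dim operand defers

# Under vmap, B logical scalars arrive as one physical 1-D tensor of shape [B],
# and dimensioned-tensor promotion would win instead:
long_batch = torch.ones(5, dtype=torch.int64)
print((int_vec + long_batch.unsqueeze(-1)).dtype)  # torch.int64

# Consulting a single element recovers the unbatched result:
print(torch.result_type(long_batch[0], int_vec))   # torch.int32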

static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view(new_sizes);
}
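A quick illustration (illustrative shapes, not from the commit) of what the padding does: with the batch dimension already moved to the front, size-1 dimensions are inserted right after it until the per-example data reaches the target logical rank, so broadcasting matches the unbatched case.

import torch

B = 7
t = torch.rand(B, 3)                  # batched operand, logical rank 1
target_logical_rank = 3               # the other operand is e.g. [2, 5, 3]

new_sizes = list(t.shape)
for _ in range(target_logical_rank - (t.dim() - 1)):
    new_sizes.insert(1, 1)            # insert after the leading batch dim
print(t.view(new_sizes).shape)        # torch.Size([7, 1, 1, 3])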

template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
    const Tensor& other, optional<int64_t> other_batch_dim,
    ExtraArgs... extra_args) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
  auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
  if (tensor_is_logical_scalar && !other_is_logical_scalar) {
    handleScalarTypePromotion(tensor_, other_);
  }
  if (other_is_logical_scalar && !tensor_is_logical_scalar) {
    handleScalarTypePromotion(other_, tensor_);
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
  auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
      ? optional<int64_t>{0} : nullopt;
  return { std::move(result), std::move(result_batch_dim) };
}
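To make the rank-alignment comment concrete, here is a hedged end-to-end check in Python; it assumes functorch's vmap is importable as functorch.vmap and uses arbitrary shapes. Adding a batch of per-example shape-[3] tensors to an unbatched [2, 5, 3] tensor should match a manual broadcast where the batched side is viewed as [B, 1, 1, 3].

import torch
from functorch import vmap  # assumed import path

B = 7
x = torch.rand(B, 3)        # batched: each example has shape [3]
y = torch.rand(2, 5, 3)     # unbatched

out = vmap(torch.add, in_dims=(0, None))(x, y)  # shape [B, 2, 5, 3]
expected = x.view(B, 1, 1, 3) + y               # what the batch rule computes
print(out.shape, torch.allclose(out, expected))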

std::tuple<Tensor,optional<int64_t>> pow_scalar_tensor_batch_rule(
    const Scalar& other,
    const Tensor& tensor, optional<int64_t> tensor_batch_dim) {
  return { at::pow(other, tensor), tensor_batch_dim };
}

TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);

  #define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
    binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
  #define BINARY_POINTWISE_WITH_SCALAR(op) \
    VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE_SCALAR(at::op));

  #define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
  #define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));

  BINARY_POINTWISE_WITH_SCALAR(add);
  BINARY_POINTWISE_WITH_SCALAR(sub);
  BINARY_POINTWISE_WITH_SCALAR(rsub);
  BINARY_POINTWISE(mul);
  BINARY_POINTWISE(div);
  VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));

  // at::pow has three out-of-place overloads
  #define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
  VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
  #undef POW_BATCH_RULE
  #define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
  VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
  #undef POW_BATCH_RULE
  VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);

  #undef BINARY_POINTWISE_BATCH_RULE_SCALAR
  #undef BINARY_POINTWISE_BATCH_RULE
  #undef BINARY_POINTWISE_WITH_SCALAR
  #undef BINARY_POINTWISE
}

}}
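As a usage sketch (again assuming functorch.vmap), the three out-of-place pow overloads registered above correspond to the three ways the base and exponent can be supplied:

import torch
from functorch import vmap  # assumed import path

B = 5
x = torch.rand(B, 3)
y = torch.rand(B, 3)

vmap(torch.pow)(x, y)                    # pow.Tensor_Tensor: both operands batched
vmap(lambda t: torch.pow(t, 2.0))(x)     # pow.Tensor_Scalar: scalar exponent
vmap(lambda t: torch.pow(2.0, t))(x)     # pow.Scalar: scalar base, batched exponent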
@@ -95,9 +95,26 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
  return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}

// NB: repeat is not actually a view, but it is in this file
std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
    const Tensor& self,
    optional<int64_t> self_bdim,
    IntArrayRef sizes) {
  if (!self_bdim) {
    return { self.repeat(sizes), nullopt };
  }

  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
  sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
  return { self_.repeat(sizes_with_bdim), 0 };
}


TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
  VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
  VMAP_SUPPORT("repeat", repeat_batch_rule);
}

}}
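A hedged sketch of what the repeat rule enables (assuming functorch.vmap): the sizes the user passes are logical, so the rule prepends a 1 for the batch dimension before calling Tensor.repeat on the physical tensor.

import torch
from functorch import vmap  # assumed import path

B = 7
x = torch.rand(B, 1, 1)
out = vmap(lambda t: t.repeat(2, 3))(x)  # each [1, 1] example repeats to [2, 3]
print(out.shape)                         # torch.Size([7, 2, 3])

# Physically this is x.repeat(1, 2, 3), the extra leading 1 covering the batch dim.
print(torch.equal(out, x.repeat(1, 2, 3)))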
@@ -173,6 +173,27 @@ Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, opt
  return self_physical.getPhysicalToLogicalMap().apply(result);
}

Tensor mean_int_batching_rule(
    const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return self.mean(dims, keepdim, dtype);
  }
  // PyTorch has a special case where mean(scalar_tensor, dim=0) does not fail
  // and instead returns a new scalar tensor (this also happens for dim=-1)
  // If the following happens:
  // >>> x = torch.randn(B0)  # the per-examples are all scalars
  // >>> vmap(partial(torch.mean, dim=0), x)
  // then we replicate the behavior of mean(scalar_tensor, dim=0).
  if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) {
    return self.clone();
  }
  auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
  auto dims_physical = self_physical.getPhysicalDims(dims);
  auto result = at::mean(self_physical.tensor(), dims_physical, keepdim, dtype);
  return self_physical.getPhysicalToLogicalMap().apply(result);
}
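A behavior sketch for the special case above (assuming functorch.vmap): torch.mean(x_i, dim=0) on a 0-dim per-example tensor returns a new scalar rather than raising, so the batched rule simply clones the input.

import torch
from functools import partial
from functorch import vmap  # assumed import path

B0 = 5
x = torch.randn(B0)                        # every per-example value is a scalar
out = vmap(partial(torch.mean, dim=0))(x)  # hits the self.dim() == 0 branch
print(torch.equal(out, x))                 # the mean of a scalar is the scalar itself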

bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
  if (logical_tensor.dim() > 0) {
    return false;
@@ -982,17 +1003,6 @@ Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

Tensor pow_scalar_Tensor_batching_rule(Scalar other, const Tensor& self) {
  if (!participatesInCurrentLevel(self)) {
    c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
    return at::pow(other, self);
  }
  auto* self_batched = unsafeGetBatchedImpl(self);
  auto output_physical = at::pow(other, self_batched->value());
  auto old_bdims = self_batched->bdims();
  return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}

// Tensor ones_like_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
//   if (!participatesInCurrentLevel(self)) {
//     c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1277,7 +1287,7 @@ Tensor new_empty_batching_rule(
  return physical_view.getPhysicalToLogicalMap().apply(result);
}

Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
  // Decomposition that is probably not very fast...
  return at::add(self * beta, at::mm(mat1, mat2), alpha);
}
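The decomposition above is just the definition of addmm; a quick numerical sanity check in plain PyTorch (illustrative shapes):

import torch

inp = torch.rand(3, 4)
mat1 = torch.rand(3, 5)
mat2 = torch.rand(5, 4)
beta, alpha = 0.5, 2.0

expected = torch.addmm(inp, mat1, mat2, beta=beta, alpha=alpha)
decomposed = torch.add(inp * beta, torch.mm(mat1, mat2), alpha=alpha)
print(torch.allclose(expected, decomposed))  # beta * inp + alpha * (mat1 @ mat2)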
@@ -1392,69 +1402,10 @@ std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, option
  return {tensor.abs(), batch_dim};
}

template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(const Tensor& tensor, optional<int64_t> batch_dim) {
  return {Func(tensor), batch_dim};
}

static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
  if (!has_bdim) {
    return tensor;
  }
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
  if (tensor_logical_rank >= logical_rank) {
    return tensor;
  }
  VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
  for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
    new_sizes.insert(new_sizes.begin() + 1, 1);
  }
  return tensor.view(new_sizes);
}

static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
  auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
  if (logical_scalar_tensor.scalar_type() != result_type) {
    logical_scalar_tensor = logical_scalar_tensor.to(result_type);
  }
  if (second.scalar_type() != result_type) {
    second = second.to(result_type);
  }
}

template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
    const Tensor& tensor, optional<int64_t> tensor_batch_dim,
    const Tensor& other, optional<int64_t> other_batch_dim) {
  // compute max logical rank
  auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
  auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
  auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);

  auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
  auto other_ = moveBatchDimToFront(other, other_batch_dim);

  // In the (0D, ND) case, type promotion semantics are different :/
  auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
  auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
  if (tensor_is_logical_scalar && !other_is_logical_scalar) {
    handleScalarTypePromotion(tensor_, other_);
  }
  if (other_is_logical_scalar && !tensor_is_logical_scalar) {
    handleScalarTypePromotion(other_, tensor_);
  }

  // If the dimensions aren't aligned, we need to line them up.
  // Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
  // Note that only tensors that have a batch dim need to be modified.
  // Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
  tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
  other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);

  auto result = Func(tensor_, other_);
  auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
      ? optional<int64_t>{0} : nullopt;
  return { std::move(result), std::move(result_batch_dim) };
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(
    const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
  return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}

Tensor matmul_decomposed(
@@ -1575,6 +1526,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);

  m.impl("mean", mean_batching_rule);
  m.impl("mean.dim", mean_int_batching_rule);
  m.impl("sum.dim_IntList", sum_batching_rule);
  m.impl("log_softmax.int", log_softmax_batching_rule);
  m.impl("_log_softmax", _log_softmax_batching_rule);
@@ -1624,7 +1576,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("view", view_batching_rule);
  m.impl("view_as", native::view_as);  // composite wrt autograd

  // m.impl("addmm", addmm_batching_rule);
  m.impl("addmm", addmm_batching_rule);
  m.impl("matmul", matmul_decomposed);
  //
  // clamp operations
@@ -1682,35 +1634,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("clone", clone_batching_rule);
  // m.impl("ones_like", ones_like_batching_rule);

  // using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
  using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
  // using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
  //
  // #define BINARY_POINTWISE(op) \
  //   m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
  //   m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
  // #define BINARY_POINTWISE_VA(op, ...) \
  //   { \
  //     using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
  //     using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \
  //     m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
  //     m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, Scalar, __VA_ARGS__>); \
  //   }

  #define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
  #define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
  // BINARY_POINTWISE_VA(add, Scalar);
  // BINARY_POINTWISE_VA(sub, Scalar);
  // BINARY_POINTWISE_VA(rsub, Scalar);
  BINARY_POINTWISE(mul);
  VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
  // BINARY_POINTWISE(div);
  //
  // // at::pow has three out-of-place overloads
  // m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
  // m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, Scalar>);
  // m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
  //
  // m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
  // m.impl(
  //     "threshold_backward",
@@ -1911,24 +1911,37 @@ class TestVmapOperators(Namespace.TestVmapBase):
        test(vmap(op), (torch.rand(B0, B1, 1),))
        test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)

    def test_sum_dim(self):
    def _test_mean_sum_dim(self, op):
        test = self._vmap_test
        B0, B1 = 5, 7

        # Single vmap, various in_dims / out_dims
        test(lambda x: x.sum(0), [torch.randn([B0])])
        test(lambda x: x.sum(-1), [torch.randn([B0])])
        test(lambda x: x.sum(0), [torch.randn([B0, 3])])
        test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
        test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
        test(lambda x: op(x, 0), [torch.randn([B0])])
        test(lambda x: op(x, -1), [torch.randn([B0])])
        test(lambda x: op(x, 0), [torch.randn([B0, 3])])
        test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2)
        test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)

        # Doubly nested vmap
        test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
        test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
        test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
        test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
        test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])])
        test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])])
        test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
        test(vmap(lambda x: op(x, 2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
             in_dims=2, out_dims=2)

    def test_sum_dim(self):
        self._test_mean_sum_dim(torch.sum)

    def test_mean_dim(self):
        self._test_mean_sum_dim(torch.mean)

    def test_repeat(self):
        test = self._vmap_test
        B0 = 7
        op = Tensor.repeat
        test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
        test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)

    def test_reshape(self):
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13