[functorch] Grab bag of batch rules

Author: Richard Zou
Date: 2021-04-28 11:04:58 -07:00
Committed by: Jon Janzen
Parent: 0abba43aa3
Commit: d7d266f51e
4 changed files with 181 additions and 115 deletions

@@ -0,0 +1,113 @@
#include <functorch/csrc/BatchRulesHelper.h>
namespace at { namespace functorch {
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> basic_unary_batch_rule(
const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {
logical_scalar_tensor = logical_scalar_tensor.to(result_type);
}
if (second.scalar_type() != result_type) {
second = second.to(result_type);
}
}
static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
if (!has_bdim) {
return tensor;
}
auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
if (tensor_logical_rank >= logical_rank) {
return tensor;
}
VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
new_sizes.insert(new_sizes.begin() + 1, 1);
}
return tensor.view(new_sizes);
}
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim,
ExtraArgs... extra_args) {
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
auto other_ = moveBatchDimToFront(other, other_batch_dim);
// In the (0D, ND) case, type promotion semantics are different :/
auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
if (tensor_is_logical_scalar && !other_is_logical_scalar) {
handleScalarTypePromotion(tensor_, other_);
}
if (other_is_logical_scalar && !tensor_is_logical_scalar) {
handleScalarTypePromotion(other_, tensor_);
}
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
auto result = Func(tensor_, other_, std::forward<ExtraArgs>(extra_args)...);
auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
? optional<int64_t>{0} : nullopt;
return { std::move(result), std::move(result_batch_dim) };
}
std::tuple<Tensor,optional<int64_t>> pow_scalar_tensor_batch_rule(
const Scalar& other,
const Tensor& tensor, optional<int64_t> tensor_batch_dim) {
return { at::pow(other, tensor), tensor_batch_dim };
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, const Scalar&);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
using TensorScalarType = Tensor (*)(const Tensor&, const Scalar&);
#define BINARY_POINTWISE_BATCH_RULE_SCALAR(op) \
binary_pointwise_batch_rule<TensorTensorScalarType, &op, const Scalar&>
#define BINARY_POINTWISE_WITH_SCALAR(op) \
VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE_SCALAR(at::op));
#define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
#define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
BINARY_POINTWISE_WITH_SCALAR(add);
BINARY_POINTWISE_WITH_SCALAR(sub);
BINARY_POINTWISE_WITH_SCALAR(rsub);
BINARY_POINTWISE(mul);
BINARY_POINTWISE(div);
VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
// at::pow has three out-of-place overloads
#define POW_BATCH_RULE binary_pointwise_batch_rule<TensorTensorType, &at::pow>
VMAP_SUPPORT("pow.Tensor_Tensor", POW_BATCH_RULE);
#undef POW_BATCH_RULE
#define POW_BATCH_RULE basic_unary_batch_rule<TensorScalarType, &at::pow, const Scalar&>
VMAP_SUPPORT("pow.Tensor_Scalar", POW_BATCH_RULE);
#undef POW_BATCH_RULE
VMAP_SUPPORT("pow.Scalar", pow_scalar_tensor_batch_rule);
#undef BINARY_POINTWISE_BATCH_RULE_SCALAR
#undef BINARY_POINTWISE_BATCH_RULE
#undef BINARY_POINTWISE_WITH_SCALAR
#undef BINARY_POINTWISE
}
}}
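
For context, a minimal Python sketch of the two situations binary_pointwise_batch_rule handles: padding the batched operand up to the larger logical rank, and the (0D, ND) type-promotion case that handleScalarTypePromotion reproduces. The import path "from functorch import vmap" is an assumption.

import torch
from functorch import vmap  # assumed import path for functorch's vmap

B = 5

# Rank alignment: each per-example x[i] has shape (3,) and broadcasts against
# a (2, 5, 3) operand, so the batched side is padded to (B, 1, 1, 3) before a
# single batched add computes all B results.
x = torch.randn(B, 3)
y = torch.randn(2, 5, 3)
print(vmap(lambda xi: xi + y)(x).shape)   # torch.Size([5, 2, 5, 3])

# (0D, ND) type promotion: in eager mode a 0-dim float64 tensor added to a
# float32 tensor gives float32, because 0-dim tensors do not promote
# dimensioned tensors within the same category. Under vmap the "scalar" side
# becomes 1-dimensional, so the rule casts it to the eager result type first.
s = torch.randn(B, dtype=torch.float64)   # per-example inputs are 0-dim
t = torch.randn(3, dtype=torch.float32)
print(vmap(lambda si: si + t)(s).dtype)   # torch.float32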

@@ -95,9 +95,26 @@ std::tuple<Tensor,optional<int64_t>> unsqueeze_batch_rule(
return { self_.unsqueeze(dim), valIfNonempty(self_bdim, 0) };
}
// NB: repeat is not actually a view, but it is in this file
std::tuple<Tensor,optional<int64_t>> repeat_batch_rule(
const Tensor& self,
optional<int64_t> self_bdim,
IntArrayRef sizes) {
if (!self_bdim) {
return { self.repeat(sizes), nullopt };
}
auto self_ = moveBatchDimToFront(self, self_bdim);
VmapDimVector sizes_with_bdim = { sizes.begin(), sizes.end() };
sizes_with_bdim.insert(sizes_with_bdim.begin(), 1);
return { self_.repeat(sizes_with_bdim), 0 };
}
TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
VMAP_SUPPORT("flatten.using_ints", flatten_batch_rule);
VMAP_SUPPORT("unsqueeze", unsqueeze_batch_rule);
VMAP_SUPPORT("repeat", repeat_batch_rule);
}
}}
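
A short usage sketch of the new repeat rule, again assuming "from functorch import vmap": it prepends a 1 to the requested sizes so the leading batch dimension is left untouched.

import torch
from functorch import vmap  # assumed import path

B0 = 7
x = torch.rand(B0, 1, 1)   # per-example shape: (1, 1)

# Per example, (1, 1).repeat(2, 3) -> (2, 3). Under vmap the batch rule calls
# self_.repeat((1, 2, 3)) on the (B0, 1, 1) tensor, producing (B0, 2, 3).
out = vmap(lambda xi: xi.repeat(2, 3))(x)
print(out.shape)   # torch.Size([7, 2, 3])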

@@ -173,6 +173,27 @@ Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, opt
return self_physical.getPhysicalToLogicalMap().apply(result);
}
Tensor mean_int_batching_rule(
const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
return self.mean(dims, keepdim, dtype);
}
// PyTorch has a special case where mean(scalar_tensor, dim=0) does not fail
// and instead returns a new scalar tensor (this also happens for dim=-1)
// If the following happens:
// >>> x = torch.randn(B0) # the per-examples are all scalars
// >>> vmap(partial(torch.mean, dim=0), x)
// then we replicate the behavior of mean(scalar_tensor, dim=0).
if (/*logical*/self.dim() == 0 && dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])) {
return self.clone();
}
auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
auto dims_physical = self_physical.getPhysicalDims(dims);
auto result = at::mean(self_physical.tensor(), dims_physical, keepdim, dtype);
return self_physical.getPhysicalToLogicalMap().apply(result);
}
bool isPhysicalScalarTensor(const Tensor& logical_tensor) {
if (logical_tensor.dim() > 0) {
return false;
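
The scalar-tensor special case above mirrors eager PyTorch, where mean(scalar_tensor, dim=0) and mean(scalar_tensor, dim=-1) succeed and return a scalar tensor. A minimal sketch of the behavior being preserved, assuming "from functorch import vmap":

import torch
from functools import partial
from functorch import vmap  # assumed import path

B0 = 5
x = torch.randn(B0)   # every per-example value is a 0-dim (scalar) tensor

# Eagerly, torch.mean(torch.tensor(3.0), dim=0) returns a scalar tensor rather
# than erroring, so vmap of mean over dim 0 on scalar per-examples is a clone.
out = vmap(partial(torch.mean, dim=0))(x)
print(torch.allclose(out, x))   # True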
@@ -982,17 +1003,6 @@ Tensor unwrap_and_call_method(const Tensor& input, ExtraArgs... extra_args) {
return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}
Tensor pow_scalar_Tensor_batching_rule(Scalar other, const Tensor& self) {
if (!participatesInCurrentLevel(self)) {
c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
return at::pow(other, self);
}
auto* self_batched = unsafeGetBatchedImpl(self);
auto output_physical = at::pow(other, self_batched->value());
auto old_bdims = self_batched->bdims();
return makeBatched(output_physical, BatchDims(old_bdims.begin(), old_bdims.end()));
}
// Tensor ones_like_batching_rule(const Tensor& self, optional<MemoryFormat> memory_format) {
// if (!participatesInCurrentLevel(self)) {
// c10::impl::ExcludeDispatchKeyGuard guard(kBatchedKey);
@@ -1277,7 +1287,7 @@ Tensor new_empty_batching_rule(
return physical_view.getPhysicalToLogicalMap().apply(result);
}
Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, Scalar beta, Scalar alpha) {
Tensor addmm_batching_rule(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
// Decomposition that is probably not very fast...
return at::add(self * beta, at::mm(mat1, mat2), alpha);
}
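
The decomposition registered here can be checked directly against the fused op; a quick standalone sketch:

import torch

# addmm(self, mat1, mat2, beta, alpha) == beta * self + alpha * (mat1 @ mat2),
# which is exactly add(self * beta, mm(mat1, mat2), alpha=alpha).
m_self, m1, m2 = torch.randn(3, 4), torch.randn(3, 5), torch.randn(5, 4)
beta, alpha = 0.5, 2.0
ref = torch.addmm(m_self, m1, m2, beta=beta, alpha=alpha)
dec = torch.add(m_self * beta, torch.mm(m1, m2), alpha=alpha)
print(torch.allclose(ref, dec))   # True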
@@ -1392,69 +1402,10 @@ std::tuple<Tensor,optional<int64_t>> abs_batch_rule(const Tensor& tensor, option
return {tensor.abs(), batch_dim};
}
template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(const Tensor& tensor, optional<int64_t> batch_dim) {
return {Func(tensor), batch_dim};
}
static Tensor maybePadToLogicalRank(const Tensor& tensor, optional<int64_t> has_bdim, int64_t logical_rank) {
if (!has_bdim) {
return tensor;
}
auto tensor_logical_rank = rankWithoutBatchDim(tensor, has_bdim);
if (tensor_logical_rank >= logical_rank) {
return tensor;
}
VmapDimVector new_sizes(tensor.sizes().begin(), tensor.sizes().end());
for (int64_t i = 0; i < logical_rank - tensor_logical_rank; i++) {
new_sizes.insert(new_sizes.begin() + 1, 1);
}
return tensor.view(new_sizes);
}
static void handleScalarTypePromotion(Tensor& logical_scalar_tensor, Tensor& second) {
auto result_type = at::native::result_type(logical_scalar_tensor[0], second);
if (logical_scalar_tensor.scalar_type() != result_type) {
logical_scalar_tensor = logical_scalar_tensor.to(result_type);
}
if (second.scalar_type() != result_type) {
second = second.to(result_type);
}
}
template <typename F, F Func>
std::tuple<Tensor,optional<int64_t>> binary_pointwise_batch_rule(
const Tensor& tensor, optional<int64_t> tensor_batch_dim,
const Tensor& other, optional<int64_t> other_batch_dim) {
// compute max logical rank
auto tensor_logical_rank = rankWithoutBatchDim(tensor, tensor_batch_dim);
auto other_logical_rank = rankWithoutBatchDim(other, other_batch_dim);
auto max_logical_rank = std::max(tensor_logical_rank, other_logical_rank);
auto tensor_ = moveBatchDimToFront(tensor, tensor_batch_dim);
auto other_ = moveBatchDimToFront(other, other_batch_dim);
// In the (0D, ND) case, type promotion semantics are different :/
auto tensor_is_logical_scalar = (tensor_logical_rank == 0 && tensor_batch_dim.has_value());
auto other_is_logical_scalar = (other_logical_rank == 0 && other_batch_dim.has_value());
if (tensor_is_logical_scalar && !other_is_logical_scalar) {
handleScalarTypePromotion(tensor_, other_);
}
if (other_is_logical_scalar && !tensor_is_logical_scalar) {
handleScalarTypePromotion(other_, tensor_);
}
// If the dimensions aren't aligned, we need to line them up.
// Tensor[B, 3] + Tensor[2, 5, 3] -> Tensor[B, 1, 1, 3] + Tensor[2, 5, 3]
// Note that only tensors that have a batch dim need to be modified.
// Tensor[B, 2, 3, 5] + Tensor[5] -> no changes needed
tensor_ = maybePadToLogicalRank(tensor_, tensor_batch_dim, max_logical_rank);
other_ = maybePadToLogicalRank(other_, other_batch_dim, max_logical_rank);
auto result = Func(tensor_, other_);
auto result_batch_dim = tensor_batch_dim.has_value() || other_batch_dim.has_value()
? optional<int64_t>{0} : nullopt;
return { std::move(result), std::move(result_batch_dim) };
template <typename F, F Func, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> unwrap_and_call2(
const Tensor& tensor, optional<int64_t> batch_dim, ExtraArgs... extra_args) {
return {Func(tensor, std::forward<ExtraArgs>(extra_args)...), batch_dim};
}
Tensor matmul_decomposed(
@@ -1575,6 +1526,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("max_pool2d_with_indices", max_pool2d_with_indices_batching_rule);
m.impl("mean", mean_batching_rule);
m.impl("mean.dim", mean_int_batching_rule);
m.impl("sum.dim_IntList", sum_batching_rule);
m.impl("log_softmax.int", log_softmax_batching_rule);
m.impl("_log_softmax", _log_softmax_batching_rule);
@@ -1624,7 +1576,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("view", view_batching_rule);
m.impl("view_as", native::view_as); // composite wrt autograd
// m.impl("addmm", addmm_batching_rule);
m.impl("addmm", addmm_batching_rule);
m.impl("matmul", matmul_decomposed);
//
// clamp operations
@@ -1682,35 +1634,6 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
m.impl("clone", clone_batching_rule);
// m.impl("ones_like", ones_like_batching_rule);
// using TensorTensorScalarType = Tensor (*)(const Tensor&, const Tensor&, Scalar);
using TensorTensorType = Tensor (*)(const Tensor&, const Tensor&);
// using TensorScalarType = Tensor (*)(const Tensor&, Scalar);
//
// #define BINARY_POINTWISE(op) \
// m.impl(#op".Tensor", binary_pointwise_batching_rule<TensorTensorType, at::op>); \
// m.impl(#op".Scalar", unwrap_and_call<TensorScalarType, at::op, Scalar>);
// #define BINARY_POINTWISE_VA(op, ...) \
// { \
// using Binop = Tensor (*)(const Tensor&, const Tensor&, __VA_ARGS__); \
// using Unop = Tensor (*)(const Tensor&, Scalar, __VA_ARGS__); \
// m.impl(#op".Tensor", binary_pointwise_batching_rule<Binop, at::op, __VA_ARGS__>); \
// m.impl(#op".Scalar", unwrap_and_call<Unop, at::op, Scalar, __VA_ARGS__>); \
// }
#define BINARY_POINTWISE_BATCH_RULE(op) binary_pointwise_batch_rule<TensorTensorType, &op>
#define BINARY_POINTWISE(op) VMAP_SUPPORT(#op".Tensor", BINARY_POINTWISE_BATCH_RULE(at::op));
// BINARY_POINTWISE_VA(add, Scalar);
// BINARY_POINTWISE_VA(sub, Scalar);
// BINARY_POINTWISE_VA(rsub, Scalar);
BINARY_POINTWISE(mul);
VMAP_SUPPORT("tanh_backward", BINARY_POINTWISE_BATCH_RULE(at::tanh_backward));
// BINARY_POINTWISE(div);
//
// // at::pow has three out-of-place overloads
// m.impl("pow.Tensor_Tensor", binary_pointwise_batching_rule<TensorTensorType, at::pow>);
// m.impl("pow.Tensor_Scalar", unwrap_and_call<TensorScalarType, at::pow, Scalar>);
// m.impl("pow.Scalar", pow_scalar_Tensor_batching_rule);
//
// m.impl("sigmoid_backward", binary_pointwise_batching_rule<TensorTensorType, at::sigmoid_backward>);
// m.impl(
// "threshold_backward",

@@ -1911,24 +1911,37 @@ class TestVmapOperators(Namespace.TestVmapBase):
test(vmap(op), (torch.rand(B0, B1, 1),))
test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
def test_sum_dim(self):
def _test_mean_sum_dim(self, op):
test = self._vmap_test
B0, B1 = 5, 7
# Single vmap, various in_dims / out_dims
test(lambda x: x.sum(0), [torch.randn([B0])])
test(lambda x: x.sum(-1), [torch.randn([B0])])
test(lambda x: x.sum(0), [torch.randn([B0, 3])])
test(lambda x: x.sum(-1), [torch.randn([2, 5, B0, 3])], in_dims=2)
test(lambda x: x.sum(2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
test(lambda x: op(x, 0), [torch.randn([B0])])
test(lambda x: op(x, -1), [torch.randn([B0])])
test(lambda x: op(x, 0), [torch.randn([B0, 3])])
test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2)
test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
# Doubly nested vmap
test(vmap(lambda x: x.sum(0)), [torch.randn([B0, B1])])
test(vmap(lambda x: x.sum(-1)), [torch.randn([B0, B1])])
test(vmap(lambda x: x.sum(-2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
test(vmap(lambda x: x.sum(2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])])
test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])])
test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
test(vmap(lambda x: op(x, 2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
in_dims=2, out_dims=2)
def test_sum_dim(self):
self._test_mean_sum_dim(torch.sum)
def test_mean_dim(self):
self._test_mean_sum_dim(torch.mean)
def test_repeat(self):
test = self._vmap_test
B0 = 7
op = Tensor.repeat
test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)
def test_reshape(self):
test = self._vmap_test
B0, B1, B2 = 7, 11, 13
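
For reference, the property the new mean.dim tests exercise can be checked standalone by comparing vmap against a per-example loop (a rough sketch, assuming "from functorch import vmap"):

import torch
from functorch import vmap  # assumed import path

B0 = 5
x = torch.randn(2, 5, B0, 3)

# vmap over dim 2, reducing the last per-example dimension...
out = vmap(lambda xi: torch.mean(xi, -1), in_dims=2)(x)

# ...should match reducing each batch slice separately and stacking.
ref = torch.stack([torch.mean(x.select(2, i), -1) for i in range(B0)])
print(torch.allclose(out, ref))   # True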