Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[functorch] add no-dim for batchrulesreduce and move trace from old batching rule to new decomposition
@@ -32,9 +32,6 @@ R ret(R(*)(A...));
// Optional implies the weird case with 0-dim tensors i.e. torch.sum(torch.randn(()), 0)
template <typename F, F Func, typename... ExtraArgs>
optional<std::tuple<decltype(ret(Func)), optional<int64_t>>> reduction_dimarray_batch_rule_impl(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims, ExtraArgs... extra_args) {
  if (!self_bdim.has_value()) {
    return std::make_tuple(Func(self, dims, std::forward<ExtraArgs>(extra_args)...), nullopt);
  }
  auto logical_dim = rankWithoutBatchDim(self, self_bdim);

  // If the dim intlist is empty, that's equivalent to passing in a dim on all dimensions.
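The comments in this hunk carry the intent of the rule: the `optional` return lets it decline the odd 0-dim cases, and an empty dim list means "reduce over every dimension", which the batch rule must reproduce per batch element. A minimal Python sketch of the empty-dim behavior being matched (PyTorch semantics at the time of this commit; not part of the diff):

import torch

x = torch.randn(3, 4)
# Reducing with an empty dim list is equivalent to reducing over all dimensions,
# so per batch element the rule has to collapse every non-batch dimension.
assert torch.allclose(torch.sum(x, dim=[]), x.sum())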
@@ -69,9 +66,6 @@ std::tuple<Tensor,optional<int64_t>> reduction_dimarray_batch_rule(
// Optional implies the weird case with 0-dim tensors i.e. torch.sum(torch.randn(()), 0)
template <typename F, F Func, typename... ExtraArgs>
optional<std::tuple<decltype(ret(Func)), optional<int64_t>>> reduction_dim_batch_rule_impl(const Tensor& self, optional<int64_t> self_bdim, int64_t dim, ExtraArgs... extra_args) {
  if (!self_bdim.has_value()) {
    return std::make_tuple(Func(self, dim, std::forward<ExtraArgs>(extra_args)...), nullopt);
  }
  auto logical_dim = rankWithoutBatchDim(self, self_bdim);
  if (logical_dim == 0 && is_allowed_dim_on_scalar_tensor(dim)) {
    return nullopt;
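`is_allowed_dim_on_scalar_tensor` covers the case the comment describes: a 0-dim tensor still accepts dim 0 (or -1), and the reduction is then a no-op; returning `nullopt` hands the call off to another rule. A hedged Python sketch of the behavior that has to hold under vmap (functorch-era `vmap` API assumed):

import torch
from functorch import vmap

s = torch.randn(())
# On a 0-dim tensor an explicit dim of 0 is legal and sum is a no-op.
assert torch.equal(torch.sum(s, 0), s)

# Under vmap each batch element is logically 0-dim, so the same call must still work.
batched = torch.randn(5)
assert torch.equal(vmap(lambda t: torch.sum(t, 0))(batched), batched)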
@@ -107,9 +101,6 @@ std::tuple<Tensor,optional<int64_t>,Tensor,optional<int64_t>> reduction_dim_ret_
template <typename F, F Func, typename G, G DimRule, typename... ExtraArgs>
std::tuple<Tensor,optional<int64_t>> reduction_no_dim_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim, ExtraArgs... extra_args) {
  if (!self_bdim.has_value()) {
    return std::make_tuple(Func(self, std::forward<ExtraArgs>(extra_args)...), nullopt);
  }
  if (self.dim() == 1) {
    return std::make_tuple(self.clone(), 0);
  }
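`reduction_no_dim_batch_rule` handles reductions that take no dim argument and collapse the whole tensor. Under vmap the reduction must apply per batch element, and when each element is logically 0-dim (the `self.dim() == 1` branch on the physical tensor) the per-element reduction is the identity, hence the `clone()`. A small Python sketch of the expected semantics, assuming the functorch `vmap` API:

import torch
from functorch import vmap

x = torch.rand(7, 3, 4)
# A no-dim reduction collapses everything except the vmapped dimension.
assert torch.allclose(vmap(torch.sum)(x), x.sum(dim=(1, 2)))

# When every batch element is a scalar, the per-element reduction is a no-op.
s = torch.rand(7)
assert torch.equal(vmap(torch.sum)(s), s)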
@@ -235,9 +226,6 @@ std::tuple<Tensor,optional<int64_t>> norm_scalar_batch_rule(
template<typename F, F Func>
std::tuple<Tensor,optional<int64_t>> argx_batch_rule(
    const Tensor& self, optional<int64_t> self_bdim, optional<int64_t> dim, bool keepdim) {
  if (!self_bdim.has_value()) {
    return std::make_tuple( Func(self, dim, keepdim), nullopt );
  }
  auto self_ = moveBatchDimToFront(self, self_bdim);
  if (!dim) {
    // If no dimension is given, then argmax gives you the flattened index of
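The comment is cut off by the hunk, but the point is that `argmax`/`argmin` without a `dim` return an index into the flattened tensor, so the batch rule has to flatten each batch element separately rather than the whole batched tensor. A Python sketch of the behavior being matched (not part of the diff):

import torch
from functorch import vmap

x = torch.rand(7, 3, 4)
# argmax with no dim indexes into the flattened view of each example.
out = vmap(torch.argmax)(x)
expected = x.reshape(7, -1).argmax(dim=-1)
assert torch.equal(out, expected)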
@@ -269,9 +257,6 @@ _log_softmax_backward_data(
    const Tensor& self, optional<int64_t> self_bdim) {
  TORCH_INTERNAL_ASSERT(!(output_bdim.has_value() ^ self_bdim.has_value()),
      "output_bdim and self_bdim must be the same");
  if (!grad_output_bdim && !self_bdim) {
    return std::make_tuple( at::_log_softmax_backward_data(grad_output, output, dim, self), nullopt );
  }
  if (grad_output_bdim && self_bdim) {
    auto grad_output_ = moveBatchDimToFront(grad_output, grad_output_bdim);
    auto output_ = moveBatchDimToFront(output, output_bdim);
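This rule covers `_log_softmax_backward_data` when the gradient, output, and input may or may not each carry a batch dimension; the assert only rules out the inconsistent case. One way this path gets exercised from Python is per-example gradients through `log_softmax`; a hedged sketch, where the function `f` is made up for illustration:

import torch
import torch.nn.functional as F
from functorch import grad, vmap

def f(row):
    # Scalar loss so that grad() applies; the backward of log_softmax
    # calls _log_softmax_backward_data under the hood.
    return F.log_softmax(row, dim=-1).sum()

x = torch.randn(5, 3)
per_example_grads = vmap(grad(f))(x)
assert per_example_grads.shape == x.shape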
@@ -136,6 +136,10 @@ std::tuple<Tensor,optional<int64_t>> _unsafe_view_batch_rule(
  return std::make_tuple(at::_unsafe_view(self, view_size), self_bdim);
}

Tensor trace_decomp(const Tensor& self) {
  return at::sum(at::diagonal(self));
}

std::tuple<Tensor,optional<int64_t>> flip_batch_rule(const Tensor& self, optional<int64_t> self_bdim, IntArrayRef dims) {
  auto self_ = moveBatchDimToFront(self, self_bdim);
  VmapDimVector new_dims;
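`trace_decomp` is the heart of the "move trace to a decomposition" part of this commit: instead of a hand-written batching rule, `trace` is re-expressed as `sum(diagonal(self))`, and vmap then reuses the existing batch rules for `diagonal` and `sum`. A Python sketch of the equivalence this relies on (functorch `vmap` API assumed):

import torch
from functorch import vmap

x = torch.rand(7, 4, 4)
# Per-example trace via vmap...
out = vmap(torch.trace)(x)
# ...matches summing the per-example diagonal directly.
expected = torch.diagonal(x, dim1=-2, dim2=-1).sum(-1)
assert torch.allclose(out, expected)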
@@ -150,6 +154,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("flatten.using_ints", static_cast<decltype(&ATEN_FN2(flatten, using_ints))>(native::flatten));
  VMAP_SUPPORT("flip", flip_batch_rule);
  OP_DECOMPOSE(meshgrid);
  m.impl("trace", trace_decomp);
  VMAP_SUPPORT("tril", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(tril)), &at::tril, int64_t>));
  VMAP_SUPPORT("triu", SINGLE_ARG(variadic_bdims_batch_rule<decltype(&ATEN_FN(triu)), &at::triu, int64_t>));
  VMAP_SUPPORT("repeat", repeat_batch_rule);
@@ -1374,7 +1374,7 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
  m.impl("squeeze.dim", squeeze_dim_batching_rule);
  m.impl("squeeze_.dim", squeeze_dim__batching_rule);
  m.impl("t", native::t);  // composite wrt autograd
  m.impl("trace", trace_batching_rule);
  // m.impl("trace", trace_batching_rule);
  m.impl("transpose.int", transpose_int_batching_rule);
  m.impl("unbind.int", unbind_batching_rule);
  m.impl("unfold", unfold_batching_rule);
@@ -2283,7 +2283,6 @@ class TestVmapOperators(Namespace.TestVmapBase):
        op = torch.trace
        test = self._vmap_test
        B0, B1, B2 = 7, 11, 13

        test(op, (torch.rand(B0, 2, 5),))
        test(op, (torch.rand(2, B0, 5),), in_dims=1)
        test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
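For reference, `in_dims=1` in the test above means the batch dimension sits at index 1 of the input rather than at the front; vmap is then equivalent to slicing along that dimension and stacking the per-slice results. A small Python sketch of that equivalence (shapes mirror the test; names are illustrative only):

import torch
from functorch import vmap

B0 = 7
x = torch.rand(2, B0, 5)
out = vmap(torch.trace, in_dims=1)(x)
expected = torch.stack([torch.trace(x[:, i]) for i in range(B0)])
assert torch.allclose(out, expected)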