Revert "Enable dim=None for torch.sum (#75845)"

This reverts commit e79a51f7db181be2e6e196d6d9d90403022bc465.

Reverted https://github.com/pytorch/pytorch/pull/75845 on behalf of https://github.com/malfet due to Breaks MacOS builds, see e79a51f7db
PyTorch MergeBot
2022-06-16 22:01:40 +00:00
parent f9656817df
commit ee6ebfc06b
23 changed files with 67 additions and 156 deletions
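For context, the reverted PR made the dim argument of sum.dim_IntList optional (int[1]? dim) so that torch.sum(x, dim=None) reduced over every dimension. A rough sketch of the user-visible difference, assuming standard torch.sum semantics (illustrative only, not part of this diff):

    import torch

    x = torch.randn(2, 3)
    torch.sum(x)              # full reduction; unaffected by this revert
    torch.sum(x, dim=(0, 1))  # explicit dim list; also unaffected
    # torch.sum(x, dim=None)  # accepted only while the int[1]? dim schema was in place;
    #                         # after this revert it is rejected again (see the re-added
    #                         # test_dim_none skips in common_methods_invocations.py)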

View File

@@ -1 +1 @@
-35f759fdd7eb585679df7c1e6db4569b1aba5475
+de45c7c503f403be2c85066013b6a860f04f1152

View File

@@ -56,9 +56,7 @@ static bool is_allowed_dim_on_scalar_tensor(int64_t dim) {
   return dim == 0 || dim == -1;
 }
 
-Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool keepdim, optional<ScalarType> dtype) {
-  if (opt_dims.has_value()) {
-    auto dims = opt_dims.value();
+Tensor sum_batching_rule(const Tensor& self, IntArrayRef dims, bool keepdim, optional<ScalarType> dtype) {
   // PyTorch has a special case where sum(scalar_tensor, dim=0) does not fail
   // and instead returns a new scalar tensor (this also happens for dim=-1)
   // If the following happens:
@@ -68,9 +66,8 @@ Tensor sum_batching_rule(const Tensor& self, OptionalIntArrayRef opt_dims, bool
   if (/*logical*/self.dim() == 0 && (dims.size() == 0 || (dims.size() == 1 && is_allowed_dim_on_scalar_tensor(dims[0])))) {
     return self.clone();
   }
-  }
   auto self_physical = MultiBatchVmapTransform::logicalToPhysical(self);
-  auto dims_physical = self_physical.getPhysicalDims(opt_dims);
+  auto dims_physical = self_physical.getPhysicalDims(dims);
   auto result = at::sum(self_physical.tensor(), dims_physical, keepdim, dtype);
   return self_physical.getPhysicalToLogicalMap().apply(result);
 }
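The special case mentioned in the comment above can be seen directly in eager mode; a minimal illustration, assuming standard torch.sum semantics (not code from this commit):

    import torch

    s = torch.tensor(3.0)   # 0-dim (scalar) tensor
    torch.sum(s, dim=0)     # tensor(3.) -- allowed even though s.dim() == 0
    torch.sum(s, dim=-1)    # tensor(3.) -- the same special case for dim=-1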

View File

@@ -55,21 +55,14 @@ int64_t VmapPhysicalView::numLogicalDims() const {
   return /*physical*/tensor_.dim() - numBatchDims();
 }
 
-VmapDimVector VmapPhysicalView::getPhysicalDims(OptionalIntArrayRef opt_logical_dims) const {
+VmapDimVector VmapPhysicalView::getPhysicalDims(IntArrayRef logical_dims) const {
   auto logical_ndim = numLogicalDims();
   // NB: fmap doesn't have a SmallVector variant, so we don't use it here.
   VmapDimVector result;
   result.reserve(logical_ndim);
-  if (opt_logical_dims.has_value()) {
-    auto logical_dims = opt_logical_dims.value();
   for (auto dim : logical_dims) {
     result.push_back(maybe_wrap_dim(dim, logical_ndim) + numBatchDims());
   }
-  } else {
-    for (int64_t dim = 0; dim < logical_ndim; dim++) {
-      result.push_back(dim + numBatchDims());
-    }
-  }
   return result;
 }
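getPhysicalDims shifts each wrapped logical dim past the leading batch dimensions, as the header comment in the next file spells out. A hypothetical Python sketch of that arithmetic (the helper name and the simplified wrapping are illustrative, not PyTorch API):

    def get_physical_dims(logical_dims, logical_ndim, num_batch_dims):
        # wrap negative dims into [0, logical_ndim), then shift past the batch dims
        return [(d % logical_ndim) + num_batch_dims for d in logical_dims]

    get_physical_dims([0, -1], logical_ndim=3, num_batch_dims=2)  # [2, 4]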

View File

@@ -131,7 +131,7 @@ struct TORCH_API VmapPhysicalView {
   // This is because the size of levels tell us that the first two dimensions
   // of `tensor_` are batch dimensions, so a logical dim of `n` is actually
   // a physical dim of `n + 2`.
-  VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const;
+  VmapDimVector getPhysicalDims(IntArrayRef logical_dims) const;
   int64_t getPhysicalDim(int64_t logical_dim) const;
 
   // Returns a VmapPhysicalToLogicalMap object. This can be used for

View File

@@ -14,7 +14,7 @@ namespace at {
 constexpr size_t dim_bitset_size = 64;
 
 static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
-    OptionalIntArrayRef opt_dims,
+    IntArrayRef dims,
     int64_t ndims) {
   TORCH_CHECK(
       ndims <= (int64_t)dim_bitset_size,
@@ -22,22 +22,12 @@ static inline std::bitset<dim_bitset_size> dim_list_to_bitset(
       dim_bitset_size,
       " dims are supported");
   std::bitset<dim_bitset_size> seen;
-  if (opt_dims.has_value()) {
-    auto dims = opt_dims.value();
   for (const auto i : c10::irange(dims.size())) {
     size_t dim = maybe_wrap_dim(dims[i], ndims);
     TORCH_CHECK(
-        !seen[dim],
-        "dim ",
-        dim,
-        " appears multiple times in the list of dims");
+        !seen[dim], "dim ", dim, " appears multiple times in the list of dims");
     seen[dim] = true;
   }
-  } else {
-    for (int64_t dim = 0; dim < ndims; dim++) {
-      seen[dim] = true;
-    }
-  }
 
   return seen;
 }
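With the revert, dim_list_to_bitset again requires an explicit dim list: each dim is wrapped, checked for duplicates, and marked in the bitset (the deleted else-branch had marked every dim to cover the dim=None case). A hypothetical Python sketch of the restored behavior, using a set in place of std::bitset and simplified wrapping:

    def dim_list_to_set(dims, ndims):
        seen = set()
        for d in dims:
            d = d % ndims  # rough stand-in for maybe_wrap_dim
            assert d not in seen, f"dim {d} appears multiple times in the list of dims"
            seen.add(d)
        return seen

    dim_list_to_set([0, -1], ndims=3)  # {0, 2}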

View File

@@ -455,7 +455,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   // KERNEL(ADD_NS(norm), "norm.ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, IntArrayRef, bool, ScalarType), fp32_set_opt_dtype)
   // KERNEL(ADD_NS(norm), "norm.names_ScalarOpt_dim_dtype", Tensor (const Tensor &, c10::optional<Scalar>, DimnameList, bool, ScalarType), fp32_set_opt_dtype)
   KERNEL(ADD_NS(sum), "sum", Tensor (const Tensor &, c10::optional<ScalarType>), fp32_set_opt_dtype)
-  KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, OptionalIntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
+  KERNEL(ADD_NS(sum), "sum.dim_IntList", Tensor (const Tensor &, IntArrayRef, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
   KERNEL(ADD_NS(sum), "sum.dim_DimnameList", Tensor (const Tensor &, DimnameList, bool, c10::optional<ScalarType>), fp32_set_opt_dtype)
   // fp32_append_dtype
   // The fp32_append_dtype wrapper overrides implicit promotion behavior.

View File

@@ -52,6 +52,8 @@ namespace meta {
 
 static ScalarType infer_dtype_from_optional(
     const Tensor& self,
+    IntArrayRef dim,
+    bool keepdim,
     const optional<ScalarType>& opt_dtype,
     const Tensor& result) {
   // 'opt_dtype' has the priority for both cases.
@@ -185,9 +187,9 @@ TORCH_META_FUNC(cumprod)
 }
 
 TORCH_META_FUNC2(sum, dim_IntList)
-(const Tensor& self, OptionalIntArrayRef opt_dim, bool keepdim, optional<ScalarType> opt_dtype) {
-  auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
-  resize_reduction(*this, self, opt_dim, keepdim, out_dtype);
+(const Tensor& self, IntArrayRef dim, bool keepdim, optional<ScalarType> opt_dtype) {
+  auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
+  resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
 
 TORCH_META_FUNC2(prod, dim_int)
@@ -195,7 +197,7 @@ TORCH_META_FUNC2(prod, dim_int)
  int64_t dim,
  bool keepdim,
  c10::optional<ScalarType> dtype) {
-  auto out_dtype = infer_dtype_from_optional(self, dtype, maybe_get_output());
+  auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, dtype, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
@@ -219,7 +221,7 @@ TORCH_META_FUNC2(mean, dim)
         "Got: ", dtype);
   }
 
-  auto out_dtype = infer_dtype_from_optional(self, opt_dtype, maybe_get_output());
+  auto out_dtype = infer_dtype_from_optional(self, dim, keepdim, opt_dtype, maybe_get_output());
   resize_reduction(*this, self, dim, keepdim, out_dtype);
 }
@@ -1059,11 +1061,11 @@ inline ScalarType get_dtype_from_result(Tensor& result, optional<ScalarType> dty
 
 TORCH_IMPL_FUNC(sum_out)
 (const Tensor& self,
- OptionalIntArrayRef opt_dim,
+ IntArrayRef dim,
  bool keepdim,
  optional<ScalarType> opt_dtype,
  const Tensor& result) {
-  auto iter = meta::make_reduction_from_out_ty(self, result, opt_dim, keepdim, result.scalar_type());
+  auto iter = meta::make_reduction_from_out_ty(self, result, dim, keepdim, result.scalar_type());
   if (iter.numel() == 0) {
     result.zero_();
   } else {

View File

@@ -110,28 +110,13 @@ static inline Tensor integer_upcast(const Tensor& self, optional<ScalarType> dty
 
 using DimMask = TensorIterator::DimMask;
 
-static DimVector make_dim_vector(OptionalIntArrayRef opt_dims, int64_t ndim) {
-  if (opt_dims.has_value()) {
-    return DimVector(opt_dims.value());
-  } else {
-    std::vector<int64_t> all_dims(ndim);
-    std::iota(all_dims.begin(), all_dims.end(), 0);
-    return DimVector(all_dims);
-  }
-}
-
-static DimMask make_dim_mask(OptionalIntArrayRef opt_dims, int64_t ndim) {
+static DimMask make_dim_mask(IntArrayRef dims, int64_t ndim) {
   DimMask mask;
-  if (opt_dims.has_value()) {
-    auto dims = opt_dims.value();
   if (dims.empty()) {
     mask = DimMask().flip();
   } else {
     mask = at::dim_list_to_bitset(dims, ndim);
   }
-  } else {
-    mask = DimMask().flip();
-  }
   return mask;
 }
@@ -335,10 +320,10 @@ static C10_UNUSED DimVector get_reduction_shape(
 static void resize_reduction(
     impl::MetaBase& meta,
     const Tensor& self,
-    OptionalIntArrayRef opt_dims,
+    IntArrayRef dims,
     bool keepdim,
     ScalarType out_dtype) {
-  DimVector dims_ = at::native::make_dim_vector(opt_dims, self.dim());
+  DimVector dims_(dims);
   maybe_wrap_dims(dims_, self.dim());
   auto shape = get_reduction_shape(self, dims_, keepdim);
   meta.set_output_raw_strided(0, shape, {}, self.options().dtype(out_dtype));
@@ -366,11 +351,11 @@ static void resize_reduction_with_indices(
 static TensorIterator make_reduction(
     const Tensor& self,
     const Tensor& result,
-    OptionalIntArrayRef opt_dims,
+    IntArrayRef dims,
     bool keepdim,
     ScalarType in_dtype) {
   int64_t ndim = self.dim();
-  auto mask = at::native::make_dim_mask(opt_dims, ndim);
+  auto mask = at::native::make_dim_mask(dims, ndim);
   auto viewed_result =
       at::native::review_reduce_result(result, ndim, mask, keepdim);
   if (self.scalar_type() == in_dtype) {
@@ -404,7 +389,7 @@ static TensorIterator make_reduction(
 static C10_UNUSED TensorIterator make_reduction_from_out_ty(
     const Tensor& self,
     const Tensor& result,
-    OptionalIntArrayRef opt_dims,
+    IntArrayRef dims,
     bool keepdim,
     ScalarType out_dtype) {
   // special case for type promotion in mixed precision, improves computational
@@ -416,7 +401,7 @@ static C10_UNUSED TensorIterator make_reduction_from_out_ty(
       (self.scalar_type() == kHalf || self.scalar_type() == kBFloat16) &&
       out_dtype == kFloat);
   auto in_dtype = gpu_lowp_to_f32 ? self.scalar_type() : out_dtype;
-  return make_reduction(self, result, opt_dims, keepdim, in_dtype);
+  return make_reduction(self, result, dims, keepdim, in_dtype);
 }
 
 } // namespace meta

View File

@@ -4534,7 +4534,7 @@
     CompositeExplicitAutograd: sum
     SparseCsrCPU, SparseCsrCUDA: sum_csr
 
-- func: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+- func: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   structured_delegate: sum.IntList_out
   device_check: NoCheck # TensorIterator
   variants: function, method
@@ -4543,7 +4543,7 @@
   device_check: NoCheck # TensorIterator
   variants: function, method
 
-- func: sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+- func: sum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   structured: True
   device_check: NoCheck # TensorIterator
   dispatch:
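In the native_functions.yaml schema language, int[1] declares a required int list that also accepts a single bare int, and the trailing ? is what additionally allows None; dropping the ? here is what removes dim=None from the generated Python signature. For illustration, assuming standard torch.sum semantics:

    x = torch.randn(2, 3, 4)
    torch.sum(x, dim=1)        # a bare int is accepted for an int[1] argument
    torch.sum(x, dim=[0, 2])   # as is a list of dims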

View File

@@ -74,9 +74,6 @@ class OptionalArrayRef final {
       Args&&... args)
       : wrapped_opt_array_ref(ip, il, args...) {}
 
-  constexpr OptionalArrayRef(const std::initializer_list<T>& Vec)
-      : wrapped_opt_array_ref(ArrayRef<T>(Vec)) {}
-
   // Destructor
 
   ~OptionalArrayRef() = default;

View File

@@ -5365,7 +5365,7 @@ a")
         def func2(x):
             return x.sum(dim=4)
 
-        # test that shape analysis is written correctly for sum with OptionalIntArrayRef[1] dim argument
+        # test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
         self.run_pass('constant_propagation', func.graph)
         self.run_pass('constant_propagation', func2.graph)
         g = _propagate_shapes(func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)

View File

@@ -1195,9 +1195,6 @@ class TestNamedTensor(TestCase):
             check_output(op(t, 1), ['N', 'L'])
             check_output(op(t, -1), ['N', 'C'])
             check_output(op(t, 'C'), ['N', 'L'])
-            if op.__name__ in ['sum']:
-                check_output(op(t, None), [])
-            else:
             with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
                 op(t, None)
             with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):

View File

@@ -1517,7 +1517,7 @@
   self: grad.expand(self.sizes())
   result: auto_linear
 
-- name: sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+- name: sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   self: sum_backward(grad, self.sizes(), dim, keepdim)
   result: auto_linear

View File

@@ -20,7 +20,6 @@
 #include <ATen/native/IndexingUtils.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <c10/core/TensorOptions.h>
-#include <c10/util/OptionalArrayRef.h>
 #include <c10/util/SmallBuffer.h>
 #include <c10/util/accumulate.h>
 #include <c10/util/irange.h>
@@ -39,7 +38,6 @@ namespace details {
 using at::areAnyTensorSubclassLike;
 using at::IntArrayRef;
-using at::OptionalIntArrayRef;
 using at::Scalar;
 using at::Tensor;
 using at::TensorList;
@@ -537,11 +535,8 @@ Tensor deg2rad_backward(const Tensor& grad) {
   return at::mul(grad, at::native::wrapped_scalar_tensor(Scalar(M_PI_180)));
 }
 
-Tensor unsqueeze_multiple(
-    const Tensor& t,
-    OptionalIntArrayRef opt_dim,
-    size_t n_dims) {
-  auto dims_to_unsqueeze = at::dim_list_to_bitset(opt_dim, n_dims);
+Tensor unsqueeze_multiple(const Tensor& t, IntArrayRef dim, size_t n_dims) {
+  auto dims_to_unsqueeze = at::dim_list_to_bitset(dim, n_dims);
   Tensor res = t;
   for (const auto i : c10::irange(n_dims)) {
     if (dims_to_unsqueeze[i]) {
@@ -554,13 +549,13 @@ Tensor unsqueeze_multiple(
 Tensor sum_backward(
     const Tensor& grad,
     IntArrayRef sizes,
-    OptionalIntArrayRef opt_dims,
+    IntArrayRef dims,
     bool keepdim) {
   if (!keepdim && sizes.size() > 0) {
-    if (opt_dims.has_value() && opt_dims.value().size() == 1) {
-      return grad.unsqueeze(opt_dims.value()[0]).expand(sizes);
+    if (dims.size() == 1) {
+      return grad.unsqueeze(dims[0]).expand(sizes);
     } else {
-      Tensor res = unsqueeze_multiple(grad, opt_dims, sizes.size());
+      Tensor res = unsqueeze_multiple(grad, dims, sizes.size());
       return res.expand(sizes);
     }
   } else {
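sum_backward rebuilds an input-shaped gradient by unsqueezing grad at each reduced dim and expanding it to the original sizes. A hypothetical Python sketch of the keepdim=False path restored above (not the actual autograd code):

    import torch

    def sum_backward_sketch(grad, sizes, dims, keepdim):
        if not keepdim and len(sizes) > 0:
            for d in sorted(d % len(sizes) for d in dims):
                grad = grad.unsqueeze(d)   # reinsert each reduced dim as size 1
        return grad.expand(*sizes)         # broadcast back to the input's sizes

    g = torch.ones(2)                                   # grad of y = x.sum(dim=1), x of shape (2, 3)
    sum_backward_sketch(g, (2, 3), (1,), False).shape   # torch.Size([2, 3])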

View File

@@ -146,12 +146,12 @@ at::Tensor rad2deg_backward(const at::Tensor& grad);
 at::Tensor deg2rad_backward(const at::Tensor& grad);
 at::Tensor unsqueeze_multiple(
     const at::Tensor& t,
-    at::OptionalIntArrayRef opt_dim,
+    at::IntArrayRef dim,
     size_t n_dims);
 at::Tensor sum_backward(
     const at::Tensor& grad,
     at::IntArrayRef sizes,
-    at::OptionalIntArrayRef opt_dims,
+    at::IntArrayRef dims,
     bool keepdim);
 at::Tensor nansum_backward(
     const at::Tensor& grad,

View File

@@ -2478,7 +2478,7 @@ class IrParser {
     {
       auto ptr_op = getOperatorForLiteral(
-          "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
+          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)");
       REGISTER_PARSE_RULE(
           ptr_op,
           {
@@ -3855,7 +3855,7 @@ bool insertProfileIValue(ProfilingRecord* pr, Node* node, size_t offset) {
   static auto reduction_operator_schema =
       getOperatorForLiteral(
-          "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
+          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)")
           ->schema();
   if (node->matches(reduction_operator_schema)) {
     switch (offset) {

View File

@@ -1980,7 +1980,7 @@ class ShapePropagator : public PropertyPropBase {
           return true;
         } else if (
             node->matches(
-                "aten::sum(Tensor self, int[]? dim, bool keepdim, *, int? dtype) -> Tensor",
+                "aten::sum(Tensor self, int[] dim, bool keepdim, *, int? dtype) -> Tensor",
                 /*const_inputs=*/{attr::dim, attr::keepdim})) {
           auto& tp = tensor_types.at(0);
           auto sizes = tp->sizes().concrete_sizes().value();

View File

@@ -94,7 +94,7 @@ bool isSupported(Node* node) {
   static const OperatorSet supported_reduction_set{
       "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
-      "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
+      "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor",
       "aten::softmax.int(Tensor self, int dim , ScalarType? dtype=None) -> Tensor",
       "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor",
   };

View File

@@ -2158,54 +2158,6 @@ def transpose(self: List[int],
       _4 = torch.append(out, self[idx])
   return out
-)=====")
-    + std::string(R"=====(def sum_dim(self: List[int],
-    opt_dims: Optional[List[int]],
-    keep_dim: bool,
-    dt: Any) -> List[int]:
-  out = annotate(List[int], [])
-  if opt_dims is None:
-    dims:List[int] = []
-  else:
-    dims = opt_dims
-  for idx in range(torch.len(self)):
-    is_mean_dim = False
-    for _0 in range(torch.len(dims)):
-      reduce_dim = dims[_0]
-      _1 = torch.len(self)
-      if torch.le(_1, 0):
-        dim_post_expr = 1
-      else:
-        dim_post_expr = _1
-      min = torch.neg(dim_post_expr)
-      max = torch.sub(dim_post_expr, 1)
-      if torch.lt(reduce_dim, min):
-        _2 = True
-      else:
-        _2 = torch.gt(reduce_dim, max)
-      if torch.__not__(_2):
-        pass
-      else:
-        ops.prim.RaiseException("AssertionError: ")
-      if torch.lt(reduce_dim, 0):
-        dim0 = torch.add(reduce_dim, dim_post_expr)
-        dim = dim0
-      else:
-        dim = reduce_dim
-      if torch.eq(idx, dim):
-        is_mean_dim0 = True
-      else:
-        is_mean_dim0 = is_mean_dim
-      is_mean_dim = is_mean_dim0
-    if is_mean_dim:
-      if keep_dim:
-        _3 = torch.append(out, 1)
-      else:
-        pass
-    else:
-      _4 = torch.append(out, self[idx])
-  return out
 )=====")
     + std::string(R"=====(def max_dim(self: List[int],
     dim: int,
@@ -2797,7 +2749,7 @@ const OperatorMap<std::string>& GetShapeFunctionMappings() {
       {"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "expand"},
       {"aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", "expand_one_unused"},
       {"aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
-      {"aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "sum_dim"},
+      {"aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "mean_dim"},
       {"aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "max_dim"},
       {"aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},
       {"aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "zero_dim_tensor"},

View File

@@ -1691,10 +1691,10 @@ REGISTER_OPERATOR_FUNCTOR(aten::sum, aten_sum, [](Node* n) -> SROperator {
     };
   }
   if (n->matches(torch::schema(
-          "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
+          "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor"))) {
     return [](ProcessedNode* p_node) {
       const at::Tensor& self = p_node->Input(0).toTensor();
-      auto dim = p_node->Input(1).toDimVector();
+      auto dim = p_node->Input(1).toIntList().vec();
       auto keepdim = p_node->Input(2).toBool();
       auto dtype = p_node->Input(3).toOptional<at::ScalarType>();
       if (p_node->Output(0).isNone()) {

View File

@@ -1767,7 +1767,7 @@ int nnc_lowerings_lazy_registration() {
   RegisterNNCLoweringsFunction aten_sum(
       {"aten::sum(Tensor self, *, int? dtype=None) -> (Tensor)",
-       "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
+       "aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, int? dtype=None) -> (Tensor)"},
       computeSum);
 
   RegisterNNCLoweringsFunction aten_softmax(

View File

@@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
 add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
 add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
 add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
-add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
+add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
 add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
 add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
 add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)

View File

@@ -18703,6 +18703,9 @@ op_db: List[OpInfo] = [
            # FIXME: sum reduces all dimensions when dim=[]
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
+           # FIXME: sum does not support passing None to dim
+           DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
+           DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
            # FIXME: improve precision
            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
                         dtypes=[torch.float16]),