Enable dim=None for torch.sum (#75845)

Part of #29137

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75845
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2022-06-16 20:17:07 +00:00
committed by PyTorch MergeBot
parent c8b073c5c5
commit e79a51f7db
23 changed files with 157 additions and 68 deletions

View File

@ -1004,7 +1004,7 @@ add_shape_compute_mapping("aten::view(Tensor(a) self, int[] size) -> Tensor(a)",
add_shape_compute_mapping("aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", expand)
add_shape_compute_mapping("aten::expand(Tensor(a) self, int[] size, *, bool implicit=False) -> Tensor(a)", expand_one_unused)
add_shape_compute_mapping("aten::mean.dim(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", mean_dim)
add_shape_compute_mapping("aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", max_dim)
add_shape_compute_mapping("aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)
add_shape_compute_mapping("aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", zero_dim_tensor)