remove unneeded overload for nansum

Per the title, this removes the redundant dim-less overload of nansum; along the way it also fixes a few OpInfo skips.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76356
Approved by: https://github.com/albanD, https://github.com/mruberry
Natalia Gimelshein
2022-04-27 16:02:54 +00:00
committed by PyTorch MergeBot
parent 122999919c
commit 04b3313379
6 changed files with 9 additions and 26 deletions
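
The net effect, before the per-file diffs: torch.nansum keeps a single schema in which dim defaults to an empty list and an omitted dim reduces over all dimensions, mirroring torch.sum. A minimal sketch of the user-visible behavior (illustrative values, not taken from the diff):

import torch

t = torch.tensor([[1.0, float('nan')], [2.0, 3.0]])
# Omitting dim reduces over all dimensions, treating NaN as zero.
print(torch.nansum(t))         # tensor(6.)
# An explicit dim reduces along that dimension only.
print(torch.nansum(t, dim=1))  # tensor([1., 5.])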

aten/src/ATen/native/ReduceOps.cpp

@@ -1082,10 +1082,6 @@ Tensor& nansum_out(const Tensor& self, IntArrayRef dim,
   return result;
 }
 
-Tensor nansum(const Tensor &self, c10::optional<ScalarType> dtype) {
-  return at::native::nansum(self, std::vector<int64_t>{}, false, dtype);
-}
-
 Tensor nansum(const Tensor& self, IntArrayRef dim, bool keepdim, c10::optional<ScalarType> opt_dtype) {
   ScalarType dtype = get_dtype_from_self(self, opt_dtype, true);
   Tensor result = create_reduction_result(self, dim, keepdim, dtype);
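
With the dim-less overload gone, the IntArrayRef overload above is the single native entry point; a call without dim now arrives here with an empty dim list and reduces over all dimensions. A small sketch of that equivalence, assuming the Python arg parser accepts an empty tuple for dim as it does for sum:

import torch

t = torch.tensor([1.0, float('nan'), 2.0])
# The dim-less call and an explicit empty dim now take the same path:
# both reduce over every dimension, with NaN treated as zero.
assert torch.nansum(t) == torch.nansum(t, dim=())
assert torch.nansum(t).item() == 3.0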

aten/src/ATen/native/native_functions.yaml

@@ -4359,17 +4359,12 @@
 - func: sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   device_check: NoCheck # TensorIterator
 
-- func: nansum(Tensor self, *, ScalarType? dtype=None) -> Tensor
+- func: nansum(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   variants: function, method
   dispatch:
     CPU, CUDA: nansum
 
-- func: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
-  variants: function, method
-  dispatch:
-    CPU, CUDA: nansum
-
-- func: nansum.IntList_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
+- func: nansum.out(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)
   dispatch:
     CPU, CUDA: nansum_out
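
Renaming nansum.IntList_out to nansum.out and defaulting dim there as well is what backs the Python out= keyword. A usage sketch (illustrative values, not from the diff):

import torch

t = torch.tensor([[1.0, float('nan')], [2.0, 3.0]])
out = torch.empty(2)
# Partial reduction written into a preallocated buffer via nansum.out.
torch.nansum(t, dim=1, out=out)
print(out)  # tensor([1., 5.])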

Binary file not shown.

test/forward_backward_compatibility/check_forward_backward_compatibility.py

@@ -116,6 +116,7 @@ ALLOW_LIST = [
     ("prim::infer_squeeze_size.dim", datetime.date(9999, 1, 1)),
     ("prim::infer_squeeze_size", datetime.date(9999, 1, 1)),
     ("aten::_cat", datetime.date(2022, 5, 15)),
+    ("aten::nansum", datetime.date(2022, 5, 15)),
     ("aten::zero", datetime.date(2022, 5, 15)),
 ]
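
Because the serialized nansum schema changes, the forward/backward compatibility check would otherwise flag it; the new entry waives that failure until the listed date. A minimal sketch of how such an allow list is typically consumed (the is_allowed helper is hypothetical; the real logic lives in this script):

import datetime

ALLOW_LIST = [("aten::nansum", datetime.date(2022, 5, 15))]

def is_allowed(schema_name, today):
    # Hypothetical helper: a schema mismatch for an op is tolerated
    # until its allow-list entry expires.
    return any(name == schema_name and today < expiry
               for name, expiry in ALLOW_LIST)

print(is_allowed("aten::nansum", datetime.date(2022, 4, 27)))  # True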

tools/autograd/derivatives.yaml

@@ -1495,10 +1495,7 @@
   self: sum_backward(grad, self.sizes(), dim, keepdim)
   result: auto_linear
 
-- name: nansum(Tensor self, *, ScalarType? dtype=None) -> Tensor
-  self: grad.expand(self.sizes()).to(self.scalar_type()) * self.isnan().logical_not()
-
-- name: nansum.dim_IntList(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
+- name: nansum(Tensor self, int[1] dim=[], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor
   self: nansum_backward(grad.to(self.scalar_type()), self, dim, keepdim)
 
 # We never call _linalg_svd with compute_uv=False in an autograd context, so we don't even consider it here
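
The two autograd entries collapse into one backed by nansum_backward, which masks NaN inputs out of the gradient. A quick check of that behavior:

import torch

x = torch.tensor([1.0, float('nan'), 3.0], requires_grad=True)
torch.nansum(x).backward()
# NaN positions contribute nothing to the sum, so their gradient is zero;
# every other element gets the usual sum gradient of one.
print(x.grad)  # tensor([1., 0., 1.])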

torch/testing/_internal/common_methods_invocations.py

@@ -16221,24 +16221,18 @@ op_db: List[OpInfo] = [
         'nansum',
         identity=0,
         nan_policy='omit',
-        supports_out=False,
+        supports_out=True,
         promotes_int_to_int64=True,
         dtypes=all_types_and(torch.bool, torch.float16, torch.bfloat16),
         ref=reference_reduction_numpy(np.nansum),
         skips=(
-            # FIXME: nansum does not support passing keepdim without passing dim
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_default_keepdim'),
-            # FIXME: nansum reduces all dimensions when dim=[]
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_empty_keepdim'),
-            # FIXME: nansum does not support passing None to dim
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none'),
-            DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_dim_none_keepdim'),
-            # FIXME: improve precision
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty'),
+            DecorateInfo(unittest.expectedFailure, 'TestReductions', 'test_dim_empty_keepdim'),
+            # FIXME: flaky test so skipped instead of xfailed
+            # possibly bad low precision reference in numpy
             DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_small_input',
                          dtypes=[torch.float16]),
             DecorateInfo(unittest.skip("Skipped!"), 'TestReductions', 'test_ref_duplicate_values',
                          dtypes=[torch.float16]),
         ),
     ),
     ReductionOpInfo(
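
With supports_out=True the OpInfo suite now exercises the out= path, and the dropped skips correspond to cases the unified schema handles. A quick sanity check of the previously skipped behaviors, assuming dim=[] reduces over all dimensions as the remaining xfails document:

import torch

t = torch.tensor([[1.0, float('nan')], [2.0, 3.0]])
# keepdim without an explicit dim (test_dim_default_keepdim was skipped
# before): binds against the defaulted dim and keeps both dims.
print(torch.nansum(t, keepdim=True).shape)  # torch.Size([1, 1])
# out= is now tested (supports_out was False): column sums, NaN as zero.
buf = torch.empty(2)
torch.nansum(t, dim=0, out=buf)
print(buf)  # tensor([3., 3.])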