Raise error for int64 and bool dtypes in nanmean, even for empty tensors (#138745)

This PR ensures that the `nanmean()` function raises a `RuntimeError` when using `int64` or `bool` dtypes, even for empty tensors. Previously, non-empty tensors correctly raised errors for unsupported dtypes, while empty tensors did not. This change brings consistent error handling for both cases.
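
For illustration, a minimal repro of the inconsistency this PR fixes (a sketch; exactly what the empty-tensor case used to return is not shown in this PR, only that it did not raise):

```python
import torch

for t in (
    torch.tensor([1, 2, 3], dtype=torch.int64),  # non-empty: raised before this PR too
    torch.empty(0, dtype=torch.int64),           # empty: previously did not raise
):
    try:
        torch.nanmean(t)
    except RuntimeError as e:
        print(e)  # with this PR, both calls raise a dtype error
```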

This addresses the need raised by @hyperkai in Issue [#131043](https://github.com/pytorch/pytorch/issues/131043).

### Changes

- Added checks in `nanmean_out()` to raise errors for `int64` and `bool` dtypes regardless of tensor size (a Python-level sketch of the guard follows below).
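
For intuition, the new guard behaves roughly like the following Python-level predicate (a sketch, not the actual C++ dispatch; the helper name and dtype set are illustrative):

```python
import torch

# Hypothetical Python analogue of the new TORCH_CHECK in nanmean_out():
# at::isIntegralType(self.scalar_type(), /*includeBool=*/true) is true for
# bool and the integer dtypes, so these are rejected regardless of numel().
_REJECTED_DTYPES = {
    torch.bool, torch.uint8, torch.int8,
    torch.int16, torch.int32, torch.int64,
}

def check_nanmean_dtype(t: torch.Tensor) -> None:
    if t.dtype in _REJECTED_DTYPES:
        raise RuntimeError(
            "nanmean(): integral types and 'Bool' are not supported for "
            "nanmean, even for empty tensors."
        )
```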

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138745
Approved by: https://github.com/ezyang

Commit f6e5d09682 (parent 232af152b5), authored by axel on 2024-11-02 22:52:38 +00:00 and committed by PyTorch MergeBot. 2 changed files with 31 additions and 0 deletions.

@@ -1434,6 +1434,10 @@ Tensor& nanmean_out(
     bool keepdim,
     std::optional<ScalarType> opt_dtype,
     Tensor& result) {
+  // Check if dtype is an integral type or Bool and raise an error
+  TORCH_CHECK(
+      !at::isIntegralType(self.scalar_type(), /*includeBool=*/true),
+      "nanmean(): integral types and 'Bool' are not supported for nanmean, even for empty tensors.");
   TORCH_CHECK(
       self.is_floating_point() || self.is_complex(),
       "nanmean(): expected input to have floating point or complex dtype but got ",

@@ -2252,6 +2252,33 @@ class TestReductions(TestCase):
         self.assertEqual(x[:, :2].amax().item(), 5)
         self.assertEqual(x[:, :2].argmax().item(), 2)
+
+    @onlyCPU
+    @dtypes(*integral_types_and(torch.bool))
+    def test_nanmean_integral_types(self, device, dtype):
+        # List of tensor shapes to test
+        shapes = [
+            (),
+            (0,),
+            (1,),
+            (3, 4, 5),
+            (2, 0, 3),
+            (10, 10, 10),
+            (2, 3, 0, 4),
+            (100,),
+            (1, 1, 1),
+            (5, 5, 5, 5, 5),
+        ]
+        for shape in shapes:
+            # Tensor of the specified shape and dtype
+            t = make_tensor(shape, dtype=dtype, device=device)
+            # Attempt to call torch.nanmean and expect a RuntimeError
+            with self.assertRaisesRegex(
+                RuntimeError,
+                r"nanmean\(\): expected input to have floating point or complex dtype but got \w+"
+            ):
+                torch.nanmean(t)
+
     @precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
     @dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
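
Outside the test harness, the same coverage can be sketched as a standalone loop (assuming `integral_types_and(torch.bool)` expands to the integer dtypes plus bool; the shapes here are a representative subset):

```python
import torch
from torch.testing import make_tensor

# Standalone sketch of the new test's core loop; every call is expected
# to raise, including the empty-tensor cases.
for dtype in (torch.bool, torch.uint8, torch.int8,
              torch.int16, torch.int32, torch.int64):
    for shape in ((), (0,), (2, 0, 3), (3, 4, 5)):
        t = make_tensor(shape, dtype=dtype, device="cpu")
        try:
            torch.nanmean(t)
        except RuntimeError:
            pass  # expected for every dtype/shape combination
        else:
            raise AssertionError(f"nanmean accepted {dtype} with shape {shape}")
```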