mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Raise error for int64 and bool dtypes in nanmean, even for empty tensors (#138745)
This PR ensures that the `nanmean()` function raises a `RuntimeError` when using `int64` or `bool` dtypes, even for empty tensors. Previously, non-empty tensors correctly raised errors for unsupported dtypes, while empty tensors did not. This change brings consistent error handling for both cases. addressing the need raised in an issue by @hyperkai (Issue [#131043](https://github.com/pytorch/pytorch/issues/131043)). ### Changes - Added checks in `nanmean_out()` to raise errors for `int64` and `bool` dtypes regardless of tensor size. Pull Request resolved: https://github.com/pytorch/pytorch/pull/138745 Approved by: https://github.com/ezyang
This commit is contained in:
@ -2252,6 +2252,33 @@ class TestReductions(TestCase):
|
||||
self.assertEqual(x[:, :2].amax().item(), 5)
|
||||
self.assertEqual(x[:, :2].argmax().item(), 2)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*integral_types_and(torch.bool))
|
||||
def test_nanmean_integral_types(self, device, dtype):
|
||||
|
||||
# List of tensor shapes to test
|
||||
shapes = [
|
||||
(),
|
||||
(0,),
|
||||
(1,),
|
||||
(3, 4, 5),
|
||||
(2, 0, 3),
|
||||
(10, 10, 10),
|
||||
(2, 3, 0, 4),
|
||||
(100,),
|
||||
(1, 1, 1),
|
||||
(5, 5, 5, 5, 5),
|
||||
]
|
||||
|
||||
for shape in shapes:
|
||||
# Tensor of the specified shape and dtype
|
||||
t = make_tensor(shape, dtype=dtype, device=device)
|
||||
# Attempt to call torch.nanmean and expect a RuntimeError
|
||||
with self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
r"nanmean\(\): expected input to have floating point or complex dtype but got \w+"
|
||||
):
|
||||
torch.nanmean(t)
|
||||
|
||||
@precisionOverride({torch.float16: 1e-2, torch.bfloat16: 1e-2})
|
||||
@dtypes(*set(all_types_and(torch.half, torch.bfloat16)) - {torch.uint8})
|
||||
|
||||
Reference in New Issue
Block a user