[MPS] Fix [nan]median output for empty tensors (#162846)

It should be `NaN` rather than 0

Added respective checks to `test_empty_tensor`

Fixes https://github.com/pytorch/pytorch/issues/162798
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162846
Approved by: https://github.com/dcci
This commit is contained in:
Nikita Shulga
2025-09-12 13:33:49 -07:00
committed by PyTorch MergeBot
parent ee53ad2dd0
commit d25c35d2b2
2 changed files with 4 additions and 0 deletions

View File

@ -617,6 +617,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
// we allocate 1 here due to MacOS13 bug for gather MPSGraph op, look below for the error // we allocate 1 here due to MacOS13 bug for gather MPSGraph op, look below for the error
Tensor output_t = at::empty({1}, input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt); Tensor output_t = at::empty({1}, input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
if (output_t.numel() == 0 || num_in_elements == 0) { if (output_t.numel() == 0 || num_in_elements == 0) {
output_t.fill_(std::numeric_limits<float>::quiet_NaN());
return output_t; return output_t;
} }

View File

@ -11628,6 +11628,9 @@ class TestAdvancedIndexing(TestCaseMPS):
def test_empty_reduce(self, device="mps"): def test_empty_reduce(self, device="mps"):
x = torch.rand(0, 3, device=device) x = torch.rand(0, 3, device=device)
self.assertTrue(x.mean().isnan()) self.assertTrue(x.mean().isnan())
self.assertTrue(x.nanmean().isnan())
self.assertTrue(x.median().isnan())
self.assertTrue(x.nanmedian().isnan())
self.assertEqual(x.count_nonzero(), 0) self.assertEqual(x.count_nonzero(), 0)
self.assertEqual(x.sum(), 0) self.assertEqual(x.sum(), 0)
self.assertEqual(x.nansum(), 0) self.assertEqual(x.nansum(), 0)