mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 21:49:24 +08:00
[MPS] Fix [nan]median
output for empty tensors (#162846)
It should be `NaN` rather than 0 Added respective checks to `test_empty_tensor` Fixes https://github.com/pytorch/pytorch/issues/162798 Pull Request resolved: https://github.com/pytorch/pytorch/pull/162846 Approved by: https://github.com/dcci
This commit is contained in:
committed by
PyTorch MergeBot
parent
ee53ad2dd0
commit
d25c35d2b2
@ -617,6 +617,7 @@ static Tensor median_common_mps(const Tensor& input_t, bool nanmedian) {
|
|||||||
// we allocate 1 here due to MacOS13 bug for gather MPSGraph op, look below for the error
|
// we allocate 1 here due to MacOS13 bug for gather MPSGraph op, look below for the error
|
||||||
Tensor output_t = at::empty({1}, input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
|
Tensor output_t = at::empty({1}, input_t.scalar_type(), std::nullopt, kMPS, std::nullopt, std::nullopt);
|
||||||
if (output_t.numel() == 0 || num_in_elements == 0) {
|
if (output_t.numel() == 0 || num_in_elements == 0) {
|
||||||
|
output_t.fill_(std::numeric_limits<float>::quiet_NaN());
|
||||||
return output_t;
|
return output_t;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -11628,6 +11628,9 @@ class TestAdvancedIndexing(TestCaseMPS):
|
|||||||
def test_empty_reduce(self, device="mps"):
|
def test_empty_reduce(self, device="mps"):
|
||||||
x = torch.rand(0, 3, device=device)
|
x = torch.rand(0, 3, device=device)
|
||||||
self.assertTrue(x.mean().isnan())
|
self.assertTrue(x.mean().isnan())
|
||||||
|
self.assertTrue(x.nanmean().isnan())
|
||||||
|
self.assertTrue(x.median().isnan())
|
||||||
|
self.assertTrue(x.nanmedian().isnan())
|
||||||
self.assertEqual(x.count_nonzero(), 0)
|
self.assertEqual(x.count_nonzero(), 0)
|
||||||
self.assertEqual(x.sum(), 0)
|
self.assertEqual(x.sum(), 0)
|
||||||
self.assertEqual(x.nansum(), 0)
|
self.assertEqual(x.nansum(), 0)
|
||||||
|
Reference in New Issue
Block a user