[MPS] Fix error check for torch.var on scalar (#160889)

Fixes https://github.com/pytorch/pytorch/issues/160738
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160889
Approved by: https://github.com/Skylion007
ghstack dependencies: #160850
This commit is contained in:
Nikita Shulga
2025-08-18 08:30:40 -07:00
committed by PyTorch MergeBot
parent c6333f7dae
commit b0071c65e2
2 changed files with 4 additions and 1 deletions

View File

@ -5321,6 +5321,9 @@ class TestMPS(TestCaseMPS):
helper()
# Regression test for https://github.com/pytorch/pytorch/issues/160738
self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item())
# Test forward amax
def test_amax(self):
def helper(shape, dim, keepdim):