mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix error check for torch.var on scalar (#160889)
Fixes https://github.com/pytorch/pytorch/issues/160738 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160889 Approved by: https://github.com/Skylion007 ghstack dependencies: #160850
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6333f7dae
commit
b0071c65e2
@ -5321,6 +5321,9 @@ class TestMPS(TestCaseMPS):
|
||||
|
||||
helper()
|
||||
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/160738
|
||||
self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item())
|
||||
|
||||
# Test forward amax
|
||||
def test_amax(self):
|
||||
def helper(shape, dim, keepdim):
|
||||
|
Reference in New Issue
Block a user