mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] Fix error check for torch.var on scalar (#160889)
Fixes https://github.com/pytorch/pytorch/issues/160738 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160889 Approved by: https://github.com/Skylion007 ghstack dependencies: #160850
This commit is contained in:
committed by
PyTorch MergeBot
parent
c6333f7dae
commit
b0071c65e2
@ -456,7 +456,7 @@ static Tensor std_var_common_impl_mps(const Tensor& input_t,
|
||||
errMessage += ": reduction dim must be in the range of input shape";
|
||||
for (const auto dim : dim_value) {
|
||||
auto wrap_dim = maybe_wrap_dim(dim, num_input_dims);
|
||||
TORCH_CHECK(wrap_dim < static_cast<decltype(wrap_dim)>(input_shape.size()), errMessage.c_str())
|
||||
TORCH_CHECK(wrap_dim < (num_input_dims ? num_input_dims : 1), errMessage.c_str())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5321,6 +5321,9 @@ class TestMPS(TestCaseMPS):
|
||||
|
||||
helper()
|
||||
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/160738
|
||||
self.assertTrue(torch.var(torch.tensor(3.13, device='mps'), dim=0).isnan().item())
|
||||
|
||||
# Test forward amax
|
||||
def test_amax(self):
|
||||
def helper(shape, dim, keepdim):
|
||||
|
Reference in New Issue
Block a user