mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MPS] fix empty input in posneg functions (#161824)
fix empty posneg function for mps: ```python import torch input_tensor = torch.empty(0, device="mps") out_pos = torch.isposinf(input_tensor) ``` Gives: ``` RuntimeError: [srcBuf length] > 0 INTERNAL ASSERT FAILED at "/Users/Irakli_Salia/Desktop/pytorch/aten/src/ATen/native/mps/OperationUtils.mm":551, please report a bug to PyTorch. Placeholder tensor is empty! ``` on main branch Pull Request resolved: https://github.com/pytorch/pytorch/pull/161824 Approved by: https://github.com/malfet
This commit is contained in:
committed by
PyTorch MergeBot
parent
3e459491b5
commit
3daf20f8e1
@ -8017,6 +8017,14 @@ class TestMPS(TestCaseMPS):
|
||||
x[::2].bitwise_not_()
|
||||
self.assertEqual(x_mps.cpu(), x_cpu)
|
||||
|
||||
def test_empty_posneginf(self):
|
||||
# just to check that it doesnt crash
|
||||
input_tensor = torch.empty(0, device="mps")
|
||||
out_pos = torch.isposinf(input_tensor)
|
||||
out_neg = torch.isposinf(input_tensor)
|
||||
self.assertEqual(out_pos.numel(), 0)
|
||||
self.assertEqual(out_neg.numel(), 0)
|
||||
|
||||
|
||||
class TestLargeTensors(TestCaseMPS):
|
||||
@serialTest()
|
||||
|
Reference in New Issue
Block a user