mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[MPS] Type-promote tensor-iterator common dtype (#160334)
Otherwise, `torch.add(FloatTensor, IntTensor, alpha=2)` and `torch.add(FloatTensor, IntTensor, alpha=2)` were dispatched to different kernels Fixes https://github.com/pytorch/pytorch/issues/160208 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160334 Approved by: https://github.com/Skylion007, https://github.com/dcci
This commit is contained in:
committed by
PyTorch MergeBot
parent
d0e2240f68
commit
d25c4f954d
@ -7736,6 +7736,8 @@ class TestMPS(TestCaseMPS):
|
||||
y = torch.arange(32, device='mps', dtype=torch.int32)
|
||||
self.assertEqual(torch.add(x, y, alpha=2).cpu(), torch.add(x.cpu(), y.cpu(), alpha=2))
|
||||
self.assertEqual(torch.add(x, 3, alpha=2).cpu(), torch.add(x.cpu(), 3, alpha=2))
|
||||
# Regression test for https://github.com/pytorch/pytorch/issues/160208
|
||||
self.assertEqual(torch.add(y, x, alpha=2).cpu(), torch.add(y.cpu(), x.cpu(), alpha=2))
|
||||
|
||||
# Test add
|
||||
def test_add_scalars(self):
|
||||
|
Reference in New Issue
Block a user