Support numpy array in Tensor.__eq__ (#122249)

When the `other` arg of `Tensor.__eq__` is a numpy array, it is converted to a PyTorch tensor view of the numpy array, which is then given as the `other` arg to a `Tensor.eq` call

Fixes #119965
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122249
Approved by: https://github.com/ezyang
This commit is contained in:
Kurt Mohler
2024-03-21 01:58:06 +00:00
committed by PyTorch MergeBot
parent bf18e967b4
commit b915877deb
2 changed files with 81 additions and 1 deletions

View File

@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import \
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing import make_tensor
# For testing handling NumPy objects and sending tensors to / accepting
@ -497,6 +498,49 @@ class TestNumPyInterop(TestCase):
else:
self.assertTrue(t == a)
@onlyCPU
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
def test___eq__(self, device, dtype):
a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
b = a.clone().detach()
b_np = b.numpy()
# Check all elements equal
res_check = torch.ones_like(a, dtype=torch.bool)
self.assertEqual(a == b_np, res_check)
self.assertEqual(b_np == a, res_check)
# Check one element unequal
if dtype == torch.bool:
b[1][3] = not b[1][3]
else:
b[1][3] += 1
res_check[1][3] = False
self.assertEqual(a == b_np, res_check)
self.assertEqual(b_np == a, res_check)
# Check random elements unequal
rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
res_check = rand.logical_not()
b.copy_(a)
if dtype == torch.bool:
b[rand] = b[rand].logical_not()
else:
b[rand] += 1
self.assertEqual(a == b_np, res_check)
self.assertEqual(b_np == a, res_check)
# Check all elements unequal
if dtype == torch.bool:
b.copy_(a.logical_not())
else:
b.copy_(a + 1)
res_check.fill_(False)
self.assertEqual(a == b_np, res_check)
self.assertEqual(b_np == a, res_check)
@onlyCPU
def test_empty_tensors_interop(self, device):
x = torch.rand((), dtype=torch.float16)