mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Support numpy array in Tensor.__eq__
(#122249)
When the `other` arg of `Tensor.__eq__` is a numpy array, it is converted to a PyTorch tensor view of the numpy array, which is then given as the `other` arg to a `Tensor.eq` call Fixes #119965 Pull Request resolved: https://github.com/pytorch/pytorch/pull/122249 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
bf18e967b4
commit
b915877deb
@ -13,6 +13,7 @@ from torch.testing._internal.common_utils import \
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, onlyCPU, dtypes, skipMeta)
|
||||
from torch.testing._internal.common_dtype import all_types_and_complex_and
|
||||
from torch.testing import make_tensor
|
||||
|
||||
|
||||
# For testing handling NumPy objects and sending tensors to / accepting
|
||||
@ -497,6 +498,49 @@ class TestNumPyInterop(TestCase):
|
||||
else:
|
||||
self.assertTrue(t == a)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||
def test___eq__(self, device, dtype):
|
||||
a = make_tensor((5, 7), dtype=dtype, device=device, low=-9, high=9)
|
||||
b = a.clone().detach()
|
||||
b_np = b.numpy()
|
||||
|
||||
# Check all elements equal
|
||||
res_check = torch.ones_like(a, dtype=torch.bool)
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check one element unequal
|
||||
if dtype == torch.bool:
|
||||
b[1][3] = not b[1][3]
|
||||
else:
|
||||
b[1][3] += 1
|
||||
res_check[1][3] = False
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check random elements unequal
|
||||
rand = torch.randint(0, 2, a.shape, dtype=torch.bool)
|
||||
res_check = rand.logical_not()
|
||||
b.copy_(a)
|
||||
|
||||
if dtype == torch.bool:
|
||||
b[rand] = b[rand].logical_not()
|
||||
else:
|
||||
b[rand] += 1
|
||||
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
# Check all elements unequal
|
||||
if dtype == torch.bool:
|
||||
b.copy_(a.logical_not())
|
||||
else:
|
||||
b.copy_(a + 1)
|
||||
res_check.fill_(False)
|
||||
self.assertEqual(a == b_np, res_check)
|
||||
self.assertEqual(b_np == a, res_check)
|
||||
|
||||
@onlyCPU
|
||||
def test_empty_tensors_interop(self, device):
|
||||
x = torch.rand((), dtype=torch.float16)
|
||||
|
Reference in New Issue
Block a user