[Fix] Fixed behaviour for the conversion of complex tensors to bool (#121803)

Fixes #120875

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121803
Approved by: https://github.com/lezcano
This commit is contained in:
andoorve
2024-03-14 13:35:15 +00:00
committed by PyTorch MergeBot
parent 1251f0fa31
commit 956059fa2e
3 changed files with 39 additions and 1 deletions

View File

@ -33,6 +33,18 @@ class TestComplexTensor(TestCase):
x1.copy_(xc1)
self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
@dtypes(*complex_types())
def test_all(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
self.assertTrue(torch.all(x))
@dtypes(*complex_types())
def test_any(self, device, dtype):
# issue: https://github.com/pytorch/pytorch/issues/120875
x = torch.tensor([0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype)
self.assertFalse(torch.any(x))
@onlyCPU
@dtypes(*complex_types())
def test_eq(self, device, dtype):