mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions

before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199
after: INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
431 lines
14 KiB
Python
431 lines
14 KiB
Python
# mypy: allow-untyped-defs
|
|
# Owner(s): ["module: complex"]
|
|
|
|
import torch
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
    onlyCPU,
)
from torch.testing._internal.common_dtype import complex_types
from torch.testing._internal.common_utils import run_tests, set_default_dtype, TestCase
|
|
|
|
|
|
# Device handles for the CPU and the first CUDA device.  Constructing a
# ``torch.device`` only records the type/index; nothing in this part of the
# file reads ``devices`` afterwards.
devices = tuple(torch.device(spec) for spec in ("cpu", "cuda:0"))
|
|
|
|
|
|
class TestComplexTensor(TestCase):
|
|
@dtypes(*complex_types())
|
|
def test_to_list(self, device, dtype):
|
|
# test that the complex float tensor has expected values and
|
|
# there's no garbage value in the resultant list
|
|
self.assertEqual(
|
|
torch.zeros((2, 2), device=device, dtype=dtype).tolist(),
|
|
[[0j, 0j], [0j, 0j]],
|
|
)
|
|
|
|
@dtypes(torch.float32, torch.float64, torch.float16)
|
|
def test_dtype_inference(self, device, dtype):
|
|
# issue: https://github.com/pytorch/pytorch/issues/36834
|
|
with set_default_dtype(dtype):
|
|
x = torch.tensor([3.0, 3.0 + 5.0j], device=device)
|
|
if dtype == torch.float16:
|
|
self.assertEqual(x.dtype, torch.chalf)
|
|
elif dtype == torch.float32:
|
|
self.assertEqual(x.dtype, torch.cfloat)
|
|
else:
|
|
self.assertEqual(x.dtype, torch.cdouble)
|
|
|
|
@dtypes(*complex_types())
|
|
def test_conj_copy(self, device, dtype):
|
|
# issue: https://github.com/pytorch/pytorch/issues/106051
|
|
x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
|
|
xc1 = torch.conj(x1)
|
|
x1.copy_(xc1)
|
|
self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
|
|
|
|
@dtypes(*complex_types())
|
|
def test_all(self, device, dtype):
|
|
# issue: https://github.com/pytorch/pytorch/issues/120875
|
|
x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
|
|
|
|
self.assertTrue(torch.all(x))
|
|
|
|
@dtypes(*complex_types())
|
|
def test_any(self, device, dtype):
|
|
# issue: https://github.com/pytorch/pytorch/issues/120875
|
|
x = torch.tensor(
|
|
[0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
|
|
)
|
|
|
|
self.assertFalse(torch.any(x))
|
|
|
|
@onlyCPU
|
|
@dtypes(*complex_types())
|
|
def test_eq(self, device, dtype):
|
|
"Test eq on complex types"
|
|
nan = float("nan")
|
|
# Non-vectorized operations
|
|
for a, b in (
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
|
|
),
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
|
|
),
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
|
|
),
|
|
):
|
|
actual = torch.eq(a, b)
|
|
expected = torch.tensor([False], device=device, dtype=torch.bool)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.eq(a, a)
|
|
expected = torch.tensor([True], device=device, dtype=torch.bool)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.eq(a, b, out=actual)
|
|
expected = torch.tensor([complex(0)], device=device, dtype=dtype)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.eq(a, a, out=actual)
|
|
expected = torch.tensor([complex(1)], device=device, dtype=dtype)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
# Vectorized operations
|
|
for a, b in (
|
|
(
|
|
torch.tensor(
|
|
[
|
|
-0.0610 - 2.1172j,
|
|
5.1576 + 5.4775j,
|
|
complex(2.8871, nan),
|
|
-6.6545 - 3.7655j,
|
|
-2.7036 - 1.4470j,
|
|
0.3712 + 7.989j,
|
|
-0.0610 - 2.1172j,
|
|
5.1576 + 5.4775j,
|
|
complex(nan, -3.2650),
|
|
-6.6545 - 3.7655j,
|
|
-2.7036 - 1.4470j,
|
|
0.3712 + 7.989j,
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
),
|
|
torch.tensor(
|
|
[
|
|
-6.1278 - 8.5019j,
|
|
0.5886 + 8.8816j,
|
|
complex(2.8871, nan),
|
|
6.3505 + 2.2683j,
|
|
0.3712 + 7.9659j,
|
|
0.3712 + 7.989j,
|
|
-6.1278 - 2.1172j,
|
|
5.1576 + 8.8816j,
|
|
complex(nan, -3.2650),
|
|
6.3505 + 2.2683j,
|
|
0.3712 + 7.9659j,
|
|
0.3712 + 7.989j,
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
),
|
|
),
|
|
):
|
|
actual = torch.eq(a, b)
|
|
expected = torch.tensor(
|
|
[
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
True,
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
True,
|
|
],
|
|
device=device,
|
|
dtype=torch.bool,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.eq(a, a)
|
|
expected = torch.tensor(
|
|
[
|
|
True,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
],
|
|
device=device,
|
|
dtype=torch.bool,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.eq(a, b, out=actual)
|
|
expected = torch.tensor(
|
|
[
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(1),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(1),
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.eq(a, a, out=actual)
|
|
expected = torch.tensor(
|
|
[
|
|
complex(1),
|
|
complex(1),
|
|
complex(0),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(0),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
@onlyCPU
|
|
@dtypes(*complex_types())
|
|
def test_ne(self, device, dtype):
|
|
"Test ne on complex types"
|
|
nan = float("nan")
|
|
# Non-vectorized operations
|
|
for a, b in (
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype),
|
|
),
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype),
|
|
),
|
|
(
|
|
torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
|
|
torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype),
|
|
),
|
|
):
|
|
actual = torch.ne(a, b)
|
|
expected = torch.tensor([True], device=device, dtype=torch.bool)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.ne(a, a)
|
|
expected = torch.tensor([False], device=device, dtype=torch.bool)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.ne(a, b, out=actual)
|
|
expected = torch.tensor([complex(1)], device=device, dtype=dtype)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.ne(a, a, out=actual)
|
|
expected = torch.tensor([complex(0)], device=device, dtype=dtype)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
# Vectorized operations
|
|
for a, b in (
|
|
(
|
|
torch.tensor(
|
|
[
|
|
-0.0610 - 2.1172j,
|
|
5.1576 + 5.4775j,
|
|
complex(2.8871, nan),
|
|
-6.6545 - 3.7655j,
|
|
-2.7036 - 1.4470j,
|
|
0.3712 + 7.989j,
|
|
-0.0610 - 2.1172j,
|
|
5.1576 + 5.4775j,
|
|
complex(nan, -3.2650),
|
|
-6.6545 - 3.7655j,
|
|
-2.7036 - 1.4470j,
|
|
0.3712 + 7.989j,
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
),
|
|
torch.tensor(
|
|
[
|
|
-6.1278 - 8.5019j,
|
|
0.5886 + 8.8816j,
|
|
complex(2.8871, nan),
|
|
6.3505 + 2.2683j,
|
|
0.3712 + 7.9659j,
|
|
0.3712 + 7.989j,
|
|
-6.1278 - 2.1172j,
|
|
5.1576 + 8.8816j,
|
|
complex(nan, -3.2650),
|
|
6.3505 + 2.2683j,
|
|
0.3712 + 7.9659j,
|
|
0.3712 + 7.989j,
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
),
|
|
),
|
|
):
|
|
actual = torch.ne(a, b)
|
|
expected = torch.tensor(
|
|
[
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
False,
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
True,
|
|
False,
|
|
],
|
|
device=device,
|
|
dtype=torch.bool,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.ne(a, a)
|
|
expected = torch.tensor(
|
|
[
|
|
False,
|
|
False,
|
|
True,
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
False,
|
|
True,
|
|
False,
|
|
False,
|
|
False,
|
|
],
|
|
device=device,
|
|
dtype=torch.bool,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.ne(a, b, out=actual)
|
|
expected = torch.tensor(
|
|
[
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(0),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(1),
|
|
complex(0),
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
actual = torch.full_like(b, complex(2, 2))
|
|
torch.ne(a, a, out=actual)
|
|
expected = torch.tensor(
|
|
[
|
|
complex(0),
|
|
complex(0),
|
|
complex(1),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
complex(1),
|
|
complex(0),
|
|
complex(0),
|
|
complex(0),
|
|
],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertEqual(
|
|
actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}"
|
|
)
|
|
|
|
|
|
# Hand the generic test class and this module's namespace to the device-type
# test machinery.  NOTE(review): this is understood to generate one concrete
# test class per available device type and register it in ``globals()`` —
# confirm against ``common_device_type.instantiate_device_type_tests``.
instantiate_device_type_tests(TestComplexTensor, globals())

if __name__ == "__main__":
    # Switch on the default-dtype check before running the suite when this
    # file is executed directly.  NOTE(review): presumably makes ``TestCase``
    # verify that tests do not leak a changed default dtype — confirm in
    # ``common_utils``.
    TestCase._default_dtype_check_enabled = True
    run_tests()
|