mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
This PR is part of a series attempting to re-submit https://github.com/pytorch/pytorch/pull/134592 as smaller PRs.

In jit tests:
- Add and use a common raise_on_run_directly method for when a user runs a test file directly which should not be run this way. Print the file which the user should have run.
- Raise a RuntimeError on tests which have been disabled (not run).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154725
Approved by: https://github.com/Skylion007
624 lines
15 KiB
Python
624 lines
15 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import cmath
|
|
import os
|
|
import sys
|
|
from itertools import product
|
|
from textwrap import dedent
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
from torch.testing._internal.common_utils import IS_MACOS, raise_on_run_directly
|
|
from torch.testing._internal.jit_utils import execWrapper, JitTestCase
|
|
|
|
|
|
# Make the helper files in test/ importable: the directory two dirname() hops
# above this file's resolved (symlink-free) path is appended to sys.path.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
|
|
|
|
|
|
class TestComplex(JitTestCase):
    """TorchScript tests for the ``complex`` scalar type.

    Most tests hand a small Python function and sample inputs to
    ``checkScript``; the pickle tests round-trip a ``ScriptModule`` through
    ``getExportImportCopy`` and assert on the copy.

    Test methods are documented with ``#`` comments rather than docstrings so
    that unittest's verbose test descriptions stay unchanged.
    """

    def test_script(self):
        # A complex argument returned unchanged.
        def fn(a: complex):
            return a

        self.checkScript(fn, (3 + 5j,))

    def test_complexlist(self):
        # Indexing into a List[complex]; the sample list mixes complex, int
        # and purely imaginary literals.
        def fn(a: List[complex], idx: int):
            return a[idx]

        values = [1j, 2, 3 + 4j, -5, -7j]
        self.checkScript(fn, (values, 2))

    def test_complexdict(self):
        # Lookup in a Dict[complex, complex] using a complex key.
        def fn(a: Dict[complex, complex], key: complex) -> complex:
            return a[key]

        mapping = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}
        self.checkScript(fn, (mapping, -4.3 - 2j))

    def test_pickle(self):
        # Complex scalar, list and dict attributes set in __init__ must be
        # present on the export/import copy, and the copy's forward() must
        # still produce a complex result.
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j
                self.b = [2 + 3j, 3 + 4j, 0 - 3j, -4 + 0j]
                self.c = {2 + 3j: 2 - 3j, -4.3 - 2j: 3j}

            @torch.jit.script_method
            def forward(self, b: int):
                return b + 2j

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded.a, 3 + 5j)
        self.assertEqual(loaded.b, [2 + 3j, 3 + 4j, -3j, -4])
        self.assertEqual(loaded.c, {2 + 3j: 2 - 3j, -4.3 - 2j: 3j})
        self.assertEqual(loaded(2), 2 + 2j)

    def test_complex_parse(self):
        # Complex literals in an arithmetic expression and as the value
        # assigned into a tensor element.
        def fn(a: int, b: torch.Tensor, dim: int):
            # verifies `emitValueToTensor()` 's behavior
            b[dim] = 2.4 + 0.5j
            return (3 * 2j) + a + 5j - 7.4j - 4

        t1 = torch.tensor(1)
        t2 = torch.tensor([0.4, 1.4j, 2.35])

        self.checkScript(fn, (t1, t2, 2))

    def test_complex_constants_and_ops(self):
        # Real/imaginary components for the sample grid: signed zeros,
        # fractional values, ints and +/- 10**i for i in (0, 1).
        vals = (
            [0.0, 1.0, 2.2, -1.0, -0.0, -2.2, 1, 0, 2]
            + [10.0**i for i in range(2)]
            + [-(10.0**i) for i in range(2)]
        )
        complex_vals = tuple(complex(x, y) for x, y in product(vals, vals))

        funcs_template = dedent(
            """
            def func(a: complex):
                return cmath.{func_or_const}(a)
            """
        )

        def checkCmath(func_name, funcs_template=funcs_template):
            # Build the same one-argument function twice from the template:
            # once as plain Python (via execWrapper) and once as TorchScript.
            funcs_str = funcs_template.format(func_or_const=func_name)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            f_script = cu.func
            f = scope["func"]

            if func_name in ["isinf", "isnan", "isfinite"]:
                # The classification predicates are additionally fed inf/nan
                # components.
                new_vals = vals + ([float("inf"), float("nan"), -1 * float("inf")])
                final_vals = tuple(
                    complex(x, y) for x, y in product(new_vals, new_vals)
                )
            else:
                final_vals = complex_vals

            for a in final_vals:
                try:
                    res_python = f(a)
                except Exception as e:
                    res_python = e
                try:
                    res_script = f_script(a)
                except Exception as e:
                    res_script = e

                if res_python != res_script:
                    # Inputs on which Python itself raises are skipped; the
                    # script side is not required to raise the same error.
                    if isinstance(res_python, Exception):
                        continue

                    msg = f"Failed on {func_name} with input {a}. Python: {res_python}, Script: {res_script}"
                    self.assertEqual(res_python, res_script, msg=msg)

        unary_ops = [
            "log",
            "log10",
            "sqrt",
            "exp",
            "sin",
            "cos",
            "asin",
            "acos",
            "atan",
            "sinh",
            "cosh",
            "tanh",
            "asinh",
            "acosh",
            "atanh",
            "phase",
            "isinf",
            "isnan",
            "isfinite",
        ]

        # --- Unary ops ---
        for op in unary_ops:
            checkCmath(op)

        def fn(x: complex):
            return abs(x)

        for val in complex_vals:
            self.checkScript(fn, (val,))

        def pow_complex_float(x: complex, y: float):
            return pow(x, y)

        def pow_float_complex(x: float, y: complex):
            return pow(x, y)

        self.checkScript(pow_float_complex, (2, 3j))
        self.checkScript(pow_complex_float, (3j, 2))

        def pow_complex_complex(x: complex, y: complex):
            return pow(x, y)

        # zip() pairs every sample with itself, so this exercises x ** x.
        for x, y in zip(complex_vals, complex_vals):
            # Reference: https://github.com/pytorch/pytorch/issues/54622
            if x == 0:
                continue
            self.checkScript(pow_complex_complex, (x, y))

        if not IS_MACOS:
            # --- Binary op ---
            def rect_fn(x: float, y: float):
                return cmath.rect(x, y)

            for x, y in product(vals, vals):
                self.checkScript(rect_fn, (x, y))

        func_constants_template = dedent(
            """
            def func():
                return cmath.{func_or_const}
            """
        )

        def checkCmathConstant(const_name):
            # The constant accessor takes no arguments. Routing it through
            # checkCmath would call it with one argument, make both the
            # Python and the scripted call raise, and skip every comparison,
            # so the values are compared directly here instead.
            funcs_str = func_constants_template.format(func_or_const=const_name)
            scope = {}
            execWrapper(funcs_str, globals(), scope)
            cu = torch.jit.CompilationUnit(funcs_str)
            res_python = scope["func"]()
            res_script = cu.func()
            msg = f"Failed on constant {const_name}. Python: {res_python}, Script: {res_script}"
            self.assertEqual(res_python, res_script, msg=msg)

        float_consts = ["pi", "e", "tau", "inf", "nan"]
        complex_consts = ["infj", "nanj"]
        for x in float_consts + complex_consts:
            checkCmathConstant(x)

    def test_infj_nanj_pickle(self):
        # cmath.infj / cmath.nanj used inside forward() must still evaluate
        # correctly on the export/import copy of the module.
        class ComplexModule(torch.jit.ScriptModule):
            def __init__(self) -> None:
                super().__init__()
                self.a = 3 + 5j

            @torch.jit.script_method
            def forward(self, infj: int, nanj: int):
                if infj == 2:
                    return infj + cmath.infj
                else:
                    return nanj + cmath.nanj

        loaded = self.getExportImportCopy(ComplexModule())
        self.assertEqual(loaded(2, 3), 2 + cmath.infj)
        self.assertEqual(loaded(3, 4), 4 + cmath.nanj)

    def test_complex_constructor(self):
        # Test all scalar types: complex(real, img) for every pairing of
        # int / float / bool annotations. The checkScript calls stay in this
        # method's own frame (plain loops, no helper function).
        def fn_int(real: int, img: int):
            return complex(real, img)

        for args in ((0, 0), (-1234, 0), (0, -1256), (-167, -1256)):
            self.checkScript(fn_int, args)

        def fn_float(real: float, img: float):
            return complex(real, img)

        # Some samples deliberately pass an int where a float is annotated.
        for args in ((0.0, 0.0), (-1234.78, 0), (0, 56.18), (-1.9, -19.8)):
            self.checkScript(fn_float, args)

        def fn_bool(real: bool, img: bool):
            return complex(real, img)

        for args in ((True, True), (False, False), (False, True), (True, False)):
            self.checkScript(fn_bool, args)

        def fn_bool_int(real: bool, img: int):
            return complex(real, img)

        for args in ((True, 0), (False, 0), (False, -1), (True, 3)):
            self.checkScript(fn_bool_int, args)

        def fn_int_bool(real: int, img: bool):
            return complex(real, img)

        for args in ((0, True), (0, False), (-3, True), (6, False)):
            self.checkScript(fn_int_bool, args)

        def fn_bool_float(real: bool, img: float):
            return complex(real, img)

        for args in ((True, 0.0), (False, 0.0), (False, -1.0), (True, 3.0)):
            self.checkScript(fn_bool_float, args)

        def fn_float_bool(real: float, img: bool):
            return complex(real, img)

        for args in ((0.0, True), (0.0, False), (-3.0, True), (6.0, False)):
            self.checkScript(fn_float_bool, args)

        def fn_float_int(real: float, img: int):
            return complex(real, img)

        for args in ((0.0, 1), (0.0, -1), (1.8, -3), (2.7, 8)):
            self.checkScript(fn_float_int, args)

        def fn_int_float(real: int, img: float):
            return complex(real, img)

        for args in ((1, 0.0), (-1, 1.7), (-3, 0.0), (2, -8.9)):
            self.checkScript(fn_int_float, args)

    def test_torch_complex_constructor_with_tensor(self):
        # complex(real, img) where one or both arguments are one-element
        # tensors (float, integer and bool valued).
        tensors = [torch.rand(1), torch.randint(-5, 5, (1,)), torch.tensor([False])]

        def fn_tensor_float(real, img: float):
            return complex(real, img)

        def fn_tensor_int(real, img: int):
            return complex(real, img)

        def fn_tensor_bool(real, img: bool):
            return complex(real, img)

        def fn_float_tensor(real: float, img):
            return complex(real, img)

        def fn_int_tensor(real: int, img):
            return complex(real, img)

        def fn_bool_tensor(real: bool, img):
            return complex(real, img)

        for tensor in tensors:
            self.checkScript(fn_tensor_float, (tensor, 1.2))
            self.checkScript(fn_tensor_int, (tensor, 3))
            self.checkScript(fn_tensor_bool, (tensor, True))

            self.checkScript(fn_float_tensor, (1.2, tensor))
            self.checkScript(fn_int_tensor, (3, tensor))
            self.checkScript(fn_bool_tensor, (True, tensor))

        def fn_tensor_tensor(real, img):
            return complex(real, img) + complex(2)

        for x, y in product(tensors, tensors):
            self.checkScript(fn_tensor_tensor, (x, y))

    def test_comparison_ops(self):
        # == and != between two complex values, and between complex and float.
        def fn1(a: complex, b: complex):
            return a == b

        def fn2(a: complex, b: complex):
            return a != b

        def fn3(a: complex, b: float):
            return a == b

        def fn4(a: complex, b: float):
            return a != b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, x))
        self.checkScript(fn1, (x, y))
        self.checkScript(fn2, (x, x))
        self.checkScript(fn2, (x, y))

        x1, y1 = 1 + 0j, 1.0
        self.checkScript(fn3, (x1, y1))
        self.checkScript(fn4, (x1, y1))

    def test_div(self):
        # True division of two complex scalars.
        def fn1(a: complex, b: complex):
            return a / b

        x, y = 2 - 3j, 4j
        self.checkScript(fn1, (x, y))

    def test_complex_list_sum(self):
        # Builtin sum() over a List[complex].
        def fn(x: List[complex]):
            return sum(x)

        self.checkScript(fn, (torch.randn(4, dtype=torch.cdouble).tolist(),))

    def test_tensor_attributes(self):
        # .real and .imag on a complex (cdouble) tensor.
        def tensor_real(x):
            return x.real

        def tensor_imag(x):
            return x.imag

        t = torch.randn(2, 3, dtype=torch.cdouble)
        self.checkScript(tensor_real, (t,))
        self.checkScript(tensor_imag, (t,))

    def test_binary_op_complex_tensor(self):
        # Binary operators with a complex scalar on the left and a complex
        # tensor on the right; the eager result must match the scripted one.
        def mul(x: complex, y: torch.Tensor):
            return x * y

        def add(x: complex, y: torch.Tensor):
            return x + y

        def eq(x: complex, y: torch.Tensor):
            return x == y

        def ne(x: complex, y: torch.Tensor):
            return x != y

        def sub(x: complex, y: torch.Tensor):
            return x - y

        def div(x: complex, y: torch.Tensor):
            # Must be a real division: this previously repeated `x - y`,
            # which left scalar / Tensor untested.
            return x / y

        ops = [mul, add, eq, ne, sub, div]

        for shape in [(1,), (2, 2)]:
            x = 0.71 + 0.71j
            y = torch.randn(shape, dtype=torch.cfloat)
            for op in ops:
                eager_result = op(x, y)
                scripted = torch.jit.script(op)
                jit_result = scripted(x, y)
                self.assertEqual(eager_result, jit_result)
|
|
|
|
|
|
if __name__ == "__main__":
    # This file is not meant to be executed directly; the helper raises and
    # names the file that should be run instead.
    raise_on_run_directly("test/test_jit.py")
|