Summary: Currently, running explain with TORCH_LOGS enabled causes duplicate logging because explain uses the exact same code path as conversion. This PR disables logging while running explain and moves all logging into convert(), so nothing is logged from __init__ when only explain is used. Test Plan: Manual testing with attached outputs. Reviewed By: SherlockNoMad, angelayi Differential Revision: D60199007 Pull Request resolved: https://github.com/pytorch/pytorch/pull/132082 Approved by: https://github.com/ydwu4
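A minimal sketch of the behavior described above, assuming only what the summary states (that TS2EPConverter, as used in the tests below, exposes explain() alongside convert()); it is illustrative and not part of the test file that follows:

import torch
from torch._export.converter import TS2EPConverter

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

# With TORCH_LOGS enabled:
converter = TS2EPConverter(torch.jit.script(M()), (torch.ones(3),))
converter.explain()       # runs the same conversion path but should no longer emit duplicate logs
ep = converter.convert()  # the single place where conversion logging now happens; returns an ExportedProgram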
# Owner(s): ["oncall: export"]
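# Tests for TS2EPConverter, which converts TorchScript (jit.trace / jit.script)
# modules into torch.export ExportedPrograms.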
import unittest
from collections import OrderedDict
from typing import Any, Dict, List, Tuple, Union

import torch
import torch.utils._pytree as pytree
from torch._dynamo.test_case import TestCase
from torch._export.converter import TS2EPConverter
from torch.export import ExportedProgram
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


class TestConverter(TestCase):
    def _check_equal_ts_ep_converter(
        self,
        M,
        inp,
        option: Union[List[str], None] = None,
        check_persistent=False,
        lifted_tensor_constants=None,
    ) -> ExportedProgram:
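        # Convert the TorchScript model (via jit.trace and/or jit.script, per
        # `option`) to an ExportedProgram with TS2EPConverter, then check that
        # the exported module matches the original TorchScript model on
        # flattened outputs and, for nn.Modules, on state_dict keys.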
        # By default, it tests both jit.trace and jit.script.
        if option is None:
            option = ["trace", "script"]

        if check_persistent:
            num_iterations = 10
        else:
            num_iterations = 1

        ep_list = []
        for opt in option:
            if opt == "script":
                # Separate two models for testing non-functional effects
                if check_persistent:
                    original_ts_model = torch.jit.script(M())
                    ts_model = torch.jit.script(M())
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.script(M)
                    ts_model = torch.jit.script(M)
                    eager_model = M
            elif opt == "trace":
                if check_persistent:
                    original_ts_model = torch.jit.trace(M(), inp)
                    ts_model = torch.jit.trace(M(), inp)
                    eager_model = M()
                else:
                    original_ts_model = torch.jit.trace(M, inp)
                    ts_model = torch.jit.trace(M, inp)
                    eager_model = M
            else:
                raise RuntimeError(f"Unrecognized mode for torch.jit: {opt}")

            ep = TS2EPConverter(ts_model, inp).convert()
            ep_list.append(ep)

            for _ in range(num_iterations):
                orig_out, _ = pytree.tree_flatten(original_ts_model(*inp))
                ep_out, _ = pytree.tree_flatten(ep.module()(*inp))

                # Check module.
                if isinstance(eager_model, torch.nn.Module):
                    expected_state_dict = OrderedDict()
                    expected_state_dict.update(ts_model.state_dict())
                    if lifted_tensor_constants:
                        expected_state_dict.update(lifted_tensor_constants)
                    self.assertEqual(
                        ep.state_dict.keys(),
                        expected_state_dict.keys(),
                    )

                # Check results
                self._check_tensor_list_equal(ep_out, orig_out)
        return ep_list

    def _check_tensor_list_equal(self, xs: List[torch.Tensor], ys: List[torch.Tensor]):
        self.assertEqual(len(xs), len(ys))
        for x, y in zip(xs, ys):
            if isinstance(x, torch.Tensor) and isinstance(y, torch.Tensor):
                self.assertEqual(x.shape, y.shape)
                self.assertTrue(torch.allclose(x, y))
            else:
                self.assertEqual(type(x), type(y))
                self.assertEqual(x, y)

    def test_ts2ep_converter_basic(self):
        class MSingle(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        class MMulti(torch.nn.Module):
            def forward(self, x, y):
                x = x.cos() + 1
                y = y.sin() - 1
                return x, y

        inp = (torch.ones(1, 3), torch.ones(1, 3))
        self._check_equal_ts_ep_converter(MSingle(), inp)
        self._check_equal_ts_ep_converter(MMulti(), inp)

    def test_ts2ep_converter_container_output(self):
        # Output is a List.
        class MOutputList(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return [a, b]

        # Output is a Tuple.
        class MOutputTuple(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return (a, b)

        # Output is a Dict.
        class MOutputDict(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                a = x * x
                b = y + y
                return {"data": {"mul": a, "add": b}}

        inp = (torch.tensor(4), torch.tensor(4))

        # A traced function must use an immutable structure as output.
        self._check_equal_ts_ep_converter(MOutputList(), inp, ["script"])
        self._check_equal_ts_ep_converter(MOutputTuple(), inp)
        self._check_equal_ts_ep_converter(MOutputDict(), inp, ["script"])

    def test_aten_dim(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                num_dim = x.dim()
                return torch.ones(num_dim)

        inp = (torch.ones(1, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten_len(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                length = len(x)
                return torch.ones(length)

        # aten::len.Tensor
        inp = (torch.ones(2, 3),)
        self._check_equal_ts_ep_converter(Module(), inp)

        class Module(torch.nn.Module):
            def forward(self, x: List[int]):
                length = len(x)
                return torch.ones(length)

        # aten::len.t
        inp = ([1, 2, 3],)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[int, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_int
        inp = ({1: "a", 2: "b", 3: "c"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[bool, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_bool
        inp = ({True: "a", False: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[float, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_float
        inp = ({1.2: "a", 3.4: "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: Dict[torch.Tensor, str]):
                length = len(x)
                return torch.ones(length)

        # aten::len.Dict_Tensor
        inp = ({torch.zeros(2, 3): "a", torch.ones(2, 3): "b"},)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # aten::len.str and aten::len.Dict_str are not supported
        # since torch._C._jit_flatten does not support str
        # inp = ("abcdefg",)
        # self._check_equal_ts_ep_converter(Module(), inp)
        # inp = ({"a": 1, "b": 2},)
        # self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_min(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::min.int
                len_int = min(x_len, y_len)

                # prim::min.float
                len_float = int(min(x_len * 2.0, y_len * 2.0))

                # prim::min.self_int
                len_self_int = min([x_len, y_len])

                # prim::min.self_float
                len_self_float = int(min([x_len * 2.0, y_len * 2.0]))

                # prim::min.float_int
                len_float_int = int(min(x_len * 2.0, y_len))

                # prim::min.int_float
                len_int_float = int(min(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_max(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                x_len = len(x)
                y_len = len(y)

                # prim::max.int
                len_int = max(x_len, y_len)

                # prim::max.float
                len_float = int(max(x_len * 2.0, y_len * 2.0))

                # prim::max.self_int
                len_self_int = max([x_len, y_len])

                # prim::max.self_float
                len_self_float = int(max([x_len * 2.0, y_len * 2.0]))

                # prim::max.float_int
                len_float_int = int(max(x_len * 2.0, y_len))

                # prim::max.int_float
                len_int_float = int(max(x_len, y_len * 2.0))

                return torch.ones(
                    len_int
                    + len_float
                    + len_self_int
                    + len_self_float
                    + len_float_int
                    + len_int_float
                )

        inp = (torch.randn(10, 2), torch.randn(5))
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___list(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                return y[0]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_aten___getitem___dict(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = torch.split(x, 2)
                d_int = {0: y[0], 1: y[1]}
                d_str = {"0": y[0], "1": y[1]}
                d_bool = {True: y[0], False: y[1]}
                d_float = {0.1: y[0], 2.3: y[1]}
                return d_int[0], d_str["0"], d_bool[True], d_float[0.1]

        inp = (torch.rand((3, 2)),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_device(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand(3, 4),)
        self._check_equal_ts_ep_converter(Module(), inp)

    @requires_cuda
    def test_prim_device_cuda(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                device = x.device
                return torch.ones(2, 3, device=device)

        inp = (torch.rand((3, 4), device="cuda:0"),)
        self._check_equal_ts_ep_converter(Module(), inp)

    def test_prim_dtype(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                dtype = x.dtype
                return torch.ones(2, 3, dtype=dtype)

        for dtype in [
            torch.float32,
            torch.double,
        ]:
            inp = (torch.rand((3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

        for dtype in [
            torch.uint8,
            torch.int8,
            torch.int32,
        ]:
            inp = (torch.randint(high=128, size=(3, 4), dtype=dtype),)
            self._check_equal_ts_ep_converter(Module(), inp)

    def test_convert_if_basic(self):
        class M(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return y * y
                else:
                    return y + y

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

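        # Skip jit.traced because it specializes on one path.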
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_convert_if_tuple_out(self):
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return (z * z, z + z)

            def false_fn(self, y, z):
                return (y * y * y, y + y)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res = self.true_fn(y, z)
                else:
                    res = self.false_fn(y, z)

                return res[0] + res[1]

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

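        # Skip jit.traced because it specializes on one path.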
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    @unittest.skip("Wrong fx subgraph for cond, need to fix")
    def test_convert_if_multiple_out(self):
        class M(torch.nn.Module):
            def true_fn(self, y, z):
                return z * z

            def false_fn(self, y, z):
                return y * y * y

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                z = y * y

                if x:
                    res1 = self.true_fn(y, z)
                    res2 = y
                else:
                    res1 = z
                    res2 = self.false_fn(y, z)

                return res1 + res2

        inp = (torch.tensor(True), torch.tensor(4))
        ep_list = self._check_equal_ts_ep_converter(M(), inp)

        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(torch.tensor(False), torch.tensor(4)),
                M()(torch.tensor(False), torch.tensor(4)),
            )

    def test_profiler__record_function(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                handle = torch.ops.profiler._record_function_enter_new("foo", None)
                y = x * 2 + 4
                torch.ops.profiler._record_function_exit(handle)
                return y

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten_floordiv(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x // 2

        x = torch.randn(10, 10)
        self._check_equal_ts_ep_converter(Module(), (x,))

    def test_aten___is__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___isnot__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return x is not y, z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_aten___not__(self):
        class Module(torch.nn.Module):
            def forward(
                self, x: torch.Tensor, y: torch.Tensor
            ) -> Tuple[bool, torch.Tensor]:
                z = x + 1
                return not (x is not y), z

        # Traced function must return output that has tensors.
        inp = (torch.randn(10, 10), torch.rand(10, 10))
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_ts2ep_converter_unpack(self):
        class MUnpackList(torch.nn.Module):
            def forward(self, x):
                x, y = torch.split(x, 2)
                return x + y

        class MUnpackTuple(torch.nn.Module):
            def forward(self, x_tuple: Tuple[torch.Tensor, torch.Tensor]):
                x, y = x_tuple
                x = x.cos()
                return x + y

        inp = (torch.ones(4),)
        self._check_equal_ts_ep_converter(MUnpackList(), inp)
        inp = ((torch.zeros(1, 4), torch.ones(1, 4)),)
        self._check_equal_ts_ep_converter(MUnpackTuple(), inp)

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on Windows, "
        "causing the output to not be normalized as a list",
    )
    def test_convert_retrace_nested_scripted_modules(self):
        class Wrapper(torch.nn.Module):
            def __init__(self, mod) -> None:
                super().__init__()
                self.mod = mod

            def forward(self, x, y):
                return self.mod(x, y)

        class LinearM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x, y):
                return self.linear(y)

        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = LinearM(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return -self.mod1(x, y) - self.mod2(x, y)
                else:
                    return -self.mod1(x, y) + self.mod2(x, y)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                m = M(dim)
                m = torch.jit.script(m)
                self.mod1 = m
                self.mod2 = Wrapper(m)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                if x:
                    return self.mod1(x, y) + self.mod2(x, y)
                else:
                    return self.mod1(x, y) - self.mod2(x, y)

        inp = (
            torch.tensor(True),
            torch.randn([3, 3]),
        )
        self._check_equal_ts_ep_converter(NestedM(3), inp)

    def test_convert_nn_module_with_nested_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = M(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        class SuperNestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)
                self.m = NestedM(dim)

            def forward(self, x: torch.Tensor):
                return self.linear(self.m(x))

        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM(3)
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + x

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = M()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = NestedM()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                return self.w + self.m(x)

        inp = (torch.ones(1),)
        orig_m = NestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)
        orig_m = SuperNestedM()
        self._check_equal_ts_ep_converter(orig_m, inp)

    def test_convert_nn_module_with_nested_if_and_buffer(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w", torch.randn(1))
                self.count = 1

            def forward(self, x: torch.Tensor):
                return self.w + x + self.count

        class NestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M()
                self.m2 = M()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested; parameters need to be lifted
        # multiple times.
        class SuperNestedM(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = NestedM()
                self.m2 = NestedM()
                self.register_buffer("w", torch.randn(1))

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.w + self.m1(x)
                else:
                    return self.w + self.m2(x)

        # Super nested module testing.
        inp = (torch.ones(1),)
        orig_m = SuperNestedM()
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 1
        for ep in ep_list:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    @unittest.skipIf(
        IS_WINDOWS,
        "torch.cond doesn't go through torch.compile on Windows, "
        "causing the output to not be normalized as a list",
    )
    def test_convert_nn_module_with_nested_if_and_param(self):
        class M(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                return self.linear(x)

        class NestedM(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = M(dim)
                self.m2 = M(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested; parameters need to be lifted
        # multiple times.
        class SuperNestedM1(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.max(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Super nested; even the input needs to be
        # lifted recursively due to value propagation optimization.
        class SuperNestedM2(torch.nn.Module):
            def __init__(self, dim: int) -> None:
                super().__init__()
                self.m1 = NestedM(dim)
                self.m2 = NestedM(dim)
                self.linear = torch.nn.Linear(dim, dim)

            def forward(self, x: torch.Tensor):
                if torch.sum(x) > 1:
                    return self.linear(self.m1(x))
                else:
                    return self.linear(self.m2(x))

        # Basic module testing.
        inp = (torch.ones(3),)
        orig_m = M(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Nested module testing.
        inp = (torch.ones(3),)
        orig_m = NestedM(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM1(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

        # Super nested module testing.
        inp = (torch.ones(3),)
        orig_m = SuperNestedM2(3)
        ep_list = self._check_equal_ts_ep_converter(orig_m, inp)

        t = inp[0]
        t -= 0.8
        # Skip jit.traced because it specializes on one path.
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                orig_m(*inp),
            )

    def test_ts2ep_converter_contains(self):
        class MIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.float32, torch.float64]

        class MNotIn(torch.nn.Module):
            def forward(self, x: torch.Tensor):
                return x.dtype in [torch.int8]

        class MTensorIn(torch.nn.Module):
            def forward(self, x: torch.Tensor, x_dict: Dict[torch.Tensor, str]):
                return x in x_dict

        # Traced function must return output that has tensors.
        inp = (torch.tensor(4),)
        self._check_equal_ts_ep_converter(MIn(), inp, ["script"])
        self._check_equal_ts_ep_converter(MNotIn(), inp, ["script"])

        # TODO: update test to use reference for in.
        inp = (torch.tensor(4), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])
        inp = (torch.tensor(1), {torch.tensor(4): "foo"})
        self._check_equal_ts_ep_converter(MTensorIn(), inp, ["script"])

    def test_ts2ep_converter_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::foo",
                "(Tensor x) -> Tensor",
                lib=lib,
            )

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::foo",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def foo_impl(x):
                return x + x

            # Meta function of the custom op.
            @torch.library.impl_abstract(
                "mylib::foo",
                lib=lib,
            )
            def foo_meta(x):
                return x + x

            class M(torch.nn.Module):
                def forward(self, x):
                    return torch.ops.mylib.foo(x)

            inp = (torch.randn(3, 3),)
            m = M()
            self._check_equal_ts_ep_converter(m, inp)

    def test_convert_func_without_param(self):
        def func1(x, y):
            return x + y

        def func2(x, y):
            if x.sum() > 0:
                return x + y
            else:
                return x - y

        inp = (
            torch.tensor(1),
            torch.tensor(1),
        )
        self._check_equal_ts_ep_converter(func1, inp)

        ep_list = self._check_equal_ts_ep_converter(func2, inp)

        t = inp[0]
        t -= 1
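        # Skip jit.traced because it specializes on one path.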
        for ep in ep_list[1:]:
            torch.testing.assert_close(
                ep.module()(*inp),
                func2(*inp),
            )

    def test_implicit_constant_to_tensor_handling(self):
        def func1(x):
            return x + 2

        def func2(x, y):
            return x * y / (x - 2 * y) + y

        def func3(x):
            return x + torch.tensor([3])

        def func4():
            val = torch.tensor(float("inf"))
            return torch.full((10, 10), val)

        def func5():
            x = -1
            return x * torch.ones(1, dtype=torch.float), torch.zeros(
                1, dtype=torch.float
            )

        def func6(x1, x2, x3, x4):
            return (
                x1.numel(),
                x1.size(),
                x2.numel(),
                x2.size(),
                x3.numel(),
                x3.size(),
                x4.numel(),
                x4.size(),
                torch.ones(x1.numel()),  # Just make sure downstream ops still work.
                torch.ones(x1.size()),  # Just make sure downstream ops still work.
            )

        class M1(torch.nn.Module):
            def __init__(self, value):
                super().__init__()
                self.x = torch.tensor(value)

            def forward(self):
                return self.x.clone()

        class M2(torch.nn.Module):
            def forward(self, x):
                return torch.tensor(4) + x

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func1, inp)
        inp = (torch.randn([2, 2]), torch.randn([2, 2]))
        self._check_equal_ts_ep_converter(func2, inp)

        inp = (torch.randn([2, 2]),)
        self._check_equal_ts_ep_converter(func3, inp)

        self._check_equal_ts_ep_converter(func4, ())
        self._check_equal_ts_ep_converter(M1(5), ())

        inp = (torch.randn(2),)
        self._check_equal_ts_ep_converter(M2(), inp)

        self._check_equal_ts_ep_converter(func5, ())
        inp = (
            torch.randn([2, 3, 4]).to(torch.int8),
            torch.randn([2, 3, 4]).to(torch.int32),
            torch.randn([2, 3, 4]).to(torch.float32),
            torch.randn([2, 3, 4]).to(torch.float64),
        )
        ep_list = self._check_equal_ts_ep_converter(func6, inp)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     self.assertEqual(
        #         ep.module()(
        #             torch.randn([1, 1, 1]).to(torch.int8),
        #             torch.randn([1, 1, 1]).to(torch.int32),
        #             torch.randn([1, 1, 1]).to(torch.float32),
        #             torch.randn([1, 1, 1]).to(torch.float64),
        #         )[0], 1
        #     )

    def test_aten_tensor_dtype_int(self):
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=torch.int32)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_prim_dtype(self):
        class M(torch.nn.Module):
            def forward(self, x):
                y = torch.tensor(1, dtype=x.dtype)
                return y + x

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.tensor(1),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 1)

    def test_aten_tensor_dynamic(self):
        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor(s)
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        for ep in ep_list:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

        class M(torch.nn.Module):
            def forward(self, x):
                s = x.shape[0]
                y = torch.tensor([s, s * 2, 1])
                return y

        ep_list = self._check_equal_ts_ep_converter(M(), (torch.ones(3),))
        # Tracing directly inlines the tensor constant.
        for ep in ep_list[1:]:
            self.assertEqual(len(ep.constants), 0)

        # TODO: Additional check once dynamic shape is supported.
        # for ep in ep_list:
        #     torch.testing.assert_close(
        #         ep.module()(torch.ones(4)),
        #         M()(torch.ones(4)),
        #     )

    def test_prim_tolist(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[int]:
                return x.tolist()

        inp = (torch.tensor([1, 2, 3]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> List[List[int]]:
                return x.tolist()

        inp = (torch.tensor([[1, 2, 3], [4, 5, 6]]),)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_get_tensor_constants(self):
        # Since self.data is only read but not written, it is lifted as a
        # constant tensor.
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.randn(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data

        class Goo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.randn(3, 2)
                self.foo = Foo()

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + self.data + self.foo.data + self.foo(x)

        inp = (torch.randn(3, 2),)
        goo = Goo()
        self._check_equal_ts_ep_converter(goo, inp)

    def test_prim_SetAttr(self):
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("data", torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + x

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("data", torch.ones(3, 2))

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module, inp, ["script"], check_persistent=True
        )

        # export lifts a tensor constant (self.data) as an input if it is not assigned.
        # If it is assigned, export will error and ask users to register it as a buffer.
        # In the converter, tensor constants that are assigned to are automatically
        # registered as buffers, since it might be hard to manually register them as buffers.
        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.data = torch.ones(3, 2)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.data = self.data + x
                return x + self.data

        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module,
            inp,
            ["script"],
            check_persistent=True,
            lifted_tensor_constants=OrderedDict([("data", torch.ones(3, 2))]),
        )

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.count = 0

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                self.count += 1
                return x + self.count

        # check_persistent is False since export specializes on non-tensor constants
        inp = (torch.ones(3, 2),)
        self._check_equal_ts_ep_converter(
            Module(), inp, ["script"], check_persistent=False
        )

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.count = 0

            def forward(self, x):
                count1 = self.count
                self.count += 1
                count2 = self.count
                self.count += 1
                count3 = self.count
                return x + count1 + count2 + count3

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M(), inp, ["script"], check_persistent=False)

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_buffer("w2", torch.ones(1))

            def forward(self, x: torch.Tensor):
                self.w2 += 1
                return self.w2

        inp = (torch.ones(1),)
        self._check_equal_ts_ep_converter(M, inp, ["script"], check_persistent=True)

    def test_raise_exception(self):
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                if y > 0:
                    raise RuntimeError("test")
                return x + y

        # match non-strict export behavior that errors when the given input leads to
        # RaiseException.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Matching non-strict export behavior that only executes 1 if-branch according
        # to the given input.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: int) -> torch.Tensor:
                z = x
                if y > 0:
                    raise RuntimeError("test")
                    # z = x
                else:
                    z = x + y
                return x + y + z

        # match non-strict export behavior that errors when the given input leads to
        # RaiseException.
        with self.assertRaisesRegex(torch.jit.Error, "builtins.RuntimeError"):
            inp = (torch.randn(3, 2), 1)
            self._check_equal_ts_ep_converter(Module(), inp, ["script"])

        # Matching non-strict export behavior that only executes 1 if-branch according
        # to the given input.
        inp = (torch.randn(3, 2), 0)
        self._check_equal_ts_ep_converter(Module(), inp, ["script"])

    def test_context_manager(self):
        class ContextManager:
            def __init__(self):
                self.count = 0
                return

            def __enter__(self):
                self.count += 1
                return

            def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
                self.count -= 1
                return

        class M(torch.nn.Module):
            def forward(self, x, y):
                with ContextManager():
                    res = x + y
                return res

        inp = (torch.ones(3, 3), torch.ones(3, 3))
        self._check_equal_ts_ep_converter(M(), inp)

    def test_hidden_input_name(self):
        @torch.jit.script
        def func1(x):
            return x + 1

        def func2(*args):
            v = torch.cat(args, dim=1)
            return v * v

        inp = (torch.randn([1, 1]),)
        self._check_equal_ts_ep_converter(func1, inp)

        inp = (torch.ones(5, 5),)
        # Cannot script again.
        self._check_equal_ts_ep_converter(torch.ops.aten.relu, inp, ["trace"])

        M = 2
        Ns = [4, 2, 1]
        empty = torch.tensor([], dtype=torch.double)
        values = [empty] + [torch.randn(M, N) for N in Ns]
        # Cannot script variable length inputs.
        self._check_equal_ts_ep_converter(func2, tuple(values), ["trace"])

    def test_ts2ep_multi_outputs_on_call_ops(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

            def forward(self, x: torch.Tensor, y: torch.Tensor):
                return (
                    torch.max(x, dim=0),
                    torch.topk(x, 3),
                    torch.sort(x, dim=0),
                    self.pool(y),
                )

        inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10]))
        self._check_equal_ts_ep_converter(M(), inp)


if __name__ == "__main__":
    run_tests()