mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
It is better to have the new tracer as a global config that can be manipulated easily. I also believe a dynamo-like config infra is more useful than relying on a custom way of patching stuff.

Differential Revision: [D82478649](https://our.internmc.facebook.com/intern/diff/D82478649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162558
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162557
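As a minimal sketch of what this config surface looks like (only `config.patch` and the `use_new_tracer_experimental` flag are actually exercised by the tests below; the module `M` here is purely illustrative):

```python
import torch
from torch._export import config
from torch.export import export


class M(torch.nn.Module):
    def forward(self, x):
        return x + 1


# Temporarily flip the experimental-tracer flag for one export call;
# the previous value is restored when the context manager exits.
with config.patch(use_new_tracer_experimental=True):
    ep = export(M(), (torch.randn(2),))
```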
448 lines | 14 KiB | Python
# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest
from dataclasses import dataclass
from typing import Any, List

from parameterized import parameterized_class

import torch
import torch._dynamo as torchdynamo
from torch import Tensor
from torch._export import config
from torch._export.utils import register_dataclass_as_pytree_node
from torch.export import export, register_dataclass
from torch.export._swap import _swap_modules
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
@parameterized_class(
    [
        {"strict": False},
        {"strict": True},
    ],
    class_name_func=lambda cls,
    _,
    params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
)
class TestSwap(TestCase):
    def test_unflatten_preserve_signature(self):
        class NestedChild(torch.nn.Module):
            def forward(self, zx, y):
                return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = NestedChild()

            def forward(self, x, y):
                z = torch.ones_like(x)
                xw = self.nested((z, x), y={"key": y})
                return xw["w"] + z - xw["x"]

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)
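
        # preserve_module_call_signature keeps the pytree calling convention of
        # these submodules in the exported program so they can be swapped later.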
        ep = export(
            orig_eager,
            inps,
            {},
            preserve_module_call_signature=("foo.nested", "bar"),
            strict=self.strict,
        )
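
        # Swap fresh eager instances in at the preserved call sites; the
        # resulting GraphModule should match ep.module() on the same inputs.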
        swapped_gm = _swap_modules(
            ep,
            {"foo.nested": NestedChild(), "bar": Child2()},
        )

        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))

    def test_unflatten_preserve_with_unused_input(self):
        class M1(torch.nn.Module):
            def forward(self, x, a, b):
                return x + a, b

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()

            def forward(self, x, y):
                a, b = torch.topk(y, 2)
                return self.m1(x, a, b)[0]

        ep = torch.export.export(
            M(),
            (torch.randn(2), torch.randn(5)),
            preserve_module_call_signature=("m1",),
            strict=self.strict,
        )

        swapped_gm = _swap_modules(
            ep,
            {"m1": M1()},
        )

        inps = (torch.randn(2), torch.randn(5))
        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))

    def test_nested_leaf(self):
        class Leaf(torch.nn.Module):
            def forward(self, x):
                return x + 1

        class Nested(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.leaf = Leaf()

            def forward(self, x):
                return self.leaf(x) + 2

        class TopLevel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.nested = Nested()

            def forward(self, x):
                return self.nested(x) + 3

        ep = torch.export.export(
            TopLevel(),
            (torch.randn(3),),
            strict=self.strict,
            preserve_module_call_signature=("nested",),
        )

        swapped_gm = _swap_modules(
            ep,
            {"nested": Nested()},
        )

        inps = (torch.randn(3),)
        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))

    def test_dedup_sym_size(self):
        # Here, sym_size & floor div are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be copied as recompute since we
        # preserve the call signature, but for m2 floordiv should be passed in
        # as a placeholder.
        # Test that this is preserved, and the unflattened module runs correctly.
        class M1(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M2(torch.nn.Module):
            def forward(self, x, y):
                d = x.size(0) // 2
                return y[:d]

        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m1 = M1()
                self.m2 = M2()

            def forward(self, x, y):
                d = x.size(0) // 2
                m1_res = self.m1(x, y)
                m2_res = self.m2(x, y)
                return y[d:] + m1_res + m2_res

        inputs = (torch.ones(10), torch.ones(10))
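        # Derived dim: d = 2 * d_ constrains each input's size(0) to be even,
        # so the floor division by 2 above is always exact.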
        d_ = torch.export.Dim("foo", max=2048)
        d = 2 * d_
        ep = torch.export.export(
            M(),
            inputs,
            dynamic_shapes=((d,), (d,)),
            strict=self.strict,
            preserve_module_call_signature=("m1",),
        )

        swapped_gm = _swap_modules(
            ep,
            {"m1": M1()},
        )

        inps = (torch.randn(10), torch.randn(10))
        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))

        inps = (torch.randn(20), torch.randn(20))
        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))

    def test_remove_duplicate_pytree_simple(self):
        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                z = torch.ones_like(x)
                w = y + z[1]
                x = y * z[1]
                return {"res1": x + y, "res2": x * y}

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return x["res2"] + x["res1"] - 1

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x = self.foo(x, y)
                x = self.bar(x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)

        ep = export(
            orig_eager,
            inps,
            {},
            preserve_module_call_signature=("foo", "bar"),
            strict=self.strict,
        )

        swapped_gm = _swap_modules(
            ep,
            {"foo": Child1(), "bar": Child2()},
        )

        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
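        # The expected graph below calls the swapped submodules directly
        # (self.foo, self.bar); the redundant unflatten/flatten pair between
        # foo and bar has been removed, so bar consumes foo's output as-is.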
        self.assertExpectedInline(
            swapped_gm.code.strip(),
            """\
def forward(self, x, y):
    x_1 = x
    y_1 = y
    _spec_0 = self._spec_0
    _spec_1 = self._spec_1
    _spec_4 = self._spec_4
    tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
    getitem = tree_flatten[0]; tree_flatten = None
    x = getitem[0]
    y = getitem[1]; getitem = None
    tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
    getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
    getitem_2 = getitem_1[0]
    getitem_3 = getitem_1[1]; getitem_1 = None
    foo = self.foo(getitem_2, getitem_3); getitem_2 = getitem_3 = None
    bar = self.bar(foo); foo = None
    tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
    getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
    tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
    return tree_unflatten""",
        )

    @unittest.expectedFailure
    def test_remove_duplicate_pytree_different_order(self):
        """
        This is not supported yet because module `foo`'s outputs are not all
        used directly as inputs to `bar` in the same order as they are output
        from `foo`. To support this, we would have to do some sort of
        reordering.
        """

        class Child1(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                return {"res1": x + y}, {"res2": x * y, "res3": x * x}

        class Child2(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, y, x):
                y = y["res2"] * y["res3"]
                x = x["res1"] + x["res1"]
                return y - x

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.foo = Child1()
                self.bar = Child2()

            def forward(self, x, y):
                x, y = self.foo(x, y)
                x = self.bar(y, x)
                return x

        orig_eager = MyModule()
        inps = torch.rand(2, 3), torch.rand(2, 3)

        ep = export(
            orig_eager,
            inps,
            {},
            preserve_module_call_signature=("foo", "bar"),
            strict=self.strict,
        )

        swapped_gm = _swap_modules(
            ep,
            {"foo": Child1(), "bar": Child2()},
        )

        self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
        self.assertExpectedInline(
            swapped_gm.code.strip(),
            """\
def forward(self, x, y):
    x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
    _spec_0 = self._spec_0
    _spec_3 = self._spec_3
    tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
    getitem = tree_unflatten[0]; tree_unflatten = None
    getitem_1 = getitem[0]
    getitem_2 = getitem[1]; getitem = None
    foo = self.foo(getitem_1, getitem_2); getitem_1 = getitem_2 = None
    getitem_3 = foo[0]
    getitem_4 = foo[1];
    bar = self.bar(getitem_4, getitem_3); foo = None
    tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_3); bar = _spec_3 = None
    getitem_9 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
    return pytree.tree_unflatten((getitem_9,), self._out_spec)""",
        )

    def test_custom_input_args(self):
        @dataclass
        class CustomInput:
            a: Tensor
            b: Tensor

        register_dataclass_as_pytree_node(
            CustomInput,
            serialized_type_name="test_swap.test_custom_input.CustomInput",
        )

        class Foo(torch.nn.Module):
            def forward(self, inputs):
                return torch.matmul(inputs.a, inputs.b)

        ep = export(
            Foo(),
            (CustomInput(torch.randn(2, 3), torch.randn(3, 2)),),
            strict=self.strict,
        )
        swapped = _swap_modules(ep, {})
        inp = (CustomInput(torch.randn(2, 3), torch.randn(3, 2)),)
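        # Interpreting the graph node-by-node and calling the module directly
        # should produce the same result through the custom pytree handling.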
        res1 = torch.fx.Interpreter(swapped).run(*inp)
        res2 = swapped(*inp)
        self.assertTrue(torch.allclose(res1, res2))

    def test_custom_input_kwargs(self):
        @dataclass
        class CustomInput:
            a: Tensor
            b: Tensor

        register_dataclass(
            CustomInput,
            serialized_type_name="test_swap.test_custom_input.CustomInput",
        )

        class Foo(torch.nn.Module):
            def forward(self, x, *, inputs):
                return x + torch.matmul(inputs.a, inputs.b)

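        # Run the export twice: once with the experimental tracer enabled via
        # the global torch._export config, and once with the default tracer.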
        for use_new_tracer in [True, False]:
            with config.patch(use_new_tracer_experimental=use_new_tracer):
                ep = export(
                    Foo(),
                    (torch.randn(2, 2),),
                    {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
                    strict=self.strict,
                )
            swapped = _swap_modules(ep, {})
            inp_args = (torch.randn(2, 2),)
            inp_kwargs = {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))}
            res1 = torch.fx.Interpreter(swapped).run(*(*inp_args, *inp_kwargs.values()))
            res2 = swapped(*inp_args, **inp_kwargs)
            self.assertTrue(torch.allclose(res1, res2))

    def test_custom_input_kwargs_use_private(self):
        @dataclass
        class CustomInput:
            a: Tensor
            b: Tensor

        register_dataclass_as_pytree_node(
            CustomInput,
            serialized_type_name="test_swap.test_custom_input.CustomInput",
        )

        class Foo(torch.nn.Module):
            def forward(self, x, *, inputs):
                return x + torch.matmul(inputs.a, inputs.b)

        # shouldn't error
        with config.patch(use_new_tracer_experimental=True):
            _ = export(
                Foo(),
                (torch.randn(2, 2),),
                {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
                strict=self.strict,
            )

    def test_custom_output(self):
        @dataclass
        class CustomOutput:
            a: Tensor
            b: Tensor

        register_dataclass_as_pytree_node(
            CustomOutput,
            serialized_type_name="test_swap.test_custom_input.CustomInput",
        )

        class Foo(torch.nn.Module):
            def forward(self, a, b):
                return (CustomOutput(a * a, b * b), CustomOutput(a * b.T, a + b.T))

        ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)), strict=True)
        swapped = _swap_modules(ep, {})
        inp = (torch.randn(2, 3), torch.randn(3, 2))
        res1 = torch.fx.Interpreter(swapped).run(*inp)
        res2 = swapped(*inp)
        self.assertTrue(torch.allclose(res1[0].a, res2[0].a))
        self.assertTrue(torch.allclose(res1[0].b, res2[0].b))
        self.assertTrue(torch.allclose(res1[1].a, res2[1].a))
        self.assertTrue(torch.allclose(res1[1].b, res2[1].b))


if __name__ == "__main__":
    run_tests()