pytorch/test/export/test_swap.py
Tugsbayasgalan Manlaibaatar 0e9f9c3a61 Fix inconsistent test and add new tracer as config (#162558)
It is better to have the new tracer as a global config that can be manipulated easily. Also, I believe a dynamo-like config infra is more useful than relying on a custom way of patching things.

Differential Revision: [D82478649](https://our.internmc.facebook.com/intern/diff/D82478649)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162558
Approved by: https://github.com/zhxchen17
ghstack dependencies: #162557
2025-09-17 17:01:48 +00:00
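
The toggle is exercised in the tests below via `torch._export.config`; a minimal sketch of the pattern (the flag name is taken from this file and is experimental, not a stable public API):

    from torch._export import config

    # Temporarily enable the experimental tracer for a single export call.
    with config.patch(use_new_tracer_experimental=True):
        ep = torch.export.export(mod, args, kwargs, strict=False)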


# Owner(s): ["oncall: export"]
# flake8: noqa
import unittest
from dataclasses import dataclass
from typing import Any, List
from parameterized import parameterized_class
import torch
import torch._dynamo as torchdynamo
from torch import Tensor
from torch._export import config
from torch._export.utils import register_dataclass_as_pytree_node
from torch.export import export, register_dataclass
from torch.export._swap import _swap_modules
from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase
@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo isn't supported")
@parameterized_class(
[
{"strict": False},
{"strict": True},
],
class_name_func=lambda cls,
_,
params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}",
)
class TestSwap(TestCase):
def test_unflatten_preserve_signature(self):
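        # Swap the preserved submodules with fresh eager instances and check that
        # the swapped module matches the exported module numerically.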
class NestedChild(torch.nn.Module):
def forward(self, zx, y):
return {"x": y["key"] + zx[1], "w": y["key"] * zx[1]}
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = NestedChild()
def forward(self, x, y):
z = torch.ones_like(x)
xw = self.nested((z, x), y={"key": y})
return xw["w"] + z - xw["x"]
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x - 1
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
ep = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo.nested", "bar"),
strict=self.strict,
)
swapped_gm = _swap_modules(
ep,
{"foo.nested": NestedChild(), "bar": Child2()},
)
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
def test_unflatten_preserve_with_unused_input(self):
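        # m1 passes `b` through untouched and the top level only uses the first
        # output; swapping m1 should still match the exported program.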
class M1(torch.nn.Module):
def forward(self, x, a, b):
return x + a, b
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
def forward(self, x, y):
a, b = torch.topk(y, 2)
return self.m1(x, a, b)[0]
ep = torch.export.export(
M(),
(torch.randn(2), torch.randn(5)),
preserve_module_call_signature=("m1",),
strict=self.strict,
)
swapped_gm = _swap_modules(
ep,
{"m1": M1()},
)
inps = (torch.randn(2), torch.randn(5))
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
def test_nested_leaf(self):
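        # Preserve and swap a submodule that itself contains a child module.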
class Leaf(torch.nn.Module):
def forward(self, x):
return x + 1
class Nested(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = Leaf()
def forward(self, x):
return self.leaf(x) + 2
class TopLevel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.nested = Nested()
def forward(self, x):
return self.nested(x) + 3
ep = torch.export.export(
TopLevel(),
(torch.randn(3),),
strict=self.strict,
preserve_module_call_signature=("nested",),
)
swapped_gm = _swap_modules(
ep,
{"nested": Nested()},
)
inps = (torch.randn(3),)
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
def test_dedup_sym_size(self):
        # Here, sym_size & floordiv are used in 3 subgraphs (top-level, m1, m2),
        # but only one copy of sym_size is created in the initial export graph.
        # For m1, sym_size & floordiv should be recomputed (copied) since we preserve its call signature,
        # but for m2, floordiv should be passed in as a placeholder.
        # Test that this structure is preserved and that the swapped module runs correctly.
class M1(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M2(torch.nn.Module):
def forward(self, x, y):
d = x.size(0) // 2
return y[:d]
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.m1 = M1()
self.m2 = M2()
def forward(self, x, y):
d = x.size(0) // 2
m1_res = self.m1(x, y)
m2_res = self.m2(x, y)
return y[d:] + m1_res + m2_res
inputs = (torch.ones(10), torch.ones(10))
d_ = torch.export.Dim("foo", max=2048)
d = 2 * d_
ep = torch.export.export(
M(),
inputs,
dynamic_shapes=((d,), (d,)),
strict=self.strict,
preserve_module_call_signature=("m1",),
)
swapped_gm = _swap_modules(
ep,
{"m1": M1()},
)
inps = (torch.randn(10), torch.randn(10))
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
inps = (torch.randn(20), torch.randn(20))
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
def test_remove_duplicate_pytree_simple(self):
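        # foo's dict output feeds directly into bar, so the redundant
        # flatten/unflatten pair between them should be removed; this is checked
        # against the expected graph code below.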
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
z = torch.ones_like(x)
w = y + z[1]
x = y * z[1]
return {"res1": x + y, "res2": x * y}
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x):
return x["res2"] + x["res1"] - 1
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x = self.foo(x, y)
x = self.bar(x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
ep = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo", "bar"),
strict=self.strict,
)
swapped_gm = _swap_modules(
ep,
{"foo": Child1(), "bar": Child2()},
)
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
self.assertExpectedInline(
swapped_gm.code.strip(),
"""\
def forward(self, x, y):
x_1 = x
y_1 = y
_spec_0 = self._spec_0
_spec_1 = self._spec_1
_spec_4 = self._spec_4
tree_flatten = torch.utils._pytree.tree_flatten((x_1, y_1)); x_1 = y_1 = None
getitem = tree_flatten[0]; tree_flatten = None
x = getitem[0]
y = getitem[1]; getitem = None
tree_unflatten_1 = torch.utils._pytree.tree_unflatten([x, y], _spec_1); x = y = _spec_1 = None
getitem_1 = tree_unflatten_1[0]; tree_unflatten_1 = None
getitem_2 = getitem_1[0]
getitem_3 = getitem_1[1]; getitem_1 = None
foo = self.foo(getitem_2, getitem_3); getitem_2 = getitem_3 = None
bar = self.bar(foo); foo = None
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_4); bar = _spec_4 = None
getitem_10 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
tree_unflatten = torch.utils._pytree.tree_unflatten((getitem_10,), _spec_0); getitem_10 = _spec_0 = None
return tree_unflatten""",
)
@unittest.expectedFailure
def test_remove_duplicate_pytree_different_order(self):
"""
This is not supported yet because module `foo`s outputs are not all
directly used in as inputs to `bar` in the same order as outputted from
`foo`. To support this, we would have to do some sort of ordering.
"""
class Child1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return {"res1": x + y}, {"res2": x * y, "res3": x * x}
class Child2(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, y, x):
y = y["res2"] * y["res3"]
x = x["res1"] + x["res1"]
return y - x
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Child1()
self.bar = Child2()
def forward(self, x, y):
x, y = self.foo(x, y)
x = self.bar(y, x)
return x
orig_eager = MyModule()
inps = torch.rand(2, 3), torch.rand(2, 3)
ep = export(
orig_eager,
inps,
{},
preserve_module_call_signature=("foo", "bar"),
strict=self.strict,
)
swapped_gm = _swap_modules(
ep,
{"foo": Child1(), "bar": Child2()},
)
self.assertTrue(torch.allclose(ep.module()(*inps), swapped_gm(*inps)))
self.assertExpectedInline(
swapped_gm.code.strip(),
"""\
def forward(self, x, y):
x, y, = fx_pytree.tree_flatten_spec(([x, y], {}), self._in_spec)
_spec_0 = self._spec_0
_spec_3 = self._spec_3
tree_unflatten = torch.utils._pytree.tree_unflatten([x, y], _spec_0); x = y = _spec_0 = None
getitem = tree_unflatten[0]; tree_unflatten = None
getitem_1 = getitem[0]
getitem_2 = getitem[1]; getitem = None
foo = self.foo(getitem_1, getitem_2); getitem_1 = getitem_2 = None
getitem_3 = foo[0]
getitem_4 = foo[1];
bar = self.bar(getitem_4, getitem_3); foo = None
tree_flatten_spec_1 = torch.fx._pytree.tree_flatten_spec(bar, _spec_3); bar = _spec_3 = None
getitem_9 = tree_flatten_spec_1[0]; tree_flatten_spec_1 = None
return pytree.tree_unflatten((getitem_9,), self._out_spec)""",
)
def test_custom_input_args(self):
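        # Custom dataclass passed as a positional pytree input; swapping with no
        # replacements should still produce a runnable module.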
@dataclass
class CustomInput:
a: Tensor
b: Tensor
register_dataclass_as_pytree_node(
CustomInput,
serialized_type_name="test_swap.test_custom_input.CustomInput",
)
class Foo(torch.nn.Module):
def forward(self, inputs):
return torch.matmul(inputs.a, inputs.b)
ep = export(
Foo(),
(CustomInput(torch.randn(2, 3), torch.randn(3, 2)),),
strict=self.strict,
)
swapped = _swap_modules(ep, {})
inp = (CustomInput(torch.randn(2, 3), torch.randn(3, 2)),)
res1 = torch.fx.Interpreter(swapped).run(*inp)
res2 = swapped(*inp)
self.assertTrue(torch.allclose(res1, res2))
def test_custom_input_kwargs(self):
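        # Custom dataclass passed as a keyword-only pytree input; also exercises
        # the experimental tracer config in both states.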
@dataclass
class CustomInput:
a: Tensor
b: Tensor
register_dataclass(
CustomInput,
serialized_type_name="test_swap.test_custom_input.CustomInput",
)
class Foo(torch.nn.Module):
def forward(self, x, *, inputs):
return x + torch.matmul(inputs.a, inputs.b)
for use_new_tracer in [True, False]:
with config.patch(use_new_tracer_experimental=use_new_tracer):
ep = export(
Foo(),
(torch.randn(2, 2),),
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
strict=self.strict,
)
swapped = _swap_modules(ep, {})
inp_args = (torch.randn(2, 2),)
inp_kwargs = {"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))}
res1 = torch.fx.Interpreter(swapped).run(*(*inp_args, *inp_kwargs.values()))
res2 = swapped(*inp_args, **inp_kwargs)
self.assertTrue(torch.allclose(res1, res2))
def test_custom_input_kwargs_use_private(self):
@dataclass
class CustomInput:
a: Tensor
b: Tensor
register_dataclass_as_pytree_node(
CustomInput,
serialized_type_name="test_swap.test_custom_input.CustomInput",
)
class Foo(torch.nn.Module):
def forward(self, x, *, inputs):
return x + torch.matmul(inputs.a, inputs.b)
# shouldn't error
with config.patch(use_new_tracer_experimental=True):
_ = export(
Foo(),
(torch.randn(2, 2),),
{"inputs": CustomInput(torch.randn(2, 3), torch.randn(3, 2))},
strict=self.strict,
)
def test_custom_output(self):
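        # Custom dataclass outputs; the swapped module should reconstruct the
        # same pytree structure and values as the exported program.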
@dataclass
class CustomOutput:
a: Tensor
b: Tensor
register_dataclass_as_pytree_node(
CustomOutput,
serialized_type_name="test_swap.test_custom_input.CustomInput",
)
class Foo(torch.nn.Module):
def forward(self, a, b):
return (CustomOutput(a * a, b * b), CustomOutput(a * b.T, a + b.T))
ep = export(Foo(), (torch.randn(2, 3), torch.randn(3, 2)), strict=True)
swapped = _swap_modules(ep, {})
inp = (torch.randn(2, 3), torch.randn(3, 2))
res1 = torch.fx.Interpreter(swapped).run(*inp)
res2 = swapped(*inp)
self.assertTrue(torch.allclose(res1[0].a, res2[0].a))
self.assertTrue(torch.allclose(res1[0].b, res2[0].b))
self.assertTrue(torch.allclose(res1[1].a, res2[1].a))
self.assertTrue(torch.allclose(res1[1].b, res2[1].b))
if __name__ == "__main__":
run_tests()