graph module retracing without preserving MCS (#143676)

Retracing while preserving module call signatures used to be a problem because graph modules do not carry submodules at the original fully qualified paths (the module hierarchy is flattened away when a program is first exported). This caused a number of retraceability tests to fail. By not trying to wrap graph modules with export tracepoints we can pass most of these tests; the only remaining exception is module swapping on retraced programs, which is still not possible.
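A minimal sketch of the scenario this unblocks (the toy modules `M` and `N` and the input are illustrative, not taken from the test suite):

```python
import torch
from torch.export import export

class N(torch.nn.Module):
    def forward(self, x):
        return x + 1

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        return self.n(x) * 2

inp = (torch.ones(4),)

# First export: ep.module() is a GraphModule whose original submodule
# hierarchy (e.g. "n") has been flattened away.
ep = export(M(), inp, preserve_module_call_signature=("n",))

# Retrace the exported module. Previously, wrapping submodules with
# export tracepoints failed here, since the GraphModule has no
# submodule at path "n"; with this change the wrapping is skipped.
ep2 = export(ep.module(), inp, preserve_module_call_signature=("n",))
assert torch.allclose(ep2.module()(*inp), M()(*inp))
```

Module swapping on the retraced program remains unsupported, which is why the swap-based assertions below stay guarded by `is_retracebility_test`.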

Differential Revision: [D67539304](https://our.internmc.facebook.com/intern/diff/D67539304/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143676
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
ghstack dependencies: #143664
Author: Avik Chaudhuri
Date: 2024-12-20 13:53:31 -08:00
Committed by: PyTorch MergeBot
Parent: d7e59c2f85
Commit: 51eacea8c4

3 changed files with 81 additions and 102 deletions


@@ -4038,7 +4038,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
                 return self.n(x + 1, True) + self.n(x + 1, False)
 
         x = torch.zeros(4)
-        types = {} if is_retracebility_test(self._testMethodName) else {"n": N}
+        types = {"n": N}
         ep = export(
             M(),
             (x,),
@@ -7722,14 +7722,11 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7787,14 +7784,11 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7851,14 +7845,11 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7934,15 +7925,12 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8060,17 +8048,14 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8248,20 +8233,17 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-                "n1.n2.n3.n4.n5.n6.n7",
-                "n1.n2.n3.n4.n5.n6.n7.n8",
-                "n1.n2.n3.n4.n5.n6.n7.n8.n9",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+            "n1.n2.n3.n4.n5.n6.n7",
+            "n1.n2.n3.n4.n5.n6.n7.n8",
+            "n1.n2.n3.n4.n5.n6.n7.n8.n9",
+        )
         ep = export(
             N0(),
             inp,
@@ -8307,13 +8289,10 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8354,13 +8333,10 @@ graph():
         inp = (torch.ones(1),)
         eager = N0()(*inp)
 
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8451,6 +8427,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(Mod(), inp, preserve_module_call_signature=(path_n,)),
                 swap={path_n: N()},
@@ -8484,6 +8461,7 @@ graph():
         eager_result = m(*inp)
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             ep = export(M(), inp, preserve_module_call_signature=("n",))
             epm = ep.module()
             ufm = torch.export.unflatten(ep)
@@ -8535,18 +8513,17 @@ graph():
         unflattened_result = ufm(*inp)
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
-        if not is_retracebility_test(self._testMethodName):
-            if is_training_ir_test(self._testMethodName):
-                test(
-                    torch.export.export_for_training(
-                        M(),
-                        inp,
-                        strict=not is_non_strict_test(self._testMethodName),
-                        preserve_module_call_signature=("n",),
-                    )
-                )
-            test(export(M(), inp, preserve_module_call_signature=("n",)))
+        if is_training_ir_test(self._testMethodName):
+            test(
+                torch.export.export_for_training(
+                    M(),
+                    inp,
+                    strict=not is_non_strict_test(self._testMethodName),
+                    preserve_module_call_signature=("n",),
+                )
+            )
+        test(export(M(), inp, preserve_module_call_signature=("n",)))
 
     def test_unflatten_multiple_graphs_preserve_signature_no_error(self):
         class N(torch.nn.Module):
@@ -8590,6 +8567,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8646,6 +8624,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8790,15 +8769,13 @@ graph():
             id(getattr(unflattened, a)), id(getattr(unflattened, b))
         )
 
-        if not is_retracebility_test(self._testMethodName):
-            # preserving module call signatures
-            ep = export(m, inp, preserve_module_call_signature=("n", "p"))
-            exported_result = ep.module()(*inp)
-            self.assertTrue(torch.allclose(exported_result, eager_result))
+        ep = export(m, inp, preserve_module_call_signature=("n", "p"))
+        exported_result = ep.module()(*inp)
+        self.assertTrue(torch.allclose(exported_result, eager_result))
 
-            unflattened = torch.export.unflatten(ep)
-            unflattened_result = unflattened(*inp)
-            self.assertTrue(torch.allclose(unflattened_result, eager_result))
+        unflattened = torch.export.unflatten(ep)
+        unflattened_result = unflattened(*inp)
+        self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         test(
             gen_m(n=True, n_1=False, p=False, p_1=False),


@@ -61,11 +61,7 @@ def export_tracepoint_cpu(*args, **kwargs):
 def _wrap_submodule(mod, path, module_call_specs):
     assert isinstance(mod, torch.nn.Module)
     assert path != ""
-    submodule = mod
-    for name in path.split("."):
-        if not hasattr(submodule, name):
-            raise RuntimeError(f"Couldn't find submodule at path {path}")
-        submodule = getattr(submodule, name)
+    submodule = torch.fx.graph_module._get_attr(mod, path)
 
     def update_module_call_signatures(path, in_spec, out_spec):
         if path in module_call_specs:
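This change swaps the hand-rolled lookup loop for the shared `torch.fx.graph_module._get_attr` helper. For reference, a minimal standalone sketch of the dotted-path resolution both perform (`get_attr_by_path` is an illustrative name, not the PyTorch helper itself):

```python
import torch

def get_attr_by_path(mod: torch.nn.Module, path: str) -> torch.nn.Module:
    # Walk a dotted path such as "n1.n2.n3" one attribute at a time,
    # reporting the full path if any segment is missing.
    obj = mod
    for name in path.split("."):
        if not hasattr(obj, name):
            raise AttributeError(f"couldn't find attribute at path {path}")
        obj = getattr(obj, name)
    return obj
```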


@@ -658,9 +658,12 @@ def _export_to_torch_ir(
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
         try:
             module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
-            with _wrap_submodules(
-                f, preserve_module_call_signature, module_call_specs
-            ), _ignore_backend_decomps():
+            ctx = nullcontext()
+            if not isinstance(f, torch.fx.GraphModule):
+                ctx = _wrap_submodules(  # type: ignore[assignment]
+                    f, preserve_module_call_signature, module_call_specs
+                )
+            with ctx, _ignore_backend_decomps():
                 gm_torch_level, _ = torch._dynamo.export(
                     f,
                     dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
@@ -1684,9 +1687,12 @@ def _non_strict_export(
         new_preserved_call_signatures = [
             "_export_root." + i for i in preserve_module_call_signature
         ]
-        with _wrap_submodules(
-            wrapped_mod, new_preserved_call_signatures, module_call_specs
-        ):
+        ctx = nullcontext()
+        if not isinstance(mod, torch.fx.GraphModule):
+            ctx = _wrap_submodules(  # type: ignore[assignment]
+                wrapped_mod, new_preserved_call_signatures, module_call_specs
+            )
+        with ctx:
             gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
 
         log.debug("Exported program from AOTAutograd:\n%s", gm)
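Both hunks apply the same conditional-context pattern. Reduced to a standalone sketch (`wrap_submodules_stub` and `trace` below are illustrative stand-ins, not the real export internals):

```python
from contextlib import contextmanager, nullcontext

import torch

@contextmanager
def wrap_submodules_stub(mod: torch.nn.Module):
    # Stand-in for _wrap_submodules: installing export tracepoints only
    # makes sense when real submodules exist at the preserved paths.
    yield

def trace(mod: torch.nn.Module):
    # A GraphModule's original hierarchy is flattened away, so there is
    # nothing to wrap; use a no-op context instead of failing.
    if isinstance(mod, torch.fx.GraphModule):
        ctx = nullcontext()
    else:
        ctx = wrap_submodules_stub(mod)
    with ctx:
        pass  # run the actual tracing under ctx
```

Skipping the wrapping (rather than erroring on missing paths) is what lets the retraceability tests above drop their `fqns = ()` special cases.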