Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
graph module retracing without preserving MCS (#143676)
Retracing while preserving module call signatures used to be a problem because graph modules don't have submodules at the given paths. This led to a number of failing retraceability tests. By not trying to wrap modules with export tracepoints, we can pass most of these tests; the only exception is module swapping on retraced programs, which is still not possible.

Differential Revision: [D67539304](https://our.internmc.facebook.com/intern/diff/D67539304/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143676
Approved by: https://github.com/zhxchen17, https://github.com/tugsbayasgalan
ghstack dependencies: #143664
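For context, a minimal sketch of the scenario this change targets (the `Outer`/`Inner` modules below are illustrative, not taken from the patch): export a module while preserving a submodule's call signature, then retrace the resulting graph module.

import torch

class Inner(torch.nn.Module):
    def forward(self, x):
        return x + 1

class Outer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = Inner()

    def forward(self, x):
        return self.n(x) * 2

inp = (torch.ones(1),)
ep = torch.export.export(Outer(), inp, preserve_module_call_signature=("n",))

# ep.module() is a flat GraphModule with no genuine submodule at path "n",
# which is why wrapping "n" with export tracepoints used to fail on retrace.
gm = ep.module()
ep2 = torch.export.export(gm, inp, preserve_module_call_signature=("n",))
print(torch.export.unflatten(ep2)(*inp))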
committed by PyTorch MergeBot
parent d7e59c2f85
commit 51eacea8c4
@@ -4038,7 +4038,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
                 return self.n(x + 1, True) + self.n(x + 1, False)
 
         x = torch.zeros(4)
-        types = {} if is_retracebility_test(self._testMethodName) else {"n": N}
+        types = {"n": N}
         ep = export(
             M(),
             (x,),
@@ -7722,14 +7722,11 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7787,14 +7784,11 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7851,14 +7845,11 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -7934,15 +7925,12 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8060,17 +8048,14 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8248,20 +8233,17 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-                "n1.n2.n3",
-                "n1.n2.n3.n4",
-                "n1.n2.n3.n4.n5",
-                "n1.n2.n3.n4.n5.n6",
-                "n1.n2.n3.n4.n5.n6.n7",
-                "n1.n2.n3.n4.n5.n6.n7.n8",
-                "n1.n2.n3.n4.n5.n6.n7.n8.n9",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+            "n1.n2.n3",
+            "n1.n2.n3.n4",
+            "n1.n2.n3.n4.n5",
+            "n1.n2.n3.n4.n5.n6",
+            "n1.n2.n3.n4.n5.n6.n7",
+            "n1.n2.n3.n4.n5.n6.n7.n8",
+            "n1.n2.n3.n4.n5.n6.n7.n8.n9",
+        )
         ep = export(
             N0(),
             inp,
@@ -8307,13 +8289,10 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8354,13 +8333,10 @@ graph():
 
         inp = (torch.ones(1),)
         eager = N0()(*inp)
-        if is_retracebility_test(self._testMethodName):
-            fqns = ()
-        else:
-            fqns = (
-                "n1",
-                "n1.n2",
-            )
+        fqns = (
+            "n1",
+            "n1.n2",
+        )
         ep = export(N0(), inp, preserve_module_call_signature=fqns)
         epm = ep.module()
         ufm = torch.export.unflatten(ep)
@@ -8451,6 +8427,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(Mod(), inp, preserve_module_call_signature=(path_n,)),
                 swap={path_n: N()},
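The guard above survives because module swapping on retraced programs is still unsupported. A hand-rolled sketch of what "swapping" means in these tests (the `test(..., swap=...)` helper is internal to the test file; plain attribute assignment is used here instead):

import torch

class N(torch.nn.Module):
    def forward(self, x):
        return x - 1

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.n = N()

    def forward(self, x):
        return self.n(x) + 2

inp = (torch.ones(1),)
ep = torch.export.export(M(), inp, preserve_module_call_signature=("n",))
ufm = torch.export.unflatten(ep)

# Swap the preserved-signature submodule for a call-compatible replacement;
# the preserved signature lets the new module receive the original inputs.
ufm.n = N()
print(ufm(*inp))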
@@ -8484,6 +8461,7 @@ graph():
         eager_result = m(*inp)
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
            ep = export(M(), inp, preserve_module_call_signature=("n",))
            epm = ep.module()
            ufm = torch.export.unflatten(ep)
@@ -8535,18 +8513,17 @@ graph():
         unflattened_result = ufm(*inp)
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
-        if not is_retracebility_test(self._testMethodName):
-            if is_training_ir_test(self._testMethodName):
-                test(
-                    torch.export.export_for_training(
-                        M(),
-                        inp,
-                        strict=not is_non_strict_test(self._testMethodName),
-                        preserve_module_call_signature=("n",),
-                    )
-                )
+        if is_training_ir_test(self._testMethodName):
+            test(
+                torch.export.export_for_training(
+                    M(),
+                    inp,
+                    strict=not is_non_strict_test(self._testMethodName),
+                    preserve_module_call_signature=("n",),
+                )
+            )
 
-            test(export(M(), inp, preserve_module_call_signature=("n",)))
+        test(export(M(), inp, preserve_module_call_signature=("n",)))
 
     def test_unflatten_multiple_graphs_preserve_signature_no_error(self):
         class N(torch.nn.Module):
@@ -8590,6 +8567,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8646,6 +8624,7 @@ graph():
         self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         if not is_retracebility_test(self._testMethodName):
+            # swapping will not work with retrace
             test(
                 export(M(), inp, preserve_module_call_signature=("n",)),
                 swap={"n": N()},
@@ -8790,15 +8769,13 @@ graph():
             id(getattr(unflattened, a)), id(getattr(unflattened, b))
         )
 
-        if not is_retracebility_test(self._testMethodName):
-            # preserving module call signatures
-            ep = export(m, inp, preserve_module_call_signature=("n", "p"))
-            exported_result = ep.module()(*inp)
-            self.assertTrue(torch.allclose(exported_result, eager_result))
+        ep = export(m, inp, preserve_module_call_signature=("n", "p"))
+        exported_result = ep.module()(*inp)
+        self.assertTrue(torch.allclose(exported_result, eager_result))
 
-            unflattened = torch.export.unflatten(ep)
-            unflattened_result = unflattened(*inp)
-            self.assertTrue(torch.allclose(unflattened_result, eager_result))
+        unflattened = torch.export.unflatten(ep)
+        unflattened_result = unflattened(*inp)
+        self.assertTrue(torch.allclose(unflattened_result, eager_result))
 
         test(
             gen_m(n=True, n_1=False, p=False, p_1=False),
@@ -61,11 +61,7 @@ def export_tracepoint_cpu(*args, **kwargs):
 def _wrap_submodule(mod, path, module_call_specs):
     assert isinstance(mod, torch.nn.Module)
     assert path != ""
-    submodule = mod
-    for name in path.split("."):
-        if not hasattr(submodule, name):
-            raise RuntimeError(f"Couldn't find submodule at path {path}")
-        submodule = getattr(submodule, name)
+    submodule = torch.fx.graph_module._get_attr(mod, path)
 
     def update_module_call_signatures(path, in_spec, out_spec):
         if path in module_call_specs:
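The removed loop and `torch.fx.graph_module._get_attr` both resolve a dotted submodule path; the public `nn.Module.get_submodule` API performs the same traversal, as this small sketch shows:

import torch

class Leaf(torch.nn.Module):
    def forward(self, x):
        return x

root = torch.nn.Module()
root.n1 = torch.nn.Module()
root.n1.n2 = Leaf()

print(root.get_submodule("n1.n2"))  # walks "n1", then "n2"

try:
    root.get_submodule("n1.missing")
except AttributeError as e:
    print(e)  # AttributeError here, where the old loop raised RuntimeError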
@@ -658,9 +658,12 @@ def _export_to_torch_ir(
     with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
         try:
             module_call_specs: Dict[str, Dict[str, pytree.TreeSpec]] = {}
-            with _wrap_submodules(
-                f, preserve_module_call_signature, module_call_specs
-            ), _ignore_backend_decomps():
+            ctx = nullcontext()
+            if not isinstance(f, torch.fx.GraphModule):
+                ctx = _wrap_submodules(  # type: ignore[assignment]
+                    f, preserve_module_call_signature, module_call_specs
+                )
+            with ctx, _ignore_backend_decomps():
                 gm_torch_level, _ = torch._dynamo.export(
                     f,
                     dynamic_shapes=dynamic_shapes,  # type: ignore[arg-type]
@@ -1684,9 +1687,12 @@ def _non_strict_export(
         new_preserved_call_signatures = [
             "_export_root." + i for i in preserve_module_call_signature
         ]
-        with _wrap_submodules(
-            wrapped_mod, new_preserved_call_signatures, module_call_specs
-        ):
+        ctx = nullcontext()
+        if not isinstance(mod, torch.fx.GraphModule):
+            ctx = _wrap_submodules(  # type: ignore[assignment]
+                wrapped_mod, new_preserved_call_signatures, module_call_specs
+            )
+        with ctx:
             gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags)
         log.debug("Exported program from AOTAutograd:\n%s", gm)
 
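Both call sites apply the same fix: default to a no-op context and only install export tracepoint wrapping when the traced object is not already a GraphModule. A standalone sketch of the pattern (`wrap_submodules_stub` is a stand-in, not the real `_wrap_submodules`):

from contextlib import contextmanager, nullcontext

import torch

@contextmanager
def wrap_submodules_stub(mod):
    # Stand-in for _wrap_submodules: install tracepoints, yield, clean up.
    print("installing export tracepoints on", type(mod).__name__)
    try:
        yield
    finally:
        print("removing export tracepoints")

def trace(mod: torch.nn.Module):
    # A retraced GraphModule has no real submodules at the preserved paths,
    # so skip the wrapping entirely in that case.
    ctx = nullcontext()
    if not isinstance(mod, torch.fx.GraphModule):
        ctx = wrap_submodules_stub(mod)
    with ctx:
        pass  # the actual tracer would run here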