mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix shared submodule module call signature (#139438)
Differential Revision: [D65308061](https://our.internmc.facebook.com/intern/diff/D65308061/) When a shared submodule is called multiple times with different aliases, e.g., `self.a` and `self.b` are both `C()` under the hood and we have calls to both `self.a(...)` and `self.b(...)`, we wrap `C()` to emit as many export tracepoints as there are aliases. This caused us to compute module call signatures that conflated information: we'd add inputs and outputs of one call to inputs and outputs of a different call. Overall, preserving module call signatures in the presence of shared submodules was broken because of this bug. The fix is to pay attention to the nn module stack, which accurately tracks individual calls, thus allowing us to ignore those export tracepoints that identify the correct module but not the alias through which the call was made. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139438 Approved by: https://github.com/zhxchen17
This commit is contained in:
committed by
PyTorch MergeBot
parent
a104b560d8
commit
9a5175e836
@ -7047,6 +7047,16 @@ graph():
|
||||
id(getattr(unflattened, a)), id(getattr(unflattened, b))
|
||||
)
|
||||
|
||||
if not is_retracebility_test(self._testMethodName):
|
||||
# preserving module call signatures
|
||||
ep = export(m, inp, preserve_module_call_signature=("n", "p"))
|
||||
exported_result = ep.module()(*inp)
|
||||
self.assertTrue(torch.allclose(exported_result, eager_result))
|
||||
|
||||
unflattened = torch.export.unflatten(ep)
|
||||
unflattened_result = unflattened(*inp)
|
||||
self.assertTrue(torch.allclose(unflattened_result, eager_result))
|
||||
|
||||
test(
|
||||
gen_m(n=True, n_1=False, p=False, p_1=False),
|
||||
# p should share n_1 graph, p_1 should be optimized away
|
||||
|
@ -85,21 +85,38 @@ class CollectTracepointsPass(PassBase):
|
||||
if node.op != "call_function":
|
||||
continue
|
||||
if node.target == torch.ops.higher_order._export_tracepoint:
|
||||
# There's some subtlety worth noting. Here fqn corresponds to
|
||||
# the call name, whereas path corresponds to the module name.
|
||||
# They are not necessarily the same! When a submodule is shared
|
||||
# through different aliases, there are as many _export_tracepoint
|
||||
# markers as there are aliases, since the shared submodule is
|
||||
# wrapped once for each alias.
|
||||
path = node.kwargs["path"]
|
||||
fqn, _ = next(reversed(node.meta["nn_module_stack"].values()))
|
||||
|
||||
module_key = next(reversed(node.meta["nn_module_stack"]))
|
||||
if "@" in module_key:
|
||||
call_path = f"{path}@{module_key.split('@')[-1]}"
|
||||
if call_path not in self.specs:
|
||||
self.specs[call_path] = copy_sig(self.specs[path])
|
||||
path = call_path
|
||||
suffix = module_key.split("@")[-1]
|
||||
path = f"{path}@{suffix}"
|
||||
|
||||
call_fqn = f"{fqn}@{suffix}"
|
||||
if call_fqn not in self.specs:
|
||||
self.specs[call_fqn] = copy_sig(self.specs[fqn])
|
||||
fqn = call_fqn
|
||||
|
||||
kind = node.kwargs["kind"]
|
||||
for i, arg in enumerate(node.args):
|
||||
if kind == "module_call_inputs":
|
||||
self.specs[path].inputs.append(get_arg_spec(arg))
|
||||
elif kind == "module_call_outputs":
|
||||
self.specs[path].outputs.append(get_arg_spec(arg))
|
||||
else:
|
||||
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
||||
# We only update the signature of the alias used to call
|
||||
# the submodule. Otherwise the signatures of all aliases
|
||||
# would get conflated; the inputs/outputs of every call
|
||||
# would be recorded in every other call as well.
|
||||
if fqn == path:
|
||||
if kind == "module_call_inputs":
|
||||
self.specs[path].inputs.append(get_arg_spec(arg))
|
||||
elif kind == "module_call_outputs":
|
||||
self.specs[path].outputs.append(get_arg_spec(arg))
|
||||
else:
|
||||
raise AssertionError(f"Unknown tracepoint kind: {kind}")
|
||||
if isinstance(arg, torch.fx.Node):
|
||||
for user in node.users:
|
||||
assert user.op == "call_function"
|
||||
|
Reference in New Issue
Block a user