[MTIA] Map names to operand indices when folding submodules (#150692)

When replacing placeholders with get_attr nodes during constant folding, the argument and parameter names can mismatch: there is no guarantee that a parameter's name matches the name of the argument used in the module call.
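A minimal, torch-free sketch of the failure mode and the fix: the old code looked up each placeholder's matching argument by name, which raises when FX renames placeholders (e.g. via name deduplication); the new code relies on placeholder order matching argument order. The node names (`w_1`, `_fx_const_w`, etc.) and the tiny `Node` class here are hypothetical illustrations, not the real FX node API.

```python
from dataclasses import dataclass


@dataclass
class Node:
    # Hypothetical stand-in for a torch.fx.Node, for illustration only.
    op: str
    name: str
    target: str


# get_attr nodes passed as args to the folded submodule call, in call order.
call_const_gm_args = [
    Node(op="get_attr", name="_fx_const_w", target="w"),
    Node(op="get_attr", name="_fx_const_b", target="b"),
]

# Placeholders inside the folded submodule. Their targets were renamed
# (e.g. by FX name deduplication), so they no longer match the arg names.
placeholders = [
    Node(op="placeholder", name="w_1", target="w_1"),
    Node(op="placeholder", name="b_1", target="b_1"),
]


def map_by_name(phs, args):
    # Old approach: breaks (StopIteration) once names diverge.
    return [next(a for a in args if a.name == ph.target) for ph in phs]


def map_by_index(phs, args):
    # New approach: placeholder order matches arg order, so an index suffices.
    assert len(phs) <= len(args)
    return [args[i] for i in range(len(phs))]


print([n.target for n in map_by_index(placeholders, call_const_gm_args)])
# ['w', 'b']
```

With the renamed placeholders above, `map_by_name` finds no argument named `w_1` and fails, while `map_by_index` recovers the intended `get_attr` targets.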

Differential Revision: D72415970

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150692
Approved by: https://github.com/jfix71
Author: Klint Qinami
Date: 2025-04-06 03:11:11 +00:00
Committed by: PyTorch MergeBot
Parent: 15768cc34b
Commit: 2d98a1caf5


@@ -252,13 +252,20 @@ def split_const_subgraphs(
     # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
     # return add
     root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
+    # The order of placeholders in the const_gm graph should match the order of
+    # args in the outer module, so we can simply use an index for the
+    # placeholder mapping
+    ph_idx = 0
     for node in root_const_gm.graph.nodes:
         if node.op == "output":
             multiple_outputs = isinstance(node.args[0], tuple)
             continue
         if node.op != "placeholder":
             continue
-        in_node = next(n for n in call_const_gm_args if n.name == node.target)
+        assert ph_idx < len(call_const_gm_args)
+        in_node = call_const_gm_args[ph_idx]
+        ph_idx += 1
         assert in_node.op == "get_attr"
         with root_const_gm.graph.inserting_before(node):
             new_node = root_const_gm.graph.get_attr(in_node.target)