mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[MTIA] Map names to operand indices when folding submodules (#150692)
When replacing placeholders with getattrs during constant folding, we can have an argument and parameter name mismatch. In fact, there is no guarantee that the parameter name is equivalent to the argument name used in the module call. Differential Revision: D72415970 Pull Request resolved: https://github.com/pytorch/pytorch/pull/150692 Approved by: https://github.com/jfix71
This commit is contained in:
committed by
PyTorch MergeBot
parent
15768cc34b
commit
2d98a1caf5
@ -252,13 +252,20 @@ def split_const_subgraphs(
|
||||
# %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {})
|
||||
# return add
|
||||
root_const_gm = torch.fx.GraphModule(split, const_gm.graph)
|
||||
|
||||
# The order of placeholders in the const_gm graph should match the order of
|
||||
# args in the outer module, so we can simply use an index for the
|
||||
# placeholder mapping
|
||||
ph_idx = 0
|
||||
for node in root_const_gm.graph.nodes:
|
||||
if node.op == "output":
|
||||
multiple_outputs = isinstance(node.args[0], tuple)
|
||||
continue
|
||||
if node.op != "placeholder":
|
||||
continue
|
||||
in_node = next(n for n in call_const_gm_args if n.name == node.target)
|
||||
assert ph_idx < len(call_const_gm_args)
|
||||
in_node = call_const_gm_args[ph_idx]
|
||||
ph_idx += 1
|
||||
assert in_node.op == "get_attr"
|
||||
with root_const_gm.graph.inserting_before(node):
|
||||
new_node = root_const_gm.graph.get_attr(in_node.target)
|
||||
|
Reference in New Issue
Block a user