fix non-strict placeholder naming with kwargs (#144278)

Fixes https://github.com/pytorch/pytorch/issues/143732

Differential Revision: [D67872055](https://our.internmc.facebook.com/intern/diff/D67872055/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144278
Approved by: https://github.com/yushangdi, https://github.com/pianpwk
This commit is contained in:
Avik Chaudhuri
2025-01-06 14:21:37 -08:00
committed by PyTorch MergeBot
parent c3b28491c8
commit 12fdb93ebd
2 changed files with 46 additions and 1 deletion

View File

@@ -9752,6 +9752,47 @@ def forward(self, x):
return (foo_functional,)""",
)
def test_placeholder_naming_order(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/143732:
    # a tensor kwarg passed after a non-tensor kwarg must not get its
    # placeholder reordered relative to the flattened user inputs.
    class TwoBranch(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.layer1 = torch.nn.Linear(3, 16)
            self.layer2 = torch.nn.Linear(3, 32)

        def forward(self, x1, x2, flag=True):
            out1 = self.layer1(x1)
            out2 = self.layer2(x2)
            return torch.cat([out1, out2], dim=1)

    model = TwoBranch()
    example_args = (torch.rand(1, 3),)
    # "flag" deliberately precedes "x2", out of signature order.
    example_kwargs = {"flag": False, "x2": torch.rand(1, 3)}
    ep = export(model, example_args, example_kwargs)

    # The exported graph must compute the same result as eager mode.
    expected = model(*example_args, **example_kwargs)
    actual = ep.module()(*example_args, **example_kwargs)
    self.assertTrue(torch.allclose(actual, expected))

    # Placeholder names must follow the flattened user-input order.
    self.assertEqual(ep.graph_signature.user_inputs, ("x1", False, "x2"))
def test_placeholder_naming_order_variadic(self):
    # Kwargs consumed through **kwargs must keep the order they were
    # supplied in, even when that differs from the signature order.
    class Arith(torch.nn.Module):
        def forward(self, a, b, c, **kwargs):
            return a - b + c * kwargs["d"]

    model = Arith()
    example_args = (torch.randn(3),)
    # "c" is deliberately passed before "b".
    example_kwargs = {"c": torch.randn(3), "b": torch.randn(3), "d": torch.randn(3)}
    ep = export(model, example_args, example_kwargs)

    # Exported graph agrees with eager execution.
    expected = model(*example_args, **example_kwargs)
    self.assertTrue(
        torch.allclose(ep.module()(*example_args, **example_kwargs), expected)
    )

    # Placeholders preserve the supplied kwarg order: "c" before "b".
    self.assertEqual(ep.graph_signature.user_inputs, ("a", "c", "b", "d"))
def test_placeholder_naming_collisions(self):
# test collisions between nested user inputs
class Foo(torch.nn.Module):

View File

@@ -725,7 +725,11 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs):
else:
sig = inspect.signature(mod.forward)
return sig.bind(*fake_args, **fake_kwargs).arguments
# Rather than binding both fake_args and fake_kwargs to sig names, we
# (partially) bind only fake_args, while reusing fake_kwarg names. This
# ensures that fake_kwargs do not get reordered, which is important to
# match flattened user inputs.
return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs}
def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None: