mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
fix non-strict placeholder naming with kwargs (#144278)
Fixes https://github.com/pytorch/pytorch/issues/143732 Differential Revision: [D67872055](https://our.internmc.facebook.com/intern/diff/D67872055/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/144278 Approved by: https://github.com/yushangdi, https://github.com/pianpwk
This commit is contained in:
committed by
PyTorch MergeBot
parent
c3b28491c8
commit
12fdb93ebd
@ -9752,6 +9752,47 @@ def forward(self, x):
|
||||
return (foo_functional,)""",
|
||||
)
|
||||
|
||||
def test_placeholder_naming_order(self):
|
||||
# See https://github.com/pytorch/pytorch/issues/143732
|
||||
|
||||
class Mod(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.layer1 = torch.nn.Linear(3, 16)
|
||||
self.layer2 = torch.nn.Linear(3, 32)
|
||||
|
||||
def forward(self, x1, x2, flag=True):
|
||||
x1o = self.layer1(x1)
|
||||
x2o = self.layer2(x2)
|
||||
return torch.cat([x1o, x2o], dim=1)
|
||||
|
||||
mod = Mod()
|
||||
args = (torch.rand(1, 3),)
|
||||
kwargs = {"flag": False, "x2": torch.rand(1, 3)}
|
||||
ep = export(mod, args, kwargs)
|
||||
|
||||
# check that graph is behaviorally correct
|
||||
self.assertTrue(
|
||||
torch.allclose(ep.module()(*args, **kwargs), mod(*args, **kwargs))
|
||||
)
|
||||
|
||||
# check that graph input names are as expected
|
||||
self.assertEqual(ep.graph_signature.user_inputs, ("x1", False, "x2"))
|
||||
|
||||
def test_placeholder_naming_order_variadic(self):
|
||||
class Mod(torch.nn.Module):
|
||||
def forward(self, a, b, c, **kwargs):
|
||||
return a - b + c * kwargs["d"]
|
||||
|
||||
mod = Mod()
|
||||
args = (torch.randn(3),)
|
||||
kwargs = {"c": torch.randn(3), "b": torch.randn(3), "d": torch.randn(3)}
|
||||
ep = export(mod, args, kwargs)
|
||||
self.assertTrue(
|
||||
torch.allclose(ep.module()(*args, **kwargs), mod(*args, **kwargs))
|
||||
)
|
||||
self.assertEqual(ep.graph_signature.user_inputs, ("a", "c", "b", "d"))
|
||||
|
||||
def test_placeholder_naming_collisions(self):
|
||||
# test collisions between nested user inputs
|
||||
class Foo(torch.nn.Module):
|
||||
|
||||
@ -725,7 +725,11 @@ def _bind_signature_to_inputs(mod, fake_args, fake_kwargs):
|
||||
else:
|
||||
sig = inspect.signature(mod.forward)
|
||||
|
||||
return sig.bind(*fake_args, **fake_kwargs).arguments
|
||||
# Rather than binding both fake_args and fake_kwargs to sig names, we
|
||||
# (partially) bind only fake_args, while reusing fake_kwarg names. This
|
||||
# ensures that fake_kwargs do not get reordered, which is important to
|
||||
# match flattened user inputs.
|
||||
return {**sig.bind_partial(*fake_args).arguments, **fake_kwargs}
|
||||
|
||||
|
||||
def _name_hoo_subgraph_placeholders(gm: torch.fx.GraphModule) -> None:
|
||||
|
||||
Reference in New Issue
Block a user