Check FakeScriptObject in _resolve_name_collision (#157736)

Summary:
Fix https://github.com/pytorch/pytorch/issues/157401

torch.equal cannot handle FakeScriptObject inputs.

Test Plan:
```
buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r  test_aoti_torchbind_name_collision
```

Rollback Plan:

Differential Revision: D77894081

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157736
Approved by: https://github.com/angelayi
This commit is contained in:
Shangdi Yu
2025-07-08 17:51:43 +00:00
committed by PyTorch MergeBot
parent 44d0800d60
commit 5b4e0255d7
2 changed files with 31 additions and 1 deletions

View File

@ -410,6 +410,30 @@ class TestTorchbind(TestCase):
):
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
def test_aoti_torchbind_name_collision(self):
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self._torchbind_obj0 = torch.classes._TorchScriptTesting._Foo(2, 3)
def forward(self, x):
a = self._torchbind_obj0.add_tensor(x)
torchbind = torch.classes._TorchScriptTesting._Foo(4, 5)
b = torchbind.add_tensor(x)
return a + b
m = M()
inputs = (torch.ones(2, 3),)
orig_res = m(*inputs)
with enable_torchbind_tracing():
ep = torch.export.export(m, inputs, strict=False)
pt2_path = torch._inductor.aoti_compile_and_package(ep)
optimized = torch._inductor.aoti_load_package(pt2_path)
result = optimized(*inputs)
self.assertEqual(result, orig_res)
if __name__ == "__main__":
run_tests()

View File

@ -355,7 +355,13 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
continue
gm_target = attrgetter(target_name)(gm)
model_target = attrgetter(target_name)(mod)
if (
if isinstance(gm_target, FakeScriptObject):
if (
isinstance(model_target, FakeScriptObject)
and gm_target.real_obj is model_target.real_obj
):
continue
elif (
torch.equal(gm_target, model_target)
and gm_target.dtype == model_target.dtype
):