mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Check FakeScriptObject in _resolve_name_collision (#157736)
Summary: Fix https://github.com/pytorch/pytorch/issues/157401 torch.equal cannot handle FakeScriptObject inputs. Test Plan: ``` buck run fbcode//mode/dev-nosan //caffe2/test/inductor:torchbind -- -r test_aoti_torchbind_name_collision ``` Rollback Plan: Differential Revision: D77894081 Pull Request resolved: https://github.com/pytorch/pytorch/pull/157736 Approved by: https://github.com/angelayi
This commit is contained in:
committed by
PyTorch MergeBot
parent
44d0800d60
commit
5b4e0255d7
@ -410,6 +410,30 @@ class TestTorchbind(TestCase):
|
||||
):
|
||||
aot_compile(ep.module(), inputs, options={"aot_inductor.package": True})
|
||||
|
||||
def test_aoti_torchbind_name_collision(self):
|
||||
class M(torch.nn.Module):
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self._torchbind_obj0 = torch.classes._TorchScriptTesting._Foo(2, 3)
|
||||
|
||||
def forward(self, x):
|
||||
a = self._torchbind_obj0.add_tensor(x)
|
||||
torchbind = torch.classes._TorchScriptTesting._Foo(4, 5)
|
||||
b = torchbind.add_tensor(x)
|
||||
return a + b
|
||||
|
||||
m = M()
|
||||
inputs = (torch.ones(2, 3),)
|
||||
orig_res = m(*inputs)
|
||||
|
||||
with enable_torchbind_tracing():
|
||||
ep = torch.export.export(m, inputs, strict=False)
|
||||
|
||||
pt2_path = torch._inductor.aoti_compile_and_package(ep)
|
||||
optimized = torch._inductor.aoti_load_package(pt2_path)
|
||||
result = optimized(*inputs)
|
||||
self.assertEqual(result, orig_res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
@ -355,7 +355,13 @@ def _resolve_name_collision(mod: GraphModule, gm: GraphModule) -> None:
|
||||
continue
|
||||
gm_target = attrgetter(target_name)(gm)
|
||||
model_target = attrgetter(target_name)(mod)
|
||||
if (
|
||||
if isinstance(gm_target, FakeScriptObject):
|
||||
if (
|
||||
isinstance(model_target, FakeScriptObject)
|
||||
and gm_target.real_obj is model_target.real_obj
|
||||
):
|
||||
continue
|
||||
elif (
|
||||
torch.equal(gm_target, model_target)
|
||||
and gm_target.dtype == model_target.dtype
|
||||
):
|
||||
|
Reference in New Issue
Block a user