[dynamo] Relax DUPLICATE_INPUT to be serializable. (#157492)

Since building the DUPLICATE_INPUT guard does not rely on any real tensor data, we can safely serialize it along with its sources and reconstruct the guard correctly in a new process. There is therefore no need to block serialization of this guard.
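
For intuition, the guard only asserts that two input sources refer to the same object, so its serialized form needs nothing beyond the two source names. A minimal sketch of that idea (not the real dynamo guard machinery; the helper and source strings here are illustrative):

import torch

def make_duplicate_input_check(ref_a: str, ref_b: str):
    # ref_a / ref_b are source expressions such as "L['x']" and "L['x_']";
    # the rebuilt check only tests object identity between the two.
    def check(scope: dict) -> bool:
        return eval(ref_a, {}, {"L": scope}) is eval(ref_b, {}, {"L": scope})
    return check

check = make_duplicate_input_check("L['x']", "L['x_']")
x = torch.randn(3, 2)
assert check({"x": x, "x_": x})              # same object: guard holds
assert not check({"x": x, "x_": x.clone()})  # different object: guard fails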

Differential Revision: [D77683302](https://our.internmc.facebook.com/intern/diff/D77683302/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157492
Approved by: https://github.com/jamesjwu, https://github.com/jansel
Author: zhxchen17
Date: 2025-07-03 12:20:13 -07:00
Committed by: PyTorch MergeBot
Parent: 336f1e2d35
Commit: 7be862ab8f
2 changed files with 9 additions and 8 deletions


@@ -1049,10 +1049,10 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             return x + x_
 
         x = torch.randn(3, 2)
-        with self.assertRaisesRegex(
-            PackageError, "DUPLICATE_INPUT guard cannot be serialized"
-        ):
-            self._test_serialization("DUPLICATE_INPUT", fn, x, x)
+        ref, loaded = self._test_serialization("DUPLICATE_INPUT", fn, x, x)
+        self._test_check_fn(ref, loaded, {"x": x, "x_": x}, True)
+        self._test_check_fn(ref, loaded, {"x": x, "x_": torch.randn(3, 2)}, False)
 
     def test_weakref_alive(self):
         mod = torch.nn.Linear(10, 10, bias=False)
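
Note the expectation in the second _test_check_fn call: a freshly allocated tensor with matching shape and dtype still fails the guard, because DUPLICATE_INPUT checks object identity, not value or metadata equality. A minimal illustration of that distinction:

import torch

x = torch.randn(3, 2)
y = x          # alias: an identity check passes
z = x.clone()  # equal values, distinct object: an identity check fails
assert y is x
assert z is not x and torch.equal(z, x)  # value equality alone is not enough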


@@ -1943,9 +1943,9 @@ class GuardBuilder(GuardBuilderBase):
     # TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
     def DUPLICATE_INPUT(self, guard, source_b):
         if self.serialization_mode == "save":
-            raise torch._dynamo.exc.PackageError(
-                "DUPLICATE_INPUT guard cannot be serialized yet."
-            )
+            if name := get_local_source_name(source_b):
+                self.check_fn_manager.additional_used_local_vars.add(name)
 
         ref_a = self.arg_ref(guard)
         ref_b = self.arg_ref(source_b.name())
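
In "save" mode the builder no longer raises; instead it records the local variable name behind source_b (when it has one) so the duplicated argument survives the later pruning of local_scope. A rough sketch of that bookkeeping, where LocalSourceStub and local_name_of are hypothetical stand-ins for dynamo's source types and get_local_source_name:

from typing import Optional

class LocalSourceStub:
    # Hypothetical stand-in for a dynamo Source that refers to a local variable.
    def __init__(self, local_name: Optional[str]):
        self.local_name = local_name

def local_name_of(source: LocalSourceStub) -> Optional[str]:
    # The real get_local_source_name yields a name only for local sources,
    # returning None for globals/attribute chains.
    return source.local_name

additional_used_local_vars: set = set()

def duplicate_input_on_save(source_b: LocalSourceStub) -> None:
    # Mirrors the new "save" branch: remember the name instead of raising.
    if name := local_name_of(source_b):
        additional_used_local_vars.add(name)

duplicate_input_on_save(LocalSourceStub("x_"))
duplicate_input_on_save(LocalSourceStub(None))  # non-local source: nothing recorded
assert additional_used_local_vars == {"x_"}
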
@@ -2821,6 +2821,7 @@ class CheckFunctionManager:
         )
         self.guards_serialization_mode = guards_serialization_mode
         self.used_builtin_vars: OrderedSet[str] = OrderedSet()
+        self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
         if runtime_global_scope:
             assert self.guards_serialization_mode == "load"
         self.runtime_global_scope = runtime_global_scope
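
The new set lives on CheckFunctionManager rather than on the builder because each GuardBuilder reaches the manager through self.check_fn_manager, and the manager is also where the serialized local_scope is assembled. Schematically (minimal hypothetical classes, not the real ones):

class ManagerSketch:
    def __init__(self):
        # Populated by guard builders while saving, consumed when the
        # serialized local_scope is put together.
        self.additional_used_local_vars: set = set()

class BuilderSketch:
    def __init__(self, manager: ManagerSketch):
        self.check_fn_manager = manager

    def duplicate_input(self, local_name: str) -> None:
        self.check_fn_manager.additional_used_local_vars.add(local_name)

manager = ManagerSketch()
BuilderSketch(manager).duplicate_input("x_")
assert "x_" in manager.additional_used_local_vars
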
@@ -2999,7 +3000,7 @@ class CheckFunctionManager:
                 local_scope={
                     k: v
                     for k, v in output_graph_guards_state.local_scope.items()
-                    if k in used_local_vars
+                    if k in used_local_vars or k in self.additional_used_local_vars
                 },
                 global_scope=global_scope_state,
                 _guards=torch._guards.GuardsSet(
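
Putting the pieces together: the serialized local_scope is a pruned copy of the traced locals, and the added "or k in self.additional_used_local_vars" clause keeps the duplicated argument that no other guard referenced. A rough end-to-end illustration with plain dicts and sets standing in for the real guard state:

# Locals captured at trace time; only the ones guards reference are serialized.
local_scope = {"x": "tensor-a", "x_": "tensor-a", "scratch": "unused"}

used_local_vars = {"x"}              # referenced by other guards
additional_used_local_vars = {"x_"}  # recorded by DUPLICATE_INPUT during save

serialized_local_scope = {
    k: v
    for k, v in local_scope.items()
    if k in used_local_vars or k in additional_used_local_vars
}

# Without the new clause, "x_" would be dropped and the reloaded
# DUPLICATE_INPUT guard would have no second operand to compare against.
assert serialized_local_scope == {"x": "tensor-a", "x_": "tensor-a"}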