mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Relax DUPLICATED_INPUT to be serializable. (#157492)
Since we don't actually rely on any real data while building DUPLICATE_INPUT guard, we can safely serialize it with sources and it should be able to reconstruct the guard correctly in the new process. Therefore we don't really need to prevent serializing it. Differential Revision: [D77683302](https://our.internmc.facebook.com/intern/diff/D77683302/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/157492 Approved by: https://github.com/jamesjwu, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
336f1e2d35
commit
7be862ab8f
@ -1049,10 +1049,10 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
return x + x_
|
||||
|
||||
x = torch.randn(3, 2)
|
||||
with self.assertRaisesRegex(
|
||||
PackageError, "DUPLICATE_INPUT guard cannot be serialized"
|
||||
):
|
||||
self._test_serialization("DUPLICATE_INPUT", fn, x, x)
|
||||
ref, loaded = self._test_serialization("DUPLICATE_INPUT", fn, x, x)
|
||||
|
||||
self._test_check_fn(ref, loaded, {"x": x, "x_": x}, True)
|
||||
self._test_check_fn(ref, loaded, {"x": x, "x_": torch.randn(3, 2)}, False)
|
||||
|
||||
def test_weakref_alive(self):
|
||||
mod = torch.nn.Linear(10, 10, bias=False)
|
||||
|
@ -1943,9 +1943,9 @@ class GuardBuilder(GuardBuilderBase):
|
||||
# TODO(voz): Deduplicate w/ AOTAutograd dupe input guards
|
||||
def DUPLICATE_INPUT(self, guard, source_b):
|
||||
if self.serialization_mode == "save":
|
||||
raise torch._dynamo.exc.PackageError(
|
||||
"DUPLICATE_INPUT guard cannot be serialized yet."
|
||||
)
|
||||
if name := get_local_source_name(source_b):
|
||||
self.check_fn_manager.additional_used_local_vars.add(name)
|
||||
|
||||
ref_a = self.arg_ref(guard)
|
||||
ref_b = self.arg_ref(source_b.name())
|
||||
|
||||
@ -2821,6 +2821,7 @@ class CheckFunctionManager:
|
||||
)
|
||||
self.guards_serialization_mode = guards_serialization_mode
|
||||
self.used_builtin_vars: OrderedSet[str] = OrderedSet()
|
||||
self.additional_used_local_vars: OrderedSet[str] = OrderedSet()
|
||||
if runtime_global_scope:
|
||||
assert self.guards_serialization_mode == "load"
|
||||
self.runtime_global_scope = runtime_global_scope
|
||||
@ -2999,7 +3000,7 @@ class CheckFunctionManager:
|
||||
local_scope={
|
||||
k: v
|
||||
for k, v in output_graph_guards_state.local_scope.items()
|
||||
if k in used_local_vars
|
||||
if k in used_local_vars or k in self.additional_used_local_vars
|
||||
},
|
||||
global_scope=global_scope_state,
|
||||
_guards=torch._guards.GuardsSet(
|
||||
|
Reference in New Issue
Block a user