[precompile] Fix guard serialization loading bugs. (#164490)

Summary: Added a set of fixes triggered by an fm training job. The overall theme is to drop saved objects from the serialized guard state whenever they are not needed for guard reconstruction. For objects that cannot be saved at all (such as local functions), we still try our best to save their closures.
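
The common mechanism behind these fixes is a custom pickler whose reducer_override substitutes a lightweight placeholder for objects that are not needed after loading (GuardsStatePickler and _Missing in the diff below). Here is a minimal standalone sketch of that pattern; the names Missing, DroppingPickler, and skip_ids are invented for illustration and are not the PR's actual classes:

import io
import pickle


class Missing:
    """Placeholder recorded in place of an object dropped at save time."""

    def __init__(self, reason=None):
        self.reason = reason

    def __repr__(self):
        return f"Missing({self.reason!r})"


class DroppingPickler(pickle.Pickler):
    """Pickler that swaps selected objects for Missing placeholders."""

    def __init__(self, file, skip_ids):
        super().__init__(file)
        self.skip_ids = skip_ids  # ids of objects not needed after loading

    def reducer_override(self, obj):
        if id(obj) in self.skip_ids:
            # Rebuild as Missing(reason) at load time instead of the real object.
            return Missing, ("not used in guard reconstruction",)
        return NotImplemented  # anything else goes through normal pickling


class Unpicklable:
    def __getstate__(self):
        raise NotImplementedError("not serializable")


if __name__ == "__main__":
    obj = Unpicklable()
    buf = io.BytesIO()
    DroppingPickler(buf, skip_ids={id(obj)}).dump({"x": 1, "unused": obj})
    # Prints: {'x': 1, 'unused': Missing('not used in guard reconstruction')}
    print(pickle.loads(buf.getvalue()))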

Test Plan:
test_guard_serialization.py
test_lazy_awaitable.py

Differential Revision: D83766926

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164490
Approved by: https://github.com/jamesjwu
Zhengxu Chen authored 2025-10-03 19:20:07 +00:00; committed by PyTorch MergeBot
parent 3c59351c6e
commit 16f9bef642
5 changed files with 481 additions and 25 deletions


@ -4,8 +4,10 @@ import dataclasses
import importlib
import pickle
import sys
import tempfile
import types
import unittest
import weakref
from collections.abc import Iterator
from unittest.mock import patch
@ -18,6 +20,7 @@ import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.exc import PackageError
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.package import CompilePackage
from torch._dynamo.symbolic_convert import (
ExceptionStack,
InstructionTranslator,
@ -44,10 +47,33 @@ class GlobalModule(torch.nn.Module):
return x + 1
class GlobalNestedModule(torch.nn.Module):
def __init__(self, submodule=None):
super().__init__()
self.linear = torch.nn.Linear(10, 10)
self.param = torch.nn.Parameter(torch.randn(3, 2))
self.nested = submodule or GlobalModule()
def forward(self, x):
return self.linear(x) + 1
def global_func(x):
return x + 1
class ModuleNotSerializable(torch.nn.Module):
def __init__(self):
super().__init__()
self.param = torch.nn.Parameter(torch.randn(3, 2))
def __getstate__(self):
raise NotImplementedError("not serialzable")
def forward(self, x):
return x + self.param
class GlobalTorchFunctionMode(TorchFunctionMode):
def __torch_function__(self, func, types, args=(), kwargs=None):
if kwargs is None:
@ -63,6 +89,39 @@ class MyClass:
return x + 1
class MyClassNotSerializable:
def __getstate__(self):
raise NotImplementedError
def add(self, x):
return x + 1
class Inputs:
def __init__(self, x, unused):
self.x = x
self.unused = unused
def _global_func_wrong_fqn(x):
return x + 1
global_func_wrong_fqn = _global_func_wrong_fqn
del _global_func_wrong_fqn
class FlatModule(torch.nn.Module):
def forward(self, x):
return x + 2
class ModWithDict(torch.nn.Module):
def __init__(self, d):
super().__init__()
self.d = d
class SubclassWithMeta(torch.Tensor):
@staticmethod
def __new__(cls, a, extra, outer_size=None, outer_stride=None):
@ -355,13 +414,14 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
self._cached_guards_state = guards_state
self._cached_f_code = self._frame_state.f_code
self.assertIsNotNone(guards_state)
- guards_state = pickle.loads(guards_state)
guards_state = torch._dynamo.package.load_guards_state(guards_state)
check_fn_manager = CheckFunctionManager(
self._frame_state.f_code,
guards_state.output_graph,
shape_code_parts=guards_state.shape_code_parts,
runtime_global_scope=self._frame_state.f_globals,
source_get_cache=guards_state.source_get_cache,
)
loaded_gm = check_fn_manager.guard_manager
@ -1372,10 +1432,289 @@ class TestGuardSerialization(TestGuardSerializationBase):
ref, loaded, {"self": m, "foo": MyClass().add, "x": torch.randn(3, 2)}, True
)
def test_bound_methods_missing(self):
class MyClass:
def __getstate__(self):
raise NotImplementedError
def add(self, x):
return x + 1
def foo(x: torch.Tensor, y: list[MyClass]):
assert len(y) == 1
return x + 1
ref, loaded = self._test_serialization(
"TYPE_MATCH", foo, torch.randn(3, 2), [MyClass()]
)
self._test_check_fn(
ref, loaded, {"x": torch.randn(3, 2), "y": [MyClass()]}, True
)
def test_bound_methods_empty(self):
def foo(x, y):
assert callable(y[0])
return x + 1
ref, loaded = self._test_serialization(
"TYPE_MATCH", foo, torch.randn(3, 2), [MyClassNotSerializable().add]
)
self._test_check_fn(
ref,
loaded,
{"x": torch.randn(3, 2), "y": [MyClassNotSerializable().add]},
True,
)
def test_ddp_module(self):
import torch.distributed as dist
if not dist.is_available():
self.skipTest("Torch distributed is not available")
from torch.nn.parallel import DistributedDataParallel as DDP
tmpfile = tempfile.NamedTemporaryFile()
dist.init_process_group(
backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}"
)
try:
ddp_model = DDP(GlobalNestedModule())
def foo(ddp, x):
return ddp(x)
x = torch.randn(10)
package = CompilePackage(foo)
torch._dynamo.optimize(
package=package,
guard_filter_fn=lambda gs: [
x.guard_type not in ("CLOSURE_MATCH", "ID_MATCH") for x in gs
],
)(foo)(ddp_model, x)
self.assertEqual(len(package._codes[foo.__code__].guarded_codes), 1)
torch._dynamo.package.load_guards_state(
package._codes[foo.__code__].guarded_codes[0].guards_state
)
finally:
dist.destroy_process_group()
def test_dict_keys_serialization(self):
d = {1: 2, 3: 4}
def foo(x, y):
for k in y:
x += k
return x
ref, loaded = self._test_serialization(
"TYPE_MATCH", foo, torch.randn(3, 2), d.keys()
)
self._test_check_fn(
ref,
loaded,
{"x": torch.randn(3, 2), "y": d.keys()},
True,
)
def test_unserializable_sharded_tensor(self):
import torch.distributed as dist
if not dist.is_available():
self.skipTest("Torch distributed is not available")
tmpfile = tempfile.NamedTemporaryFile()
dist.init_process_group(
backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}"
)
try:
ChunkShardingSpec = dist._shard.sharding_spec.ChunkShardingSpec
ShardedTensor = dist._shard.sharded_tensor.ShardedTensor
tensor = torch.arange(2, dtype=torch.int64)
local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
sharding_dim = 0
sharding_spec = ChunkShardingSpec(
dim=sharding_dim,
placements=[
"rank:0/cpu",
],
)
st = ShardedTensor._init_from_local_tensor(
local_tensor, sharding_spec, [1, 4]
)
def foo(inputs):
return inputs.x + 1
ref, loaded = self._test_serialization(
"TENSOR_MATCH", foo, Inputs(torch.randn(3, 2), st)
)
self._test_check_fn(
ref, loaded, {"inputs": Inputs(torch.randn(3, 2), st)}, True
)
finally:
dist.destroy_process_group()
def test_function_with_wrong_fqn(self):
def foo(inputs):
return inputs.x + 1
x = torch.randn(3, 2)
ref, loaded = self._test_serialization(
"TENSOR_MATCH", foo, Inputs(x, global_func_wrong_fqn)
)
self._test_check_fn(
ref, loaded, {"inputs": Inputs(x, global_func_wrong_fqn)}, True
)
def test_c10d_work(self):
import torch.distributed as dist
if not dist.is_available():
self.skipTest("Torch distributed is not available")
Work = dist.distributed_c10d.Work
class DummyWork(Work):
def __init__(self, should_succeed=True):
super().__init__()
self._done = False
self._should_succeed = should_succeed
def is_completed(self):
return self._done
def is_success(self):
return self._should_succeed
def wait(self, timeout=None):
self._done = True
if not self._should_succeed:
raise RuntimeError("DummyWork failed")
return self
def result(self):
if not self._should_succeed:
raise RuntimeError("DummyWork failed")
return "dummy_result"
def foo(inputs):
return inputs.x + 1
x = torch.randn(3, 2)
ref, loaded = self._test_serialization(
"TENSOR_MATCH", foo, Inputs(x, DummyWork())
)
self._test_check_fn(ref, loaded, {"inputs": Inputs(x, DummyWork())}, True)
def test_unused_weakref(self):
def foo(inputs):
return inputs.x + 1
x = torch.randn(3, 2)
ref, loaded = self._test_serialization(
"TENSOR_MATCH", foo, Inputs(x, weakref.ref(x))
)
self._test_check_fn(ref, loaded, {"inputs": Inputs(x, weakref.ref(x))}, True)
def test_unused_stream(self):
if not torch.cuda.is_available():
self.skipTest("CUDA is not available")
def foo(inputs):
return inputs.x + 1
x = torch.randn(3, 2)
ref, loaded = self._test_serialization(
"TENSOR_MATCH", foo, Inputs(x, torch.cuda.Stream())
)
self._test_check_fn(
ref, loaded, {"inputs": Inputs(x, torch.cuda.Stream())}, True
)
def test_unused_process_group(self):
import torch.distributed as dist
if not dist.is_available():
self.skipTest("Torch distributed is not available")
def foo(inputs):
return inputs.x + 1
tmpfile = tempfile.NamedTemporaryFile()
dist.init_process_group(
backend="gloo",
init_method=f"file://{tmpfile.name}",
rank=0,
world_size=1,
)
try:
pg = dist.distributed_c10d._get_default_group()
x = torch.randn(3, 2)
ref, loaded = self._test_serialization("TENSOR_MATCH", foo, Inputs(x, pg))
self._test_check_fn(ref, loaded, {"inputs": Inputs(x, pg)}, True)
finally:
dist.destroy_process_group()
def test_unserializable_submodule(self):
def foo(mod, x):
return mod(x)
x = torch.randn(10, 10)
mod = GlobalNestedModule(ModuleNotSerializable())
ref, loaded = self._test_serialization("TENSOR_MATCH", foo, mod, x)
self._test_check_fn(ref, loaded, {"mod": mod, "x": x}, True)
def test_closure_var_missing(self):
captured = torch.randn(3, 2)
def bar(x):
return x + captured
def foo(f, x):
return f(x)
x = torch.randn(3, 2)
ref, loaded = self._test_serialization("TENSOR_MATCH", foo, bar, x)
self._test_check_fn(ref, loaded, {"f": bar, "x": x}, True)
def test_bound_method_patched_forward(self):
def forward(x):
return x + 1
m = FlatModule()
m_forward = m.forward
m.forward = forward
def foo(f, x):
assert callable(f)
return f(x)
x = torch.randn(3, 2)
ref, loaded = self._test_serialization("TYPE_MATCH", foo, m_forward, x)
self._test_check_fn(ref, loaded, {"f": m_forward, "x": x}, True)
def test_guard_on_key_order_with_cache(self):
def foo(x, mod):
for y in mod.d.values():
x *= y
return x
x = torch.randn(3, 2)
d = {"a": 1e9, "b": 1e-9}
ref, loaded = self._test_serialization(
"DICT_KEYS_MATCH", foo, x, ModWithDict(d)
)
self._test_check_fn(
ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
)
class SimpleModule(torch.nn.Module):
- def __init__(self):
def __init__(self, c):
super().__init__()
self.c = c
self.p = torch.nn.Parameter(torch.randn(3, 2))
def forward(self, x):
@ -1400,7 +1739,7 @@ if not IS_MACOS:
from torch.distributed.fsdp import fully_shard
mesh = init_device_mesh(str(torch.get_default_device()), (1,))
- m = SimpleModule()
m = SimpleModule(42)
m = fully_shard(m, mesh=mesh)
inputs = distribute_tensor(torch.randn(3, 2), mesh, [Replicate()])
ref, loaded = self._test_serialization("TENSOR_MATCH", m, inputs)


@ -77,6 +77,8 @@ class AOTCompiledFunction:
return self._artifacts.guard_manager.check(f_locals)
def __post_init__(self) -> None:
from .package import load_guards_state
self._artifacts.check_compatibility()
import_sources = {
@ -92,7 +94,7 @@ class AOTCompiledFunction:
)
if self._artifacts.guard_manager is None:
- guards_state = pickle.loads(self._artifacts.guards_state)
guards_state = load_guards_state(self._artifacts.guards_state)
self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
self._artifacts.original_code,
guards_state.output_graph,


@ -986,6 +986,7 @@ class GuardBuilder(GuardBuilderBase):
check_fn_manager: CheckFunctionManager,
save_guards: bool = False,
runtime_global_scope: Optional[dict[str, object]] = None,
source_get_cache: Optional[dict[str, Any]] = None,
) -> None:
self.f_code = f_code
self.id_ref = id_ref
@ -993,6 +994,7 @@ class GuardBuilder(GuardBuilderBase):
self.lookup_weakrefs = lookup_weakrefs
self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
self.runtime_global_scope = runtime_global_scope or global_scope
self.source_get_cache = source_get_cache or {}
self.scope["__builtins__"] = builtins.__dict__.copy()
for (
name,
@ -1021,6 +1023,9 @@ class GuardBuilder(GuardBuilderBase):
self.check_fn_manager: CheckFunctionManager = check_fn_manager
self.guard_tree_values: dict[int, Any] = {}
self.save_guards = save_guards
# Collect the ids of dicts which need key order guarding. source_name is
# not sufficient because for nn modules, we can have different sources
# to access the same object - self._module["param"] is same as
@ -1028,7 +1033,10 @@ class GuardBuilder(GuardBuilderBase):
self.key_order_guarded_dict_ids = set()
assert self.check_fn_manager.output_graph is not None
for source in self.check_fn_manager.output_graph.guard_on_key_order:
- self.key_order_guarded_dict_ids.add(id(self.get(source.name())))
dict_obj = self.get(source.name())
if self.save_guards:
self.source_get_cache[source.name()] = dict_obj
self.key_order_guarded_dict_ids.add(id(dict_obj))
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and guard_manager and is used to
@ -1039,14 +1047,12 @@ class GuardBuilder(GuardBuilderBase):
self._cached_guard_managers: dict[str, GuardManager] = {}
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
self.object_aliasing_guard_codes: list[tuple[str, str]] = []
- self.save_guards = save_guards
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
"pytorch/compiler:guard_nn_modules"
)
self.already_guarded_not_present_in_generic_dict: OrderedSet[
tuple[str, str]
] = OrderedSet()
- self.guard_tree_values: dict[int, Any] = {}
def guard_on_dict_keys_and_ignore_order(
self, example_value: dict[Any, Any], guard: Guard
@ -1772,9 +1778,15 @@ class GuardBuilder(GuardBuilderBase):
# (like its type) which is what you permanently install into the
# guard code.
def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
if self.source_get_cache:
if name in self.source_get_cache:
return self.source_get_cache[name]
if closure_vars is None:
closure_vars = _get_closure_vars()
- return eval(name, self.scope, closure_vars)
ret = eval(name, self.scope, closure_vars)
if self.save_guards and ".__closure__" in name:
self.source_get_cache[name] = ret
return ret
# Registers the usage of the source name referenced by the
# string (or stored in the Guard) as being guarded upon. It's important
@ -3071,10 +3083,23 @@ class ShapeCodeParts:
class GuardsState:
output_graph: OutputGraphGuardsState
shape_code_parts: Optional[ShapeCodeParts]
source_get_cache: Optional[dict[str, Any]] = None
class _Missing:
pass
def __init__(self, reason: Optional[str] = None) -> None:
self._reason = reason
def __repr__(self) -> str:
return f"_Missing({self._reason})"
def __str__(self) -> str:
return f"_Missing({self._reason})"
# Sometimes a _Missing object is used as the callable with functools.partial,
# so we add a dummy __call__ here to bypass the TypeError from partial().
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return _Missing()
@functools.cache
@ -3097,6 +3122,7 @@ class GuardsStatePickler(pickle.Pickler):
self,
guard_tree_values: dict[int, Any],
empty_values: dict[int, Any],
missing_values: dict[int, Any],
*args: Any,
**kwargs: Any,
) -> None:
@ -3105,6 +3131,7 @@ class GuardsStatePickler(pickle.Pickler):
self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
self.guard_tree_values = guard_tree_values
self.empty_values = empty_values
self.missing_values = missing_values
@classmethod
def _unpickle_module(cls, state: Any) -> torch.nn.Module:
@ -3188,10 +3215,31 @@ class GuardsStatePickler(pickle.Pickler):
original_type
]
@classmethod
def _unpickle_ddp_module(
cls, state: dict[str, Any]
) -> torch.nn.parallel.DistributedDataParallel:
ty = torch.nn.parallel.DistributedDataParallel
ddp = ty.__new__(ty)
torch.nn.Module.__setstate__(ddp, state)
return ddp
@classmethod
def _unpickle_c_op(cls, name: str) -> Any:
return getattr(torch.ops._C, name)
@classmethod
def _unpickle_bound_method(cls, func: Any, base: Any) -> Any:
return types.MethodType(func, base)
@classmethod
def _unpickle_cell(cls, val: Any) -> Any:
def _() -> Any:
return val
assert _.__closure__ is not None
return _.__closure__[0]
def reducer_override(
self, obj: Any
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@ -3200,11 +3248,14 @@ class GuardsStatePickler(pickle.Pickler):
if id(obj) in self.empty_values:
return type(obj).__new__, (type(obj),)
if id(obj) in self.missing_values:
return _Missing, ("missing values",)
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if id(obj) not in self.guard_tree_values:
- return _Missing, ()
return _Missing, ("tensor guard tree",)
if is_traceable_wrapper_subclass(obj):
# inner_data is a list of tuples of:
@ -3238,6 +3289,15 @@ class GuardsStatePickler(pickle.Pickler):
)
elif isinstance(obj, torch.nn.Module):
if id(obj) not in self.guard_tree_values:
return _Missing, ("module guard tree",)
# DDP module is a special case because it tries to restore unneeded
# data in custom __setstate__. We cannot skip ddp module because it
# is often a toplevel module.
if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
return type(self)._unpickle_ddp_module, (obj.__getstate__(),)
if type(obj).__qualname__ == type(obj).__name__:
return NotImplemented
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
@ -3279,20 +3339,37 @@ class GuardsStatePickler(pickle.Pickler):
and obj.__class__.__name__ == "PyCapsule"
):
# Skipping PyCapsule since there isn't much to be guarded about them.
- return _Missing, ()
return _Missing, ("capsule",)
elif isinstance(obj, _get_unsupported_types()):
- return _Missing, ()
return _Missing, ("unsupported",)
elif inspect.isfunction(obj):
if obj.__code__.co_flags & inspect.CO_NESTED:
- return _Missing, ()
return _Missing, ("nested function",)
if obj.__module__ in sys.modules:
f = sys.modules[obj.__module__]
for name in obj.__qualname__.split("."):
f = getattr(f, name, None) # type: ignore[assignment]
if f is not obj:
- return _Missing, ()
return _Missing, ("fqn mismatch",)
elif inspect.ismethod(obj):
func = obj.__func__
method_self = obj.__self__
inner_func = getattr(method_self, func.__name__)
if inspect.ismethod(inner_func):
inner_func = inner_func.__func__
if func is not inner_func:
return type(self)._unpickle_bound_method, (func, method_self)
elif isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
return type(self)._unpickle_cell, (obj.cell_contents,)
if hasattr(torch.distributed, "distributed_c10d") and isinstance(
obj, torch.distributed.distributed_c10d.Work
):
if id(obj) not in self.guard_tree_values:
return _Missing, ("distributed_c10d.Work",)
if type(obj).__qualname__ != type(obj).__name__:
raise torch._dynamo.exc.PackageError(
@ -3301,12 +3378,6 @@ class GuardsStatePickler(pickle.Pickler):
+ "Please define the class at global scope (top level of a module)."
)
if hasattr(torch.distributed, "distributed_c10d") and isinstance(
obj, torch.distributed.distributed_c10d.Work
):
if id(obj) not in self.guard_tree_values:
return _Missing, ()
if (
inspect.isclass(obj)
and hasattr(torch.distributed, "fsdp")
@ -3327,6 +3398,7 @@ class GuardsStatePickler(pickle.Pickler):
def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -> bytes:
buf = io.BytesIO()
empty_values = {}
missing_values = {}
leaves = pytree.tree_leaves(state.output_graph.local_scope)
for leaf in leaves:
@ -3338,7 +3410,11 @@ def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -
empty_values[id(base)] = base
except: # noqa: E722, B001
pass
- pickler = GuardsStatePickler(guard_tree_values, empty_values, buf)
elif id(leaf) not in guard_tree_values:
# TODO: See if we should lift this branch to be the first one.
# Prune more objects in pytree hierarchy.
missing_values[id(leaf)] = leaf
pickler = GuardsStatePickler(guard_tree_values, empty_values, missing_values, buf)
try:
pickler.dump(state)
except AttributeError as e:
@ -3365,6 +3441,7 @@ class CheckFunctionManager:
runtime_global_scope: Optional[dict[str, Any]] = None,
save_guards: bool = False,
strict_error: bool = False,
source_get_cache: Optional[dict[str, Any]] = None,
):
guards = output_graph.guards if output_graph else None
self._weakrefs: dict[int, ReferenceType[object]] = {}
@ -3427,7 +3504,12 @@ class CheckFunctionManager:
# If we're filtering guards, we need to build it an extra time first
# because filtering depends on the builder/guard_manager results
builder, guard_manager = self.build_guards(
- sorted_guards, existing_diff_guard_sources, f_code, output_graph, False
sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
False,
source_get_cache=source_get_cache,
)
def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry:
@ -3476,6 +3558,7 @@ class CheckFunctionManager:
f_code,
output_graph,
save_guards,
source_get_cache=source_get_cache,
)
self.guard_manager = guard_manager
@ -3696,6 +3779,7 @@ class CheckFunctionManager:
guards_state = GuardsState(
output_graph=output_graph_guards_state,
shape_code_parts=self.shape_code_parts,
source_get_cache=builder.source_get_cache,
)
return pickle_guards_state(guards_state, builder.guard_tree_values)
@ -3707,6 +3791,7 @@ class CheckFunctionManager:
f_code: types.CodeType,
output_graph: OutputGraphGuardsState,
save_guards: bool,
source_get_cache: Optional[dict[str, Any]] = None,
) -> tuple[GuardBuilder, GuardManagerWrapper]:
guard_manager = GuardManagerWrapper()
guard_manager.diff_guard_sources = existing_diff_guard_sources
@ -3734,6 +3819,7 @@ class CheckFunctionManager:
self,
save_guards,
runtime_global_scope=self.runtime_global_scope,
source_get_cache=source_get_cache,
)
# Break retain cycle. See test_release_scope_memory
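
The source_get_cache plumbing added above can be summarized in isolation: at save time GuardBuilder.get() records what closure-reaching source expressions evaluate to, and at load time those recorded values are replayed instead of re-evaluating expressions whose closures may not have survived serialization. A small standalone sketch (SourceEvaluator is an invented name, not the actual GuardBuilder):

from typing import Any, Optional


class SourceEvaluator:
    """Toy stand-in for the GuardBuilder.get() caching added in this PR."""

    def __init__(self, scope: dict, save_guards: bool,
                 source_get_cache: Optional[dict[str, Any]] = None):
        self.scope = scope
        self.save_guards = save_guards
        self.source_get_cache = source_get_cache or {}

    def get(self, name: str) -> Any:
        if name in self.source_get_cache:
            return self.source_get_cache[name]  # load path: replay the saved value
        ret = eval(name, self.scope)            # normal path: evaluate the source expr
        if self.save_guards and ".__closure__" in name:
            self.source_get_cache[name] = ret   # save path: remember closure-derived values
        return ret


def make_fn(captured):
    def fn(x):
        return x + captured
    return fn


fn = make_fn(10)
saver = SourceEvaluator({"fn": fn}, save_guards=True)
assert saver.get("fn.__closure__[0].cell_contents") == 10   # evaluated and cached
loader = SourceEvaluator({}, save_guards=False,
                         source_get_cache=saver.source_get_cache)
assert loader.get("fn.__closure__[0].cell_contents") == 10  # served from the cache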


@ -25,6 +25,7 @@ import shutil
import sys
import types
from collections.abc import Generator, Iterator
from contextlib import nullcontext
from typing import Any, Callable, NewType, Optional
from typing_extensions import Never
@ -97,6 +98,17 @@ class _GuardedCodeCacheEntry:
dynamo_code: SerializedCode
def load_guards_state(guards_state: bytes) -> Any:
try:
import torch.distributed.fsdp._fully_shard._fully_shard as _fully_shard
ctx = _fully_shard.disable_fsdp_module_new_init()
except ImportError:
ctx = nullcontext() # type: ignore[assignment]
with ctx:
return pickle.loads(guards_state)
_BackendId = NewType("_BackendId", str) # __compiled_fn
_FunctionId = NewType("_FunctionId", str) # __resume_at
@ -784,7 +796,7 @@ class CompilePackage:
torch._dynamo.eval_frame.skip_code(target_code)
for guarded_code in entry.guarded_codes:
- guards_state = pickle.loads(guarded_code.guards_state)
guards_state = load_guards_state(guarded_code.guards_state)
runtime_global_scope = sys.modules[entry.python_module].__dict__
# The installed builtins dict might be absent from the runtime
# while loading guards. Populate it if it's missing.
@ -805,6 +817,7 @@ class CompilePackage:
OutputGraphCommon(guards_state.output_graph),
shape_code_parts=guards_state.shape_code_parts,
runtime_global_scope=runtime_global_scope,
source_get_cache=guards_state.source_get_cache,
)
_load_precompile_entry(
target_code,


@ -4,6 +4,7 @@
from __future__ import annotations
import functools
from contextlib import contextmanager
from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union
from typing_extensions import deprecated
@ -27,7 +28,7 @@ from ._fsdp_state import _get_module_fsdp_state, FSDPState
if TYPE_CHECKING:
- from collections.abc import Callable, Iterable
from collections.abc import Callable, Iterable, Iterator
from torch.distributed.tensor import DeviceMesh, Shard
@ -37,6 +38,7 @@ __all__ = [
"UnshardHandle",
"register_fsdp_forward_method",
"get_cls_to_fsdp_cls",
"disable_fsdp_module_new_init",
]
@ -252,6 +254,19 @@ def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
)
_enable_fsdp_module_new_init: bool = True
@contextmanager
def disable_fsdp_module_new_init() -> Iterator[None]:
global _enable_fsdp_module_new_init
prev, _enable_fsdp_module_new_init = _enable_fsdp_module_new_init, False
try:
yield
finally:
_enable_fsdp_module_new_init = prev
class FSDPModule:
def __new__(cls, *args, **kwargs):
"""
@ -262,7 +277,8 @@ class FSDPModule:
# and index 1 is the `FSDPModule` class itself
orig_cls = cls.__mro__[2]
self = orig_cls.__new__(orig_cls, *args, **kwargs)
- self.__init__(*args, **kwargs)
if _enable_fsdp_module_new_init:
self.__init__(*args, **kwargs)
return self
def reshard(self) -> None:
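
The likely reason for the new flag, given that load_guards_state wraps pickle.loads in disable_fsdp_module_new_init(), is that pickle reconstructs instances through __new__ without calling __init__; a __new__ that unconditionally re-runs __init__ (as FSDPModule's does) would therefore fail during unpickling when the original constructor arguments are unavailable. A plain-Python illustration with invented names (EagerInit, disable_new_init), mirroring but not reproducing the FSDP code:

import pickle
from contextlib import contextmanager

_enable_new_init = True


@contextmanager
def disable_new_init():
    # Mirrors disable_fsdp_module_new_init(): temporarily turn off the
    # "__new__ also runs __init__" behavior, e.g. while unpickling.
    global _enable_new_init
    prev, _enable_new_init = _enable_new_init, False
    try:
        yield
    finally:
        _enable_new_init = prev


class EagerInit:
    # Like FSDPModule.__new__, this constructor runs __init__ from __new__.
    def __new__(cls, *args, **kwargs):
        self = super().__new__(cls)
        if _enable_new_init:
            self.__init__(*args, **kwargs)
        return self

    def __init__(self, required):
        self.required = required


blob = pickle.dumps(EagerInit("x"))
with disable_new_init():
    # Unpickling calls EagerInit.__new__(EagerInit) with no arguments, so the
    # forced __init__ would raise TypeError here without the guard.
    obj = pickle.loads(blob)
assert obj.required == "x"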