[precompile] Fix guard serialization loading bugs. (#164490)

Summary: Added a set of fixes triggered by an fm training job. The overall theme is to get rid of saved objects as much as possible when they are not used in guard reconstruction. For objects that cannot be saved at all (like local functions), we still try our best to save their closures.

Test Plan: test_guard_serialization.py, test_lazy_awatiable.py

Differential Revision: D83766926
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164490
Approved by: https://github.com/jamesjwu

Committed by: PyTorch MergeBot
Parent: 3c59351c6e
Commit: 16f9bef642
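The core pattern in this change is easiest to see in isolation: a pickle.Pickler subclass overrides reducer_override so that objects which cannot (or need not) be serialized are replaced by a lightweight _Missing placeholder instead of aborting the whole guard-state dump. The sketch below is a minimal, self-contained illustration of that idea, not the actual torch._dynamo implementation; only the _Missing name mirrors the diff, while SelectivePickler, demo, and local_fn are invented for this example.

import io
import pickle
import types


class _Missing:
    """Placeholder stored wherever a real object was deliberately not saved."""

    def __init__(self, reason=None):
        self._reason = reason

    def __repr__(self):
        return f"_Missing({self._reason})"


class SelectivePickler(pickle.Pickler):
    """Swap unpicklable local functions for _Missing instead of failing."""

    def reducer_override(self, obj):
        if isinstance(obj, types.FunctionType) and obj.__qualname__ != obj.__name__:
            # Local/nested functions have no importable fully-qualified name.
            return _Missing, ("local function",)
        return NotImplemented  # everything else: default pickling behavior


def demo():
    def local_fn(v):  # would normally make a plain pickle.dumps() raise
        return v + 1

    state = {"x": [1, 2, 3], "fn": local_fn}
    buf = io.BytesIO()
    SelectivePickler(buf).dump(state)
    return pickle.loads(buf.getvalue())


print(demo())  # {'x': [1, 2, 3], 'fn': _Missing(local function)}

The real GuardsStatePickler in the diff below applies the same move to tensors, modules, PyCapsules, c10d Work objects, and functions whose fully-qualified name does not resolve, and it now records a reason string in each placeholder to make debugging easier.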
@@ -4,8 +4,10 @@ import dataclasses
import importlib
import pickle
import sys
import tempfile
import types
import unittest
import weakref
from collections.abc import Iterator
from unittest.mock import patch

@@ -18,6 +20,7 @@ import torch.utils.cpp_extension
from torch._dynamo.bytecode_transformation import transform_code_object
from torch._dynamo.exc import PackageError
from torch._dynamo.guards import CheckFunctionManager, CompileId
from torch._dynamo.package import CompilePackage
from torch._dynamo.symbolic_convert import (
    ExceptionStack,
    InstructionTranslator,
@@ -44,10 +47,33 @@ class GlobalModule(torch.nn.Module):
        return x + 1


class GlobalNestedModule(torch.nn.Module):
    def __init__(self, submodule=None):
        super().__init__()
        self.linear = torch.nn.Linear(10, 10)
        self.param = torch.nn.Parameter(torch.randn(3, 2))
        self.nested = submodule or GlobalModule()

    def forward(self, x):
        return self.linear(x) + 1


def global_func(x):
    return x + 1


class ModuleNotSerializable(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.param = torch.nn.Parameter(torch.randn(3, 2))

    def __getstate__(self):
        raise NotImplementedError("not serialzable")

    def forward(self, x):
        return x + self.param


class GlobalTorchFunctionMode(TorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
@@ -63,6 +89,39 @@ class MyClass:
        return x + 1


class MyClassNotSerializable:
    def __getstate__(self):
        raise NotImplementedError

    def add(self, x):
        return x + 1


class Inputs:
    def __init__(self, x, unused):
        self.x = x
        self.unused = unused


def _global_func_wrong_fqn(x):
    return x + 1


global_func_wrong_fqn = _global_func_wrong_fqn
del _global_func_wrong_fqn


class FlatModule(torch.nn.Module):
    def forward(self, x):
        return x + 2


class ModWithDict(torch.nn.Module):
    def __init__(self, d):
        super().__init__()
        self.d = d


class SubclassWithMeta(torch.Tensor):
    @staticmethod
    def __new__(cls, a, extra, outer_size=None, outer_stride=None):
@@ -355,13 +414,14 @@ class TestGuardSerializationBase(torch._inductor.test_case.TestCase):
        self._cached_guards_state = guards_state
        self._cached_f_code = self._frame_state.f_code
        self.assertIsNotNone(guards_state)
        guards_state = pickle.loads(guards_state)
        guards_state = torch._dynamo.package.load_guards_state(guards_state)

        check_fn_manager = CheckFunctionManager(
            self._frame_state.f_code,
            guards_state.output_graph,
            shape_code_parts=guards_state.shape_code_parts,
            runtime_global_scope=self._frame_state.f_globals,
            source_get_cache=guards_state.source_get_cache,
        )
        loaded_gm = check_fn_manager.guard_manager

@@ -1372,10 +1432,289 @@ class TestGuardSerialization(TestGuardSerializationBase):
            ref, loaded, {"self": m, "foo": MyClass().add, "x": torch.randn(3, 2)}, True
        )

    def test_bound_methods_missing(self):
        class MyClass:
            def __getstate__(self):
                raise NotImplementedError

            def add(self, x):
                return x + 1

        def foo(x: torch.Tensor, y: list[MyClass]):
            assert len(y) == 1
            return x + 1

        ref, loaded = self._test_serialization(
            "TYPE_MATCH", foo, torch.randn(3, 2), [MyClass()]
        )
        self._test_check_fn(
            ref, loaded, {"x": torch.randn(3, 2), "y": [MyClass()]}, True
        )

    def test_bound_methods_empty(self):
        def foo(x, y):
            assert callable(y[0])
            return x + 1

        ref, loaded = self._test_serialization(
            "TYPE_MATCH", foo, torch.randn(3, 2), [MyClassNotSerializable().add]
        )
        self._test_check_fn(
            ref,
            loaded,
            {"x": torch.randn(3, 2), "y": [MyClassNotSerializable().add]},
            True,
        )

    def test_ddp_module(self):
        import torch.distributed as dist

        if not dist.is_available():
            self.skipTest("Torch distributed is not available")
        from torch.nn.parallel import DistributedDataParallel as DDP

        tmpfile = tempfile.NamedTemporaryFile()
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}"
        )
        try:
            ddp_model = DDP(GlobalNestedModule())

            def foo(ddp, x):
                return ddp(x)

            x = torch.randn(10)
            package = CompilePackage(foo)
            torch._dynamo.optimize(
                package=package,
                guard_filter_fn=lambda gs: [
                    x.guard_type not in ("CLOSURE_MATCH", "ID_MATCH") for x in gs
                ],
            )(foo)(ddp_model, x)
            self.assertEqual(len(package._codes[foo.__code__].guarded_codes), 1)
            torch._dynamo.package.load_guards_state(
                package._codes[foo.__code__].guarded_codes[0].guards_state
            )
        finally:
            dist.destroy_process_group()

    def test_dict_keys_serialization(self):
        d = {1: 2, 3: 4}

        def foo(x, y):
            for k in y:
                x += k
            return x

        ref, loaded = self._test_serialization(
            "TYPE_MATCH", foo, torch.randn(3, 2), d.keys()
        )
        self._test_check_fn(
            ref,
            loaded,
            {"x": torch.randn(3, 2), "y": d.keys()},
            True,
        )

    def test_unserializable_sharded_tensor(self):
        import torch.distributed as dist

        if not dist.is_available():
            self.skipTest("Torch distributed is not available")

        tmpfile = tempfile.NamedTemporaryFile()
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, init_method=f"file://{tmpfile.name}"
        )
        try:
            ChunkShardingSpec = dist._shard.sharding_spec.ChunkShardingSpec
            ShardedTensor = dist._shard.sharded_tensor.ShardedTensor
            tensor = torch.arange(2, dtype=torch.int64)
            local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)

            sharding_dim = 0
            sharding_spec = ChunkShardingSpec(
                dim=sharding_dim,
                placements=[
                    "rank:0/cpu",
                ],
            )
            st = ShardedTensor._init_from_local_tensor(
                local_tensor, sharding_spec, [1, 4]
            )

            def foo(inputs):
                return inputs.x + 1

            ref, loaded = self._test_serialization(
                "TENSOR_MATCH", foo, Inputs(torch.randn(3, 2), st)
            )
            self._test_check_fn(
                ref, loaded, {"inputs": Inputs(torch.randn(3, 2), st)}, True
            )
        finally:
            dist.destroy_process_group()

    def test_function_with_wrong_fqn(self):
        def foo(inputs):
            return inputs.x + 1

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", foo, Inputs(x, global_func_wrong_fqn)
        )
        self._test_check_fn(
            ref, loaded, {"inputs": Inputs(x, global_func_wrong_fqn)}, True
        )

    def test_c10d_work(self):
        import torch.distributed as dist

        if not dist.is_available():
            self.skipTest("Torch distributed is not available")

        Work = dist.distributed_c10d.Work

        class DummyWork(Work):
            def __init__(self, should_succeed=True):
                super().__init__()
                self._done = False
                self._should_succeed = should_succeed

            def is_completed(self):
                return self._done

            def is_success(self):
                return self._should_succeed

            def wait(self, timeout=None):
                self._done = True
                if not self._should_succeed:
                    raise RuntimeError("DummyWork failed")
                return self

            def result(self):
                if not self._should_succeed:
                    raise RuntimeError("DummyWork failed")
                return "dummy_result"

        def foo(inputs):
            return inputs.x + 1

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", foo, Inputs(x, DummyWork())
        )
        self._test_check_fn(ref, loaded, {"inputs": Inputs(x, DummyWork())}, True)

    def test_unused_weakref(self):
        def foo(inputs):
            return inputs.x + 1

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", foo, Inputs(x, weakref.ref(x))
        )
        self._test_check_fn(ref, loaded, {"inputs": Inputs(x, weakref.ref(x))}, True)

    def test_unused_stream(self):
        if not torch.cuda.is_available():
            self.skipTest("CUDA is not available")

        def foo(inputs):
            return inputs.x + 1

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization(
            "TENSOR_MATCH", foo, Inputs(x, torch.cuda.Stream())
        )
        self._test_check_fn(
            ref, loaded, {"inputs": Inputs(x, torch.cuda.Stream())}, True
        )

    def test_unused_process_group(self):
        import torch.distributed as dist

        if not dist.is_available():
            self.skipTest("Torch distributed is not available")

        def foo(inputs):
            return inputs.x + 1

        tmpfile = tempfile.NamedTemporaryFile()
        dist.init_process_group(
            backend="gloo",
            init_method=f"file://{tmpfile.name}",
            rank=0,
            world_size=1,
        )

        try:
            pg = dist.distributed_c10d._get_default_group()
            x = torch.randn(3, 2)
            ref, loaded = self._test_serialization("TENSOR_MATCH", foo, Inputs(x, pg))
            self._test_check_fn(ref, loaded, {"inputs": Inputs(x, pg)}, True)
        finally:
            dist.destroy_process_group()

    def test_unserializable_submodule(self):
        def foo(mod, x):
            return mod(x)

        x = torch.randn(10, 10)
        mod = GlobalNestedModule(ModuleNotSerializable())
        ref, loaded = self._test_serialization("TENSOR_MATCH", foo, mod, x)
        self._test_check_fn(ref, loaded, {"mod": mod, "x": x}, True)

    def test_closure_var_missing(self):
        captured = torch.randn(3, 2)

        def bar(x):
            return x + captured

        def foo(f, x):
            return f(x)

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization("TENSOR_MATCH", foo, bar, x)
        self._test_check_fn(ref, loaded, {"f": bar, "x": x}, True)

    def test_bound_method_patched_forward(self):
        def forward(x):
            return x + 1

        m = FlatModule()
        m_forward = m.forward
        m.forward = forward

        def foo(f, x):
            assert callable(f)
            return f(x)

        x = torch.randn(3, 2)
        ref, loaded = self._test_serialization("TYPE_MATCH", foo, m_forward, x)
        self._test_check_fn(ref, loaded, {"f": m_forward, "x": x}, True)

    def test_guard_on_key_order_with_cache(self):
        def foo(x, mod):
            for y in mod.d.values():
                x *= y
            return x

        x = torch.randn(3, 2)
        d = {"a": 1e9, "b": 1e-9}
        ref, loaded = self._test_serialization(
            "DICT_KEYS_MATCH", foo, x, ModWithDict(d)
        )
        self._test_check_fn(
            ref, loaded, {"x": x, "d": ModWithDict({"b": 1e-9, "a": 1e9})}, False
        )


class SimpleModule(torch.nn.Module):
    def __init__(self):
    def __init__(self, c):
        super().__init__()
        self.c = c
        self.p = torch.nn.Parameter(torch.randn(3, 2))

    def forward(self, x):
@@ -1400,7 +1739,7 @@ if not IS_MACOS:
        from torch.distributed.fsdp import fully_shard

        mesh = init_device_mesh(str(torch.get_default_device()), (1,))
        m = SimpleModule()
        m = SimpleModule(42)
        m = fully_shard(m, mesh=mesh)
        inputs = distribute_tensor(torch.randn(3, 2), mesh, [Replicate()])
        ref, loaded = self._test_serialization("TENSOR_MATCH", m, inputs)
@@ -77,6 +77,8 @@ class AOTCompiledFunction:
        return self._artifacts.guard_manager.check(f_locals)

    def __post_init__(self) -> None:
        from .package import load_guards_state

        self._artifacts.check_compatibility()

        import_sources = {
@@ -92,7 +94,7 @@ class AOTCompiledFunction:
        )

        if self._artifacts.guard_manager is None:
            guards_state = pickle.loads(self._artifacts.guards_state)
            guards_state = load_guards_state(self._artifacts.guards_state)
            self._artifacts.guard_manager = torch._dynamo.guards.CheckFunctionManager(
                self._artifacts.original_code,
                guards_state.output_graph,
@@ -986,6 +986,7 @@ class GuardBuilder(GuardBuilderBase):
        check_fn_manager: CheckFunctionManager,
        save_guards: bool = False,
        runtime_global_scope: Optional[dict[str, object]] = None,
        source_get_cache: Optional[dict[str, Any]] = None,
    ) -> None:
        self.f_code = f_code
        self.id_ref = id_ref
@@ -993,6 +994,7 @@ class GuardBuilder(GuardBuilderBase):
        self.lookup_weakrefs = lookup_weakrefs
        self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
        self.runtime_global_scope = runtime_global_scope or global_scope
        self.source_get_cache = source_get_cache or {}
        self.scope["__builtins__"] = builtins.__dict__.copy()
        for (
            name,
@@ -1021,6 +1023,9 @@ class GuardBuilder(GuardBuilderBase):

        self.check_fn_manager: CheckFunctionManager = check_fn_manager

        self.guard_tree_values: dict[int, Any] = {}
        self.save_guards = save_guards

        # Collect the ids of dicts which need key order guarding. source_name is
        # not sufficient because for nn modules, we can have different sources
        # to access the same object - self._module["param"] is same as
@@ -1028,7 +1033,10 @@
        self.key_order_guarded_dict_ids = set()
        assert self.check_fn_manager.output_graph is not None
        for source in self.check_fn_manager.output_graph.guard_on_key_order:
            self.key_order_guarded_dict_ids.add(id(self.get(source.name())))
            dict_obj = self.get(source.name())
            if self.save_guards:
                self.source_get_cache[source.name()] = dict_obj
            self.key_order_guarded_dict_ids.add(id(dict_obj))

        # Keep track of weak references of objects with ID_MATCH guard. This
        # info is stored alongside optimized_code and guard_manager and is used to
@@ -1039,14 +1047,12 @@
        self._cached_guard_managers: dict[str, GuardManager] = {}
        self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
        self.object_aliasing_guard_codes: list[tuple[str, str]] = []
        self.save_guards = save_guards
        self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
            "pytorch/compiler:guard_nn_modules"
        )
        self.already_guarded_not_present_in_generic_dict: OrderedSet[
            tuple[str, str]
        ] = OrderedSet()
        self.guard_tree_values: dict[int, Any] = {}

    def guard_on_dict_keys_and_ignore_order(
        self, example_value: dict[Any, Any], guard: Guard
@@ -1772,9 +1778,15 @@ class GuardBuilder(GuardBuilderBase):
    # (like its type) which is what you permanently install into the
    # guard code.
    def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
        if self.source_get_cache:
            if name in self.source_get_cache:
                return self.source_get_cache[name]
        if closure_vars is None:
            closure_vars = _get_closure_vars()
        return eval(name, self.scope, closure_vars)
        ret = eval(name, self.scope, closure_vars)
        if self.save_guards and ".__closure__" in name:
            self.source_get_cache[name] = ret
        return ret

    # Registers the usage of the source name referenced by the
    # string (or stored in the Guard) as being guarded upon. It's important
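A note on the hunk above: get() now consults source_get_cache before falling back to eval, and at save time it records any value reached through a ".__closure__" source. The snippet below is a hedged, standalone sketch of that round trip (MiniBuilder and make_fn are invented for illustration; the real logic lives in GuardBuilder): values that will not exist when guards are reloaded, such as the closure cell contents of a local function, are stashed by source name at save time and served from the cache at load time.

from typing import Any


class MiniBuilder:
    """Toy stand-in for GuardBuilder's get()/source_get_cache interaction."""

    def __init__(self, scope, save_guards, source_get_cache=None):
        self.scope = scope
        self.save_guards = save_guards
        self.source_get_cache = source_get_cache or {}

    def get(self, name: str) -> Any:
        if name in self.source_get_cache:  # load path: value cached at save time
            return self.source_get_cache[name]
        ret = eval(name, self.scope)  # save path: resolve against live objects
        if self.save_guards and ".__closure__" in name:
            self.source_get_cache[name] = ret  # remember closure contents
        return ret


def make_fn():
    captured = 41

    def fn(x):
        return x + captured

    return fn


saver = MiniBuilder({"L": {"f": make_fn()}}, save_guards=True)
saver.get("L['f'].__closure__[0].cell_contents")  # caches 41 under the source name
loader = MiniBuilder({}, save_guards=False, source_get_cache=saver.source_get_cache)
print(loader.get("L['f'].__closure__[0].cell_contents"))  # 41, no live function needed

This is how the PR can drop the local function itself from the pickled state (replacing it with _Missing) while guards on its closure variables still reconstruct correctly.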
@@ -3071,10 +3083,23 @@ class ShapeCodeParts:
class GuardsState:
    output_graph: OutputGraphGuardsState
    shape_code_parts: Optional[ShapeCodeParts]
    source_get_cache: Optional[dict[str, Any]] = None


class _Missing:
    pass
    def __init__(self, reason: Optional[str] = None) -> None:
        self._reason = reason

    def __repr__(self) -> str:
        return f"_Missing({self._reason})"

    def __str__(self) -> str:
        return f"_Missing({self._reason})"

    # Sometimes _Missing object is used as the callable with functools.partial,
    # so we add a dummy __call__ here to bypass TypeError from partial().
    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return _Missing()


@functools.cache
@@ -3097,6 +3122,7 @@ class GuardsStatePickler(pickle.Pickler):
        self,
        guard_tree_values: dict[int, Any],
        empty_values: dict[int, Any],
        missing_values: dict[int, Any],
        *args: Any,
        **kwargs: Any,
    ) -> None:
@@ -3105,6 +3131,7 @@ class GuardsStatePickler(pickle.Pickler):
        self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
        self.guard_tree_values = guard_tree_values
        self.empty_values = empty_values
        self.missing_values = missing_values

    @classmethod
    def _unpickle_module(cls, state: Any) -> torch.nn.Module:
@@ -3188,10 +3215,31 @@ class GuardsStatePickler(pickle.Pickler):
            original_type
        ]

    @classmethod
    def _unpickle_ddp_module(
        cls, state: dict[str, Any]
    ) -> torch.nn.parallel.DistributedDataParallel:
        ty = torch.nn.parallel.DistributedDataParallel
        ddp = ty.__new__(ty)
        torch.nn.Module.__setstate__(ddp, state)
        return ddp

    @classmethod
    def _unpickle_c_op(cls, name: str) -> Any:
        return getattr(torch.ops._C, name)

    @classmethod
    def _unpickle_bound_method(cls, func: Any, base: Any) -> Any:
        return types.MethodType(func, base)

    @classmethod
    def _unpickle_cell(cls, val: Any) -> Any:
        def _() -> Any:
            return val

        assert _.__closure__ is not None
        return _.__closure__[0]

    def reducer_override(
        self, obj: Any
    ) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@@ -3200,11 +3248,14 @@ class GuardsStatePickler(pickle.Pickler):
        if id(obj) in self.empty_values:
            return type(obj).__new__, (type(obj),)

        if id(obj) in self.missing_values:
            return _Missing, ("missing values",)

        if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
            from torch.utils._python_dispatch import is_traceable_wrapper_subclass

            if id(obj) not in self.guard_tree_values:
                return _Missing, ()
                return _Missing, ("tensor guard tree",)

            if is_traceable_wrapper_subclass(obj):
                # inner_data is a list of tuples of:
@@ -3238,6 +3289,15 @@ class GuardsStatePickler(pickle.Pickler):
            )

        elif isinstance(obj, torch.nn.Module):
            if id(obj) not in self.guard_tree_values:
                return _Missing, ("module guard tree",)

            # DDP module is a special case because it tries to restore unneeded
            # data in custom __setstate__. We cannot skip ddp module because it
            # is often a toplevel module.
            if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
                return type(self)._unpickle_ddp_module, (obj.__getstate__(),)

            if type(obj).__qualname__ == type(obj).__name__:
                return NotImplemented
            if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
@@ -3279,20 +3339,37 @@ class GuardsStatePickler(pickle.Pickler):
            and obj.__class__.__name__ == "PyCapsule"
        ):
            # Skipping PyCapsule since there isn't much to be guarded about them.
            return _Missing, ()
            return _Missing, ("capsule",)

        elif isinstance(obj, _get_unsupported_types()):
            return _Missing, ()
            return _Missing, ("unsupported",)

        elif inspect.isfunction(obj):
            if obj.__code__.co_flags & inspect.CO_NESTED:
                return _Missing, ()
                return _Missing, ("nested function",)
            if obj.__module__ in sys.modules:
                f = sys.modules[obj.__module__]
                for name in obj.__qualname__.split("."):
                    f = getattr(f, name, None)  # type: ignore[assignment]
                if f is not obj:
                    return _Missing, ()
                    return _Missing, ("fqn mismatch",)
        elif inspect.ismethod(obj):
            func = obj.__func__
            method_self = obj.__self__
            inner_func = getattr(method_self, func.__name__)
            if inspect.ismethod(inner_func):
                inner_func = inner_func.__func__
            if func is not inner_func:
                return type(self)._unpickle_bound_method, (func, method_self)

        elif isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])):  # type: ignore[index] # noqa: PLC3002
            return type(self)._unpickle_cell, (obj.cell_contents,)

        if hasattr(torch.distributed, "distributed_c10d") and isinstance(
            obj, torch.distributed.distributed_c10d.Work
        ):
            if id(obj) not in self.guard_tree_values:
                return _Missing, ("distributed_c10d.Work",)

        if type(obj).__qualname__ != type(obj).__name__:
            raise torch._dynamo.exc.PackageError(
@@ -3301,12 +3378,6 @@ class GuardsStatePickler(pickle.Pickler):
                + "Please define the class at global scope (top level of a module)."
            )

        if hasattr(torch.distributed, "distributed_c10d") and isinstance(
            obj, torch.distributed.distributed_c10d.Work
        ):
            if id(obj) not in self.guard_tree_values:
                return _Missing, ()

        if (
            inspect.isclass(obj)
            and hasattr(torch.distributed, "fsdp")
@@ -3327,6 +3398,7 @@
def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -> bytes:
    buf = io.BytesIO()
    empty_values = {}
    missing_values = {}

    leaves = pytree.tree_leaves(state.output_graph.local_scope)
    for leaf in leaves:
@@ -3338,7 +3410,11 @@ def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -
                empty_values[id(base)] = base
            except:  # noqa: E722, B001
                pass
    pickler = GuardsStatePickler(guard_tree_values, empty_values, buf)
        elif id(leaf) not in guard_tree_values:
            # TODO See if we can lift this branch as the first one.
            # Prune more objects in pytree hierarchy.
            missing_values[id(leaf)] = leaf
    pickler = GuardsStatePickler(guard_tree_values, empty_values, missing_values, buf)
    try:
        pickler.dump(state)
    except AttributeError as e:
@@ -3365,6 +3441,7 @@ class CheckFunctionManager:
        runtime_global_scope: Optional[dict[str, Any]] = None,
        save_guards: bool = False,
        strict_error: bool = False,
        source_get_cache: Optional[dict[str, Any]] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}
@@ -3427,7 +3504,12 @@
            # If we're filtering guards, we need to build it an extra time first
            # because filtering depends on the builder/guard_manager results
            builder, guard_manager = self.build_guards(
                sorted_guards, existing_diff_guard_sources, f_code, output_graph, False
                sorted_guards,
                existing_diff_guard_sources,
                f_code,
                output_graph,
                False,
                source_get_cache=source_get_cache,
            )

            def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry:
@@ -3476,6 +3558,7 @@ class CheckFunctionManager:
            f_code,
            output_graph,
            save_guards,
            source_get_cache=source_get_cache,
        )

        self.guard_manager = guard_manager
@@ -3696,6 +3779,7 @@ class CheckFunctionManager:
        guards_state = GuardsState(
            output_graph=output_graph_guards_state,
            shape_code_parts=self.shape_code_parts,
            source_get_cache=builder.source_get_cache,
        )

        return pickle_guards_state(guards_state, builder.guard_tree_values)
@@ -3707,6 +3791,7 @@
        f_code: types.CodeType,
        output_graph: OutputGraphGuardsState,
        save_guards: bool,
        source_get_cache: Optional[dict[str, Any]] = None,
    ) -> tuple[GuardBuilder, GuardManagerWrapper]:
        guard_manager = GuardManagerWrapper()
        guard_manager.diff_guard_sources = existing_diff_guard_sources
@@ -3734,6 +3819,7 @@
            self,
            save_guards,
            runtime_global_scope=self.runtime_global_scope,
            source_get_cache=source_get_cache,
        )

        # Break retain cycle. See test_release_scope_memory
@@ -25,6 +25,7 @@ import shutil
import sys
import types
from collections.abc import Generator, Iterator
from contextlib import nullcontext
from typing import Any, Callable, NewType, Optional
from typing_extensions import Never

@@ -97,6 +98,17 @@ class _GuardedCodeCacheEntry:
    dynamo_code: SerializedCode


def load_guards_state(guards_state: bytes) -> Any:
    try:
        import torch.distributed.fsdp._fully_shard._fully_shard as _fully_shard

        ctx = _fully_shard.disable_fsdp_module_new_init()
    except ImportError:
        ctx = nullcontext()  # type: ignore[assignment]
    with ctx:
        return pickle.loads(guards_state)


_BackendId = NewType("_BackendId", str)  # __compiled_fn
_FunctionId = NewType("_FunctionId", str)  # __resume_at

@@ -784,7 +796,7 @@ class CompilePackage:
                torch._dynamo.eval_frame.skip_code(target_code)

            for guarded_code in entry.guarded_codes:
                guards_state = pickle.loads(guarded_code.guards_state)
                guards_state = load_guards_state(guarded_code.guards_state)
                runtime_global_scope = sys.modules[entry.python_module].__dict__
                # The installed builtins dict might be absent from the runtime
                # while loading guards. Populate it if it's missing.
@@ -805,6 +817,7 @@ class CompilePackage:
                    OutputGraphCommon(guards_state.output_graph),
                    shape_code_parts=guards_state.shape_code_parts,
                    runtime_global_scope=runtime_global_scope,
                    source_get_cache=guards_state.source_get_cache,
                )
                _load_precompile_entry(
                    target_code,
@@ -4,6 +4,7 @@
from __future__ import annotations

import functools
from contextlib import contextmanager
from typing import Any, cast, NoReturn, Optional, overload, TYPE_CHECKING, Union
from typing_extensions import deprecated

@@ -27,7 +28,7 @@ from ._fsdp_state import _get_module_fsdp_state, FSDPState


if TYPE_CHECKING:
    from collections.abc import Callable, Iterable
    from collections.abc import Callable, Iterable, Iterator

    from torch.distributed.tensor import DeviceMesh, Shard

@@ -37,6 +38,7 @@ __all__ = [
    "UnshardHandle",
    "register_fsdp_forward_method",
    "get_cls_to_fsdp_cls",
    "disable_fsdp_module_new_init",
]


@@ -252,6 +254,19 @@ def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn:
    )


_enable_fsdp_module_new_init: bool = True


@contextmanager
def disable_fsdp_module_new_init() -> Iterator[None]:
    global _enable_fsdp_module_new_init
    prev, _enable_fsdp_module_new_init = _enable_fsdp_module_new_init, False
    try:
        yield
    finally:
        _enable_fsdp_module_new_init = prev


class FSDPModule:
    def __new__(cls, *args, **kwargs):
        """
@@ -262,7 +277,8 @@
        # and index 1 is the `FSDPModule` class itself
        orig_cls = cls.__mro__[2]
        self = orig_cls.__new__(orig_cls, *args, **kwargs)
        self.__init__(*args, **kwargs)
        if _enable_fsdp_module_new_init:
            self.__init__(*args, **kwargs)
        return self

    def reshard(self) -> None:
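The last two hunks work together with load_guards_state() above: fully_shard() gives modules a dynamically created class whose __new__ re-runs __init__ on the original class, which is undesirable when an unpickler is rebuilding a module from saved state. The sketch below reproduces that pattern in miniature to show what the new _enable_fsdp_module_new_init flag buys; it is an illustration only (Wrapped, Linearish, and disable_new_init are made-up names), not the torch.distributed.fsdp code.

from contextlib import contextmanager

_enable_new_init = True  # mirrors _enable_fsdp_module_new_init in the diff


@contextmanager
def disable_new_init():
    global _enable_new_init
    prev, _enable_new_init = _enable_new_init, False
    try:
        yield
    finally:
        _enable_new_init = prev


class Wrapped:
    """Mimics FSDPModule.__new__: construct the original class, then init it."""

    def __new__(cls, *args, **kwargs):
        orig_cls = cls.__mro__[2]  # index 2 holds the user's original class
        self = orig_cls.__new__(orig_cls)
        if _enable_new_init:  # skipped while unpickling saved guard state
            self.__init__(*args, **kwargs)
        return self


class Linearish:
    def __init__(self, n):
        self.weight = [0.0] * n  # stands in for expensive, stateful setup


# fully_shard() builds such a subclass dynamically; done by hand here.
FSDPLinearish = type("FSDPLinearish", (Wrapped, Linearish), {})

normal = FSDPLinearish(4)  # __init__ runs as usual
with disable_new_init():
    # what an unpickler effectively does: allocate without re-initializing
    restored = FSDPLinearish.__new__(FSDPLinearish)
print(hasattr(normal, "weight"), hasattr(restored, "weight"))  # True False

Without the flag, unpickling a fully-sharded module would re-run __init__ with no arguments at load time; with it, load_guards_state() can restore the saved attributes directly.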