[precompile] Fix guard serialization loading bugs. (#164490)

Summary: Added a set of fixes triggered by fm training job. Overall the theme here is that we should get rid of saved objects as much as possible when they are not used in guard reconstruction. Sometimes for objects that cannot be saved (like local functions) we still try our best to save their closures.

Test Plan:
test_guard_serialization.py
test_lazy_awaitable.py

Differential Revision: D83766926

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164490
Approved by: https://github.com/jamesjwu
This commit is contained in:
Zhengxu Chen
2025-10-03 19:20:07 +00:00
committed by PyTorch MergeBot
parent 3c59351c6e
commit 16f9bef642
5 changed files with 481 additions and 25 deletions

View File

@ -986,6 +986,7 @@ class GuardBuilder(GuardBuilderBase):
check_fn_manager: CheckFunctionManager,
save_guards: bool = False,
runtime_global_scope: Optional[dict[str, object]] = None,
source_get_cache: Optional[dict[str, Any]] = None,
) -> None:
self.f_code = f_code
self.id_ref = id_ref
@ -993,6 +994,7 @@ class GuardBuilder(GuardBuilderBase):
self.lookup_weakrefs = lookup_weakrefs
self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
self.runtime_global_scope = runtime_global_scope or global_scope
self.source_get_cache = source_get_cache or {}
self.scope["__builtins__"] = builtins.__dict__.copy()
for (
name,
@ -1021,6 +1023,9 @@ class GuardBuilder(GuardBuilderBase):
self.check_fn_manager: CheckFunctionManager = check_fn_manager
self.guard_tree_values: dict[int, Any] = {}
self.save_guards = save_guards
# Collect the ids of dicts which need key order guarding. source_name is
# not sufficient because for nn modules, we can have different sources
# to access the same object - self._module["param"] is same as
@ -1028,7 +1033,10 @@ class GuardBuilder(GuardBuilderBase):
self.key_order_guarded_dict_ids = set()
assert self.check_fn_manager.output_graph is not None
for source in self.check_fn_manager.output_graph.guard_on_key_order:
self.key_order_guarded_dict_ids.add(id(self.get(source.name())))
dict_obj = self.get(source.name())
if self.save_guards:
self.source_get_cache[source.name()] = dict_obj
self.key_order_guarded_dict_ids.add(id(dict_obj))
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and guard_manager and is used to
@ -1039,14 +1047,12 @@ class GuardBuilder(GuardBuilderBase):
self._cached_guard_managers: dict[str, GuardManager] = {}
self._cached_duplicate_input_guards: set[tuple[str, str]] = set()
self.object_aliasing_guard_codes: list[tuple[str, str]] = []
self.save_guards = save_guards
self.guard_nn_modules = config.guard_nn_modules and justknobs_check(
"pytorch/compiler:guard_nn_modules"
)
self.already_guarded_not_present_in_generic_dict: OrderedSet[
tuple[str, str]
] = OrderedSet()
self.guard_tree_values: dict[int, Any] = {}
def guard_on_dict_keys_and_ignore_order(
self, example_value: dict[Any, Any], guard: Guard
@ -1772,9 +1778,15 @@ class GuardBuilder(GuardBuilderBase):
# (like its type) which is what you permanently install into the
# guard code.
def get(self, name: str, closure_vars: Optional[dict[str, Any]] = None) -> Any:
    """Evaluate the source expression *name* against the guarded scopes and
    return the live object it refers to.

    Results recorded in ``source_get_cache`` at guard-save time are served
    first, so that guards can be reconstructed in a process where the
    original objects (e.g. closure cells of local functions) no longer
    exist.
    """
    # `name not in {}` is False, so this single check also covers the
    # empty-cache case without a separate truthiness test.
    if name in self.source_get_cache:
        return self.source_get_cache[name]
    if closure_vars is None:
        closure_vars = _get_closure_vars()
    ret = eval(name, self.scope, closure_vars)
    # Closure contents cannot be re-evaluated when loading serialized
    # guards (the enclosing local function won't exist), so remember the
    # value now while saving.
    if self.save_guards and ".__closure__" in name:
        self.source_get_cache[name] = ret
    return ret
# Registers the usage of the source name referenced by the
# string (or stored in the Guard) as being guarded upon. It's important
@ -3071,10 +3083,23 @@ class ShapeCodeParts:
class GuardsState:
    # Serializable snapshot of everything needed to rebuild guards in a
    # different process. NOTE(review): presumably decorated as a dataclass —
    # the decorator sits above this view; confirm before relying on it.
    output_graph: OutputGraphGuardsState
    shape_code_parts: Optional[ShapeCodeParts]
    # Cached GuardBuilder.get() results keyed by source-expression string,
    # captured at save time for sources that cannot be re-evaluated when
    # loading (e.g. closure contents of local functions). Defaults to None
    # for backward compatibility with previously pickled states.
    source_get_cache: Optional[dict[str, Any]] = None
class _Missing:
pass
def __init__(self, reason: Optional[str] = None) -> None:
self._reason = reason
def __repr__(self) -> str:
return f"_Missing({self._reason})"
def __str__(self) -> str:
return f"_Missing({self._reason})"
# Sometimes _Missing object is used as the callable with functools.partial,
# so we add a dummy __call__ here to bypass TypeError from partial().
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return _Missing()
@functools.cache
@ -3097,6 +3122,7 @@ class GuardsStatePickler(pickle.Pickler):
self,
guard_tree_values: dict[int, Any],
empty_values: dict[int, Any],
missing_values: dict[int, Any],
*args: Any,
**kwargs: Any,
) -> None:
@ -3105,6 +3131,7 @@ class GuardsStatePickler(pickle.Pickler):
self.tensor_converter = torch._subclasses.fake_tensor.FakeTensorConverter()
self.guard_tree_values = guard_tree_values
self.empty_values = empty_values
self.missing_values = missing_values
@classmethod
def _unpickle_module(cls, state: Any) -> torch.nn.Module:
@ -3188,10 +3215,31 @@ class GuardsStatePickler(pickle.Pickler):
original_type
]
@classmethod
def _unpickle_ddp_module(
cls, state: dict[str, Any]
) -> torch.nn.parallel.DistributedDataParallel:
ty = torch.nn.parallel.DistributedDataParallel
ddp = ty.__new__(ty)
torch.nn.Module.__setstate__(ddp, state)
return ddp
@classmethod
def _unpickle_c_op(cls, name: str) -> Any:
return getattr(torch.ops._C, name)
@classmethod
def _unpickle_bound_method(cls, func: Any, base: Any) -> Any:
return types.MethodType(func, base)
@classmethod
def _unpickle_cell(cls, val: Any) -> Any:
def _() -> Any:
return val
assert _.__closure__ is not None
return _.__closure__[0]
def reducer_override(
self, obj: Any
) -> Union[tuple[Callable[..., Any], tuple[Any, ...]], Any]:
@ -3200,11 +3248,14 @@ class GuardsStatePickler(pickle.Pickler):
if id(obj) in self.empty_values:
return type(obj).__new__, (type(obj),)
if id(obj) in self.missing_values:
return _Missing, ("missing values",)
if isinstance(obj, torch.Tensor) and obj.device.type != "meta":
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
if id(obj) not in self.guard_tree_values:
return _Missing, ()
return _Missing, ("tensor guard tree",)
if is_traceable_wrapper_subclass(obj):
# inner_data is a list of tuples of:
@ -3238,6 +3289,15 @@ class GuardsStatePickler(pickle.Pickler):
)
elif isinstance(obj, torch.nn.Module):
if id(obj) not in self.guard_tree_values:
return _Missing, ("module guard tree",)
# DDP module is a special case because it tries to restore unneeded
# data in custom __setstate__. We cannot skip ddp module because it
# is often a toplevel module.
if isinstance(obj, torch.nn.parallel.DistributedDataParallel):
return type(self)._unpickle_ddp_module, (obj.__getstate__(),)
if type(obj).__qualname__ == type(obj).__name__:
return NotImplemented
if obj.__class__.__getstate__ == torch.nn.Module.__getstate__:
@ -3279,20 +3339,37 @@ class GuardsStatePickler(pickle.Pickler):
and obj.__class__.__name__ == "PyCapsule"
):
# Skipping PyCapsule since there isn't much to be guarded about them.
return _Missing, ()
return _Missing, ("capsule",)
elif isinstance(obj, _get_unsupported_types()):
return _Missing, ()
return _Missing, ("unsupported",)
elif inspect.isfunction(obj):
if obj.__code__.co_flags & inspect.CO_NESTED:
return _Missing, ()
return _Missing, ("nested function",)
if obj.__module__ in sys.modules:
f = sys.modules[obj.__module__]
for name in obj.__qualname__.split("."):
f = getattr(f, name, None) # type: ignore[assignment]
if f is not obj:
return _Missing, ()
return _Missing, ("fqn mismatch",)
elif inspect.ismethod(obj):
func = obj.__func__
method_self = obj.__self__
inner_func = getattr(method_self, func.__name__)
if inspect.ismethod(inner_func):
inner_func = inner_func.__func__
if func is not inner_func:
return type(self)._unpickle_bound_method, (func, method_self)
elif isinstance(obj, type((lambda x: lambda: x)(0).__closure__[0])): # type: ignore[index] # noqa: PLC3002
return type(self)._unpickle_cell, (obj.cell_contents,)
if hasattr(torch.distributed, "distributed_c10d") and isinstance(
obj, torch.distributed.distributed_c10d.Work
):
if id(obj) not in self.guard_tree_values:
return _Missing, ("distributed_c10d.Work",)
if type(obj).__qualname__ != type(obj).__name__:
raise torch._dynamo.exc.PackageError(
@ -3301,12 +3378,6 @@ class GuardsStatePickler(pickle.Pickler):
+ "Please define the class at global scope (top level of a module)."
)
if hasattr(torch.distributed, "distributed_c10d") and isinstance(
obj, torch.distributed.distributed_c10d.Work
):
if id(obj) not in self.guard_tree_values:
return _Missing, ()
if (
inspect.isclass(obj)
and hasattr(torch.distributed, "fsdp")
@ -3327,6 +3398,7 @@ class GuardsStatePickler(pickle.Pickler):
def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -> bytes:
buf = io.BytesIO()
empty_values = {}
missing_values = {}
leaves = pytree.tree_leaves(state.output_graph.local_scope)
for leaf in leaves:
@ -3338,7 +3410,11 @@ def pickle_guards_state(state: GuardsState, guard_tree_values: dict[int, Any]) -
empty_values[id(base)] = base
except: # noqa: E722, B001
pass
pickler = GuardsStatePickler(guard_tree_values, empty_values, buf)
elif id(leaf) not in guard_tree_values:
# TODO: See if we should lift this branch to be the first one checked.
# Prune more objects in pytree hierarchy.
missing_values[id(leaf)] = leaf
pickler = GuardsStatePickler(guard_tree_values, empty_values, missing_values, buf)
try:
pickler.dump(state)
except AttributeError as e:
@ -3365,6 +3441,7 @@ class CheckFunctionManager:
runtime_global_scope: Optional[dict[str, Any]] = None,
save_guards: bool = False,
strict_error: bool = False,
source_get_cache: Optional[dict[str, Any]] = None,
):
guards = output_graph.guards if output_graph else None
self._weakrefs: dict[int, ReferenceType[object]] = {}
@ -3427,7 +3504,12 @@ class CheckFunctionManager:
# If we're filtering guards, we need to build it an extra time first
# because filtering depends on the builder/guard_manager results
builder, guard_manager = self.build_guards(
sorted_guards, existing_diff_guard_sources, f_code, output_graph, False
sorted_guards,
existing_diff_guard_sources,
f_code,
output_graph,
False,
source_get_cache=source_get_cache,
)
def make_guard_filter_entry(guard: Guard) -> GuardFilterEntry:
@ -3476,6 +3558,7 @@ class CheckFunctionManager:
f_code,
output_graph,
save_guards,
source_get_cache=source_get_cache,
)
self.guard_manager = guard_manager
@ -3696,6 +3779,7 @@ class CheckFunctionManager:
guards_state = GuardsState(
output_graph=output_graph_guards_state,
shape_code_parts=self.shape_code_parts,
source_get_cache=builder.source_get_cache,
)
return pickle_guards_state(guards_state, builder.guard_tree_values)
@ -3707,6 +3791,7 @@ class CheckFunctionManager:
f_code: types.CodeType,
output_graph: OutputGraphGuardsState,
save_guards: bool,
source_get_cache: Optional[dict[str, Any]] = None,
) -> tuple[GuardBuilder, GuardManagerWrapper]:
guard_manager = GuardManagerWrapper()
guard_manager.diff_guard_sources = existing_diff_guard_sources
@ -3734,6 +3819,7 @@ class CheckFunctionManager:
self,
save_guards,
runtime_global_scope=self.runtime_global_scope,
source_get_cache=source_get_cache,
)
# Break retain cycle. See test_release_scope_memory