[dynamo] Support BUILTIN_MATCH serialization. (#157016)

Serialize BUILTIN_MATCH guards, since the guarded builtins are all stored in the __builtins__ dict.
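
Conceptually the guard is just an identity check on an entry of that dict, which is why rebinding a builtin must invalidate it. A minimal sketch (illustrative only, not the actual guard codegen; check_builtin_match is a made-up helper):

```python
import builtins

# Recorded when the guard is built: the identity of the builtin in use.
guarded_id = id(builtins.__dict__["getattr"])


def check_builtin_match(f_builtins: dict) -> bool:
    # Fails once someone rebinds the builtin (e.g. builtins.getattr = wrapper),
    # which is what the new test_builtin_match below exercises.
    return id(f_builtins.get("getattr")) == guarded_id
```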

Also fixed an issue where the wrong global scope was passed to CheckFunctionManager while loading guards. Previously we could always reuse the compile-time global scope for evaluating guards, because the compile-time and runtime global scopes were always the same.

For precompile, the compile-time global scope is serialized for loading purposes only; once loading finishes, CheckFunctionManager must be pointed at the runtime global scope so that guards are evaluated against it.
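
Roughly, the intended save/load flow (a simplified sketch; save_guards/load_guards are hypothetical wrappers, but the CheckFunctionManager arguments follow this PR):

```python
import pickle
import sys
import types

from torch._dynamo.guards import CheckFunctionManager


def save_guards(code: types.CodeType, output_graph) -> bytes:
    # Guards are built against the compile-time globals and serialized.
    manager = CheckFunctionManager(
        code, output_graph, guards_serialization_mode="save"
    )
    assert manager.guards_state is not None
    return manager.guards_state


def load_guards(code: types.CodeType, serialized: bytes, python_module: str):
    # Rebuild the guards from the serialized state, but point evaluation at
    # the runtime globals of the module the guarded code will run in.
    guards_state = pickle.loads(serialized)
    return CheckFunctionManager(
        code,
        guards_state.output_graph,
        guards_serialization_mode="load",
        shape_code_parts=guards_state.shape_code_parts,
        runtime_global_scope=sys.modules[python_module].__dict__,
    ).guard_manager
```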

Differential Revision: [D77159313](https://our.internmc.facebook.com/intern/diff/D77159313/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157016
Approved by: https://github.com/jansel, https://github.com/jamesjwu
zhxchen17
2025-07-02 08:48:53 -07:00
committed by PyTorch MergeBot
parent 172853547a
commit e20784f228
4 changed files with 110 additions and 26 deletions


@@ -254,7 +254,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
        self._frame_state = _FrameState(
            f_locals=dict(frame.f_locals),
-            f_globals=dict(frame.f_globals),
+            f_globals=frame.f_globals,
            f_code=frame.f_code,
            f_builtins=frame.f_builtins,
        )
@@ -336,13 +336,18 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
        ):
            tracer.run()
-            ref_gm = CheckFunctionManager(
-                self._frame_state.f_code,
-                tracer.output,
-                guard_filter_fn=guard_filter_fn,
-            ).guard_manager
+            check_fn_manager = CheckFunctionManager(
+                self._frame_state.f_code,
+                tracer.output,
+                guard_filter_fn=guard_filter_fn,
+                guards_serialization_mode="save",
+            )
+            ref_gm = check_fn_manager.guard_manager
+            guards_state = check_fn_manager.guards_state
+            self._cached_guards_state = guards_state
+            self._cached_f_code = self._frame_state.f_code
@@ -354,6 +359,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
            guards_state.output_graph,
            guards_serialization_mode="load",
            shape_code_parts=guards_state.shape_code_parts,
+            runtime_global_scope=self._frame_state.f_globals,
        )
        loaded_gm = check_fn_manager.guard_manager
@@ -1278,6 +1284,30 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 11, 2)}, False)
        self._test_check_fn(ref, loaded, {"x": torch.randn(3, 2, 2)}, False)

+    def test_builtin_match(self):
+        def fn(x):
+            # usage of getattr() here installs a BUILTIN_MATCH guard
+            s = getattr(x, "shape")  # noqa: B009
+            return x + s[0]
+
+        x = torch.randn(3)
+        ref, loaded = self._test_serialization("BUILTIN_MATCH", fn, x)
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+
+        getattr_original = getattr
+
+        def getattr_new(*args, **kwargs):
+            return getattr_original(*args, **kwargs)
+
+        builtins_dict = (
+            __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__
+        )
+        builtins_dict["getattr"] = getattr_new
+        try:
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+        finally:
+            builtins_dict["getattr"] = getattr_original


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests


@@ -643,12 +643,14 @@ class GuardBuilder(GuardBuilderBase):
        guard_manager: GuardManagerWrapper,
        check_fn_manager: CheckFunctionManager,
        serialization_mode: Optional[str] = None,
+        runtime_global_scope: Optional[dict[str, Any]] = None,
    ):
        self.f_code = f_code
        self.id_ref = id_ref
        self.source_ref = source_ref
        self.lookup_weakrefs = lookup_weakrefs
        self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
+        self.runtime_global_scope = runtime_global_scope or global_scope
        self.scope["__builtins__"] = builtins.__dict__.copy()
        for (
            name,
@@ -953,7 +955,7 @@
    def get_global_guard_manager(self):
        return self.guard_manager.root.globals_dict_manager(
-            f_globals=self.scope["G"],
+            f_globals=self.runtime_global_scope,
            source="G",
            example_value=self.scope["G"],
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
@@ -1537,6 +1539,9 @@
    def ID_MATCH(self, guard: Guard):
        if self.serialization_mode == "save":
            raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.")
+        return self.id_match_unchecked(guard)
+
+    def id_match_unchecked(self, guard: Guard):
        # ___check_obj_id is same as `id(x) == y`
        if isinstance(guard.originating_source, TypeSource):
            # optional optimization to produce cleaner/faster guard code
@@ -1838,7 +1843,15 @@
        self.FUNCTION_MATCH(guard)

    def BUILTIN_MATCH(self, guard: Guard):
-        return self.FUNCTION_MATCH(guard)
+        if self.serialization_mode == "save":
+            # Record which builtin variables are used for pruning later.
+            if isinstance(guard.originating_source, DictGetItemSource):
+                self.check_fn_manager.used_builtin_vars.add(
+                    guard.originating_source.index
+                )
+            return self.id_match_unchecked(guard)
+
+        return self.ID_MATCH(guard)

    def SEQUENCE_LENGTH(self, guard):
        # This guard is used to check length of PySequence objects like list,
@@ -2761,6 +2774,7 @@ class CheckFunctionManager:
        ] = None,
        guards_serialization_mode: Optional[str] = None,
        shape_code_parts: Optional[ShapeCodeParts] = None,
+        runtime_global_scope: Optional[dict[str, Any]] = None,
    ):
        guards = output_graph.guards if output_graph else None
        self._weakrefs: dict[int, ReferenceType[object]] = {}
@@ -2779,6 +2793,10 @@
            output_graph.torch_function_mode_stack if output_graph else None
        )
        self.guards_serialization_mode = guards_serialization_mode
+        self.used_builtin_vars: OrderedSet[str] = OrderedSet()
+        if runtime_global_scope:
+            assert self.guards_serialization_mode == "load"
+        self.runtime_global_scope = runtime_global_scope

        if not justknobs_check("pytorch/compiler:guard_nn_modules"):
            log.warning("guard_nn_modules is turned off using justknobs killswitch")
@@ -2893,6 +2911,7 @@
            CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))

        self.guards_state: Optional[bytes] = None
+        builtins_dict_name = self.output_graph.name_of_builtins_dict_key_in_fglobals
        if self.guards_serialization_mode == "save":
            used_global_vars = set()
            used_local_vars = set()
@@ -2900,7 +2919,11 @@
            def prune_variable(source):
                if name := get_global_source_name(source):
                    assert isinstance(name, str)
-                    used_global_vars.add(name)
+                    # Leave out the builtins dict key; we handle it specially
+                    # later because the guarded code rarely uses the entire
+                    # builtins dict in the common case.
+                    if name not in (builtins_dict_name,):
+                        used_global_vars.add(name)
                elif name := get_local_source_name(source):
                    assert isinstance(name, str)
                    used_local_vars.add(name)
@@ -2932,6 +2955,18 @@
                return x

+            global_scope_state = {
+                k: v
+                for k, v in output_graph_guards_state.global_scope.items()
+                if k in used_global_vars
+            }
+            global_scope_state[builtins_dict_name] = {
+                k: v
+                for k, v in output_graph_guards_state.global_scope[
+                    builtins_dict_name
+                ].items()
+                if k in self.used_builtin_vars
+            }
            output_graph_guards_state = dataclasses.replace(
                output_graph_guards_state,
                local_scope={
@@ -2939,11 +2974,7 @@
                    for k, v in output_graph_guards_state.local_scope.items()
                    if k in used_local_vars
                },
-                global_scope={
-                    k: v
-                    for k, v in output_graph_guards_state.global_scope.items()
-                    if k in used_global_vars
-                },
+                global_scope=global_scope_state,
                _guards=torch._guards.GuardsSet(
                    {
                        dataclasses.replace(
@@ -3015,6 +3046,7 @@
            guard_manager,
            self,
            serialization_mode,
+            runtime_global_scope=self.runtime_global_scope,
        )

        # Break retain cycle. See test_release_scope_memory


@@ -311,6 +311,7 @@ class OutputGraphGuardsState:
    functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
    current_device: Optional[torch.device]
    global_state_guard: torch._C._dynamo.guards.GlobalStateGuard
+    name_of_builtins_dict_key_in_fglobals: Optional[str] = None

    export: bool = False
    export_constraints: bool = False
@@ -344,6 +345,26 @@ class StackLocalsMetadata:
    locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)


+def get_builtins_dict(global_scope):
+    # f_globals["__builtins__"] can be a dict or a module. This is an
+    # implementation detail -
+    # https://docs.python.org/3/library/builtins.html.
+    # This makes guarding on any builtin messy because the guard check_fn
+    # has to check if the __builtins__ is a module or dict, and then access
+    # by either using getattr or getitem respectively.
+    # To solve this problem, we insert a new entry in f_globals which points
+    # to the builtins __dict__ and then we guard any builtin on this dict.
+    # To avoid any collision with the pre-existing keys, we use the
+    # install_global to give us a unique dict key.
+    f_builtins = global_scope["__builtins__"]
+    if not isinstance(f_builtins, dict):
+        f_builtins = f_builtins.__dict__
+    return f_builtins
+
+
class OutputGraph(OutputGraphGuardsState):
    """
    Wrapper class to hold outputs of InstructionTranslator. Mainly the
@@ -566,22 +587,7 @@ class OutputGraph(OutputGraphGuardsState):
        self.compiler_trace_stack.close()

    def install_builtins_dict_in_fglobals(self):
-        # f_globals["__builtins__"] can be a dict or a module. This is an
-        # implementation detail -
-        # https://docs.python.org/3/library/builtins.html.
-        # This makes guarding on any builtin messy because the guard check_fn
-        # has to check if the __builtins__ is a module or dict, and then access
-        # by either using getattr or getitem respectively.
-        # To solve this problem, we insert a new entry in f_globals which points
-        # to the builtins __dict__ and then we guard any builtin on this dict.
-        # To avoid any collision with the pre-existing keys, we use the
-        # install_global to give us a unique dict key.
-        f_builtins = self.global_scope["__builtins__"]
-        if not isinstance(f_builtins, dict):
-            f_builtins = f_builtins.__dict__
+        f_builtins = get_builtins_dict(self.global_scope)
        return self.install_global("__builtins_dict__", f_builtins)
@@ -680,6 +686,7 @@ class OutputGraph(OutputGraphGuardsState):
            functorch_layers=self.functorch_layers,
            current_device=self.current_device,
            global_state_guard=self.global_state_guard,
+            name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals,
            export=self.export,
            export_constraints=self.export_constraints,
            _guards=self.guards,


@@ -375,6 +375,8 @@ class CompilePackage:
        """
        from torch._C._dynamo.eval_frame import _load_precompile_entry

+        from .output_graph import get_builtins_dict
+
        self.uninstall()
        for code, entry in self._codes.items():
@@ -401,12 +403,25 @@ class CompilePackage:
        for code, entry in self._codes.items():
            for guarded_code in entry.guarded_codes:
                guards_state = pickle.loads(guarded_code.guards_state)
+                runtime_global_scope = sys.modules[entry.python_module].__dict__
+                # The installed builtins dict might be absent from the runtime
+                # while loading guards. Populate it if it's missing.
+                if (
+                    builtin_dict_name
+                    := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals
+                ):
+                    builtins_dict = get_builtins_dict(runtime_global_scope)
+                    if builtin_dict_name in runtime_global_scope:
+                        assert runtime_global_scope[builtin_dict_name] is builtins_dict
+                    else:
+                        runtime_global_scope[builtin_dict_name] = builtins_dict
                assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
                check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
                    code,
                    guards_state.output_graph,
                    guards_serialization_mode="load",
                    shape_code_parts=guards_state.shape_code_parts,
+                    runtime_global_scope=runtime_global_scope,
                )
                _load_precompile_entry(
                    code,