Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-20 21:14:14 +08:00
[dynamo] Support BUILTIN_MATCH serialization. (#157016)
Serialize BUILTIN_MATCH guards, since the guarded builtins are all stored in the __builtins__ dict. Also fix an issue where the wrong global scope was passed to CheckFunctionManager while loading guards: previously we could always reuse the compile-time global scope for evaluating guards, because the compile-time and runtime global scopes were always the same. For precompile, the compile-time global scope is serialized for loading only; once loading finishes, the CheckFunctionManager must be pointed at the new (runtime) global scope for evaluating guards.

Differential Revision: [D77159313](https://our.internmc.facebook.com/intern/diff/D77159313/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157016
Approved by: https://github.com/jansel, https://github.com/jamesjwu
Committed by PyTorch MergeBot
Parent: 172853547a
Commit: e20784f228
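Before the diff itself, a minimal, self-contained sketch (plain Python, not the PyTorch API; all names are illustrative) of the problem the runtime_global_scope plumbing solves — a guard compiled against the serialized snapshot of globals must be re-pointed at the live module globals after loading:

# "Compile-time" globals, as they existed when guards were saved.
saved_globals = {"x": 1}

def make_guard(scope):
    # The guard closes over whichever globals dict it is handed.
    return lambda: scope.get("x") == 1

guard = make_guard(saved_globals)
assert guard()                           # passes against the snapshot

runtime_globals = {"x": 2}               # the live module dict differs
assert guard()                           # stale: still checks the snapshot

repointed = make_guard(runtime_globals)  # what runtime_global_scope enables
assert not repointed()                   # correctly sees the live value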
@@ -254,7 +254,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         self._frame_state = _FrameState(
             f_locals=dict(frame.f_locals),
-            f_globals=dict(frame.f_globals),
+            f_globals=frame.f_globals,
             f_code=frame.f_code,
             f_builtins=frame.f_builtins,
         )
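The `dict(frame.f_globals)` → `frame.f_globals` change matters because `dict()` produces a fresh object; identity-based guard plumbing would then target the wrong dict and miss later mutations. A quick illustration:

g = {"a": 1}
snapshot = dict(g)
assert snapshot == g and snapshot is not g  # equal by contents, distinct object
g["a"] = 2
assert snapshot["a"] == 1  # the copy no longer reflects the live globals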
@@ -336,13 +336,18 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         ):
             tracer.run()

-        ref_gm = CheckFunctionManager(
-            self._frame_state.f_code,
-            tracer.output,
-            guard_filter_fn=guard_filter_fn,
-        ).guard_manager
+        check_fn_manager = CheckFunctionManager(
+            self._frame_state.f_code,
+            tracer.output,
+            guard_filter_fn=guard_filter_fn,
+            guards_serialization_mode="save",
+        )
+        ref_gm = check_fn_manager.guard_manager
+        guards_state = check_fn_manager.guards_state
+        self._cached_guards_state = guards_state
+        self._cached_f_code = self._frame_state.f_code
@@ -354,6 +359,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
             guards_state.output_graph,
             guards_serialization_mode="load",
             shape_code_parts=guards_state.shape_code_parts,
+            runtime_global_scope=self._frame_state.f_globals,
         )
         loaded_gm = check_fn_manager.guard_manager
@@ -1278,6 +1284,30 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
         self._test_check_fn(ref, loaded, {"x": torch.randn(3, 11, 2)}, False)
         self._test_check_fn(ref, loaded, {"x": torch.randn(3, 2, 2)}, False)

+    def test_builtin_match(self):
+        def fn(x):
+            # usage of getattr() here installs a BUILTIN_MATCH guard
+            s = getattr(x, "shape")  # noqa: B009
+            return x + s[0]
+
+        x = torch.randn(3)
+
+        ref, loaded = self._test_serialization("BUILTIN_MATCH", fn, x)
+        self._test_check_fn(ref, loaded, {"x": x}, True)
+        getattr_original = getattr
+
+        def getattr_new(*args, **kwargs):
+            return getattr_original(*args, **kwargs)
+
+        builtins_dict = (
+            __builtins__ if isinstance(__builtins__, dict) else __builtins__.__dict__
+        )
+        builtins_dict["getattr"] = getattr_new
+        try:
+            self._test_check_fn(ref, loaded, {"x": x}, False)
+        finally:
+            builtins_dict["getattr"] = getattr_original
+

 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
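The behavior this test exercises can be reproduced with plain Python: rebinding a builtin replaces the object stored in the builtins dict, so an identity-based guard fails until the original is restored. A standalone sketch:

import builtins

original = builtins.__dict__["getattr"]
saved_id = id(original)

def wrapper(*args, **kwargs):
    return original(*args, **kwargs)

builtins.__dict__["getattr"] = wrapper
try:
    assert id(builtins.__dict__["getattr"]) != saved_id  # guard would fail here
finally:
    builtins.__dict__["getattr"] = original
assert id(builtins.__dict__["getattr"]) == saved_id      # guard passes again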
@@ -643,12 +643,14 @@ class GuardBuilder(GuardBuilderBase):
         guard_manager: GuardManagerWrapper,
         check_fn_manager: CheckFunctionManager,
         serialization_mode: Optional[str] = None,
+        runtime_global_scope: Optional[dict[str, Any]] = None,
     ):
         self.f_code = f_code
         self.id_ref = id_ref
         self.source_ref = source_ref
         self.lookup_weakrefs = lookup_weakrefs
         self.scope: dict[str, dict[str, object]] = {"L": local_scope, "G": global_scope}
+        self.runtime_global_scope = runtime_global_scope or global_scope
         self.scope["__builtins__"] = builtins.__dict__.copy()
         for (
             name,
@@ -953,7 +955,7 @@ class GuardBuilder(GuardBuilderBase):

     def get_global_guard_manager(self):
         return self.guard_manager.root.globals_dict_manager(
-            f_globals=self.scope["G"],
+            f_globals=self.runtime_global_scope,
            source="G",
            example_value=self.scope["G"],
            guard_manager_enum=GuardManagerType.GUARD_MANAGER,
@@ -1537,6 +1539,9 @@ class GuardBuilder(GuardBuilderBase):
     def ID_MATCH(self, guard: Guard):
         if self.serialization_mode == "save":
             raise torch._dynamo.exc.PackageError("ID_MATCH guard cannot be serialized.")
+        return self.id_match_unchecked(guard)
+
+    def id_match_unchecked(self, guard: Guard):
         # ___check_obj_id is same as `id(x) == y`
         if isinstance(guard.originating_source, TypeSource):
             # optional optimization to produce cleaner/faster guard code
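The `___check_obj_id` helper referenced in the comment is, conceptually, an identity test (the real helper is implemented natively); a pure-Python equivalent:

def check_obj_id(obj, expected_id):
    # Same predicate the guard evaluates: object identity via id().
    return id(obj) == expected_id

assert check_obj_id(len, id(len))
assert not check_obj_id(max, id(len))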
@@ -1838,7 +1843,15 @@ class GuardBuilder(GuardBuilderBase):
         self.FUNCTION_MATCH(guard)

     def BUILTIN_MATCH(self, guard: Guard):
-        return self.FUNCTION_MATCH(guard)
+        if self.serialization_mode == "save":
+            # Record which builtin variables are used for pruning later.
+            if isinstance(guard.originating_source, DictGetItemSource):
+                self.check_fn_manager.used_builtin_vars.add(
+                    guard.originating_source.index
+                )
+            return self.id_match_unchecked(guard)
+
+        return self.ID_MATCH(guard)

     def SEQUENCE_LENGTH(self, guard):
         # This guard is used to check length of PySequence objects like list,
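Put together, a serialized BUILTIN_MATCH reduces to an unchecked ID match on one entry of the installed builtins dict. A hypothetical sketch (the dict key and guard shape are illustrative, not the actual generated code):

G = {"__builtins_dict_0": {"getattr": getattr}}   # installed at compile time
saved_id = id(G["__builtins_dict_0"]["getattr"])

def builtin_match_guard(scope):
    # Identity check against the recorded builtin object.
    return id(scope["__builtins_dict_0"]["getattr"]) == saved_id

assert builtin_match_guard(G)
G["__builtins_dict_0"]["getattr"] = lambda *a, **k: getattr(*a, **k)
assert not builtin_match_guard(G)  # rebinding the builtin breaks the guard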
@@ -2761,6 +2774,7 @@ class CheckFunctionManager:
         ] = None,
         guards_serialization_mode: Optional[str] = None,
         shape_code_parts: Optional[ShapeCodeParts] = None,
+        runtime_global_scope: Optional[dict[str, Any]] = None,
     ):
         guards = output_graph.guards if output_graph else None
         self._weakrefs: dict[int, ReferenceType[object]] = {}
@@ -2779,6 +2793,10 @@ class CheckFunctionManager:
             output_graph.torch_function_mode_stack if output_graph else None
         )
         self.guards_serialization_mode = guards_serialization_mode
+        self.used_builtin_vars: OrderedSet[str] = OrderedSet()
+        if runtime_global_scope:
+            assert self.guards_serialization_mode == "load"
+        self.runtime_global_scope = runtime_global_scope

         if not justknobs_check("pytorch/compiler:guard_nn_modules"):
             log.warning("guard_nn_modules is turned off using justknobs killswitch")
@@ -2893,6 +2911,7 @@ class CheckFunctionManager:
             CompileEventLogger.increment_toplevel("guard_latency_us", int(latency))

         self.guards_state: Optional[bytes] = None
+        builtins_dict_name = self.output_graph.name_of_builtins_dict_key_in_fglobals
         if self.guards_serialization_mode == "save":
             used_global_vars = set()
             used_local_vars = set()
@@ -2900,7 +2919,11 @@ class CheckFunctionManager:
             def prune_variable(source):
                 if name := get_global_source_name(source):
                     assert isinstance(name, str)
-                    used_global_vars.add(name)
+                    # Leave out the builtins dict key; we handle it specially
+                    # later, since the guarded code rarely uses the entire
+                    # builtins dict in the common case.
+                    if name not in (builtins_dict_name,):
+                        used_global_vars.add(name)
                 elif name := get_local_source_name(source):
                     assert isinstance(name, str)
                     used_local_vars.add(name)
@@ -2932,6 +2955,18 @@ class CheckFunctionManager:

                 return x

+            global_scope_state = {
+                k: v
+                for k, v in output_graph_guards_state.global_scope.items()
+                if k in used_global_vars
+            }
+            global_scope_state[builtins_dict_name] = {
+                k: v
+                for k, v in output_graph_guards_state.global_scope[
+                    builtins_dict_name
+                ].items()
+                if k in self.used_builtin_vars
+            }
             output_graph_guards_state = dataclasses.replace(
                 output_graph_guards_state,
                 local_scope={
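The effect of this two-level pruning can be shown in isolation: keep only the globals the guards actually reference, then shrink the installed builtins entry down to the recorded names (all names here are illustrative):

global_scope = {
    "some_mod": object(),
    "unused": object(),
    "__builtins_dict_0": {"getattr": getattr, "len": len},
}
used_global_vars = {"some_mod"}
used_builtin_vars = {"getattr"}
builtins_dict_name = "__builtins_dict_0"

pruned = {k: v for k, v in global_scope.items() if k in used_global_vars}
pruned[builtins_dict_name] = {
    k: v
    for k, v in global_scope[builtins_dict_name].items()
    if k in used_builtin_vars
}
assert set(pruned) == {"some_mod", "__builtins_dict_0"}
assert set(pruned["__builtins_dict_0"]) == {"getattr"}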
@@ -2939,11 +2974,7 @@ class CheckFunctionManager:
                     for k, v in output_graph_guards_state.local_scope.items()
                     if k in used_local_vars
                 },
-                global_scope={
-                    k: v
-                    for k, v in output_graph_guards_state.global_scope.items()
-                    if k in used_global_vars
-                },
+                global_scope=global_scope_state,
                 _guards=torch._guards.GuardsSet(
                     {
                         dataclasses.replace(
@@ -3015,6 +3046,7 @@ class CheckFunctionManager:
             guard_manager,
             self,
             serialization_mode,
+            runtime_global_scope=self.runtime_global_scope,
         )

         # Break retain cycle. See test_release_scope_memory
@@ -311,6 +311,7 @@ class OutputGraphGuardsState:
     functorch_layers: list[torch._functorch.pyfunctorch.FuncTorchInterpreter]
     current_device: Optional[torch.device]
     global_state_guard: torch._C._dynamo.guards.GlobalStateGuard
+    name_of_builtins_dict_key_in_fglobals: Optional[str] = None

     export: bool = False
     export_constraints: bool = False
@@ -344,6 +345,26 @@ class StackLocalsMetadata:
     locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)


+def get_builtins_dict(global_scope):
+    # f_globals["__builtins__"] can be a dict or a module. This is an
+    # implementation detail -
+    # https://docs.python.org/3/library/builtins.html.
+
+    # This makes guarding on any builtin messy because the guard check_fn
+    # has to check if the __builtins__ is a module or dict, and then access
+    # by either using getattr or getitem respectively.
+
+    # To solve this problem, we insert a new entry in f_globals which points
+    # to the builtins __dict__ and then we guard any builtin on this dict.
+    # To avoid any collision with the pre-existing keys, we use the
+    # install_global to give us a unique dict key.
+
+    f_builtins = global_scope["__builtins__"]
+    if not isinstance(f_builtins, dict):
+        f_builtins = f_builtins.__dict__
+    return f_builtins
+
+
 class OutputGraph(OutputGraphGuardsState):
     """
     Wrapper class to hold outputs of InstructionTranslator. Mainly the
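The normalization get_builtins_dict performs can be checked standalone: __builtins__ may be a module (typically in __main__) or a dict (typically in imported modules), and both should resolve to the same underlying dict. The function body below is copied from the diff; the surrounding scaffolding is illustrative:

import builtins

def get_builtins_dict(global_scope):
    f_builtins = global_scope["__builtins__"]
    if not isinstance(f_builtins, dict):
        f_builtins = f_builtins.__dict__
    return f_builtins

as_module = {"__builtins__": builtins}           # module form
as_dict = {"__builtins__": builtins.__dict__}    # dict form
assert get_builtins_dict(as_module) is get_builtins_dict(as_dict)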
@@ -566,22 +587,7 @@ class OutputGraph(OutputGraphGuardsState):
         self.compiler_trace_stack.close()

     def install_builtins_dict_in_fglobals(self):
-        # f_globals["__builtins__"] can be a dict or a module. This is an
-        # implementation detail -
-        # https://docs.python.org/3/library/builtins.html.
-
-        # This makes guarding on any builtin messy because the guard check_fn
-        # has to check if the __builtins__ is a module or dict, and then access
-        # by either using getattr or getitem respectively.
-
-        # To solve this problem, we insert a new entry in f_globals which points
-        # to the builtins __dict__ and then we guard any builtin on this dict.
-        # To avoid any collision with the pre-existing keys, we use the
-        # install_global to give us a unique dict key.
-
-        f_builtins = self.global_scope["__builtins__"]
-        if not isinstance(f_builtins, dict):
-            f_builtins = f_builtins.__dict__
+        f_builtins = get_builtins_dict(self.global_scope)
         return self.install_global("__builtins_dict__", f_builtins)

     def add_backward_state_hook(self, hook: VariableTracker, prefix="hook"):
@@ -680,6 +686,7 @@ class OutputGraph(OutputGraphGuardsState):
             functorch_layers=self.functorch_layers,
             current_device=self.current_device,
             global_state_guard=self.global_state_guard,
+            name_of_builtins_dict_key_in_fglobals=self.name_of_builtins_dict_key_in_fglobals,
             export=self.export,
             export_constraints=self.export_constraints,
             _guards=self.guards,
@@ -375,6 +375,8 @@ class CompilePackage:
         """
         from torch._C._dynamo.eval_frame import _load_precompile_entry
+
+        from .output_graph import get_builtins_dict

         self.uninstall()

         for code, entry in self._codes.items():
@@ -401,12 +403,25 @@ class CompilePackage:
         for code, entry in self._codes.items():
             for guarded_code in entry.guarded_codes:
                 guards_state = pickle.loads(guarded_code.guards_state)
+                runtime_global_scope = sys.modules[entry.python_module].__dict__
+                # The installed builtins dict might be absent from the runtime
+                # global scope while loading guards. Populate it if it's missing.
+                if (
+                    builtin_dict_name
+                    := guards_state.output_graph.name_of_builtins_dict_key_in_fglobals
+                ):
+                    builtins_dict = get_builtins_dict(runtime_global_scope)
+                    if builtin_dict_name in runtime_global_scope:
+                        assert runtime_global_scope[builtin_dict_name] is builtins_dict
+                    else:
+                        runtime_global_scope[builtin_dict_name] = builtins_dict
                 assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
                 check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
                     code,
                     guards_state.output_graph,
                     guards_serialization_mode="load",
                     shape_code_parts=guards_state.shape_code_parts,
                     runtime_global_scope=runtime_global_scope,
+                )
                 _load_precompile_entry(
                     code,
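The load-time population logic above can be exercised against a throwaway module (the module name and dict key here are illustrative, not what precompile actually generates):

import builtins
import sys
import types

mod = types.ModuleType("fake_runtime_module")
mod.__dict__["__builtins__"] = builtins
sys.modules[mod.__name__] = mod

runtime_global_scope = sys.modules["fake_runtime_module"].__dict__
builtin_dict_name = "__builtins_dict_0"  # would come from guards_state in real use

f_builtins = runtime_global_scope["__builtins__"]
builtins_dict = f_builtins if isinstance(f_builtins, dict) else f_builtins.__dict__
if builtin_dict_name in runtime_global_scope:
    # Already installed: it must be the same underlying dict.
    assert runtime_global_scope[builtin_dict_name] is builtins_dict
else:
    # Missing at runtime: populate it, as the load path above does.
    runtime_global_scope[builtin_dict_name] = builtins_dict

assert runtime_global_scope[builtin_dict_name] is builtins.__dict__
del sys.modules["fake_runtime_module"]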