From b2fc9cfea16c8eb52c1ce79b2032793dd1a545fb Mon Sep 17 00:00:00 2001 From: James Wu Date: Thu, 12 Jun 2025 09:46:38 -0700 Subject: [PATCH] [precompile] Add CompilePackage to serialize dynamo states. (#155118) Adding a per torch.compile() object CompilePackage which tracks dynamo artifact. CompilePackage is considered a low level component and should not be directly exposed to end users. It has the following interface: 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. a. when `dynamo` argument is None, it will contruct a brand new CompilePackage object. b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. 2. `package.save()` which dumps the dynamo states into _DynamoCacheEntry. 3. `package.install(backends)` which will handle all the side-effectful global scope updates with compiled functions and resume functions. This diff focus on making the low level mechanism for precompile. It will be left to upper level interface to use these API to build more user-facing frontend. Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118 Approved by: https://github.com/jamesjwu Co-authored-by: James Wu --- test/dynamo/test_guard_serialization.py | 1 + test/dynamo/test_package.py | 196 ++++++++++++++ torch/_C/_dynamo/eval_frame.pyi | 6 +- torch/_dynamo/convert_frame.py | 41 ++- torch/_dynamo/eval_frame.py | 18 +- torch/_dynamo/guards.py | 8 +- torch/_dynamo/output_graph.py | 5 + torch/_dynamo/package.py | 331 ++++++++++++++++++++++++ torch/_dynamo/symbolic_convert.py | 23 +- torch/_dynamo/testing.py | 1 + torch/_dynamo/utils.py | 2 + 11 files changed, 618 insertions(+), 14 deletions(-) create mode 100644 test/dynamo/test_package.py create mode 100644 torch/_dynamo/package.py diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index f421597b5eaa..6fa40064beb1 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -322,6 +322,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): speculation_log=SpeculationLog(), exn_vt_stack=ExceptionStack(), distributed_state=None, + package=None, ) with compile_context(CompileContext(CompileId(0, 0))), tracing( tracer.output.tracing_context diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py new file mode 100644 index 000000000000..b299a3752db5 --- /dev/null +++ b/test/dynamo/test_package.py @@ -0,0 +1,196 @@ +# Owner(s): ["module: dynamo"] + +import os +import pickle + +import torch +import torch._dynamo.testing +import torch._inductor.config +import torch._inductor.test_case +import torch.onnx.operators +import torch.utils.cpp_extension +from torch._dynamo.package import CompilePackage +from torch._inductor.runtime.runtime_utils import cache_dir + + +class StorageForTesting: + def __init__(self, path: str): + self.path = path + self.backends = {} + + def _write_pickle(self, data, *path: str): + with open(os.path.join(self.path, *path) + ".pickle", "wb") as f: + pickle.dump(data, f) + + def write_dynamo(self, dynamo): + self._write_pickle(dynamo, "dynamo") + + def write_backend(self, backend_id): + os.makedirs(os.path.join(self.path, backend_id), exist_ok=True) + self._write_pickle(self.backends[backend_id], backend_id, "fx_graph") + + def _read_pickle(self, *path): + with open(os.path.join(self.path, *path) + ".pickle", "rb") as f: + return pickle.load(f) + + def read_backend(self, backend_id): + return self._read_pickle(backend_id, "fx_graph") + + def read_dynamo(self): + return self._read_pickle("dynamo") + + def add_backend(self, backend_id, backend): + self.backends[backend_id] = backend + + def save_package(self, dynamo_cache_entry): + self.write_dynamo(dynamo_cache_entry) + for backend_id in dynamo_cache_entry.backend_ids: + self.write_backend(backend_id) + + def load_package(self): + dynamo = self.read_dynamo() + self.backends = {} + for backend_id in dynamo.backend_ids: + self.backends[backend_id] = self.read_backend(backend_id) + return dynamo + + +class TestPackage(torch._inductor.test_case.TestCase): + def storage(self): + path = os.path.join(cache_dir(), f"package_{self.id()}") + os.makedirs(path, exist_ok=True) + return StorageForTesting(path) + + def test_basic_fn(self): + storage = self.storage() + + def fn(x): + return x + 1 + + args = (torch.randn(3, 2),) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) + expected = compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(storage.backends) + self.assertEqual(expected, compiled_fn(*args)) + + def test_graph_break_bomb(self): + storage = self.storage() + + def fn(x, l, r): + if l > r: + return x.sum() + mid = (l + r) // 2 + if x.sum() == mid: + return x.sum() + elif x.sum() < mid: + return fn(x, l, mid) + else: + return fn(x, mid + 1, r) + + def guard_filter_fn(guards): + return [ + guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH") + for guard in guards + ] + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + N = 10 + args_list = [(torch.tensor(x), 0, N - 1) for x in range(N)] + for args in args_list: + compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + for args in args_list: + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + package.install(storage.backends) + for args in args_list: + self.assertEqual(compiled_fn(*args), args[0].sum()) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(torch.tensor(N), 0, N - 1) + + def test_dynamic_shape(self): + storage = self.storage() + + def fn(x): + return x + x.shape[0] + + args = (torch.randn(3, 2),) + args1 = (torch.randn(5, 2),) + args2 = (torch.randn(7, 2),) + expected1 = fn(*args1) + + torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) + compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args1) + + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(storage.backends) + + self.assertEqual(expected1, compiled_fn(*args1)) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args2) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index da0b32637759..c89de9a1ff9f 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -2,7 +2,7 @@ import enum import types from typing import overload -from torch._dynamo.types import DynamoCallback, DynamoGuardHook +from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -57,3 +57,7 @@ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... +def _load_precompile_entry( + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index df0d04557707..c0ef12ae3a06 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -161,6 +161,7 @@ except ModuleNotFoundError: if typing.TYPE_CHECKING: from .backends.registry import CompilerFn + from .package import CompilePackage from .repro.after_dynamo import WrapBackendDebug from .types import BytecodeHook, CacheEntry, DynamoFrameType from .variables.builder import FrameStateSizeEntry @@ -478,6 +479,7 @@ class ConvertFrameAssert: one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> None: # assert export_constraints is None reset_graph_break_dup_checker() @@ -485,6 +487,7 @@ class ConvertFrameAssert: self._one_graph = one_graph self._export = export self._export_constraints = export_constraints + self._package = package @property def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: @@ -640,6 +643,7 @@ class ConvertFrameAssert: frame_state=frame_state, compile_id=compile_id, skip=skip + 1, + package=self._package, ) @@ -648,9 +652,12 @@ def convert_frame_assert( one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> ConvertFrameAssert: """Fully convert a frame into an FX graph""" - return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) + return ConvertFrameAssert( + compiler_fn, one_graph, export, export_constraints, package + ) from collections import OrderedDict @@ -693,6 +700,7 @@ def _compile( *, compile_id: CompileId, skip: int = 0, + package: Optional[CompilePackage] = None, ) -> ConvertFrameReturn: from torch.fx.experimental.validator import ( bisect, @@ -717,7 +725,7 @@ def _compile( ) -> None: nonlocal output nonlocal tracer - speculation_log.restart() + speculation_log.restart() # type: ignore[has-type] exn_vt_stack = ExceptionStack() tracer = InstructionTranslator( instructions, @@ -733,9 +741,10 @@ def _compile( export, export_constraints, frame_state=frame_state, - speculation_log=speculation_log, + speculation_log=speculation_log, # type: ignore[has-type] exn_vt_stack=exn_vt_stack, - distributed_state=distributed_state, + distributed_state=distributed_state, # type: ignore[has-type] + package=package, ) try: @@ -743,7 +752,7 @@ def _compile( with tracing(tracer.output.tracing_context), tracer.set_current_tx(): tracer.run() except exc.UnspecializeRestartAnalysis: - speculation_log.clear() + speculation_log.clear() # type: ignore[has-type] raise except ( exc.SpeculationRestartAnalysis, @@ -857,7 +866,7 @@ def _compile( log.debug("No graph captured with one_graph=True") return ConvertFrameReturn() - assert distributed_state is None or distributed_state.all_states is not None, ( + assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type] "compiler collective wasn't run before compilation completed" ) @@ -936,8 +945,13 @@ def _compile( cache_entry, hooks.guard_fail_fn if hooks else None, hooks.guard_filter_fn if hooks else None, + guards_serialization_mode="save" if package else None, ) + if package is not None: + assert check_fn.guards_state is not None + package.add_guarded_code(check_fn.guards_state, out_code) + compile_id_str = str(compile_id) if compile_id is not None else "Unknown" annotation_str = "Torch-Compiled Region: " + compile_id_str guarded_code = GuardedCode( @@ -958,6 +972,9 @@ def _compile( return wrap_guarded_code(guarded_code) metrics_context = get_metrics_context() + code_context = ( + package.code_context(code) if package is not None else contextlib.nullcontext() + ) with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), @@ -971,6 +988,7 @@ def _compile( phase_name="entire_frame_compile", dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ), + code_context, ): restart_reasons: set[str] = set() # This is shared across restarts @@ -1226,9 +1244,12 @@ class ConvertFrame: self, compiler_fn: CompilerFn, hooks: Hooks, + package: Optional[CompilePackage] = None, ) -> None: self._torchdynamo_orig_callable = compiler_fn - self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._inner_convert = convert_frame_assert( + compiler_fn, one_graph=False, package=package + ) self._hooks = hooks @property @@ -1331,9 +1352,11 @@ class ConvertFrame: return ConvertFrameReturn() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame: +def convert_frame( + compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None +) -> ConvertFrame: """Try to convert a frame into an FX graph, if error leave frame unmodified""" - return ConvertFrame(compiler_fn, hooks) + return ConvertFrame(compiler_fn, hooks, package=package) # TODO mlazos: add support for same args, or record them diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 771d1071f763..380291f9f5ba 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -531,6 +531,7 @@ class _TorchDynamoContext: export=False, dynamic=None, compiler_config=None, + package=None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -543,6 +544,7 @@ class _TorchDynamoContext: self.compiler_config = compiler_config self.cleanup_fns: list[Callable[[], Any]] = [] self.enter_exit_hooks = [] + self._package = package patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset @@ -792,6 +794,7 @@ class OptimizeContext(_TorchDynamoContext): rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, + package=None, ) -> None: def on_enter(): install_generation_tagging_init() @@ -805,6 +808,7 @@ class OptimizeContext(_TorchDynamoContext): export=export, dynamic=dynamic, compiler_config=compiler_config, + package=package, ) if config.compiled_autograd: @@ -928,6 +932,7 @@ def _optimize_catch_errors( dynamic=None, compiler_config=None, rebuild_ctx=None, + package=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -937,6 +942,7 @@ def _optimize_catch_errors( dynamic=dynamic, compiler_config=compiler_config, rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1027,6 +1033,7 @@ def _optimize( guard_filter_fn=None, disable=False, dynamic=None, + package=None, ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1079,6 +1086,7 @@ def _optimize( dynamic=dynamic, hooks=hooks, rebuild_ctx=rebuild_ctx, + package=package, ) backend = get_compiler_fn(backend) @@ -1090,7 +1098,7 @@ def _optimize( # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can # be used by eval_frame.c to insert a guard on the backend. return _optimize_catch_errors( - convert_frame.convert_frame(backend, hooks=hooks), + convert_frame.convert_frame(backend, hooks=hooks, package=package), hooks, backend_ctx_ctor, dynamic=dynamic, @@ -1100,6 +1108,7 @@ def _optimize( else None ), rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1990,6 +1999,7 @@ def _optimize_assert( export=False, export_constraints=None, dynamic=None, + package=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -2001,13 +2011,17 @@ def _optimize_assert( return _optimize_catch_errors( convert_frame.convert_frame_assert( - backend, export=export, export_constraints=export_constraints + backend, + export=export, + export_constraints=export_constraints, + package=package, ), hooks, backend_ctx_ctor, export=export, dynamic=dynamic, rebuild_ctx=rebuild_ctx, + package=package, ) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index c744b65e1f74..d5afdaff120e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -497,6 +497,8 @@ def get_verbose_code_parts( def convert_int_to_concrete_values(dim) -> Optional[int]: + if dim is None: + return None if not is_symbolic(dim): return dim else: @@ -1457,7 +1459,11 @@ class GuardBuilder(GuardBuilderBase): def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` value = self.get(guard.name) - t = type(value) + if isinstance(value, torch._subclasses.FakeTensor) and value.pytype: + t = value.pytype + else: + t = type(value) + if self.serialization_mode == "save": if t.__qualname__ != t.__name__: raise_local_type_error(value) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 8e0899450680..797abbeb4a8f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -368,6 +368,7 @@ class OutputGraph(OutputGraphGuardsState): global_scope: Scope, f_code, torch_function_mode_stack, + package, ): super().__init__( local_scope, @@ -471,6 +472,7 @@ class OutputGraph(OutputGraphGuardsState): self.compiler_fn: Optional[CompilerFn] = compiler_fn self.root_tx = root_tx + self.package = package # Given a source, what are the user stacks of all locations that # accessed it? # @@ -1715,6 +1717,9 @@ class OutputGraph(OutputGraphGuardsState): # replace compiled_fn with the real forward method compiled_fn = lazy_gm.forward + if self.package is not None: + self.package.add_backend_id(name, compiled_fn) + compiled_fn = disable( compiled_fn, reason="do not trace Dynamo-compiled graph" ) diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py new file mode 100644 index 000000000000..41157e07f87d --- /dev/null +++ b/torch/_dynamo/package.py @@ -0,0 +1,331 @@ +""" +This module provides the infrastructure for creating and managing compile package +for torch.compile. We mainly have two abstractions here: + - CompilePackage: Overarching data structure for store and lookup a list of compiled codes. + - CodeCacheEntry: Data structure for a single code being compiled by torch.compile. +The caching behavior is always under user control explicitly so that a stronger guarantee can +be provided about cache hit for a specific compiled model. Users can load the compile package +from a different process or host. +""" + +import contextlib +import dataclasses +import functools +import hashlib +import importlib +import logging +import pickle +import platform +import sys +import types +from collections.abc import Generator +from typing import Any, NewType, Optional + +import torch +import torch._inductor.package + +from .bytecode_transformation import get_code_keys + + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class SerializedCode: + co_argcount: int + co_posonlyargcount: int + co_kwonlyargcount: int + co_nlocals: int + co_stacksize: int + co_flags: int + co_code: bytes + co_consts: tuple[Any, ...] + co_names: tuple[str, ...] + co_varnames: tuple[str, ...] + co_filename: str + co_name: str + co_firstlineno: int + co_cellvars: tuple[str, ...] + co_freevars: tuple[str, ...] + co_linetable: Optional[bytes] = None + co_qualname: Optional[str] = None + co_exceptiontable: Optional[bytes] = None + co_lnotab: Optional[str] = None + + @classmethod + @functools.cache + def from_code_object(cls, code: types.CodeType) -> "SerializedCode": + kwargs = {key: getattr(code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.from_code_object(c) if isinstance(c, types.CodeType) else c + for c in kwargs["co_consts"] + ) + return cls(**kwargs) + + @classmethod + @functools.cache + def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType: + kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.to_code_object(c) if isinstance(c, SerializedCode) else c + for c in kwargs["co_consts"] + ) + return types.CodeType( + *kwargs.values(), + ) + + +@dataclasses.dataclass +class _GuardedCodeCacheEntry: + """ + Contains the serializable information associated with a single compilation in dynamo. + To restore an execution of compiled code, we will need to serialize the following data: + - Dynamo bytecode for mapping Python inputs/outputs. + - Dynamo guards. + """ + + guards_state: bytes + dynamo_code: SerializedCode + + +_BackendId = NewType("_BackendId", str) # __compiled_fn +_FunctionId = NewType("_FunctionId", str) # __resume_at + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry: + """ + Contains the serializable information associated with a single code object + in dynamo. To restore an execution of compiled code, we will need the following + ingredients: + 1. The "original" code object, which serves as the entry point for eager + execution, i.e. the code only executed when there's no cache entry hit. + 2. The python module name this code object belongs to, for idenfifying the + enclosing global scope to inject compiled and resume functions. + 3. A list of function names that pointing to this code object. There could be + multiple function objects pointing to the same code such as recursive functions. + 4. A list of guarded code that eval frame dispatches to. + 5. A list of imported module objects unioned from all compiled branches. + 6. A list of "backends" (compiled fx graph) unioned from all compield branches. + """ + + python_code: SerializedCode + python_module: str + function_names: list[_FunctionId] + guarded_codes: list[_GuardedCodeCacheEntry] + import_sources: dict[str, str] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCacheEntry: + codes: list[_DynamoCodeCacheEntry] + python_version: str = platform.python_version() + torch_version: str = torch.__version__ + + @property + def backend_ids(self) -> set[_BackendId]: + return {backend_id for code in self.codes for backend_id in code.backend_ids} + + +class CompilePackage: + """ + CompilePackage is considered a low level component and should not be directly exposed to + end users. It has the following interface: + + 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. + a. when `dynamo` argument is None, it will contruct a brand new CompilePackage object. + b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. + 2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object. + 3. `package.install(backends) which will handle all the side-effectful global scope + updates with compiled functions and resume functions. + """ + + def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + self._innermost_fn = None + self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} + + self._current_entry: Optional[_DynamoCodeCacheEntry] = None + self._installed_globals: dict[types.ModuleType, list[str]] = {} + + # For debugging/testing purpose only. + self._cached_backends: dict[_BackendId, Any] = {} + + self._initialize(fn, dynamo) + # Always go back to a clean state after initialization. + self.uninstall() + self.validate() + + def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + from .eval_frame import innermost_fn + + self._innermost_fn = innermost_fn(fn) + assert self._innermost_fn is not None + if dynamo is not None: + assert isinstance(dynamo, _DynamoCacheEntry) + if dynamo.python_version != platform.python_version(): + raise RuntimeError( + f"Compile package was created with a different Python version: {dynamo.python_version}" + ) + if dynamo.torch_version != torch.__version__: + raise RuntimeError( + f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" + ) + + main, *codes = dynamo.codes + self._codes = {self._innermost_fn.__code__: main} + for code in codes: + self._codes[SerializedCode.to_code_object(code.python_code)] = code + else: + self._add_function( + self._innermost_fn.__code__, self._innermost_fn.__module__ + ) + + def _add_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[_FunctionId] = None, + ) -> None: + if python_code not in self._codes: + code = _DynamoCodeCacheEntry( + python_code=SerializedCode.from_code_object(python_code), + python_module=python_module, + function_names=[], + guarded_codes=[], + import_sources={}, + backend_ids=[], + ) + self._codes[python_code] = code + else: + code = self._codes[python_code] + assert code.python_module == python_module + + if name is not None: + code.function_names.append(name) + + @property + def cached_backends(self) -> dict[_BackendId, Any]: + return self._cached_backends + + @functools.cached_property + def source_id(self) -> str: + assert self._innermost_fn is not None + sha256_hash = hashlib.sha256() + sha256_hash.update(self._innermost_fn.__qualname__.encode()) + sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + @contextlib.contextmanager + def code_context(self, code: types.CodeType) -> Generator[None, None, None]: + assert self._current_entry is None + + entry = self._codes[code] + self._current_entry = entry + try: + yield + finally: + self._current_entry = None + + def add_guarded_code( + self, + guards_state: bytes, + dynamo_code: types.CodeType, + ) -> None: + assert self._current_entry is not None + guarded_code_entry = _GuardedCodeCacheEntry( + guards_state=guards_state, + dynamo_code=SerializedCode.from_code_object(dynamo_code), + ) + self._current_entry.guarded_codes.append(guarded_code_entry) + + def add_resume_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[str], + ) -> None: + self._add_function( + python_code, python_module, _FunctionId(name) if name else None + ) + + def add_import_source(self, alias: str, module_name: str) -> None: + assert self._current_entry is not None + self._current_entry.import_sources[alias] = module_name + + def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None: + assert self._current_entry is not None + assert backend_id.startswith("__compiled_fn_") # sanity check + backend_id = _BackendId(backend_id) + self._current_entry.backend_ids.append(backend_id) + if backend is not None: + self._cached_backends[backend_id] = backend + + def validate(self) -> None: + assert self._current_entry is None + assert self._innermost_fn is not None + assert next(iter(self._codes)) is self._innermost_fn.__code__ + + def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: + module.__dict__[name] = value + self._installed_globals.setdefault(module, []).append(name) + + def uninstall(self) -> None: + from torch._C._dynamo.eval_frame import _reset_precompile_entries + + assert self._innermost_fn is not None + for module, names in self._installed_globals.items(): + for name in names: + module.__dict__.pop(name) + + self._installed_globals = {} + + _reset_precompile_entries(self._innermost_fn.__code__) + + def install(self, backends: dict[_BackendId, Any]) -> None: + """ + Sync the package states to the compiled function. This includes the following actions: + 1. Clean up the previously installed states. + 2. Install the compiled functions to global scopes. + 3. Install the precompiled cache entries to ExtraStates on the code object. + """ + from torch._C._dynamo.eval_frame import _load_precompile_entry + + self.uninstall() + + for code, entry in self._codes.items(): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) + ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + backend = backends[backend_id] + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) + + for code, entry in self._codes.items(): + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) + + def save(self) -> _DynamoCacheEntry: + self.validate() + return _DynamoCacheEntry(codes=list(self._codes.values())) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index eae8ee46e09b..9c2054ce6f43 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -44,7 +44,7 @@ import traceback import types import typing import weakref -from typing import Any, Callable, cast, NoReturn, Optional, Union +from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union from unittest.mock import patch import torch @@ -167,6 +167,9 @@ from .variables.user_defined import ( ) +if TYPE_CHECKING: + from .package import CompilePackage + log = logging.getLogger(__name__) graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") @@ -1094,6 +1097,7 @@ class InstructionTranslatorBase( is_leaf_tracer: bool parent: Optional["InstructionTranslatorBase"] debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] + package: Optional["CompilePackage"] def mark_inconsistent_side_effects(self): """ @@ -1554,6 +1558,9 @@ class InstructionTranslatorBase( else: value = _import_module(module_name) alias = f"__import_{module_name.replace('.', '_dot_')}" + + if self.package is not None: + self.package.add_import_source(alias, module_name) f_globals = self.output.global_scope assert alias not in f_globals or f_globals[alias] is value f_globals[alias] = value @@ -3187,6 +3194,7 @@ class InstructionTranslatorBase( distributed_state: Optional[DistributedState], # This determines whether to use the execution recorder. closure: Optional[tuple[types.CellType]] = None, + package: Optional["CompilePackage"] = None, ) -> None: super().__init__() self.speculation_log = speculation_log @@ -3245,6 +3253,8 @@ class InstructionTranslatorBase( self.parent = None self.debug_locals = [] + self.package = package + if sys.version_info >= (3, 10): from .resume_execution import ( CO_ASYNC_GENERATOR, @@ -3305,6 +3315,7 @@ class InstructionTranslator(InstructionTranslatorBase): speculation_log: SpeculationLog, exn_vt_stack: ExceptionStack, distributed_state: Optional[DistributedState], + package: Optional["CompilePackage"], ) -> None: _step_logger()( logging.INFO, @@ -3322,6 +3333,7 @@ class InstructionTranslator(InstructionTranslatorBase): global_scope=f_globals, f_code=f_code, torch_function_mode_stack=torch_function_mode_stack, + package=package, ), instructions=instructions, f_locals=f_locals, @@ -3339,6 +3351,7 @@ class InstructionTranslator(InstructionTranslatorBase): speculation_log=speculation_log, exn_vt_stack=exn_vt_stack, distributed_state=distributed_state, + package=package, ) self._throw_if_in_functorch() @@ -3575,12 +3588,19 @@ class InstructionTranslator(InstructionTranslatorBase): # expose code object for debugging purposes self.output.install_global_unsafe(name, new_code) cg.make_function_with_closure(name, new_code, True, stack_len) + package_name = None else: # This is safe: we pre-generate a unique name self.output.install_global_unsafe( name, types.FunctionType(new_code, self.f_globals, name) ) cg.extend_output(cg.load_function_name(name, True, stack_len)) + package_name = name + + if self.package is not None: + self.package.add_resume_function( + new_code, self.f_globals["__name__"], package_name + ) cg.extend_output([cg.create_load(k) for k in argnames]) cg.extend_output(create_call_function(nargs, False)) @@ -3975,6 +3995,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): speculation_log=parent.speculation_log, exn_vt_stack=parent.exn_vt_stack, distributed_state=parent.distributed_state, + package=parent.package, ) self.funcvar = funcvar self.parent = parent diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index e1b32b289abc..85e44b7c7e48 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -213,6 +213,7 @@ def debug_insert_nops( global_scope=globals(), f_code=frame.f_code, torch_function_mode_stack=[], + package=None, ) return wrap_guarded_code( diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9300a6048496..1ac1051eab64 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4660,6 +4660,7 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool: return node is None or "example_value" in node.meta or "val" in node.meta +@torch._disable_dynamo def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: cm: AbstractContextManager[None] = ( torch._C._profiler._RecordFunctionFast("Pregraph bytecode") @@ -4670,6 +4671,7 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: return cm +@torch._disable_dynamo def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: cm.__exit__(None, None, None)