diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index f421597b5eaa..6fa40064beb1 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -322,6 +322,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): speculation_log=SpeculationLog(), exn_vt_stack=ExceptionStack(), distributed_state=None, + package=None, ) with compile_context(CompileContext(CompileId(0, 0))), tracing( tracer.output.tracing_context diff --git a/test/dynamo/test_package.py b/test/dynamo/test_package.py new file mode 100644 index 000000000000..b299a3752db5 --- /dev/null +++ b/test/dynamo/test_package.py @@ -0,0 +1,196 @@ +# Owner(s): ["module: dynamo"] + +import os +import pickle + +import torch +import torch._dynamo.testing +import torch._inductor.config +import torch._inductor.test_case +import torch.onnx.operators +import torch.utils.cpp_extension +from torch._dynamo.package import CompilePackage +from torch._inductor.runtime.runtime_utils import cache_dir + + +class StorageForTesting: + def __init__(self, path: str): + self.path = path + self.backends = {} + + def _write_pickle(self, data, *path: str): + with open(os.path.join(self.path, *path) + ".pickle", "wb") as f: + pickle.dump(data, f) + + def write_dynamo(self, dynamo): + self._write_pickle(dynamo, "dynamo") + + def write_backend(self, backend_id): + os.makedirs(os.path.join(self.path, backend_id), exist_ok=True) + self._write_pickle(self.backends[backend_id], backend_id, "fx_graph") + + def _read_pickle(self, *path): + with open(os.path.join(self.path, *path) + ".pickle", "rb") as f: + return pickle.load(f) + + def read_backend(self, backend_id): + return self._read_pickle(backend_id, "fx_graph") + + def read_dynamo(self): + return self._read_pickle("dynamo") + + def add_backend(self, backend_id, backend): + self.backends[backend_id] = backend + + def save_package(self, dynamo_cache_entry): + 
self.write_dynamo(dynamo_cache_entry) + for backend_id in dynamo_cache_entry.backend_ids: + self.write_backend(backend_id) + + def load_package(self): + dynamo = self.read_dynamo() + self.backends = {} + for backend_id in dynamo.backend_ids: + self.backends[backend_id] = self.read_backend(backend_id) + return dynamo + + +class TestPackage(torch._inductor.test_case.TestCase): + def storage(self): + path = os.path.join(cache_dir(), f"package_{self.id()}") + os.makedirs(path, exist_ok=True) + return StorageForTesting(path) + + def test_basic_fn(self): + storage = self.storage() + + def fn(x): + return x + 1 + + args = (torch.randn(3, 2),) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) + expected = compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(storage.backends) + self.assertEqual(expected, compiled_fn(*args)) + + def test_graph_break_bomb(self): + storage = self.storage() + + def fn(x, l, r): + if l > r: + return x.sum() + mid = (l + r) // 2 + if x.sum() == mid: + return x.sum() + elif x.sum() < mid: + return fn(x, l, mid) + else: + return fn(x, mid + 1, r) + + def guard_filter_fn(guards): + return [ + guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH") + for guard in guards + ] + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + N = 10 + args_list = [(torch.tensor(x), 0, N - 1) for x in range(N)] + 
for args in args_list: + compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + for args in args_list: + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args) + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize( + backend="eager", package=package, guard_filter_fn=guard_filter_fn + )(fn) + package.install(storage.backends) + for args in args_list: + self.assertEqual(compiled_fn(*args), args[0].sum()) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(torch.tensor(N), 0, N - 1) + + def test_dynamic_shape(self): + storage = self.storage() + + def fn(x): + return x + x.shape[0] + + args = (torch.randn(3, 2),) + args1 = (torch.randn(5, 2),) + args2 = (torch.randn(7, 2),) + expected1 = fn(*args1) + + torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5) + + # Saving + package = CompilePackage(fn) + compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn) + compiled_fn(*args) + for backend_id, backend in package.cached_backends.items(): + storage.add_backend(backend_id, backend) + storage.save_package(package.save()) + + # Loading + torch._dynamo.reset() + with torch.compiler.set_stance("fail_on_recompile"): + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args1) + + package = CompilePackage(fn, storage.load_package()) + compiled_fn = torch._dynamo.optimize(package=package)(fn) + package.install(storage.backends) + + self.assertEqual(expected1, compiled_fn(*args1)) + + with self.assertRaisesRegex( + RuntimeError, + "Detected recompile when 
torch.compile stance is 'fail_on_recompile'", + ): + compiled_fn(*args2) + + +if __name__ == "__main__": + from torch._dynamo.test_case import run_tests + + run_tests() diff --git a/torch/_C/_dynamo/eval_frame.pyi b/torch/_C/_dynamo/eval_frame.pyi index da0b32637759..c89de9a1ff9f 100644 --- a/torch/_C/_dynamo/eval_frame.pyi +++ b/torch/_C/_dynamo/eval_frame.pyi @@ -2,7 +2,7 @@ import enum import types from typing import overload -from torch._dynamo.types import DynamoCallback, DynamoGuardHook +from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ... def set_skip_guard_eval_unsafe(value: bool) -> bool: ... @@ -57,3 +57,7 @@ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ... py_opcode_caches: list[int] def code_framelocals_names(code: types.CodeType) -> tuple[str]: ... +def _load_precompile_entry( + code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType +) -> None: ... +def _reset_precompile_entries(code: types.CodeType) -> None: ... 
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index df0d04557707..c0ef12ae3a06 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -161,6 +161,7 @@ except ModuleNotFoundError: if typing.TYPE_CHECKING: from .backends.registry import CompilerFn + from .package import CompilePackage from .repro.after_dynamo import WrapBackendDebug from .types import BytecodeHook, CacheEntry, DynamoFrameType from .variables.builder import FrameStateSizeEntry @@ -478,6 +479,7 @@ class ConvertFrameAssert: one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> None: # assert export_constraints is None reset_graph_break_dup_checker() @@ -485,6 +487,7 @@ class ConvertFrameAssert: self._one_graph = one_graph self._export = export self._export_constraints = export_constraints + self._package = package @property def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]: @@ -640,6 +643,7 @@ class ConvertFrameAssert: frame_state=frame_state, compile_id=compile_id, skip=skip + 1, + package=self._package, ) @@ -648,9 +652,12 @@ def convert_frame_assert( one_graph: bool = True, export: bool = False, export_constraints: Optional[typing.Never] = None, + package: Optional[CompilePackage] = None, ) -> ConvertFrameAssert: """Fully convert a frame into an FX graph""" - return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints) + return ConvertFrameAssert( + compiler_fn, one_graph, export, export_constraints, package + ) from collections import OrderedDict @@ -693,6 +700,7 @@ def _compile( *, compile_id: CompileId, skip: int = 0, + package: Optional[CompilePackage] = None, ) -> ConvertFrameReturn: from torch.fx.experimental.validator import ( bisect, @@ -717,7 +725,7 @@ def _compile( ) -> None: nonlocal output nonlocal tracer - speculation_log.restart() + speculation_log.restart() # type: ignore[has-type] exn_vt_stack 
= ExceptionStack() tracer = InstructionTranslator( instructions, @@ -733,9 +741,10 @@ def _compile( export, export_constraints, frame_state=frame_state, - speculation_log=speculation_log, + speculation_log=speculation_log, # type: ignore[has-type] exn_vt_stack=exn_vt_stack, - distributed_state=distributed_state, + distributed_state=distributed_state, # type: ignore[has-type] + package=package, ) try: @@ -743,7 +752,7 @@ def _compile( with tracing(tracer.output.tracing_context), tracer.set_current_tx(): tracer.run() except exc.UnspecializeRestartAnalysis: - speculation_log.clear() + speculation_log.clear() # type: ignore[has-type] raise except ( exc.SpeculationRestartAnalysis, @@ -857,7 +866,7 @@ def _compile( log.debug("No graph captured with one_graph=True") return ConvertFrameReturn() - assert distributed_state is None or distributed_state.all_states is not None, ( + assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type] "compiler collective wasn't run before compilation completed" ) @@ -936,8 +945,13 @@ def _compile( cache_entry, hooks.guard_fail_fn if hooks else None, hooks.guard_filter_fn if hooks else None, + guards_serialization_mode="save" if package else None, ) + if package is not None: + assert check_fn.guards_state is not None + package.add_guarded_code(check_fn.guards_state, out_code) + compile_id_str = str(compile_id) if compile_id is not None else "Unknown" annotation_str = "Torch-Compiled Region: " + compile_id_str guarded_code = GuardedCode( @@ -958,6 +972,9 @@ def _compile( return wrap_guarded_code(guarded_code) metrics_context = get_metrics_context() + code_context = ( + package.code_context(code) if package is not None else contextlib.nullcontext() + ) with ( _use_lazy_graph_module(config.use_lazy_graph_module), compile_context(CompileContext(compile_id)), @@ -971,6 +988,7 @@ def _compile( phase_name="entire_frame_compile", dynamo_compile_column_us="dynamo_cumulative_compile_time_us", ), + 
code_context, ): restart_reasons: set[str] = set() # This is shared across restarts @@ -1226,9 +1244,12 @@ class ConvertFrame: self, compiler_fn: CompilerFn, hooks: Hooks, + package: Optional[CompilePackage] = None, ) -> None: self._torchdynamo_orig_callable = compiler_fn - self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False) + self._inner_convert = convert_frame_assert( + compiler_fn, one_graph=False, package=package + ) self._hooks = hooks @property @@ -1331,9 +1352,11 @@ class ConvertFrame: return ConvertFrameReturn() -def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame: +def convert_frame( + compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None +) -> ConvertFrame: """Try to convert a frame into an FX graph, if error leave frame unmodified""" - return ConvertFrame(compiler_fn, hooks) + return ConvertFrame(compiler_fn, hooks, package=package) # TODO mlazos: add support for same args, or record them diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 771d1071f763..380291f9f5ba 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -531,6 +531,7 @@ class _TorchDynamoContext: export=False, dynamic=None, compiler_config=None, + package=None, ) -> None: super().__init__() assert callable(callback) or callback is False or callback is None @@ -543,6 +544,7 @@ class _TorchDynamoContext: self.compiler_config = compiler_config self.cleanup_fns: list[Callable[[], Any]] = [] self.enter_exit_hooks = [] + self._package = package patch_fn() # Save the backends so that we can reset them during torch._dynamo.reset @@ -792,6 +794,7 @@ class OptimizeContext(_TorchDynamoContext): rebuild_ctx: Optional[ Callable[[], Union[OptimizeContext, _NullDecorator]] ] = None, + package=None, ) -> None: def on_enter(): install_generation_tagging_init() @@ -805,6 +808,7 @@ class OptimizeContext(_TorchDynamoContext): export=export, dynamic=dynamic, compiler_config=compiler_config, + 
package=package, ) if config.compiled_autograd: @@ -928,6 +932,7 @@ def _optimize_catch_errors( dynamic=None, compiler_config=None, rebuild_ctx=None, + package=None, ): return OptimizeContext( convert_frame.catch_errors_wrapper(compile_fn, hooks), @@ -937,6 +942,7 @@ def _optimize_catch_errors( dynamic=dynamic, compiler_config=compiler_config, rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1027,6 +1033,7 @@ def _optimize( guard_filter_fn=None, disable=False, dynamic=None, + package=None, ) -> Union[OptimizeContext, _NullDecorator]: """ The main entrypoint of TorchDynamo. Do graph capture and call @@ -1079,6 +1086,7 @@ def _optimize( dynamic=dynamic, hooks=hooks, rebuild_ctx=rebuild_ctx, + package=package, ) backend = get_compiler_fn(backend) @@ -1090,7 +1098,7 @@ def _optimize( # _optimize_catch_errors in the field _torchdynamo_orig_callable. This can # be used by eval_frame.c to insert a guard on the backend. return _optimize_catch_errors( - convert_frame.convert_frame(backend, hooks=hooks), + convert_frame.convert_frame(backend, hooks=hooks, package=package), hooks, backend_ctx_ctor, dynamic=dynamic, @@ -1100,6 +1108,7 @@ def _optimize( else None ), rebuild_ctx=rebuild_ctx, + package=package, ) @@ -1990,6 +1999,7 @@ def _optimize_assert( export=False, export_constraints=None, dynamic=None, + package=None, ): """ The same as `torch._dynamo.optimize(backend, nopython=True)` @@ -2001,13 +2011,17 @@ def _optimize_assert( return _optimize_catch_errors( convert_frame.convert_frame_assert( - backend, export=export, export_constraints=export_constraints + backend, + export=export, + export_constraints=export_constraints, + package=package, ), hooks, backend_ctx_ctor, export=export, dynamic=dynamic, rebuild_ctx=rebuild_ctx, + package=package, ) diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py index c744b65e1f74..d5afdaff120e 100644 --- a/torch/_dynamo/guards.py +++ b/torch/_dynamo/guards.py @@ -497,6 +497,8 @@ def get_verbose_code_parts( def 
convert_int_to_concrete_values(dim) -> Optional[int]: + if dim is None: + return None if not is_symbolic(dim): return dim else: @@ -1457,7 +1459,11 @@ class GuardBuilder(GuardBuilderBase): def TYPE_MATCH(self, guard: Guard) -> None: # ___check_type_id is same as `id(type(x)) == y` value = self.get(guard.name) - t = type(value) + if isinstance(value, torch._subclasses.FakeTensor) and value.pytype: + t = value.pytype + else: + t = type(value) + if self.serialization_mode == "save": if t.__qualname__ != t.__name__: raise_local_type_error(value) diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 8e0899450680..797abbeb4a8f 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -368,6 +368,7 @@ class OutputGraph(OutputGraphGuardsState): global_scope: Scope, f_code, torch_function_mode_stack, + package, ): super().__init__( local_scope, @@ -471,6 +472,7 @@ class OutputGraph(OutputGraphGuardsState): self.compiler_fn: Optional[CompilerFn] = compiler_fn self.root_tx = root_tx + self.package = package # Given a source, what are the user stacks of all locations that # accessed it? # @@ -1715,6 +1717,9 @@ class OutputGraph(OutputGraphGuardsState): # replace compiled_fn with the real forward method compiled_fn = lazy_gm.forward + if self.package is not None: + self.package.add_backend_id(name, compiled_fn) + compiled_fn = disable( compiled_fn, reason="do not trace Dynamo-compiled graph" ) diff --git a/torch/_dynamo/package.py b/torch/_dynamo/package.py new file mode 100644 index 000000000000..41157e07f87d --- /dev/null +++ b/torch/_dynamo/package.py @@ -0,0 +1,331 @@ +""" +This module provides the infrastructure for creating and managing compile packages +for torch.compile. We mainly have two abstractions here: + - CompilePackage: Overarching data structure for storing and looking up a list of compiled codes. + - _DynamoCodeCacheEntry: Data structure for a single code being compiled by torch.compile. 
+The caching behavior is always under user control explicitly so that a stronger guarantee can +be provided about cache hit for a specific compiled model. Users can load the compile package +from a different process or host. +""" + +import contextlib +import dataclasses +import functools +import hashlib +import importlib +import logging +import pickle +import platform +import sys +import types +from collections.abc import Generator +from typing import Any, NewType, Optional + +import torch +import torch._inductor.package + +from .bytecode_transformation import get_code_keys + + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass(frozen=True) +class SerializedCode: + co_argcount: int + co_posonlyargcount: int + co_kwonlyargcount: int + co_nlocals: int + co_stacksize: int + co_flags: int + co_code: bytes + co_consts: tuple[Any, ...] + co_names: tuple[str, ...] + co_varnames: tuple[str, ...] + co_filename: str + co_name: str + co_firstlineno: int + co_cellvars: tuple[str, ...] + co_freevars: tuple[str, ...] 
+ co_linetable: Optional[bytes] = None + co_qualname: Optional[str] = None + co_exceptiontable: Optional[bytes] = None + co_lnotab: Optional[bytes] = None + + @classmethod + @functools.cache + def from_code_object(cls, code: types.CodeType) -> "SerializedCode": + kwargs = {key: getattr(code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.from_code_object(c) if isinstance(c, types.CodeType) else c + for c in kwargs["co_consts"] + ) + return cls(**kwargs) + + @classmethod + @functools.cache + def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType: + kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()} + kwargs["co_consts"] = tuple( + cls.to_code_object(c) if isinstance(c, SerializedCode) else c + for c in kwargs["co_consts"] + ) + return types.CodeType( + *kwargs.values(), + ) + + +@dataclasses.dataclass +class _GuardedCodeCacheEntry: + """ + Contains the serializable information associated with a single compilation in dynamo. + To restore an execution of compiled code, we will need to serialize the following data: + - Dynamo bytecode for mapping Python inputs/outputs. + - Dynamo guards. + """ + + guards_state: bytes + dynamo_code: SerializedCode + + +_BackendId = NewType("_BackendId", str) # __compiled_fn +_FunctionId = NewType("_FunctionId", str) # __resume_at + + +@dataclasses.dataclass +class _DynamoCodeCacheEntry: + """ + Contains the serializable information associated with a single code object + in dynamo. To restore an execution of compiled code, we will need the following + ingredients: + 1. The "original" code object, which serves as the entry point for eager + execution, i.e. the code only executed when there's no cache entry hit. + 2. The python module name this code object belongs to, for identifying the + enclosing global scope to inject compiled and resume functions. + 3. A list of function names that point to this code object. 
There could be + multiple function objects pointing to the same code such as recursive functions. + 4. A list of guarded code that eval frame dispatches to. + 5. A list of imported module objects unioned from all compiled branches. + 6. A list of "backends" (compiled fx graph) unioned from all compiled branches. + """ + + python_code: SerializedCode + python_module: str + function_names: list[_FunctionId] + guarded_codes: list[_GuardedCodeCacheEntry] + import_sources: dict[str, str] + backend_ids: list[_BackendId] + + +@dataclasses.dataclass +class _DynamoCacheEntry: + codes: list[_DynamoCodeCacheEntry] + python_version: str = platform.python_version() + torch_version: str = torch.__version__ + + @property + def backend_ids(self) -> set[_BackendId]: + return {backend_id for code in self.codes for backend_id in code.backend_ids} + + +class CompilePackage: + """ + CompilePackage is considered a low level component and should not be directly exposed to + end users. It has the following interface: + + 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. + a. when `dynamo` argument is None, it will construct a brand new CompilePackage object. + b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. + 2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object. + 3. `package.install(backends)` which will handle all the side-effectful global scope + updates with compiled functions and resume functions. + """ + + def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + self._innermost_fn = None + self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {} + + self._current_entry: Optional[_DynamoCodeCacheEntry] = None + self._installed_globals: dict[types.ModuleType, list[str]] = {} + + # For debugging/testing purpose only. 
+ self._cached_backends: dict[_BackendId, Any] = {} + + self._initialize(fn, dynamo) + # Always go back to a clean state after initialization. + self.uninstall() + self.validate() + + def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None: + from .eval_frame import innermost_fn + + self._innermost_fn = innermost_fn(fn) + assert self._innermost_fn is not None + if dynamo is not None: + assert isinstance(dynamo, _DynamoCacheEntry) + if dynamo.python_version != platform.python_version(): + raise RuntimeError( + f"Compile package was created with a different Python version: {dynamo.python_version}" + ) + if dynamo.torch_version != torch.__version__: + raise RuntimeError( + f"Compile package was created with a different PyTorch version: {dynamo.torch_version}" + ) + + main, *codes = dynamo.codes + self._codes = {self._innermost_fn.__code__: main} + for code in codes: + self._codes[SerializedCode.to_code_object(code.python_code)] = code + else: + self._add_function( + self._innermost_fn.__code__, self._innermost_fn.__module__ + ) + + def _add_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[_FunctionId] = None, + ) -> None: + if python_code not in self._codes: + code = _DynamoCodeCacheEntry( + python_code=SerializedCode.from_code_object(python_code), + python_module=python_module, + function_names=[], + guarded_codes=[], + import_sources={}, + backend_ids=[], + ) + self._codes[python_code] = code + else: + code = self._codes[python_code] + assert code.python_module == python_module + + if name is not None: + code.function_names.append(name) + + @property + def cached_backends(self) -> dict[_BackendId, Any]: + return self._cached_backends + + @functools.cached_property + def source_id(self) -> str: + assert self._innermost_fn is not None + sha256_hash = hashlib.sha256() + sha256_hash.update(self._innermost_fn.__qualname__.encode()) + 
sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode()) + return sha256_hash.hexdigest() + + @contextlib.contextmanager + def code_context(self, code: types.CodeType) -> Generator[None, None, None]: + assert self._current_entry is None + + entry = self._codes[code] + self._current_entry = entry + try: + yield + finally: + self._current_entry = None + + def add_guarded_code( + self, + guards_state: bytes, + dynamo_code: types.CodeType, + ) -> None: + assert self._current_entry is not None + guarded_code_entry = _GuardedCodeCacheEntry( + guards_state=guards_state, + dynamo_code=SerializedCode.from_code_object(dynamo_code), + ) + self._current_entry.guarded_codes.append(guarded_code_entry) + + def add_resume_function( + self, + python_code: types.CodeType, + python_module: str, + name: Optional[str], + ) -> None: + self._add_function( + python_code, python_module, _FunctionId(name) if name else None + ) + + def add_import_source(self, alias: str, module_name: str) -> None: + assert self._current_entry is not None + self._current_entry.import_sources[alias] = module_name + + def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None: + assert self._current_entry is not None + assert backend_id.startswith("__compiled_fn_") # sanity check + backend_id = _BackendId(backend_id) + self._current_entry.backend_ids.append(backend_id) + if backend is not None: + self._cached_backends[backend_id] = backend + + def validate(self) -> None: + assert self._current_entry is None + assert self._innermost_fn is not None + assert next(iter(self._codes)) is self._innermost_fn.__code__ + + def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None: + module.__dict__[name] = value + self._installed_globals.setdefault(module, []).append(name) + + def uninstall(self) -> None: + from torch._C._dynamo.eval_frame import _reset_precompile_entries + + assert self._innermost_fn is not None + for module, names in 
self._installed_globals.items(): + for name in names: + module.__dict__.pop(name, None) + + self._installed_globals = {} + + _reset_precompile_entries(self._innermost_fn.__code__) + + def install(self, backends: dict[_BackendId, Any]) -> None: + """ + Sync the package states to the compiled function. This includes the following actions: + 1. Clean up the previously installed states. + 2. Install the compiled functions to global scopes. + 3. Install the precompiled cache entries to ExtraStates on the code object. + """ + from torch._C._dynamo.eval_frame import _load_precompile_entry + + self.uninstall() + + for code, entry in self._codes.items(): + module = sys.modules[entry.python_module] + for alias, module_name in entry.import_sources.items(): + self._install_global( + module, alias, importlib.import_module(module_name) + ) + for function_name in entry.function_names: + fn = types.FunctionType(code, module.__dict__, function_name) + self._install_global(module, function_name, fn) + for backend_id in entry.backend_ids: + backend = backends[backend_id] + self._install_global( + module, + backend_id, + torch._dynamo.disable(backend), + ) + + for code, entry in self._codes.items(): + for guarded_code in entry.guarded_codes: + guards_state = pickle.loads(guarded_code.guards_state) + assert isinstance(guards_state, torch._dynamo.guards.GuardsState) + check_fn_manager = torch._dynamo.guards.CheckFunctionManager( + code, + guards_state.output_graph, + guards_serialization_mode="load", + shape_code_parts=guards_state.shape_code_parts, + ) + _load_precompile_entry( + code, + check_fn_manager.guard_manager, + SerializedCode.to_code_object(guarded_code.dynamo_code), + ) + + def save(self) -> _DynamoCacheEntry: + self.validate() + return _DynamoCacheEntry(codes=list(self._codes.values())) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index eae8ee46e09b..9c2054ce6f43 100644 --- a/torch/_dynamo/symbolic_convert.py +++ 
b/torch/_dynamo/symbolic_convert.py @@ -44,7 +44,7 @@ import traceback import types import typing import weakref -from typing import Any, Callable, cast, NoReturn, Optional, Union +from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union from unittest.mock import patch import torch @@ -167,6 +167,9 @@ from .variables.user_defined import ( ) +if TYPE_CHECKING: + from .package import CompilePackage + log = logging.getLogger(__name__) graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call") @@ -1094,6 +1097,7 @@ class InstructionTranslatorBase( is_leaf_tracer: bool parent: Optional["InstructionTranslatorBase"] debug_locals: list[tuple[VariableTracker, list[VariableTracker]]] + package: Optional["CompilePackage"] def mark_inconsistent_side_effects(self): """ @@ -1554,6 +1558,9 @@ class InstructionTranslatorBase( else: value = _import_module(module_name) alias = f"__import_{module_name.replace('.', '_dot_')}" + + if self.package is not None: + self.package.add_import_source(alias, module_name) f_globals = self.output.global_scope assert alias not in f_globals or f_globals[alias] is value f_globals[alias] = value @@ -3187,6 +3194,7 @@ class InstructionTranslatorBase( distributed_state: Optional[DistributedState], # This determines whether to use the execution recorder. 
closure: Optional[tuple[types.CellType]] = None, + package: Optional["CompilePackage"] = None, ) -> None: super().__init__() self.speculation_log = speculation_log @@ -3245,6 +3253,8 @@ class InstructionTranslatorBase( self.parent = None self.debug_locals = [] + self.package = package + if sys.version_info >= (3, 10): from .resume_execution import ( CO_ASYNC_GENERATOR, @@ -3305,6 +3315,7 @@ class InstructionTranslator(InstructionTranslatorBase): speculation_log: SpeculationLog, exn_vt_stack: ExceptionStack, distributed_state: Optional[DistributedState], + package: Optional["CompilePackage"], ) -> None: _step_logger()( logging.INFO, @@ -3322,6 +3333,7 @@ class InstructionTranslator(InstructionTranslatorBase): global_scope=f_globals, f_code=f_code, torch_function_mode_stack=torch_function_mode_stack, + package=package, ), instructions=instructions, f_locals=f_locals, @@ -3339,6 +3351,7 @@ class InstructionTranslator(InstructionTranslatorBase): speculation_log=speculation_log, exn_vt_stack=exn_vt_stack, distributed_state=distributed_state, + package=package, ) self._throw_if_in_functorch() @@ -3575,12 +3588,19 @@ class InstructionTranslator(InstructionTranslatorBase): # expose code object for debugging purposes self.output.install_global_unsafe(name, new_code) cg.make_function_with_closure(name, new_code, True, stack_len) + package_name = None else: # This is safe: we pre-generate a unique name self.output.install_global_unsafe( name, types.FunctionType(new_code, self.f_globals, name) ) cg.extend_output(cg.load_function_name(name, True, stack_len)) + package_name = name + + if self.package is not None: + self.package.add_resume_function( + new_code, self.f_globals["__name__"], package_name + ) cg.extend_output([cg.create_load(k) for k in argnames]) cg.extend_output(create_call_function(nargs, False)) @@ -3975,6 +3995,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase): speculation_log=parent.speculation_log, exn_vt_stack=parent.exn_vt_stack, 
distributed_state=parent.distributed_state, + package=parent.package, ) self.funcvar = funcvar self.parent = parent diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py index e1b32b289abc..85e44b7c7e48 100644 --- a/torch/_dynamo/testing.py +++ b/torch/_dynamo/testing.py @@ -213,6 +213,7 @@ def debug_insert_nops( global_scope=globals(), f_code=frame.f_code, torch_function_mode_stack=[], + package=None, ) return wrap_guarded_code( diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 9300a6048496..1ac1051eab64 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -4660,6 +4660,7 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool: return node is None or "example_value" in node.meta or "val" in node.meta +@torch._disable_dynamo def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: cm: AbstractContextManager[None] = ( torch._C._profiler._RecordFunctionFast("Pregraph bytecode") @@ -4670,6 +4671,7 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: return cm +@torch._disable_dynamo def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: cm.__exit__(None, None, None)