[precompile] Add CompilePackage to serialize dynamo states. (#155118)

Add a per-torch.compile() object, CompilePackage, which tracks dynamo artifacts. CompilePackage is considered a low-level component and should not be directly exposed to end users. It has the following interface:

1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
     a. when `dynamo` argument is None, it will construct a brand new CompilePackage object.
     b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
2. `package.save()` which dumps the dynamo states into _DynamoCacheEntry.
3. `package.install(backends)` which will handle all the side-effectful global scope updates with compiled functions and resume functions.

This diff focuses on building the low-level mechanism for precompile. It is left to upper-level interfaces to use these APIs to build a more user-facing frontend.

Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118
Approved by: https://github.com/jamesjwu

Co-authored-by: James Wu <jjwu@meta.com>
This commit is contained in:
James Wu
2025-06-12 09:46:38 -07:00
committed by PyTorch MergeBot
parent 670dab6c63
commit b2fc9cfea1
11 changed files with 618 additions and 14 deletions

View File

@ -322,6 +322,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
speculation_log=SpeculationLog(),
exn_vt_stack=ExceptionStack(),
distributed_state=None,
package=None,
)
with compile_context(CompileContext(CompileId(0, 0))), tracing(
tracer.output.tracing_context

196
test/dynamo/test_package.py Normal file
View File

@ -0,0 +1,196 @@
# Owner(s): ["module: dynamo"]
import os
import pickle
import torch
import torch._dynamo.testing
import torch._inductor.config
import torch._inductor.test_case
import torch.onnx.operators
import torch.utils.cpp_extension
from torch._dynamo.package import CompilePackage
from torch._inductor.runtime.runtime_utils import cache_dir
class StorageForTesting:
    """
    Minimal pickle-on-disk store used by the package tests.

    Backends are kept in memory in ``self.backends``; the dynamo cache entry is
    pickled to ``<path>/dynamo.pickle`` and each backend to
    ``<path>/<backend_id>/fx_graph.pickle``.
    """

    def __init__(self, path: str):
        self.backends = {}
        self.path = path

    def _write_pickle(self, data, *path: str):
        # Components are joined under the root; ".pickle" is appended to the result.
        target = os.path.join(self.path, *path) + ".pickle"
        with open(target, "wb") as stream:
            pickle.dump(data, stream)

    def _read_pickle(self, *path):
        source = os.path.join(self.path, *path) + ".pickle"
        with open(source, "rb") as stream:
            return pickle.load(stream)

    def add_backend(self, backend_id, backend):
        """Register a backend in memory so that it can be written later."""
        self.backends[backend_id] = backend

    def write_dynamo(self, dynamo):
        self._write_pickle(dynamo, "dynamo")

    def read_dynamo(self):
        return self._read_pickle("dynamo")

    def write_backend(self, backend_id):
        # Each backend lives in its own sub-directory of the root.
        backend_dir = os.path.join(self.path, backend_id)
        os.makedirs(backend_dir, exist_ok=True)
        payload = self.backends[backend_id]
        self._write_pickle(payload, backend_id, "fx_graph")

    def read_backend(self, backend_id):
        return self._read_pickle(backend_id, "fx_graph")

    def save_package(self, dynamo_cache_entry):
        """Write the cache entry, then every backend it lists in ``backend_ids``."""
        self.write_dynamo(dynamo_cache_entry)
        for ident in dynamo_cache_entry.backend_ids:
            self.write_backend(ident)

    def load_package(self):
        """Read the cache entry back and repopulate ``self.backends`` from disk."""
        entry = self.read_dynamo()
        restored = self.backends = {}
        for ident in entry.backend_ids:
            restored[ident] = self.read_backend(ident)
        return entry
class TestPackage(torch._inductor.test_case.TestCase):
    """
    Round-trip tests for CompilePackage: compile with a package attached, save it
    through StorageForTesting, reset dynamo, then load and install the package and
    check that calls are served without recompilation.
    """

    def storage(self):
        """Return a StorageForTesting rooted at a per-test directory under cache_dir()."""
        path = os.path.join(cache_dir(), f"package_{self.id()}")
        os.makedirs(path, exist_ok=True)
        return StorageForTesting(path)

    def test_basic_fn(self):
        """Save and reload a package for a single-graph function."""
        storage = self.storage()

        def fn(x):
            return x + 1

        args = (torch.randn(3, 2),)

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
        expected = compiled_fn(*args)
        for backend_id, backend in package.cached_backends.items():
            storage.add_backend(backend_id, backend)
        storage.save_package(package.save())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            # After the reset and before anything is installed, the call is
            # expected to hit the recompile error.
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args)

            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(storage.backends)
            self.assertEqual(expected, compiled_fn(*args))

    def test_graph_break_bomb(self):
        """Save and reload a package for a recursive function with data-dependent branches."""
        storage = self.storage()

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            # One boolean per guard: False for CLOSURE_MATCH / FUNCTION_MATCH guards.
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(
            backend="eager", package=package, guard_filter_fn=guard_filter_fn
        )(fn)
        N = 10
        args_list = [(torch.tensor(x), 0, N - 1) for x in range(N)]
        for args in args_list:
            compiled_fn(*args)
        for backend_id, backend in package.cached_backends.items():
            storage.add_backend(backend_id, backend)
        storage.save_package(package.save())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            for args in args_list:
                with self.assertRaisesRegex(
                    RuntimeError,
                    "Detected recompile when torch.compile stance is 'fail_on_recompile'",
                ):
                    compiled_fn(*args)
            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            package.install(storage.backends)
            for args in args_list:
                self.assertEqual(compiled_fn(*args), args[0].sum())

            # torch.tensor(N) was not among the inputs used before saving, so this
            # call is expected to hit the recompile error.
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(torch.tensor(N), 0, N - 1)

    def test_dynamic_shape(self):
        """Save and reload a package compiled with dim 0 marked dynamic in [3, 5]."""
        storage = self.storage()

        def fn(x):
            return x + x.shape[0]

        args = (torch.randn(3, 2),)
        args1 = (torch.randn(5, 2),)
        args2 = (torch.randn(7, 2),)
        expected1 = fn(*args1)

        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

        # Saving
        package = CompilePackage(fn)
        compiled_fn = torch._dynamo.optimize(backend="eager", package=package)(fn)
        compiled_fn(*args)
        for backend_id, backend in package.cached_backends.items():
            storage.add_backend(backend_id, backend)
        storage.save_package(package.save())

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args1)

            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(storage.backends)

            # Size 5 is inside the marked range; size 7 is outside it.
            self.assertEqual(expected1, compiled_fn(*args1))

            with self.assertRaisesRegex(
                RuntimeError,
                "Detected recompile when torch.compile stance is 'fail_on_recompile'",
            ):
                compiled_fn(*args2)
# Allow running this test file directly as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()

View File

@ -2,7 +2,7 @@ import enum
import types
from typing import overload
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
@ -57,3 +57,7 @@ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
py_opcode_caches: list[int]
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
def _load_precompile_entry(
code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
) -> None: ...
def _reset_precompile_entries(code: types.CodeType) -> None: ...

View File

@ -161,6 +161,7 @@ except ModuleNotFoundError:
if typing.TYPE_CHECKING:
from .backends.registry import CompilerFn
from .package import CompilePackage
from .repro.after_dynamo import WrapBackendDebug
from .types import BytecodeHook, CacheEntry, DynamoFrameType
from .variables.builder import FrameStateSizeEntry
@ -478,6 +479,7 @@ class ConvertFrameAssert:
one_graph: bool = True,
export: bool = False,
export_constraints: Optional[typing.Never] = None,
package: Optional[CompilePackage] = None,
) -> None:
# assert export_constraints is None
reset_graph_break_dup_checker()
@ -485,6 +487,7 @@ class ConvertFrameAssert:
self._one_graph = one_graph
self._export = export
self._export_constraints = export_constraints
self._package = package
@property
def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
@ -640,6 +643,7 @@ class ConvertFrameAssert:
frame_state=frame_state,
compile_id=compile_id,
skip=skip + 1,
package=self._package,
)
@ -648,9 +652,12 @@ def convert_frame_assert(
one_graph: bool = True,
export: bool = False,
export_constraints: Optional[typing.Never] = None,
package: Optional[CompilePackage] = None,
) -> ConvertFrameAssert:
"""Fully convert a frame into an FX graph"""
return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)
return ConvertFrameAssert(
compiler_fn, one_graph, export, export_constraints, package
)
from collections import OrderedDict
@ -693,6 +700,7 @@ def _compile(
*,
compile_id: CompileId,
skip: int = 0,
package: Optional[CompilePackage] = None,
) -> ConvertFrameReturn:
from torch.fx.experimental.validator import (
bisect,
@ -717,7 +725,7 @@ def _compile(
) -> None:
nonlocal output
nonlocal tracer
speculation_log.restart()
speculation_log.restart() # type: ignore[has-type]
exn_vt_stack = ExceptionStack()
tracer = InstructionTranslator(
instructions,
@ -733,9 +741,10 @@ def _compile(
export,
export_constraints,
frame_state=frame_state,
speculation_log=speculation_log,
speculation_log=speculation_log, # type: ignore[has-type]
exn_vt_stack=exn_vt_stack,
distributed_state=distributed_state,
distributed_state=distributed_state, # type: ignore[has-type]
package=package,
)
try:
@ -743,7 +752,7 @@ def _compile(
with tracing(tracer.output.tracing_context), tracer.set_current_tx():
tracer.run()
except exc.UnspecializeRestartAnalysis:
speculation_log.clear()
speculation_log.clear() # type: ignore[has-type]
raise
except (
exc.SpeculationRestartAnalysis,
@ -857,7 +866,7 @@ def _compile(
log.debug("No graph captured with one_graph=True")
return ConvertFrameReturn()
assert distributed_state is None or distributed_state.all_states is not None, (
assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type]
"compiler collective wasn't run before compilation completed"
)
@ -936,8 +945,13 @@ def _compile(
cache_entry,
hooks.guard_fail_fn if hooks else None,
hooks.guard_filter_fn if hooks else None,
guards_serialization_mode="save" if package else None,
)
if package is not None:
assert check_fn.guards_state is not None
package.add_guarded_code(check_fn.guards_state, out_code)
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
annotation_str = "Torch-Compiled Region: " + compile_id_str
guarded_code = GuardedCode(
@ -958,6 +972,9 @@ def _compile(
return wrap_guarded_code(guarded_code)
metrics_context = get_metrics_context()
code_context = (
package.code_context(code) if package is not None else contextlib.nullcontext()
)
with (
_use_lazy_graph_module(config.use_lazy_graph_module),
compile_context(CompileContext(compile_id)),
@ -971,6 +988,7 @@ def _compile(
phase_name="entire_frame_compile",
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
),
code_context,
):
restart_reasons: set[str] = set()
# This is shared across restarts
@ -1226,9 +1244,12 @@ class ConvertFrame:
self,
compiler_fn: CompilerFn,
hooks: Hooks,
package: Optional[CompilePackage] = None,
) -> None:
self._torchdynamo_orig_callable = compiler_fn
self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
self._inner_convert = convert_frame_assert(
compiler_fn, one_graph=False, package=package
)
self._hooks = hooks
@property
@ -1331,9 +1352,11 @@ class ConvertFrame:
return ConvertFrameReturn()
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame:
def convert_frame(
compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None
) -> ConvertFrame:
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
return ConvertFrame(compiler_fn, hooks)
return ConvertFrame(compiler_fn, hooks, package=package)
# TODO mlazos: add support for same args, or record them

View File

@ -531,6 +531,7 @@ class _TorchDynamoContext:
export=False,
dynamic=None,
compiler_config=None,
package=None,
) -> None:
super().__init__()
assert callable(callback) or callback is False or callback is None
@ -543,6 +544,7 @@ class _TorchDynamoContext:
self.compiler_config = compiler_config
self.cleanup_fns: list[Callable[[], Any]] = []
self.enter_exit_hooks = []
self._package = package
patch_fn()
# Save the backends so that we can reset them during torch._dynamo.reset
@ -792,6 +794,7 @@ class OptimizeContext(_TorchDynamoContext):
rebuild_ctx: Optional[
Callable[[], Union[OptimizeContext, _NullDecorator]]
] = None,
package=None,
) -> None:
def on_enter():
install_generation_tagging_init()
@ -805,6 +808,7 @@ class OptimizeContext(_TorchDynamoContext):
export=export,
dynamic=dynamic,
compiler_config=compiler_config,
package=package,
)
if config.compiled_autograd:
@ -928,6 +932,7 @@ def _optimize_catch_errors(
dynamic=None,
compiler_config=None,
rebuild_ctx=None,
package=None,
):
return OptimizeContext(
convert_frame.catch_errors_wrapper(compile_fn, hooks),
@ -937,6 +942,7 @@ def _optimize_catch_errors(
dynamic=dynamic,
compiler_config=compiler_config,
rebuild_ctx=rebuild_ctx,
package=package,
)
@ -1027,6 +1033,7 @@ def _optimize(
guard_filter_fn=None,
disable=False,
dynamic=None,
package=None,
) -> Union[OptimizeContext, _NullDecorator]:
"""
The main entrypoint of TorchDynamo. Do graph capture and call
@ -1079,6 +1086,7 @@ def _optimize(
dynamic=dynamic,
hooks=hooks,
rebuild_ctx=rebuild_ctx,
package=package,
)
backend = get_compiler_fn(backend)
@ -1090,7 +1098,7 @@ def _optimize(
# _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
# be used by eval_frame.c to insert a guard on the backend.
return _optimize_catch_errors(
convert_frame.convert_frame(backend, hooks=hooks),
convert_frame.convert_frame(backend, hooks=hooks, package=package),
hooks,
backend_ctx_ctor,
dynamic=dynamic,
@ -1100,6 +1108,7 @@ def _optimize(
else None
),
rebuild_ctx=rebuild_ctx,
package=package,
)
@ -1990,6 +1999,7 @@ def _optimize_assert(
export=False,
export_constraints=None,
dynamic=None,
package=None,
):
"""
The same as `torch._dynamo.optimize(backend, nopython=True)`
@ -2001,13 +2011,17 @@ def _optimize_assert(
return _optimize_catch_errors(
convert_frame.convert_frame_assert(
backend, export=export, export_constraints=export_constraints
backend,
export=export,
export_constraints=export_constraints,
package=package,
),
hooks,
backend_ctx_ctor,
export=export,
dynamic=dynamic,
rebuild_ctx=rebuild_ctx,
package=package,
)

View File

@ -497,6 +497,8 @@ def get_verbose_code_parts(
def convert_int_to_concrete_values(dim) -> Optional[int]:
if dim is None:
return None
if not is_symbolic(dim):
return dim
else:
@ -1457,7 +1459,11 @@ class GuardBuilder(GuardBuilderBase):
def TYPE_MATCH(self, guard: Guard) -> None:
# ___check_type_id is same as `id(type(x)) == y`
value = self.get(guard.name)
t = type(value)
if isinstance(value, torch._subclasses.FakeTensor) and value.pytype:
t = value.pytype
else:
t = type(value)
if self.serialization_mode == "save":
if t.__qualname__ != t.__name__:
raise_local_type_error(value)

View File

@ -368,6 +368,7 @@ class OutputGraph(OutputGraphGuardsState):
global_scope: Scope,
f_code,
torch_function_mode_stack,
package,
):
super().__init__(
local_scope,
@ -471,6 +472,7 @@ class OutputGraph(OutputGraphGuardsState):
self.compiler_fn: Optional[CompilerFn] = compiler_fn
self.root_tx = root_tx
self.package = package
# Given a source, what are the user stacks of all locations that
# accessed it?
#
@ -1715,6 +1717,9 @@ class OutputGraph(OutputGraphGuardsState):
# replace compiled_fn with the real forward method
compiled_fn = lazy_gm.forward
if self.package is not None:
self.package.add_backend_id(name, compiled_fn)
compiled_fn = disable(
compiled_fn, reason="do not trace Dynamo-compiled graph"
)

331
torch/_dynamo/package.py Normal file
View File

@ -0,0 +1,331 @@
"""
This module provides the infrastructure for creating and managing compile package
for torch.compile. We mainly have two abstractions here:
- CompilePackage: Overarching data structure for storing and looking up a list of compiled codes.
- CodeCacheEntry: Data structure for a single code being compiled by torch.compile.
The caching behavior is always under user control explicitly so that a stronger guarantee can
be provided about cache hit for a specific compiled model. Users can load the compile package
from a different process or host.
"""
import contextlib
import dataclasses
import functools
import hashlib
import importlib
import logging
import pickle
import platform
import sys
import types
from collections.abc import Generator
from typing import Any, NewType, Optional
import torch
import torch._inductor.package
from .bytecode_transformation import get_code_keys
logger = logging.getLogger(__name__)
@dataclasses.dataclass(frozen=True)
class SerializedCode:
    """
    Picklable, hashable mirror of a Python code object.

    The fields are the code-object attributes read through ``get_code_keys()``.
    Nested code objects found in ``co_consts`` are converted recursively.
    Attributes that only exist on some Python versions default to ``None``.
    """

    co_argcount: int
    co_posonlyargcount: int
    co_kwonlyargcount: int
    co_nlocals: int
    co_stacksize: int
    co_flags: int
    co_code: bytes
    co_consts: tuple[Any, ...]
    co_names: tuple[str, ...]
    co_varnames: tuple[str, ...]
    co_filename: str
    co_name: str
    co_firstlineno: int
    co_cellvars: tuple[str, ...]
    co_freevars: tuple[str, ...]
    co_linetable: Optional[bytes] = None
    co_qualname: Optional[str] = None
    co_exceptiontable: Optional[bytes] = None
    # CPython exposes co_lnotab as a bytes object, not a str.
    co_lnotab: Optional[bytes] = None

    @classmethod
    @functools.cache
    def from_code_object(cls, code: types.CodeType) -> "SerializedCode":
        """Build a SerializedCode from ``code``; results are memoized per code object."""
        kwargs = {key: getattr(code, key) for key in get_code_keys()}
        # Nested code objects (inner functions, comprehensions, ...) are converted too.
        kwargs["co_consts"] = tuple(
            cls.from_code_object(c) if isinstance(c, types.CodeType) else c
            for c in kwargs["co_consts"]
        )
        return cls(**kwargs)

    @classmethod
    @functools.cache
    def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType:
        """
        Rebuild a code object from ``serialized_code``.

        Results are memoized, so equal SerializedCode values map to the same
        code object.
        """
        kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()}
        kwargs["co_consts"] = tuple(
            cls.to_code_object(c) if isinstance(c, SerializedCode) else c
            for c in kwargs["co_consts"]
        )
        # The values are passed positionally, in the order get_code_keys() yields them.
        return types.CodeType(
            *kwargs.values(),
        )
@dataclasses.dataclass
class _GuardedCodeCacheEntry:
    """
    Contains the serializable information associated with a single compilation in dynamo.
    To restore an execution of compiled code, we will need to serialize the following data:
    - Dynamo bytecode for mapping Python inputs/outputs.
    - Dynamo guards.
    """

    # Serialized guards; unpickled into a GuardsState when the package is installed.
    guards_state: bytes
    # The dynamo-generated code object, in serializable form.
    dynamo_code: SerializedCode
_BackendId = NewType("_BackendId", str) # __compiled_fn
_FunctionId = NewType("_FunctionId", str) # __resume_at
@dataclasses.dataclass
class _DynamoCodeCacheEntry:
    """
    Contains the serializable information associated with a single code object
    in dynamo. To restore an execution of compiled code, we will need the following
    ingredients:
    1. The "original" code object, which serves as the entry point for eager
    execution, i.e. the code only executed when there's no cache entry hit.
    2. The python module name this code object belongs to, for identifying the
    enclosing global scope to inject compiled and resume functions.
    3. A list of function names that point to this code object. There could be
    multiple function objects pointing to the same code such as recursive functions.
    4. A list of guarded code that eval frame dispatches to.
    5. A mapping of import aliases to module names, unioned from all compiled branches.
    6. A list of "backends" (compiled fx graph) unioned from all compiled branches.
    """

    # The original (uncompiled) code object.
    python_code: SerializedCode
    # Name of the module whose globals receive the installed functions.
    python_module: str
    # Global names under which a function wrapping this code is installed.
    function_names: list[_FunctionId]
    # One entry per compilation recorded for this code object.
    guarded_codes: list[_GuardedCodeCacheEntry]
    # Global alias -> module name to import under that alias.
    import_sources: dict[str, str]
    # Names of the compiled backend functions this code refers to.
    backend_ids: list[_BackendId]
@dataclasses.dataclass
class _DynamoCacheEntry:
    """
    Top-level serializable record for one package: the per-code entries plus the
    Python and PyTorch versions recorded when this module was imported.
    """

    codes: list[_DynamoCodeCacheEntry]
    python_version: str = platform.python_version()
    torch_version: str = torch.__version__

    @property
    def backend_ids(self) -> set[_BackendId]:
        """Union of the backend ids referenced by every code entry."""
        collected: set[_BackendId] = set()
        for entry in self.codes:
            collected.update(entry.backend_ids)
        return collected
class CompilePackage:
    """
    CompilePackage is considered a low level component and should not be directly exposed to
    end users. It has the following interface:

    1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
        a. when `dynamo` argument is None, it will construct a brand new CompilePackage object.
        b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
    2. `package.save()` which dumps the dynamo states into a _DynamoCacheEntry object.
    3. `package.install(backends)` which will handle all the side-effectful global scope
        updates with compiled functions and resume functions.
    """

    def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
        """
        Args:
            fn: the function being compiled; its innermost function is tracked.
            dynamo: a previously saved cache entry to load, or None to start empty.
        """
        self._innermost_fn = None
        # Code object -> cache entry. The first key is always the entry-point code.
        self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}

        # Entry that is currently being compiled; set by code_context().
        self._current_entry: Optional[_DynamoCodeCacheEntry] = None
        # Module -> names written into its globals by install().
        self._installed_globals: dict[types.ModuleType, list[str]] = {}

        # For debugging/testing purpose only.
        self._cached_backends: dict[_BackendId, Any] = {}

        self._initialize(fn, dynamo)
        # Always go back to a clean state after initialization.
        self.uninstall()
        self.validate()

    def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
        """
        Populate ``_codes`` either from ``dynamo`` or with a fresh entry for ``fn``.

        Raises:
            RuntimeError: if ``dynamo`` was created with a different Python or
                PyTorch version than the running one.
        """
        from .eval_frame import innermost_fn

        self._innermost_fn = innermost_fn(fn)
        assert self._innermost_fn is not None
        if dynamo is not None:
            assert isinstance(dynamo, _DynamoCacheEntry)
            if dynamo.python_version != platform.python_version():
                raise RuntimeError(
                    f"Compile package was created with a different Python version: {dynamo.python_version}"
                )
            if dynamo.torch_version != torch.__version__:
                raise RuntimeError(
                    f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
                )

            # The first saved entry belongs to the entry point; it is keyed by the
            # live code object of fn rather than by a deserialized copy.
            main, *codes = dynamo.codes
            self._codes = {self._innermost_fn.__code__: main}
            for code in codes:
                self._codes[SerializedCode.to_code_object(code.python_code)] = code
        else:
            self._add_function(
                self._innermost_fn.__code__, self._innermost_fn.__module__
            )

    def _add_function(
        self,
        python_code: types.CodeType,
        python_module: str,
        name: Optional[_FunctionId] = None,
    ) -> None:
        """Create (or look up) the entry for ``python_code`` and optionally record ``name`` on it."""
        if python_code not in self._codes:
            code = _DynamoCodeCacheEntry(
                python_code=SerializedCode.from_code_object(python_code),
                python_module=python_module,
                function_names=[],
                guarded_codes=[],
                import_sources={},
                backend_ids=[],
            )
            self._codes[python_code] = code
        else:
            code = self._codes[python_code]
            assert code.python_module == python_module

        if name is not None:
            code.function_names.append(name)

    @property
    def cached_backends(self) -> dict[_BackendId, Any]:
        """Backends recorded through add_backend_id(); for debugging/testing only."""
        return self._cached_backends

    @functools.cached_property
    def source_id(self) -> str:
        """SHA-256 hex digest of the entry point's qualified name and first line number."""
        assert self._innermost_fn is not None
        sha256_hash = hashlib.sha256()
        sha256_hash.update(self._innermost_fn.__qualname__.encode())
        sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode())
        return sha256_hash.hexdigest()

    @contextlib.contextmanager
    def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
        """
        Make the entry for ``code`` the current entry for the duration of the block.

        Contexts may not be nested, and ``code`` must already be registered.
        """
        assert self._current_entry is None

        entry = self._codes[code]
        self._current_entry = entry
        try:
            yield
        finally:
            self._current_entry = None

    def add_guarded_code(
        self,
        guards_state: bytes,
        dynamo_code: types.CodeType,
    ) -> None:
        """Append a (guards, dynamo code) pair to the current entry."""
        assert self._current_entry is not None
        guarded_code_entry = _GuardedCodeCacheEntry(
            guards_state=guards_state,
            dynamo_code=SerializedCode.from_code_object(dynamo_code),
        )
        self._current_entry.guarded_codes.append(guarded_code_entry)

    def add_resume_function(
        self,
        python_code: types.CodeType,
        python_module: str,
        name: Optional[str],
    ) -> None:
        """Register a resume function's code; a falsy ``name`` registers the code without a name."""
        self._add_function(
            python_code, python_module, _FunctionId(name) if name else None
        )

    def add_import_source(self, alias: str, module_name: str) -> None:
        """Record on the current entry that ``alias`` must be bound to module ``module_name``."""
        assert self._current_entry is not None
        self._current_entry.import_sources[alias] = module_name

    def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None:
        """Record ``backend_id`` on the current entry and, if given, cache ``backend`` for it."""
        assert self._current_entry is not None
        assert backend_id.startswith("__compiled_fn_")  # sanity check
        backend_id = _BackendId(backend_id)
        self._current_entry.backend_ids.append(backend_id)
        if backend is not None:
            self._cached_backends[backend_id] = backend

    def validate(self) -> None:
        """Assert that no compilation is in progress and that the entry point is the first code."""
        assert self._current_entry is None
        assert self._innermost_fn is not None
        assert next(iter(self._codes)) is self._innermost_fn.__code__

    def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None:
        """Bind ``name`` in ``module``'s globals and remember it for uninstall()."""
        module.__dict__[name] = value
        self._installed_globals.setdefault(module, []).append(name)

    def uninstall(self) -> None:
        """Remove every global written by install() and reset the entry point's precompile entries."""
        from torch._C._dynamo.eval_frame import _reset_precompile_entries

        assert self._innermost_fn is not None
        for module, names in self._installed_globals.items():
            for name in names:
                # A name can be recorded more than once (e.g. an import alias
                # shared by several code entries of the same module) or may have
                # been removed already; a missing key must not abort the cleanup.
                module.__dict__.pop(name, None)

        self._installed_globals = {}

        _reset_precompile_entries(self._innermost_fn.__code__)

    def install(self, backends: dict[_BackendId, Any]) -> None:
        """
        Sync the package states to the compiled function. This includes the following actions:
          1. Clean up the previously installed states.
          2. Install the compiled functions to global scopes.
          3. Install the precompiled cache entries to ExtraStates on the code object.

        Args:
            backends: backend id -> compiled backend; every id referenced by the
                package must be present.
        """
        from torch._C._dynamo.eval_frame import _load_precompile_entry

        self.uninstall()

        for code, entry in self._codes.items():
            module = sys.modules[entry.python_module]
            for alias, module_name in entry.import_sources.items():
                self._install_global(
                    module, alias, importlib.import_module(module_name)
                )
            for function_name in entry.function_names:
                fn = types.FunctionType(code, module.__dict__, function_name)
                self._install_global(module, function_name, fn)
            for backend_id in entry.backend_ids:
                backend = backends[backend_id]
                self._install_global(
                    module,
                    backend_id,
                    torch._dynamo.disable(backend),
                )

        for code, entry in self._codes.items():
            for guarded_code in entry.guarded_codes:
                # NOTE: pickle.loads can execute arbitrary code; only install
                # packages that come from a trusted source.
                guards_state = pickle.loads(guarded_code.guards_state)
                assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
                check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
                    code,
                    guards_state.output_graph,
                    guards_serialization_mode="load",
                    shape_code_parts=guards_state.shape_code_parts,
                )
                _load_precompile_entry(
                    code,
                    check_fn_manager.guard_manager,
                    SerializedCode.to_code_object(guarded_code.dynamo_code),
                )

    def save(self) -> _DynamoCacheEntry:
        """Return a _DynamoCacheEntry holding all code entries, entry point first."""
        self.validate()
        return _DynamoCacheEntry(codes=list(self._codes.values()))

View File

@ -44,7 +44,7 @@ import traceback
import types
import typing
import weakref
from typing import Any, Callable, cast, NoReturn, Optional, Union
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from unittest.mock import patch
import torch
@ -167,6 +167,9 @@ from .variables.user_defined import (
)
if TYPE_CHECKING:
from .package import CompilePackage
log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
@ -1094,6 +1097,7 @@ class InstructionTranslatorBase(
is_leaf_tracer: bool
parent: Optional["InstructionTranslatorBase"]
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
package: Optional["CompilePackage"]
def mark_inconsistent_side_effects(self):
"""
@ -1554,6 +1558,9 @@ class InstructionTranslatorBase(
else:
value = _import_module(module_name)
alias = f"__import_{module_name.replace('.', '_dot_')}"
if self.package is not None:
self.package.add_import_source(alias, module_name)
f_globals = self.output.global_scope
assert alias not in f_globals or f_globals[alias] is value
f_globals[alias] = value
@ -3187,6 +3194,7 @@ class InstructionTranslatorBase(
distributed_state: Optional[DistributedState],
# This determines whether to use the execution recorder.
closure: Optional[tuple[types.CellType]] = None,
package: Optional["CompilePackage"] = None,
) -> None:
super().__init__()
self.speculation_log = speculation_log
@ -3245,6 +3253,8 @@ class InstructionTranslatorBase(
self.parent = None
self.debug_locals = []
self.package = package
if sys.version_info >= (3, 10):
from .resume_execution import (
CO_ASYNC_GENERATOR,
@ -3305,6 +3315,7 @@ class InstructionTranslator(InstructionTranslatorBase):
speculation_log: SpeculationLog,
exn_vt_stack: ExceptionStack,
distributed_state: Optional[DistributedState],
package: Optional["CompilePackage"],
) -> None:
_step_logger()(
logging.INFO,
@ -3322,6 +3333,7 @@ class InstructionTranslator(InstructionTranslatorBase):
global_scope=f_globals,
f_code=f_code,
torch_function_mode_stack=torch_function_mode_stack,
package=package,
),
instructions=instructions,
f_locals=f_locals,
@ -3339,6 +3351,7 @@ class InstructionTranslator(InstructionTranslatorBase):
speculation_log=speculation_log,
exn_vt_stack=exn_vt_stack,
distributed_state=distributed_state,
package=package,
)
self._throw_if_in_functorch()
@ -3575,12 +3588,19 @@ class InstructionTranslator(InstructionTranslatorBase):
# expose code object for debugging purposes
self.output.install_global_unsafe(name, new_code)
cg.make_function_with_closure(name, new_code, True, stack_len)
package_name = None
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
name, types.FunctionType(new_code, self.f_globals, name)
)
cg.extend_output(cg.load_function_name(name, True, stack_len))
package_name = name
if self.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], package_name
)
cg.extend_output([cg.create_load(k) for k in argnames])
cg.extend_output(create_call_function(nargs, False))
@ -3975,6 +3995,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
speculation_log=parent.speculation_log,
exn_vt_stack=parent.exn_vt_stack,
distributed_state=parent.distributed_state,
package=parent.package,
)
self.funcvar = funcvar
self.parent = parent

View File

@ -213,6 +213,7 @@ def debug_insert_nops(
global_scope=globals(),
f_code=frame.f_code,
torch_function_mode_stack=[],
package=None,
)
return wrap_guarded_code(

View File

@ -4660,6 +4660,7 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
return node is None or "example_value" in node.meta or "val" in node.meta
@torch._disable_dynamo
def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
cm: AbstractContextManager[None] = (
torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
@ -4670,6 +4671,7 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
return cm
@torch._disable_dynamo
def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
cm.__exit__(None, None, None)