mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[precompile] Add CompilePackage to serialize dynamo states. (#155118)
Adding a per torch.compile() object CompilePackage which tracks dynamo artifacts. CompilePackage is considered a low level component and should not be directly exposed to end users. It has the following interface: 1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states. a. when `dynamo` argument is None, it will construct a brand new CompilePackage object. b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state. 2. `package.save()` which dumps the dynamo states into _DynamoCacheEntry. 3. `package.install(backends)` which will handle all the side-effectful global scope updates with compiled functions and resume functions. This diff focuses on building the low level mechanism for precompile. It will be left to the upper level interface to use these APIs to build a more user-facing frontend. Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118 Approved by: https://github.com/jamesjwu Co-authored-by: James Wu <jjwu@meta.com>
This commit is contained in:
committed by
PyTorch MergeBot
parent
670dab6c63
commit
b2fc9cfea1
@ -322,6 +322,7 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase):
|
||||
speculation_log=SpeculationLog(),
|
||||
exn_vt_stack=ExceptionStack(),
|
||||
distributed_state=None,
|
||||
package=None,
|
||||
)
|
||||
with compile_context(CompileContext(CompileId(0, 0))), tracing(
|
||||
tracer.output.tracing_context
|
||||
|
196
test/dynamo/test_package.py
Normal file
196
test/dynamo/test_package.py
Normal file
@ -0,0 +1,196 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import os
|
||||
import pickle
|
||||
|
||||
import torch
|
||||
import torch._dynamo.testing
|
||||
import torch._inductor.config
|
||||
import torch._inductor.test_case
|
||||
import torch.onnx.operators
|
||||
import torch.utils.cpp_extension
|
||||
from torch._dynamo.package import CompilePackage
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
|
||||
|
||||
class StorageForTesting:
    """
    Minimal pickle-on-disk store used by the tests below.

    Layout under ``path``:
      - ``dynamo.pickle``: the dynamo cache entry.
      - ``<backend_id>/fx_graph.pickle``: one file per backend.

    ``backends`` is the in-memory mapping of backend id to backend object; it is
    filled by ``add_backend`` before saving and rebuilt by ``load_package``.
    """

    def __init__(self, path: str):
        self.path = path
        self.backends = {}

    def _pickle_file(self, *parts):
        # Every artifact lives at <root>/<parts...>.pickle
        return os.path.join(self.path, *parts) + ".pickle"

    def _write_pickle(self, data, *path: str):
        """Pickle ``data`` into the file addressed by the ``path`` components."""
        with open(self._pickle_file(*path), "wb") as out:
            pickle.dump(data, out)

    def write_dynamo(self, dynamo):
        """Persist the dynamo cache entry."""
        self._write_pickle(dynamo, "dynamo")

    def write_backend(self, backend_id):
        """Persist the previously registered backend ``backend_id`` in its own directory."""
        backend_dir = os.path.join(self.path, backend_id)
        os.makedirs(backend_dir, exist_ok=True)
        backend = self.backends[backend_id]
        self._write_pickle(backend, backend_id, "fx_graph")

    def _read_pickle(self, *path):
        """Unpickle and return the file addressed by the ``path`` components."""
        with open(self._pickle_file(*path), "rb") as src:
            return pickle.load(src)

    def read_backend(self, backend_id):
        """Load the backend stored for ``backend_id``."""
        return self._read_pickle(backend_id, "fx_graph")

    def read_dynamo(self):
        """Load the stored dynamo cache entry."""
        return self._read_pickle("dynamo")

    def add_backend(self, backend_id, backend):
        """Register ``backend`` in memory so that a later save can write it out."""
        self.backends[backend_id] = backend

    def save_package(self, dynamo_cache_entry):
        """Write the dynamo entry, then every backend listed in its ``backend_ids``."""
        self.write_dynamo(dynamo_cache_entry)
        for bid in dynamo_cache_entry.backend_ids:
            self.write_backend(bid)

    def load_package(self):
        """Read the dynamo entry, reload its backends into ``self.backends``, and return the entry."""
        entry = self.read_dynamo()
        loaded = self.backends = {}
        for bid in entry.backend_ids:
            loaded[bid] = self.read_backend(bid)
        return entry
|
||||
|
||||
|
||||
class TestPackage(torch._inductor.test_case.TestCase):
    """
    Save/load round-trip tests for CompilePackage.

    Each test compiles a function while recording into a package, persists the
    package through StorageForTesting, resets dynamo, and then checks under the
    "fail_on_recompile" stance that calls only succeed once the reloaded
    package has been installed.
    """

    _RECOMPILE_ERROR = (
        "Detected recompile when torch.compile stance is 'fail_on_recompile'"
    )

    def storage(self):
        """Return a StorageForTesting rooted at a per-test directory under the cache dir."""
        root = os.path.join(cache_dir(), f"package_{self.id()}")
        os.makedirs(root, exist_ok=True)
        return StorageForTesting(root)

    def _persist(self, storage, package):
        # Copy the backends recorded during compilation into the storage, then
        # write the dynamo entry together with those backends.
        for backend_id, backend in package.cached_backends.items():
            storage.add_backend(backend_id, backend)
        storage.save_package(package.save())

    def _assert_recompile_rejected(self):
        # Context manager expecting the error raised by the "fail_on_recompile" stance.
        return self.assertRaisesRegex(RuntimeError, self._RECOMPILE_ERROR)

    def test_basic_fn(self):
        storage = self.storage()

        def fn(x):
            return x + 1

        inputs = (torch.randn(3, 2),)

        # Saving
        package = CompilePackage(fn)
        optimizer = torch._dynamo.optimize(backend="eager", package=package)
        compiled_fn = optimizer(fn)
        expected = compiled_fn(*inputs)
        self._persist(storage, package)

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            # Nothing is installed after the reset, so a call would have to compile.
            with self._assert_recompile_rejected():
                compiled_fn(*inputs)

            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(storage.backends)
            self.assertEqual(expected, compiled_fn(*inputs))

    def test_graph_break_bomb(self):
        storage = self.storage()

        def fn(x, l, r):
            if l > r:
                return x.sum()
            mid = (l + r) // 2
            if x.sum() == mid:
                return x.sum()
            elif x.sum() < mid:
                return fn(x, l, mid)
            else:
                return fn(x, mid + 1, r)

        def guard_filter_fn(guards):
            return [
                guard.guard_type not in ("CLOSURE_MATCH", "FUNCTION_MATCH")
                for guard in guards
            ]

        # Saving
        package = CompilePackage(fn)
        optimizer = torch._dynamo.optimize(
            backend="eager", package=package, guard_filter_fn=guard_filter_fn
        )
        compiled_fn = optimizer(fn)
        N = 10
        args_list = [(torch.tensor(x), 0, N - 1) for x in range(N)]
        for call_args in args_list:
            compiled_fn(*call_args)
        self._persist(storage, package)

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            for call_args in args_list:
                with self._assert_recompile_rejected():
                    compiled_fn(*call_args)

            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(
                backend="eager", package=package, guard_filter_fn=guard_filter_fn
            )(fn)
            package.install(storage.backends)
            for call_args in args_list:
                self.assertEqual(compiled_fn(*call_args), call_args[0].sum())

            # An input that was never seen while saving must still be rejected.
            with self._assert_recompile_rejected():
                compiled_fn(torch.tensor(N), 0, N - 1)

    def test_dynamic_shape(self):
        storage = self.storage()

        def fn(x):
            return x + x.shape[0]

        args = (torch.randn(3, 2),)
        args1 = (torch.randn(5, 2),)
        args2 = (torch.randn(7, 2),)
        expected1 = fn(*args1)

        torch._dynamo.mark_dynamic(args[0], 0, min=3, max=5)

        # Saving
        package = CompilePackage(fn)
        optimizer = torch._dynamo.optimize(backend="eager", package=package)
        compiled_fn = optimizer(fn)
        compiled_fn(*args)
        self._persist(storage, package)

        # Loading
        torch._dynamo.reset()
        with torch.compiler.set_stance("fail_on_recompile"):
            with self._assert_recompile_rejected():
                compiled_fn(*args1)

            package = CompilePackage(fn, storage.load_package())
            compiled_fn = torch._dynamo.optimize(package=package)(fn)
            package.install(storage.backends)

            # args1 has dim 0 == 5, inside the min=3/max=5 range marked above.
            self.assertEqual(expected1, compiled_fn(*args1))

            # args2 has dim 0 == 7, outside that range.
            with self._assert_recompile_rejected():
                compiled_fn(*args2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Only needed when this file is executed directly as a script.
    from torch._dynamo.test_case import run_tests

    run_tests()
|
@ -2,7 +2,7 @@ import enum
|
||||
import types
|
||||
from typing import overload
|
||||
|
||||
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
|
||||
from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
|
||||
|
||||
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
|
||||
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
|
||||
@ -57,3 +57,7 @@ def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
|
||||
py_opcode_caches: list[int]
|
||||
|
||||
def code_framelocals_names(code: types.CodeType) -> tuple[str]: ...
|
||||
def _load_precompile_entry(
|
||||
code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
|
||||
) -> None: ...
|
||||
def _reset_precompile_entries(code: types.CodeType) -> None: ...
|
||||
|
@ -161,6 +161,7 @@ except ModuleNotFoundError:
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .backends.registry import CompilerFn
|
||||
from .package import CompilePackage
|
||||
from .repro.after_dynamo import WrapBackendDebug
|
||||
from .types import BytecodeHook, CacheEntry, DynamoFrameType
|
||||
from .variables.builder import FrameStateSizeEntry
|
||||
@ -478,6 +479,7 @@ class ConvertFrameAssert:
|
||||
one_graph: bool = True,
|
||||
export: bool = False,
|
||||
export_constraints: Optional[typing.Never] = None,
|
||||
package: Optional[CompilePackage] = None,
|
||||
) -> None:
|
||||
# assert export_constraints is None
|
||||
reset_graph_break_dup_checker()
|
||||
@ -485,6 +487,7 @@ class ConvertFrameAssert:
|
||||
self._one_graph = one_graph
|
||||
self._export = export
|
||||
self._export_constraints = export_constraints
|
||||
self._package = package
|
||||
|
||||
@property
|
||||
def _clone_with_backend(self) -> Callable[[CompilerFn], ConvertFrameAssert]:
|
||||
@ -640,6 +643,7 @@ class ConvertFrameAssert:
|
||||
frame_state=frame_state,
|
||||
compile_id=compile_id,
|
||||
skip=skip + 1,
|
||||
package=self._package,
|
||||
)
|
||||
|
||||
|
||||
@ -648,9 +652,12 @@ def convert_frame_assert(
|
||||
one_graph: bool = True,
|
||||
export: bool = False,
|
||||
export_constraints: Optional[typing.Never] = None,
|
||||
package: Optional[CompilePackage] = None,
|
||||
) -> ConvertFrameAssert:
|
||||
"""Fully convert a frame into an FX graph"""
|
||||
return ConvertFrameAssert(compiler_fn, one_graph, export, export_constraints)
|
||||
return ConvertFrameAssert(
|
||||
compiler_fn, one_graph, export, export_constraints, package
|
||||
)
|
||||
|
||||
|
||||
from collections import OrderedDict
|
||||
@ -693,6 +700,7 @@ def _compile(
|
||||
*,
|
||||
compile_id: CompileId,
|
||||
skip: int = 0,
|
||||
package: Optional[CompilePackage] = None,
|
||||
) -> ConvertFrameReturn:
|
||||
from torch.fx.experimental.validator import (
|
||||
bisect,
|
||||
@ -717,7 +725,7 @@ def _compile(
|
||||
) -> None:
|
||||
nonlocal output
|
||||
nonlocal tracer
|
||||
speculation_log.restart()
|
||||
speculation_log.restart() # type: ignore[has-type]
|
||||
exn_vt_stack = ExceptionStack()
|
||||
tracer = InstructionTranslator(
|
||||
instructions,
|
||||
@ -733,9 +741,10 @@ def _compile(
|
||||
export,
|
||||
export_constraints,
|
||||
frame_state=frame_state,
|
||||
speculation_log=speculation_log,
|
||||
speculation_log=speculation_log, # type: ignore[has-type]
|
||||
exn_vt_stack=exn_vt_stack,
|
||||
distributed_state=distributed_state,
|
||||
distributed_state=distributed_state, # type: ignore[has-type]
|
||||
package=package,
|
||||
)
|
||||
|
||||
try:
|
||||
@ -743,7 +752,7 @@ def _compile(
|
||||
with tracing(tracer.output.tracing_context), tracer.set_current_tx():
|
||||
tracer.run()
|
||||
except exc.UnspecializeRestartAnalysis:
|
||||
speculation_log.clear()
|
||||
speculation_log.clear() # type: ignore[has-type]
|
||||
raise
|
||||
except (
|
||||
exc.SpeculationRestartAnalysis,
|
||||
@ -857,7 +866,7 @@ def _compile(
|
||||
log.debug("No graph captured with one_graph=True")
|
||||
return ConvertFrameReturn()
|
||||
|
||||
assert distributed_state is None or distributed_state.all_states is not None, (
|
||||
assert distributed_state is None or distributed_state.all_states is not None, ( # type: ignore[has-type]
|
||||
"compiler collective wasn't run before compilation completed"
|
||||
)
|
||||
|
||||
@ -936,8 +945,13 @@ def _compile(
|
||||
cache_entry,
|
||||
hooks.guard_fail_fn if hooks else None,
|
||||
hooks.guard_filter_fn if hooks else None,
|
||||
guards_serialization_mode="save" if package else None,
|
||||
)
|
||||
|
||||
if package is not None:
|
||||
assert check_fn.guards_state is not None
|
||||
package.add_guarded_code(check_fn.guards_state, out_code)
|
||||
|
||||
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
|
||||
annotation_str = "Torch-Compiled Region: " + compile_id_str
|
||||
guarded_code = GuardedCode(
|
||||
@ -958,6 +972,9 @@ def _compile(
|
||||
return wrap_guarded_code(guarded_code)
|
||||
|
||||
metrics_context = get_metrics_context()
|
||||
code_context = (
|
||||
package.code_context(code) if package is not None else contextlib.nullcontext()
|
||||
)
|
||||
with (
|
||||
_use_lazy_graph_module(config.use_lazy_graph_module),
|
||||
compile_context(CompileContext(compile_id)),
|
||||
@ -971,6 +988,7 @@ def _compile(
|
||||
phase_name="entire_frame_compile",
|
||||
dynamo_compile_column_us="dynamo_cumulative_compile_time_us",
|
||||
),
|
||||
code_context,
|
||||
):
|
||||
restart_reasons: set[str] = set()
|
||||
# This is shared across restarts
|
||||
@ -1226,9 +1244,12 @@ class ConvertFrame:
|
||||
self,
|
||||
compiler_fn: CompilerFn,
|
||||
hooks: Hooks,
|
||||
package: Optional[CompilePackage] = None,
|
||||
) -> None:
|
||||
self._torchdynamo_orig_callable = compiler_fn
|
||||
self._inner_convert = convert_frame_assert(compiler_fn, one_graph=False)
|
||||
self._inner_convert = convert_frame_assert(
|
||||
compiler_fn, one_graph=False, package=package
|
||||
)
|
||||
self._hooks = hooks
|
||||
|
||||
@property
|
||||
@ -1331,9 +1352,11 @@ class ConvertFrame:
|
||||
return ConvertFrameReturn()
|
||||
|
||||
|
||||
def convert_frame(compiler_fn: CompilerFn, hooks: Hooks) -> ConvertFrame:
|
||||
def convert_frame(
|
||||
compiler_fn: CompilerFn, hooks: Hooks, package: Optional[CompilePackage] = None
|
||||
) -> ConvertFrame:
|
||||
"""Try to convert a frame into an FX graph, if error leave frame unmodified"""
|
||||
return ConvertFrame(compiler_fn, hooks)
|
||||
return ConvertFrame(compiler_fn, hooks, package=package)
|
||||
|
||||
|
||||
# TODO mlazos: add support for same args, or record them
|
||||
|
@ -531,6 +531,7 @@ class _TorchDynamoContext:
|
||||
export=False,
|
||||
dynamic=None,
|
||||
compiler_config=None,
|
||||
package=None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert callable(callback) or callback is False or callback is None
|
||||
@ -543,6 +544,7 @@ class _TorchDynamoContext:
|
||||
self.compiler_config = compiler_config
|
||||
self.cleanup_fns: list[Callable[[], Any]] = []
|
||||
self.enter_exit_hooks = []
|
||||
self._package = package
|
||||
patch_fn()
|
||||
|
||||
# Save the backends so that we can reset them during torch._dynamo.reset
|
||||
@ -792,6 +794,7 @@ class OptimizeContext(_TorchDynamoContext):
|
||||
rebuild_ctx: Optional[
|
||||
Callable[[], Union[OptimizeContext, _NullDecorator]]
|
||||
] = None,
|
||||
package=None,
|
||||
) -> None:
|
||||
def on_enter():
|
||||
install_generation_tagging_init()
|
||||
@ -805,6 +808,7 @@ class OptimizeContext(_TorchDynamoContext):
|
||||
export=export,
|
||||
dynamic=dynamic,
|
||||
compiler_config=compiler_config,
|
||||
package=package,
|
||||
)
|
||||
|
||||
if config.compiled_autograd:
|
||||
@ -928,6 +932,7 @@ def _optimize_catch_errors(
|
||||
dynamic=None,
|
||||
compiler_config=None,
|
||||
rebuild_ctx=None,
|
||||
package=None,
|
||||
):
|
||||
return OptimizeContext(
|
||||
convert_frame.catch_errors_wrapper(compile_fn, hooks),
|
||||
@ -937,6 +942,7 @@ def _optimize_catch_errors(
|
||||
dynamic=dynamic,
|
||||
compiler_config=compiler_config,
|
||||
rebuild_ctx=rebuild_ctx,
|
||||
package=package,
|
||||
)
|
||||
|
||||
|
||||
@ -1027,6 +1033,7 @@ def _optimize(
|
||||
guard_filter_fn=None,
|
||||
disable=False,
|
||||
dynamic=None,
|
||||
package=None,
|
||||
) -> Union[OptimizeContext, _NullDecorator]:
|
||||
"""
|
||||
The main entrypoint of TorchDynamo. Do graph capture and call
|
||||
@ -1079,6 +1086,7 @@ def _optimize(
|
||||
dynamic=dynamic,
|
||||
hooks=hooks,
|
||||
rebuild_ctx=rebuild_ctx,
|
||||
package=package,
|
||||
)
|
||||
|
||||
backend = get_compiler_fn(backend)
|
||||
@ -1090,7 +1098,7 @@ def _optimize(
|
||||
# _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
|
||||
# be used by eval_frame.c to insert a guard on the backend.
|
||||
return _optimize_catch_errors(
|
||||
convert_frame.convert_frame(backend, hooks=hooks),
|
||||
convert_frame.convert_frame(backend, hooks=hooks, package=package),
|
||||
hooks,
|
||||
backend_ctx_ctor,
|
||||
dynamic=dynamic,
|
||||
@ -1100,6 +1108,7 @@ def _optimize(
|
||||
else None
|
||||
),
|
||||
rebuild_ctx=rebuild_ctx,
|
||||
package=package,
|
||||
)
|
||||
|
||||
|
||||
@ -1990,6 +1999,7 @@ def _optimize_assert(
|
||||
export=False,
|
||||
export_constraints=None,
|
||||
dynamic=None,
|
||||
package=None,
|
||||
):
|
||||
"""
|
||||
The same as `torch._dynamo.optimize(backend, nopython=True)`
|
||||
@ -2001,13 +2011,17 @@ def _optimize_assert(
|
||||
|
||||
return _optimize_catch_errors(
|
||||
convert_frame.convert_frame_assert(
|
||||
backend, export=export, export_constraints=export_constraints
|
||||
backend,
|
||||
export=export,
|
||||
export_constraints=export_constraints,
|
||||
package=package,
|
||||
),
|
||||
hooks,
|
||||
backend_ctx_ctor,
|
||||
export=export,
|
||||
dynamic=dynamic,
|
||||
rebuild_ctx=rebuild_ctx,
|
||||
package=package,
|
||||
)
|
||||
|
||||
|
||||
|
@ -497,6 +497,8 @@ def get_verbose_code_parts(
|
||||
|
||||
|
||||
def convert_int_to_concrete_values(dim) -> Optional[int]:
|
||||
if dim is None:
|
||||
return None
|
||||
if not is_symbolic(dim):
|
||||
return dim
|
||||
else:
|
||||
@ -1457,7 +1459,11 @@ class GuardBuilder(GuardBuilderBase):
|
||||
def TYPE_MATCH(self, guard: Guard) -> None:
|
||||
# ___check_type_id is same as `id(type(x)) == y`
|
||||
value = self.get(guard.name)
|
||||
t = type(value)
|
||||
if isinstance(value, torch._subclasses.FakeTensor) and value.pytype:
|
||||
t = value.pytype
|
||||
else:
|
||||
t = type(value)
|
||||
|
||||
if self.serialization_mode == "save":
|
||||
if t.__qualname__ != t.__name__:
|
||||
raise_local_type_error(value)
|
||||
|
@ -368,6 +368,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
global_scope: Scope,
|
||||
f_code,
|
||||
torch_function_mode_stack,
|
||||
package,
|
||||
):
|
||||
super().__init__(
|
||||
local_scope,
|
||||
@ -471,6 +472,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
self.compiler_fn: Optional[CompilerFn] = compiler_fn
|
||||
self.root_tx = root_tx
|
||||
|
||||
self.package = package
|
||||
# Given a source, what are the user stacks of all locations that
|
||||
# accessed it?
|
||||
#
|
||||
@ -1715,6 +1717,9 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# replace compiled_fn with the real forward method
|
||||
compiled_fn = lazy_gm.forward
|
||||
|
||||
if self.package is not None:
|
||||
self.package.add_backend_id(name, compiled_fn)
|
||||
|
||||
compiled_fn = disable(
|
||||
compiled_fn, reason="do not trace Dynamo-compiled graph"
|
||||
)
|
||||
|
331
torch/_dynamo/package.py
Normal file
331
torch/_dynamo/package.py
Normal file
@ -0,0 +1,331 @@
|
||||
"""
|
||||
This module provides the infrastructure for creating and managing compile package
|
||||
for torch.compile. We mainly have two abstractions here:
|
||||
- CompilePackage: Overarching data structure for storing and looking up a list of compiled codes.
|
||||
- _DynamoCodeCacheEntry: Data structure for a single code object being compiled by torch.compile.
|
||||
The caching behavior is always under user control explicitly so that a stronger guarantee can
|
||||
be provided about cache hit for a specific compiled model. Users can load the compile package
|
||||
from a different process or host.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import dataclasses
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import logging
|
||||
import pickle
|
||||
import platform
|
||||
import sys
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
from typing import Any, NewType, Optional
|
||||
|
||||
import torch
|
||||
import torch._inductor.package
|
||||
|
||||
from .bytecode_transformation import get_code_keys
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
class SerializedCode:
    """
    Picklable stand-in for a ``types.CodeType`` object.

    Each field carries the code-object attribute of the same name. Nested code
    objects found in ``co_consts`` are converted recursively, so a whole
    function (including inner functions) round-trips through
    ``from_code_object`` / ``to_code_object``.

    The dataclass is frozen, which makes instances hashable; ``to_code_object``
    relies on that because it is memoized with ``functools.cache``.
    """

    co_argcount: int
    co_posonlyargcount: int
    co_kwonlyargcount: int
    co_nlocals: int
    co_stacksize: int
    co_flags: int
    co_code: bytes
    co_consts: tuple[Any, ...]
    co_names: tuple[str, ...]
    co_varnames: tuple[str, ...]
    co_filename: str
    co_name: str
    co_firstlineno: int
    co_cellvars: tuple[str, ...]
    co_freevars: tuple[str, ...]
    # The remaining fields default to None; presumably get_code_keys() only
    # yields the ones that exist on the running interpreter -- TODO confirm.
    co_linetable: Optional[bytes] = None
    co_qualname: Optional[str] = None
    co_exceptiontable: Optional[bytes] = None
    # The line-number table is a bytes object on code objects, not a str.
    co_lnotab: Optional[bytes] = None

    @classmethod
    @functools.cache
    def from_code_object(cls, code: types.CodeType) -> "SerializedCode":
        """
        Build a SerializedCode from ``code``.

        Reads every attribute named by ``get_code_keys()`` and replaces code
        objects inside ``co_consts`` with their SerializedCode equivalents.

        NOTE: the result is memoized per ``(cls, code)`` with an unbounded
        ``functools.cache``, so every converted code object stays referenced
        by the cache.
        """
        kwargs = {key: getattr(code, key) for key in get_code_keys()}
        kwargs["co_consts"] = tuple(
            cls.from_code_object(c) if isinstance(c, types.CodeType) else c
            for c in kwargs["co_consts"]
        )
        return cls(**kwargs)

    @classmethod
    @functools.cache
    def to_code_object(cls, serialized_code: "SerializedCode") -> types.CodeType:
        """
        Rebuild a real code object from ``serialized_code``.

        Nested SerializedCode constants are converted back first. The values
        are handed to ``types.CodeType`` positionally, so the order produced by
        ``get_code_keys()`` has to match the constructor's argument order.
        """
        kwargs = {key: getattr(serialized_code, key) for key in get_code_keys()}
        kwargs["co_consts"] = tuple(
            cls.to_code_object(c) if isinstance(c, SerializedCode) else c
            for c in kwargs["co_consts"]
        )
        return types.CodeType(
            *kwargs.values(),
        )
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _GuardedCodeCacheEntry:
    """
    Serializable record of a single dynamo compilation.

    Restoring the execution of compiled code needs two pieces of data:
      - the Dynamo bytecode that maps Python inputs/outputs, and
      - the Dynamo guards.
    """

    # Serialized guards, kept as raw bytes.
    guards_state: bytes
    # Bytecode produced by dynamo for this compilation.
    dynamo_code: SerializedCode
|
||||
|
||||
|
||||
# Name of a compiled backend function (names of the form "__compiled_fn...").
_BackendId = NewType("_BackendId", str)
# Name of a resume function (names of the form "__resume_at...").
_FunctionId = NewType("_FunctionId", str)


@dataclasses.dataclass
class _DynamoCodeCacheEntry:
    """
    Serializable state that dynamo keeps for one code object.

    The following ingredients are needed to restore an execution of compiled code:
      1. The "original" code object. It is the entry point for eager execution,
         i.e. the code that only runs when no cache entry is hit.
      2. The name of the python module the code object belongs to, used to
         identify the enclosing global scope into which compiled and resume
         functions are injected.
      3. The names of the functions pointing to this code object. Several
         function objects can share one code object, e.g. recursive functions.
      4. The guarded codes that eval frame dispatches to.
      5. The imported modules, unioned over all compiled branches.
      6. The "backends" (compiled fx graphs), unioned over all compiled branches.
    """

    python_code: SerializedCode
    python_module: str
    function_names: list[_FunctionId]
    guarded_codes: list[_GuardedCodeCacheEntry]
    # Maps an import alias to the name of the module to import for it.
    import_sources: dict[str, str]
    backend_ids: list[_BackendId]
|
||||
|
||||
|
||||
@dataclasses.dataclass
class _DynamoCacheEntry:
    """
    Top-level serializable dynamo state: the per-code entries plus the Python
    and PyTorch versions of the process that created it.
    """

    codes: list[_DynamoCodeCacheEntry]
    # Both defaults are evaluated once, when this class is defined.
    python_version: str = platform.python_version()
    torch_version: str = torch.__version__

    @property
    def backend_ids(self) -> set[_BackendId]:
        """Return the set of backend ids referenced by any code entry."""
        collected: set[_BackendId] = set()
        for code_entry in self.codes:
            collected.update(code_entry.backend_ids)
        return collected
|
||||
|
||||
|
||||
class CompilePackage:
|
||||
"""
|
||||
CompilePackage is considered a low level component and should not be directly exposed to
|
||||
end users. It has the following interface:
|
||||
|
||||
1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
|
||||
a. when `dynamo` argument is None, it will contruct a brand new CompilePackage object.
|
||||
b. when `dynamo` argument is not None, it will load a pre-compiled dynamo state.
|
||||
2. `package.save()` which dumps the dynamo and backend states to a DynamoCacheEntry object.
|
||||
3. `package.install(backends) which will handle all the side-effectful global scope
|
||||
updates with compiled functions and resume functions.
|
||||
"""
|
||||
|
||||
def __init__(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
|
||||
self._innermost_fn = None
|
||||
self._codes: dict[types.CodeType, _DynamoCodeCacheEntry] = {}
|
||||
|
||||
self._current_entry: Optional[_DynamoCodeCacheEntry] = None
|
||||
self._installed_globals: dict[types.ModuleType, list[str]] = {}
|
||||
|
||||
# For debugging/testing purpose only.
|
||||
self._cached_backends: dict[_BackendId, Any] = {}
|
||||
|
||||
self._initialize(fn, dynamo)
|
||||
# Always go back to a clean state after initialization.
|
||||
self.uninstall()
|
||||
self.validate()
|
||||
|
||||
def _initialize(self, fn: Any, dynamo: Optional[_DynamoCacheEntry] = None) -> None:
|
||||
from .eval_frame import innermost_fn
|
||||
|
||||
self._innermost_fn = innermost_fn(fn)
|
||||
assert self._innermost_fn is not None
|
||||
if dynamo is not None:
|
||||
assert isinstance(dynamo, _DynamoCacheEntry)
|
||||
if dynamo.python_version != platform.python_version():
|
||||
raise RuntimeError(
|
||||
f"Compile package was created with a different Python version: {dynamo.python_version}"
|
||||
)
|
||||
if dynamo.torch_version != torch.__version__:
|
||||
raise RuntimeError(
|
||||
f"Compile package was created with a different PyTorch version: {dynamo.torch_version}"
|
||||
)
|
||||
|
||||
main, *codes = dynamo.codes
|
||||
self._codes = {self._innermost_fn.__code__: main}
|
||||
for code in codes:
|
||||
self._codes[SerializedCode.to_code_object(code.python_code)] = code
|
||||
else:
|
||||
self._add_function(
|
||||
self._innermost_fn.__code__, self._innermost_fn.__module__
|
||||
)
|
||||
|
||||
def _add_function(
|
||||
self,
|
||||
python_code: types.CodeType,
|
||||
python_module: str,
|
||||
name: Optional[_FunctionId] = None,
|
||||
) -> None:
|
||||
if python_code not in self._codes:
|
||||
code = _DynamoCodeCacheEntry(
|
||||
python_code=SerializedCode.from_code_object(python_code),
|
||||
python_module=python_module,
|
||||
function_names=[],
|
||||
guarded_codes=[],
|
||||
import_sources={},
|
||||
backend_ids=[],
|
||||
)
|
||||
self._codes[python_code] = code
|
||||
else:
|
||||
code = self._codes[python_code]
|
||||
assert code.python_module == python_module
|
||||
|
||||
if name is not None:
|
||||
code.function_names.append(name)
|
||||
|
||||
@property
|
||||
def cached_backends(self) -> dict[_BackendId, Any]:
|
||||
return self._cached_backends
|
||||
|
||||
@functools.cached_property
|
||||
def source_id(self) -> str:
|
||||
assert self._innermost_fn is not None
|
||||
sha256_hash = hashlib.sha256()
|
||||
sha256_hash.update(self._innermost_fn.__qualname__.encode())
|
||||
sha256_hash.update(str(self._innermost_fn.__code__.co_firstlineno).encode())
|
||||
return sha256_hash.hexdigest()
|
||||
|
||||
@contextlib.contextmanager
|
||||
def code_context(self, code: types.CodeType) -> Generator[None, None, None]:
|
||||
assert self._current_entry is None
|
||||
|
||||
entry = self._codes[code]
|
||||
self._current_entry = entry
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
self._current_entry = None
|
||||
|
||||
def add_guarded_code(
|
||||
self,
|
||||
guards_state: bytes,
|
||||
dynamo_code: types.CodeType,
|
||||
) -> None:
|
||||
assert self._current_entry is not None
|
||||
guarded_code_entry = _GuardedCodeCacheEntry(
|
||||
guards_state=guards_state,
|
||||
dynamo_code=SerializedCode.from_code_object(dynamo_code),
|
||||
)
|
||||
self._current_entry.guarded_codes.append(guarded_code_entry)
|
||||
|
||||
def add_resume_function(
|
||||
self,
|
||||
python_code: types.CodeType,
|
||||
python_module: str,
|
||||
name: Optional[str],
|
||||
) -> None:
|
||||
self._add_function(
|
||||
python_code, python_module, _FunctionId(name) if name else None
|
||||
)
|
||||
|
||||
def add_import_source(self, alias: str, module_name: str) -> None:
|
||||
assert self._current_entry is not None
|
||||
self._current_entry.import_sources[alias] = module_name
|
||||
|
||||
def add_backend_id(self, backend_id: str, backend: Optional[Any] = None) -> None:
|
||||
assert self._current_entry is not None
|
||||
assert backend_id.startswith("__compiled_fn_") # sanity check
|
||||
backend_id = _BackendId(backend_id)
|
||||
self._current_entry.backend_ids.append(backend_id)
|
||||
if backend is not None:
|
||||
self._cached_backends[backend_id] = backend
|
||||
|
||||
def validate(self) -> None:
|
||||
assert self._current_entry is None
|
||||
assert self._innermost_fn is not None
|
||||
assert next(iter(self._codes)) is self._innermost_fn.__code__
|
||||
|
||||
def _install_global(self, module: types.ModuleType, name: str, value: Any) -> None:
|
||||
module.__dict__[name] = value
|
||||
self._installed_globals.setdefault(module, []).append(name)
|
||||
|
||||
def uninstall(self) -> None:
|
||||
from torch._C._dynamo.eval_frame import _reset_precompile_entries
|
||||
|
||||
assert self._innermost_fn is not None
|
||||
for module, names in self._installed_globals.items():
|
||||
for name in names:
|
||||
module.__dict__.pop(name)
|
||||
|
||||
self._installed_globals = {}
|
||||
|
||||
_reset_precompile_entries(self._innermost_fn.__code__)
|
||||
|
||||
def install(self, backends: dict[_BackendId, Any]) -> None:
|
||||
"""
|
||||
Sync the package states to the compiled function. This includes the following actions:
|
||||
1. Clean up the previously installed states.
|
||||
2. Install the compiled functions to global scopes.
|
||||
3. Install the precompiled cache entries to ExtraStates on the code object.
|
||||
"""
|
||||
from torch._C._dynamo.eval_frame import _load_precompile_entry
|
||||
|
||||
self.uninstall()
|
||||
|
||||
for code, entry in self._codes.items():
|
||||
module = sys.modules[entry.python_module]
|
||||
for alias, module_name in entry.import_sources.items():
|
||||
self._install_global(
|
||||
module, alias, importlib.import_module(module_name)
|
||||
)
|
||||
for function_name in entry.function_names:
|
||||
fn = types.FunctionType(code, module.__dict__, function_name)
|
||||
self._install_global(module, function_name, fn)
|
||||
for backend_id in entry.backend_ids:
|
||||
backend = backends[backend_id]
|
||||
self._install_global(
|
||||
module,
|
||||
backend_id,
|
||||
torch._dynamo.disable(backend),
|
||||
)
|
||||
|
||||
for code, entry in self._codes.items():
|
||||
for guarded_code in entry.guarded_codes:
|
||||
guards_state = pickle.loads(guarded_code.guards_state)
|
||||
assert isinstance(guards_state, torch._dynamo.guards.GuardsState)
|
||||
check_fn_manager = torch._dynamo.guards.CheckFunctionManager(
|
||||
code,
|
||||
guards_state.output_graph,
|
||||
guards_serialization_mode="load",
|
||||
shape_code_parts=guards_state.shape_code_parts,
|
||||
)
|
||||
_load_precompile_entry(
|
||||
code,
|
||||
check_fn_manager.guard_manager,
|
||||
SerializedCode.to_code_object(guarded_code.dynamo_code),
|
||||
)
|
||||
|
||||
def save(self) -> _DynamoCacheEntry:
    """
    Validate the package, then dump its states into a ``_DynamoCacheEntry``.

    ``validate()`` runs first. The returned entry holds a new list built
    from the current values of ``self._codes``.
    """
    self.validate()
    code_entries = [*self._codes.values()]
    return _DynamoCacheEntry(codes=code_entries)
|
@ -44,7 +44,7 @@ import traceback
|
||||
import types
|
||||
import typing
|
||||
import weakref
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, Union
|
||||
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -167,6 +167,9 @@ from .variables.user_defined import (
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .package import CompilePackage
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
|
||||
@ -1094,6 +1097,7 @@ class InstructionTranslatorBase(
|
||||
is_leaf_tracer: bool
|
||||
parent: Optional["InstructionTranslatorBase"]
|
||||
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
|
||||
package: Optional["CompilePackage"]
|
||||
|
||||
def mark_inconsistent_side_effects(self):
|
||||
"""
|
||||
@ -1554,6 +1558,9 @@ class InstructionTranslatorBase(
|
||||
else:
|
||||
value = _import_module(module_name)
|
||||
alias = f"__import_{module_name.replace('.', '_dot_')}"
|
||||
|
||||
if self.package is not None:
|
||||
self.package.add_import_source(alias, module_name)
|
||||
f_globals = self.output.global_scope
|
||||
assert alias not in f_globals or f_globals[alias] is value
|
||||
f_globals[alias] = value
|
||||
@ -3187,6 +3194,7 @@ class InstructionTranslatorBase(
|
||||
distributed_state: Optional[DistributedState],
|
||||
# This determines whether to use the execution recorder.
|
||||
closure: Optional[tuple[types.CellType]] = None,
|
||||
package: Optional["CompilePackage"] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.speculation_log = speculation_log
|
||||
@ -3245,6 +3253,8 @@ class InstructionTranslatorBase(
|
||||
self.parent = None
|
||||
self.debug_locals = []
|
||||
|
||||
self.package = package
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
from .resume_execution import (
|
||||
CO_ASYNC_GENERATOR,
|
||||
@ -3305,6 +3315,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
speculation_log: SpeculationLog,
|
||||
exn_vt_stack: ExceptionStack,
|
||||
distributed_state: Optional[DistributedState],
|
||||
package: Optional["CompilePackage"],
|
||||
) -> None:
|
||||
_step_logger()(
|
||||
logging.INFO,
|
||||
@ -3322,6 +3333,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
global_scope=f_globals,
|
||||
f_code=f_code,
|
||||
torch_function_mode_stack=torch_function_mode_stack,
|
||||
package=package,
|
||||
),
|
||||
instructions=instructions,
|
||||
f_locals=f_locals,
|
||||
@ -3339,6 +3351,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
speculation_log=speculation_log,
|
||||
exn_vt_stack=exn_vt_stack,
|
||||
distributed_state=distributed_state,
|
||||
package=package,
|
||||
)
|
||||
|
||||
self._throw_if_in_functorch()
|
||||
@ -3575,12 +3588,19 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
# expose code object for debugging purposes
|
||||
self.output.install_global_unsafe(name, new_code)
|
||||
cg.make_function_with_closure(name, new_code, True, stack_len)
|
||||
package_name = None
|
||||
else:
|
||||
# This is safe: we pre-generate a unique name
|
||||
self.output.install_global_unsafe(
|
||||
name, types.FunctionType(new_code, self.f_globals, name)
|
||||
)
|
||||
cg.extend_output(cg.load_function_name(name, True, stack_len))
|
||||
package_name = name
|
||||
|
||||
if self.package is not None:
|
||||
self.package.add_resume_function(
|
||||
new_code, self.f_globals["__name__"], package_name
|
||||
)
|
||||
|
||||
cg.extend_output([cg.create_load(k) for k in argnames])
|
||||
cg.extend_output(create_call_function(nargs, False))
|
||||
@ -3975,6 +3995,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
speculation_log=parent.speculation_log,
|
||||
exn_vt_stack=parent.exn_vt_stack,
|
||||
distributed_state=parent.distributed_state,
|
||||
package=parent.package,
|
||||
)
|
||||
self.funcvar = funcvar
|
||||
self.parent = parent
|
||||
|
@ -213,6 +213,7 @@ def debug_insert_nops(
|
||||
global_scope=globals(),
|
||||
f_code=frame.f_code,
|
||||
torch_function_mode_stack=[],
|
||||
package=None,
|
||||
)
|
||||
|
||||
return wrap_guarded_code(
|
||||
|
@ -4660,6 +4660,7 @@ def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
|
||||
return node is None or "example_value" in node.meta or "val" in node.meta
|
||||
|
||||
|
||||
@torch._disable_dynamo
|
||||
def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
|
||||
cm: AbstractContextManager[None] = (
|
||||
torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
|
||||
@ -4670,6 +4671,7 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
|
||||
return cm
|
||||
|
||||
|
||||
@torch._disable_dynamo
|
||||
def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
|
||||
cm.__exit__(None, None, None)
|
||||
|
||||
|
Reference in New Issue
Block a user