pytorch/torch/_C/_dynamo/eval_frame.pyi
James Wu b2fc9cfea1 [precompile] Add CompilePackage to serialize dynamo states. (#155118)
Adding a per-torch.compile() object, CompilePackage, which tracks dynamo artifacts. CompilePackage is considered a low-level component and should not be directly exposed to end users. It has the following interface:

1. `CompilePackage.__init__()` which optionally takes previously serialized dynamo states.
     a. when the `dynamo` argument is None, it constructs a brand-new CompilePackage object.
     b. when the `dynamo` argument is not None, it loads a previously compiled dynamo state.
2. `package.save()`, which dumps the dynamo states into a `_DynamoCacheEntry`.
3. `package.install(backends)`, which handles all the side-effectful global-scope updates that install compiled functions and resume functions.
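The lifecycle in the interface list can be sketched as a toy model in plain Python. Everything here is a hypothetical illustration, not torch's API: `ToyCompilePackage`, `ToyDynamoCacheEntry`, and `record_backend` are invented names. `install` only writes backend callables into the function's module globals. The real CompilePackage carries much more state, including guards, bytecode, and resume functions.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Optional


@dataclass
class ToyDynamoCacheEntry:
    # Hypothetical stand-in for _DynamoCacheEntry: plain, serializable data.
    code_name: str
    backend_ids: list[str] = field(default_factory=list)


class ToyCompilePackage:
    """Toy model of the CompilePackage lifecycle (hypothetical, not torch's API)."""

    def __init__(
        self, fn: Callable[..., Any], dynamo: Optional[ToyDynamoCacheEntry] = None
    ) -> None:
        self._fn = fn
        if dynamo is None:
            # Case (a): start a brand-new package with nothing compiled yet.
            self._entry = ToyDynamoCacheEntry(code_name=fn.__code__.co_name)
        else:
            # Case (b): resume from previously saved state.
            self._entry = dynamo

    def record_backend(self, backend_id: str) -> None:
        # Hypothetical hook: remember that a compiled backend was produced.
        self._entry.backend_ids.append(backend_id)

    def save(self) -> ToyDynamoCacheEntry:
        # Return a copy so later mutations do not leak into the saved state.
        return ToyDynamoCacheEntry(self._entry.code_name, list(self._entry.backend_ids))

    def install(self, backends: dict[str, Callable[..., Any]]) -> None:
        # The side-effectful step: bind each backend into fn's module globals,
        # where compiled bytecode would look it up by name.
        for backend_id in self._entry.backend_ids:
            self._fn.__globals__[backend_id] = backends[backend_id]
```

Keeping `save()` free of side effects and concentrating global-namespace mutation in `install()` mirrors the split described in the interface list.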

This diff focuses on the low-level mechanism for precompile. Upper-level interfaces are left to use these APIs to build a more user-facing frontend.

Differential Revision: [D75956538](https://our.internmc.facebook.com/intern/diff/D75956538/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155118
Approved by: https://github.com/jamesjwu

Co-authored-by: James Wu <jjwu@meta.com>
2025-06-13 13:54:10 +00:00


import enum
import types
from typing import overload
from torch._dynamo.types import DynamoCallback, DynamoGuardHook, GuardFn
def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def set_skip_guard_eval_unsafe(value: bool) -> bool: ...
def get_eval_frame_callback() -> DynamoCallback: ...
def reset_code(code: types.CodeType) -> None: ...
def unsupported(obj1: object, obj2: object) -> object: ...
def set_code_exec_strategy(
    code: types.CodeType, strategy: _FrameExecStrategy
) -> None: ...
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
def raise_sigtrap() -> None: ...

class _CacheEntry:
    def check_fn(self, *args: object, **kwargs: object) -> bool: ...
    code: types.CodeType
    next: _CacheEntry | None

class _ExtraState:
    def invalidate(self, cache_entry: _CacheEntry, guard_manager: object) -> None: ...

class _FrameAction(enum.IntEnum):
    DEFAULT = 0
    SKIP = 1
    RUN_ONLY = 2

class _FrameExecStrategy:
    cur_action: _FrameAction
    recursive_action: _FrameAction
    @overload
    def __init__(self) -> None: ...
    @overload
    def __init__(
        self, cur_action: _FrameAction, recursive_action: _FrameAction
    ) -> None: ...

# This is an object that encapsulates the Python FrameType and exposes the
# properties Dynamo cares about for a frame.
class _PyInterpreterFrame:
    f_code: types.CodeType
    f_locals: dict[str, object]
    f_globals: dict[str, object]
    f_builtins: dict[str, object]
    f_lasti: int
    f_lineno: int
    f_back: types.FrameType
    # A tuple containing cell objects captured by this frame.
    closure: tuple[types.CellType, ...]

def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
py_opcode_caches: list[int]
def code_framelocals_names(code: types.CodeType) -> tuple[str, ...]: ...
def _load_precompile_entry(
    code: types.CodeType, guard_manager: GuardFn, dynamo_code: types.CodeType
) -> None: ...
def _reset_precompile_entries(code: types.CodeType) -> None: ...
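The `next` field on `_CacheEntry` suggests that the cached compiled variants of a code object form a singly linked list, each guarded by `check_fn`. The sketch below walks such a chain. It is a minimal illustration under assumptions: `ToyCacheEntry` and `lookup` are hypothetical names, guards are assumed to take the frame's locals as a dict, and a miss is signalled by returning None. Dynamo's real lookup runs in C and its guard calling convention may differ.

```python
from __future__ import annotations

import types
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class ToyCacheEntry:
    # Hypothetical pure-Python mirror of the stub's _CacheEntry fields.
    check_fn: Callable[[dict[str, object]], bool]  # guard: does this variant apply?
    code: types.CodeType  # compiled code to run on a guard hit
    next: Optional[ToyCacheEntry] = None  # next (older) entry in the chain


def lookup(
    head: Optional[ToyCacheEntry], f_locals: dict[str, object]
) -> Optional[types.CodeType]:
    """Return the code of the first entry whose guard passes, or None on a miss."""
    entry = head
    while entry is not None:
        if entry.check_fn(f_locals):
            return entry.code
        entry = entry.next
    return None  # miss: a compiler would produce and prepend a new entry here
```

Guards are checked in list order, so the first matching entry wins even if a later one would also match.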
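`_PyInterpreterFrame` exposes attributes that ordinary Python frame objects already provide. The sketch below reads the same attributes (`f_code`, `f_locals`, `f_globals`, `f_builtins`, `f_back`, `f_lineno`) from a regular frame obtained with `inspect.currentframe()`. It uses a function's `__closure__` as an analogue of the `closure` field. `make_adder` and `closure_cells` are hypothetical helpers. This shows the standard CPython frame API, not Dynamo's wrapper object.

```python
from __future__ import annotations

import inspect
import types
from typing import Callable


def make_adder(n: int) -> Callable[[int], dict[str, object]]:
    def adder(x: int) -> dict[str, object]:
        frame = inspect.currentframe()  # the frame executing `adder`
        assert frame is not None  # None only on implementations without frames
        try:
            return {
                "co_name": frame.f_code.co_name,
                "x": frame.f_locals["x"],
                "globals_is_module_dict": frame.f_globals is globals(),
                "has_len_builtin": "len" in frame.f_builtins,
                "caller": frame.f_back.f_code.co_name if frame.f_back else None,
                "lineno": frame.f_lineno,
                "result": x + n,  # `n` is captured in a closure cell
            }
        finally:
            del frame  # break the frame <-> local-variable reference cycle

    return adder


def closure_cells(fn: types.FunctionType) -> tuple[types.CellType, ...]:
    # Analogue of the stub's `closure` field: the cells a function captured.
    return fn.__closure__ or ()
```

Usage: `make_adder(5)(2)` reports `co_name == "adder"`, `x == 2`, and `result == 7`, and `closure_cells(make_adder(5))[0].cell_contents` is `5`.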