mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Add a DynamoFrameType
type above Python frame object (#140330)
This patch introduces a `DynamoFrameType` to serve as a layer between Dynamo and different versions of Python frame object. In `DynamoFrameType`, we only register attributes Dynamo cares about (e.g., `f_code`, `f_locals`, etc. This will be helpful when it comes to adding new attributes to this `DynamoFrameType`, or dealing with Python version changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/140330 Approved by: https://github.com/jansel, https://github.com/williamwen42
This commit is contained in:
committed by
PyTorch MergeBot
parent
c05eff278a
commit
85dd7b84cf
@ -1,13 +1,9 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import types
|
||||
from typing import NewType
|
||||
from typing import Dict, NewType
|
||||
|
||||
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
|
||||
|
||||
# We implement our own FrameType-like type for Python >= 3.11. So it's not actually an alias of FrameType, but still
|
||||
# exposes the same interface.
|
||||
_PyInterpreterFrame = NewType("_PyInterpreterFrame", types.FrameType)
|
||||
|
||||
# For typechecking
|
||||
SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object)
|
||||
CacheLimitHitFlag = NewType("CacheLimitHitFlag", object)
|
||||
@ -31,6 +27,17 @@ class _CacheEntry:
|
||||
class _ExtraState:
|
||||
def invalidate(self, cache_entry: _CacheEntry): ...
|
||||
|
||||
# This is an object that encapsulates the Python FrameType, and exposes
|
||||
# properties Dynamo cares about for a frame.
|
||||
class _PyInterpreterFrame:
|
||||
f_code: types.CodeType
|
||||
f_locals: Dict[str, object]
|
||||
f_globals: Dict[str, object]
|
||||
f_builtins: Dict[str, object]
|
||||
f_lasti: int
|
||||
f_lineo: int
|
||||
f_back: types.FrameType
|
||||
|
||||
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
|
||||
|
||||
py_opcode_caches: list[int]
|
||||
|
@ -1,6 +1,5 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
import types
|
||||
import weakref
|
||||
from dataclasses import dataclass
|
||||
from typing import Tuple
|
||||
@ -8,6 +7,7 @@ from typing import Tuple
|
||||
from torch._guards import CompileId
|
||||
|
||||
from . import config
|
||||
from .types import DynamoFrameType
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -100,7 +100,7 @@ class CacheSizeRelevantForFrame:
|
||||
return self.num_cache_entries_with_same_id_matched_objs >= limit
|
||||
|
||||
|
||||
def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
|
||||
def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
|
||||
obj = frame.f_locals.get(local_name, None)
|
||||
weak_id = None
|
||||
try:
|
||||
@ -110,7 +110,7 @@ def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
|
||||
return weak_id
|
||||
|
||||
|
||||
def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
|
||||
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
|
||||
"""
|
||||
Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones
|
||||
in frame.f_locals.
|
||||
@ -132,7 +132,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
|
||||
|
||||
|
||||
def compute_cache_size(
|
||||
frame: types.FrameType, cache_entry
|
||||
frame: DynamoFrameType, cache_entry
|
||||
) -> CacheSizeRelevantForFrame:
|
||||
# Walk the linked list to calculate the cache size
|
||||
num_cache_entries = 0
|
||||
|
@ -21,7 +21,7 @@ import typing
|
||||
import warnings
|
||||
import weakref
|
||||
from pathlib import Path
|
||||
from types import CodeType, FrameType, FunctionType, ModuleType
|
||||
from types import CodeType, FunctionType, ModuleType
|
||||
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
from weakref import ReferenceType
|
||||
@ -138,7 +138,7 @@ except ModuleNotFoundError:
|
||||
if typing.TYPE_CHECKING:
|
||||
from .backends.registry import CompilerFn
|
||||
from .repro.after_dynamo import WrapBackendDebug
|
||||
from .types import BytecodeHook, CacheEntry
|
||||
from .types import BytecodeHook, CacheEntry, DynamoFrameType
|
||||
from .variables.builder import FrameStateSizeEntry
|
||||
|
||||
|
||||
@ -257,7 +257,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
||||
|
||||
|
||||
@TorchPatcher.suppress_torch_distributed_warnings
|
||||
def has_tensor_in_frame(frame: FrameType) -> bool:
|
||||
def has_tensor_in_frame(frame: DynamoFrameType) -> bool:
|
||||
"""Check if the frame has torch.* related bits"""
|
||||
# Check if the function was decorated using torch._dynamo.optimize
|
||||
if frame.f_code in always_optimize_code_objects:
|
||||
@ -338,7 +338,7 @@ def has_tensor_in_frame(frame: FrameType) -> bool:
|
||||
def exception_handler(
|
||||
e: Exception,
|
||||
code: CodeType,
|
||||
frame: Optional[FrameType] = None,
|
||||
frame: Optional[DynamoFrameType] = None,
|
||||
export: bool = False,
|
||||
) -> None:
|
||||
record_filename = None
|
||||
@ -450,7 +450,7 @@ class ConvertFrameAssert:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
frame: FrameType,
|
||||
frame: DynamoFrameType,
|
||||
cache_entry: Optional[CacheEntry],
|
||||
hooks: Hooks,
|
||||
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
|
||||
@ -609,7 +609,7 @@ def _compile(
|
||||
hooks: Hooks,
|
||||
cache_entry: Optional[CacheEntry],
|
||||
cache_size: CacheSizeRelevantForFrame,
|
||||
frame: Optional[FrameType] = None,
|
||||
frame: Optional[DynamoFrameType] = None,
|
||||
frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None,
|
||||
*,
|
||||
compile_id: CompileId,
|
||||
@ -1165,7 +1165,7 @@ class ConvertFrame:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
frame: FrameType,
|
||||
frame: DynamoFrameType,
|
||||
cache_entry: Optional[CacheEntry],
|
||||
hooks: Hooks,
|
||||
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
|
||||
@ -1310,7 +1310,7 @@ def first_real_inst_idx(code: CodeType) -> int:
|
||||
class ConvertFrameProtocol(typing.Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
frame: FrameType,
|
||||
frame: DynamoFrameType,
|
||||
cache_entry: Optional[CacheEntry],
|
||||
hooks: Hooks,
|
||||
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
|
||||
@ -1328,7 +1328,7 @@ class CatchErrorsWrapper:
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
frame: FrameType,
|
||||
frame: DynamoFrameType,
|
||||
cache_entry: Optional[CacheEntry],
|
||||
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
|
||||
) -> Optional[GuardedCode]:
|
||||
|
@ -108,7 +108,14 @@ from .source import (
|
||||
UnspecializedParamBufferSource,
|
||||
WeakRefCallSource,
|
||||
)
|
||||
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
|
||||
from .types import ( # noqa: F401
|
||||
CacheEntry,
|
||||
DynamoFrameType,
|
||||
ExtraState,
|
||||
GuardedCode,
|
||||
GuardFail,
|
||||
GuardFn,
|
||||
)
|
||||
from .utils import (
|
||||
common_constant_types,
|
||||
dict_keys_repr,
|
||||
@ -2600,7 +2607,7 @@ def get_guard_fail_reason(
|
||||
|
||||
|
||||
def get_and_maybe_log_recompilation_reason(
|
||||
cache_entry, frame: types.FrameType
|
||||
cache_entry, frame: DynamoFrameType
|
||||
) -> List[str]:
|
||||
"""
|
||||
Return the list of guard failure reasons using cache_entry.
|
||||
|
@ -35,6 +35,7 @@ from .bytecode_transformation import (
|
||||
transform_code_object,
|
||||
)
|
||||
from .guards import CheckFunctionManager, CompileId, GuardedCode
|
||||
from .types import DynamoFrameType
|
||||
from .utils import same
|
||||
|
||||
|
||||
@ -164,7 +165,7 @@ def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None:
|
||||
|
||||
|
||||
def debug_insert_nops(
|
||||
frame: types.FrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
|
||||
frame: DynamoFrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
|
||||
) -> Optional[GuardedCode]:
|
||||
"""used to debug jump updates"""
|
||||
|
||||
|
@ -1,5 +1,4 @@
|
||||
import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
|
||||
|
||||
@ -7,16 +6,11 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Un
|
||||
from torch._C._dynamo.eval_frame import (
|
||||
_CacheEntry as CacheEntry,
|
||||
_ExtraState as ExtraState,
|
||||
_PyInterpreterFrame as DynamoFrameType,
|
||||
)
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
|
||||
else:
|
||||
from types import FrameType as DynamoFrameType
|
||||
|
||||
|
||||
# We use a dict to store additional data per frame.
|
||||
FrameState = Dict[Any, Any]
|
||||
|
||||
|
@ -38,17 +38,20 @@ static void eval_frame_callback_set(PyObject* obj) {
|
||||
// https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
#define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
|
||||
#else
|
||||
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
|
||||
#endif // IS_PYTHON_3_11_PLUS
|
||||
|
||||
// We need to be able to return the _PyInterpreterFrame to python so create
|
||||
// a python binding for it
|
||||
|
||||
typedef struct THPPyInterpreterFrame {
|
||||
PyObject_HEAD
|
||||
_PyInterpreterFrame* frame; // Borrowed reference
|
||||
THP_EVAL_API_FRAME_OBJECT* frame; // Borrowed reference
|
||||
PyObject* locals;
|
||||
} THPPyInterpreterFrame;
|
||||
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame);
|
||||
|
||||
#define DECLARE_PYOBJ_ATTR(name) \
|
||||
static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
|
||||
@ -57,12 +60,6 @@ static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObj
|
||||
return res; \
|
||||
}
|
||||
|
||||
#if IS_PYTHON_3_12_PLUS
|
||||
DECLARE_PYOBJ_ATTR(f_funcobj)
|
||||
#else
|
||||
DECLARE_PYOBJ_ATTR(f_func)
|
||||
#endif
|
||||
|
||||
DECLARE_PYOBJ_ATTR(f_globals)
|
||||
DECLARE_PYOBJ_ATTR(f_builtins)
|
||||
|
||||
@ -78,22 +75,20 @@ DECLARE_PYOBJ_ATTR(f_executable)
|
||||
DECLARE_PYOBJ_ATTR(f_code)
|
||||
#endif
|
||||
|
||||
DECLARE_PYOBJ_ATTR(frame_obj)
|
||||
|
||||
#undef DECLARE_PYOBJ_ATTR
|
||||
|
||||
static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
|
||||
return res;
|
||||
}
|
||||
|
||||
// This is not a true attribute of the class but we do access it in python and it is hard to implement
|
||||
// on the python side, so do it here:
|
||||
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
|
||||
#else
|
||||
return PyLong_FromLong(self->frame->f_lasti);
|
||||
#endif // IS_PYTHON_3_11_PLUS
|
||||
}
|
||||
|
||||
static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
if (!self->frame->frame_obj) {
|
||||
return PyLong_FromLong(F_CODE(self->frame)->co_firstlineno);
|
||||
}
|
||||
@ -102,22 +97,24 @@ static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyO
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
return PyLong_FromLong(lineno);
|
||||
#else
|
||||
return PyLong_FromLong(self->frame->f_lineno);
|
||||
#endif // IS_PYTHON_3_11_PLUS
|
||||
}
|
||||
|
||||
static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
if (!self->frame->frame_obj) {
|
||||
Py_RETURN_NONE;
|
||||
}
|
||||
return (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
|
||||
#else
|
||||
return Py_XNewRef(self->frame->f_back);
|
||||
#endif // IS_PYTHON_3_11_PLUS
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
||||
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
|
||||
#if IS_PYTHON_3_12_PLUS
|
||||
{"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL},
|
||||
#else
|
||||
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
|
||||
#endif
|
||||
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
|
||||
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
|
||||
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
|
||||
@ -126,8 +123,6 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
|
||||
#else
|
||||
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
|
||||
#endif
|
||||
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
|
||||
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
|
||||
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
|
||||
{"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
|
||||
{"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
|
||||
@ -142,7 +137,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
|
||||
};
|
||||
|
||||
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
|
||||
THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame) {
|
||||
PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
|
||||
THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
|
||||
if (!self)
|
||||
@ -152,13 +147,6 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
|
||||
return self;
|
||||
}
|
||||
|
||||
|
||||
#else
|
||||
|
||||
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
|
||||
|
||||
#endif
|
||||
|
||||
static PyObject* dynamo__custom_eval_frame_shim(
|
||||
PyThreadState* tstate,
|
||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||
@ -246,6 +234,8 @@ static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
|
||||
return PyUnicode_AsUTF8(F_CODE(frame)->co_name);
|
||||
}
|
||||
|
||||
// Remember to update the type signature for DynamoCallbackFn.__call__ in
|
||||
// torch/_dynamo/types.py if this function's signature changes.
|
||||
static PyObject* dynamo_call_callback(
|
||||
PyObject* callable,
|
||||
THP_EVAL_API_FRAME_OBJECT* _frame,
|
||||
@ -253,18 +243,11 @@ static PyObject* dynamo_call_callback(
|
||||
CacheEntry* cache_entry,
|
||||
FrameState* frame_state) {
|
||||
|
||||
// remember to update the type signature for DynamoCallbackFn.__call__ in torch/_dynamo/types.py
|
||||
// if this function changes
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
|
||||
if (frame == NULL) {
|
||||
return NULL;
|
||||
}
|
||||
frame->locals = locals;
|
||||
#else
|
||||
PyObject* frame = Py_NewRef(_frame);
|
||||
#endif
|
||||
|
||||
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
|
||||
PyObject* res = PyObject_CallFunction(
|
||||
callable,
|
||||
@ -716,7 +699,7 @@ static PyObject* dynamo__custom_eval_frame(
|
||||
}
|
||||
}
|
||||
|
||||
#else // IS_PYTHON_3_14_PLUS
|
||||
#else // !(IS_PYTHON_3_14_PLUS)
|
||||
|
||||
// Fake definitions for everything we removed
|
||||
|
||||
@ -738,7 +721,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
|
||||
.tp_getset = THPPyInterpreterFrame_properties,
|
||||
};
|
||||
|
||||
#endif // CPython 3.14
|
||||
#endif // !(IS_PYTHON_3_14_PLUS)
|
||||
|
||||
static PyObject* increment_working_threads(PyThreadState* tstate) {
|
||||
active_dynamo_threads = active_dynamo_threads + 1;
|
||||
@ -909,7 +892,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
|
||||
PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
|
||||
#endif
|
||||
|
||||
#if IS_PYTHON_3_11_PLUS
|
||||
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
|
||||
return NULL;
|
||||
}
|
||||
@ -917,7 +899,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
|
||||
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
|
||||
return NULL;
|
||||
}
|
||||
#endif
|
||||
|
||||
skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type);
|
||||
if (skip_code_recursive_flag == NULL) {
|
||||
|
Reference in New Issue
Block a user