[dynamo] Add a DynamoFrameType type above Python frame object (#140330)

This patch introduces a `DynamoFrameType` to serve as a layer between
Dynamo and different versions of Python frame object. In
`DynamoFrameType`, we only register attributes Dynamo cares about (e.g.,
`f_code`, `f_locals`, etc.

This will be helpful when it comes to adding new attributes to this
`DynamoFrameType`, or dealing with Python version changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140330
Approved by: https://github.com/jansel, https://github.com/williamwen42
This commit is contained in:
Ryan Guo
2024-11-14 16:53:37 -05:00
committed by PyTorch MergeBot
parent c05eff278a
commit 85dd7b84cf
7 changed files with 59 additions and 69 deletions

View File

@ -1,13 +1,9 @@
# mypy: allow-untyped-defs
import types
from typing import NewType
from typing import Dict, NewType
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
# We implement our own FrameType-like type for Python >= 3.11. So it's not actually an alias of FrameType, but still
# exposes the same interface.
_PyInterpreterFrame = NewType("_PyInterpreterFrame", types.FrameType)
# For typechecking
SkipCodeRecursiveFlag = NewType("SkipCodeRecursiveFlag", object)
CacheLimitHitFlag = NewType("CacheLimitHitFlag", object)
@ -31,6 +27,17 @@ class _CacheEntry:
class _ExtraState:
def invalidate(self, cache_entry: _CacheEntry): ...
# This is an object that encapsulates the Python FrameType, and exposes
# properties Dynamo cares about for a frame.
class _PyInterpreterFrame:
f_code: types.CodeType
f_locals: Dict[str, object]
f_globals: Dict[str, object]
f_builtins: Dict[str, object]
f_lasti: int
f_lineo: int
f_back: types.FrameType
def _debug_get_cache_entry_list(code: types.CodeType) -> list[_CacheEntry]: ...
py_opcode_caches: list[int]

View File

@ -1,6 +1,5 @@
# mypy: allow-untyped-defs
import logging
import types
import weakref
from dataclasses import dataclass
from typing import Tuple
@ -8,6 +7,7 @@ from typing import Tuple
from torch._guards import CompileId
from . import config
from .types import DynamoFrameType
log = logging.getLogger(__name__)
@ -100,7 +100,7 @@ class CacheSizeRelevantForFrame:
return self.num_cache_entries_with_same_id_matched_objs >= limit
def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
def _get_weakref_from_f_locals(frame: DynamoFrameType, local_name: str):
obj = frame.f_locals.get(local_name, None)
weak_id = None
try:
@ -110,7 +110,7 @@ def _get_weakref_from_f_locals(frame: types.FrameType, local_name: str):
return weak_id
def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
def _has_same_id_matched_objs(frame: DynamoFrameType, cache_entry) -> bool:
"""
Checks if the ID_MATCH'd objects saved on cache_entry are same as the ones
in frame.f_locals.
@ -132,7 +132,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
def compute_cache_size(
frame: types.FrameType, cache_entry
frame: DynamoFrameType, cache_entry
) -> CacheSizeRelevantForFrame:
# Walk the linked list to calculate the cache size
num_cache_entries = 0

View File

@ -21,7 +21,7 @@ import typing
import warnings
import weakref
from pathlib import Path
from types import CodeType, FrameType, FunctionType, ModuleType
from types import CodeType, FunctionType, ModuleType
from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
from typing_extensions import ParamSpec
from weakref import ReferenceType
@ -138,7 +138,7 @@ except ModuleNotFoundError:
if typing.TYPE_CHECKING:
from .backends.registry import CompilerFn
from .repro.after_dynamo import WrapBackendDebug
from .types import BytecodeHook, CacheEntry
from .types import BytecodeHook, CacheEntry, DynamoFrameType
from .variables.builder import FrameStateSizeEntry
@ -257,7 +257,7 @@ def preserve_global_state(fn: Callable[_P, _T]) -> Callable[_P, _T]:
@TorchPatcher.suppress_torch_distributed_warnings
def has_tensor_in_frame(frame: FrameType) -> bool:
def has_tensor_in_frame(frame: DynamoFrameType) -> bool:
"""Check if the frame has torch.* related bits"""
# Check if the function was decorated using torch._dynamo.optimize
if frame.f_code in always_optimize_code_objects:
@ -338,7 +338,7 @@ def has_tensor_in_frame(frame: FrameType) -> bool:
def exception_handler(
e: Exception,
code: CodeType,
frame: Optional[FrameType] = None,
frame: Optional[DynamoFrameType] = None,
export: bool = False,
) -> None:
record_filename = None
@ -450,7 +450,7 @@ class ConvertFrameAssert:
def __call__(
self,
frame: FrameType,
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry],
hooks: Hooks,
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
@ -609,7 +609,7 @@ def _compile(
hooks: Hooks,
cache_entry: Optional[CacheEntry],
cache_size: CacheSizeRelevantForFrame,
frame: Optional[FrameType] = None,
frame: Optional[DynamoFrameType] = None,
frame_state: Optional[Dict[str, Union[int, FrameStateSizeEntry]]] = None,
*,
compile_id: CompileId,
@ -1165,7 +1165,7 @@ class ConvertFrame:
def __call__(
self,
frame: FrameType,
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry],
hooks: Hooks,
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
@ -1310,7 +1310,7 @@ def first_real_inst_idx(code: CodeType) -> int:
class ConvertFrameProtocol(typing.Protocol):
def __call__(
self,
frame: FrameType,
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry],
hooks: Hooks,
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
@ -1328,7 +1328,7 @@ class CatchErrorsWrapper:
def __call__(
self,
frame: FrameType,
frame: DynamoFrameType,
cache_entry: Optional[CacheEntry],
frame_state: Dict[str, Union[int, FrameStateSizeEntry]],
) -> Optional[GuardedCode]:

View File

@ -108,7 +108,14 @@ from .source import (
UnspecializedParamBufferSource,
WeakRefCallSource,
)
from .types import CacheEntry, ExtraState, GuardedCode, GuardFail, GuardFn # noqa: F401
from .types import ( # noqa: F401
CacheEntry,
DynamoFrameType,
ExtraState,
GuardedCode,
GuardFail,
GuardFn,
)
from .utils import (
common_constant_types,
dict_keys_repr,
@ -2600,7 +2607,7 @@ def get_guard_fail_reason(
def get_and_maybe_log_recompilation_reason(
cache_entry, frame: types.FrameType
cache_entry, frame: DynamoFrameType
) -> List[str]:
"""
Return the list of guard failure reasons using cache_entry.

View File

@ -35,6 +35,7 @@ from .bytecode_transformation import (
transform_code_object,
)
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .types import DynamoFrameType
from .utils import same
@ -164,7 +165,7 @@ def debug_dump(name: str, code: types.CodeType, extra: str = "") -> None:
def debug_insert_nops(
frame: types.FrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
frame: DynamoFrameType, cache_size: int, hooks: Any, _: Any, *, skip: int = 0
) -> Optional[GuardedCode]:
"""used to debug jump updates"""

View File

@ -1,5 +1,4 @@
import dataclasses
import sys
import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
@ -7,16 +6,11 @@ from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Un
from torch._C._dynamo.eval_frame import (
_CacheEntry as CacheEntry,
_ExtraState as ExtraState,
_PyInterpreterFrame as DynamoFrameType,
)
from torch._guards import CompileId
if sys.version_info >= (3, 11):
from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
else:
from types import FrameType as DynamoFrameType
# We use a dict to store additional data per frame.
FrameState = Dict[Any, Any]

View File

@ -38,17 +38,20 @@ static void eval_frame_callback_set(PyObject* obj) {
// https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
#if IS_PYTHON_3_11_PLUS
#define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
#endif // IS_PYTHON_3_11_PLUS
// We need to be able to return the _PyInterpreterFrame to python so create
// a python binding for it
typedef struct THPPyInterpreterFrame {
PyObject_HEAD
_PyInterpreterFrame* frame; // Borrowed reference
THP_EVAL_API_FRAME_OBJECT* frame; // Borrowed reference
PyObject* locals;
} THPPyInterpreterFrame;
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame);
#define DECLARE_PYOBJ_ATTR(name) \
static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObject* _noargs) { \
@ -57,12 +60,6 @@ static PyObject* THPPyInterpreterFrame_##name(THPPyInterpreterFrame* self, PyObj
return res; \
}
#if IS_PYTHON_3_12_PLUS
DECLARE_PYOBJ_ATTR(f_funcobj)
#else
DECLARE_PYOBJ_ATTR(f_func)
#endif
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
@ -78,22 +75,20 @@ DECLARE_PYOBJ_ATTR(f_executable)
DECLARE_PYOBJ_ATTR(f_code)
#endif
DECLARE_PYOBJ_ATTR(frame_obj)
#undef DECLARE_PYOBJ_ATTR
static THPPyInterpreterFrame* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
THPPyInterpreterFrame* res = THPPyInterpreterFrame_New(self->frame->previous);
return res;
}
// This is not a true attribute of the class but we do access it in python and it is hard to implement
// on the python side, so do it here:
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
#if IS_PYTHON_3_11_PLUS
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
#else
return PyLong_FromLong(self->frame->f_lasti);
#endif // IS_PYTHON_3_11_PLUS
}
static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyObject* _noargs) {
#if IS_PYTHON_3_11_PLUS
if (!self->frame->frame_obj) {
return PyLong_FromLong(F_CODE(self->frame)->co_firstlineno);
}
@ -102,22 +97,24 @@ static PyObject* THPPyInterpreterFrame_f_lineno(THPPyInterpreterFrame* self, PyO
Py_RETURN_NONE;
}
return PyLong_FromLong(lineno);
#else
return PyLong_FromLong(self->frame->f_lineno);
#endif // IS_PYTHON_3_11_PLUS
}
static PyObject* THPPyInterpreterFrame_f_back(THPPyInterpreterFrame* self, PyObject* _noargs) {
#if IS_PYTHON_3_11_PLUS
if (!self->frame->frame_obj) {
Py_RETURN_NONE;
}
return (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
#else
return Py_XNewRef(self->frame->f_back);
#endif // IS_PYTHON_3_11_PLUS
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
#if IS_PYTHON_3_12_PLUS
{"f_func", (getter)THPPyInterpreterFrame_f_funcobj, NULL, NULL, NULL},
#else
{"f_func", (getter)THPPyInterpreterFrame_f_func, NULL, NULL, NULL},
#endif
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
@ -126,8 +123,6 @@ static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
#else
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
#endif
{"frame_obj", (getter)THPPyInterpreterFrame_frame_obj, NULL, NULL, NULL},
{"previous", (getter)THPPyInterpreterFrame_previous, NULL, NULL, NULL},
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
{"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
{"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
@ -142,7 +137,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
};
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
THPPyInterpreterFrame* THPPyInterpreterFrame_New(THP_EVAL_API_FRAME_OBJECT* frame) {
PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
if (!self)
@ -152,13 +147,6 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
return self;
}
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
#endif
static PyObject* dynamo__custom_eval_frame_shim(
PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame,
@ -246,6 +234,8 @@ static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
return PyUnicode_AsUTF8(F_CODE(frame)->co_name);
}
// Remember to update the type signature for DynamoCallbackFn.__call__ in
// torch/_dynamo/types.py if this function's signature changes.
static PyObject* dynamo_call_callback(
PyObject* callable,
THP_EVAL_API_FRAME_OBJECT* _frame,
@ -253,18 +243,11 @@ static PyObject* dynamo_call_callback(
CacheEntry* cache_entry,
FrameState* frame_state) {
// remember to update the type signature for DynamoCallbackFn.__call__ in torch/_dynamo/types.py
// if this function changes
#if IS_PYTHON_3_11_PLUS
THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
if (frame == NULL) {
return NULL;
}
frame->locals = locals;
#else
PyObject* frame = Py_NewRef(_frame);
#endif
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
PyObject* res = PyObject_CallFunction(
callable,
@ -716,7 +699,7 @@ static PyObject* dynamo__custom_eval_frame(
}
}
#else // IS_PYTHON_3_14_PLUS
#else // !(IS_PYTHON_3_14_PLUS)
// Fake definitions for everything we removed
@ -738,7 +721,7 @@ static PyTypeObject THPPyInterpreterFrameType = {
.tp_getset = THPPyInterpreterFrame_properties,
};
#endif // CPython 3.14
#endif // !(IS_PYTHON_3_14_PLUS)
static PyObject* increment_working_threads(PyThreadState* tstate) {
active_dynamo_threads = active_dynamo_threads + 1;
@ -909,7 +892,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
#endif
#if IS_PYTHON_3_11_PLUS
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
return NULL;
}
@ -917,7 +899,6 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
return NULL;
}
#endif
skip_code_recursive_flag = PyObject_New(PyObject, &PyBaseObject_Type);
if (skip_code_recursive_flag == NULL) {