mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164400 Approved by: https://github.com/Camyll, https://github.com/atalman
783 lines
25 KiB
C
783 lines
25 KiB
C
#define PY_SSIZE_T_CLEAN
|
|
#include <opcode.h>
|
|
#include <signal.h>
|
|
#include <torch/csrc/dynamo/cache_entry.h>
|
|
#include <torch/csrc/dynamo/cpp_shim.h>
|
|
#include <torch/csrc/dynamo/cpython_defs.h>
|
|
#include <torch/csrc/dynamo/cpython_includes.h>
|
|
#include <torch/csrc/dynamo/debug_macros.h>
|
|
#include <torch/csrc/dynamo/eval_frame.h>
|
|
#include <torch/csrc/dynamo/eval_frame_cpp.h>
|
|
#include <torch/csrc/utils/python_compat.h>
|
|
|
|
PyObject* guard_error_hook = NULL;
|
|
PyObject* guard_complete_hook = NULL;
|
|
|
|
typedef struct {
|
|
int active_dynamo_threads;
|
|
} ModuleState;
|
|
|
|
// static int active_dynamo_threads = 0;
|
|
|
|
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
|
|
|
|
static PyObject* eval_frame_callback_get(void) {
|
|
void* result = PyThread_tss_get(&eval_frame_callback_key);
|
|
if (unlikely(result == NULL)) {
|
|
return (PyObject*)Py_None;
|
|
} else {
|
|
return (PyObject*)result;
|
|
}
|
|
}
|
|
|
|
// Stores `obj` as the calling thread's eval-frame callback. The reference
// count of `obj` is not modified and the status returned by
// PyThread_tss_set is discarded.
void eval_frame_callback_set(PyObject* obj) {
  (void)PyThread_tss_set(&eval_frame_callback_key, obj);
}
|
|
|
|
// 3.15 Not supported at all. See cpython_defs.c for hints
|
|
// 3.14 currently not fully supported on Windows
|
|
#if !(IS_PYTHON_3_15_PLUS || (IS_PYTHON_3_14_PLUS && defined(_WIN32)))
|
|
|
|
#define DECLARE_PYOBJ_ATTR(name) \
|
|
static PyObject* THPPyInterpreterFrame_##name( \
|
|
THPPyInterpreterFrame* self, PyObject* _noargs) { \
|
|
PyObject* res = (PyObject*)self->frame->name; \
|
|
Py_XINCREF(res); \
|
|
return res; \
|
|
}
|
|
|
|
DECLARE_PYOBJ_ATTR(f_globals)
|
|
DECLARE_PYOBJ_ATTR(f_builtins)
|
|
|
|
static PyObject* THPPyInterpreterFrame_f_locals(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
DEBUG_NULL_CHECK(self->locals);
|
|
Py_XINCREF(self->locals);
|
|
return self->locals;
|
|
}
|
|
|
|
#if IS_PYTHON_3_14_PLUS
|
|
static PyObject* THPPyInterpreterFrame_f_executable(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
return PyStackRef_AsPyObjectNew(self->frame->f_executable);
|
|
}
|
|
#elif IS_PYTHON_3_13_PLUS
|
|
DECLARE_PYOBJ_ATTR(f_executable)
|
|
#else
|
|
DECLARE_PYOBJ_ATTR(f_code)
|
|
#endif
|
|
|
|
#undef DECLARE_PYOBJ_ATTR
|
|
|
|
// This is not a true attribute of the class but we do access it in python and
|
|
// it is hard to implement on the python side, so do it here:
|
|
static PyObject* THPPyInterpreterFrame_f_lasti(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
#if IS_PYTHON_3_11_PLUS
|
|
return PyLong_FromLong(_PyInterpreterFrame_LASTI(self->frame));
|
|
#else
|
|
return PyLong_FromLong(self->frame->f_lasti);
|
|
#endif // IS_PYTHON_3_11_PLUS
|
|
}
|
|
|
|
static PyObject* THPPyInterpreterFrame_f_lineno(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
#if IS_PYTHON_3_11_PLUS
|
|
if (!self->frame->frame_obj) {
|
|
return PyLong_FromLong(F_CODE(self->frame)->co_firstlineno);
|
|
}
|
|
int lineno = PyFrame_GetLineNumber(self->frame->frame_obj);
|
|
if (lineno < 0) {
|
|
Py_RETURN_NONE;
|
|
}
|
|
return PyLong_FromLong(lineno);
|
|
#else
|
|
return PyLong_FromLong(self->frame->f_lineno);
|
|
#endif // IS_PYTHON_3_11_PLUS
|
|
}
|
|
|
|
static PyObject* THPPyInterpreterFrame_f_back(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
#if IS_PYTHON_3_11_PLUS
|
|
if (!self->frame->frame_obj) {
|
|
Py_RETURN_NONE;
|
|
}
|
|
return (PyObject*)PyFrame_GetBack(self->frame->frame_obj);
|
|
#else
|
|
return Py_XNewRef(self->frame->f_back);
|
|
#endif // IS_PYTHON_3_11_PLUS
|
|
}
|
|
|
|
static PyObject* THPPyInterpreterFrame_closure(
|
|
THPPyInterpreterFrame* self,
|
|
PyObject* _noargs) {
|
|
#if IS_PYTHON_3_11_PLUS
|
|
PyObject* closure = FUNC(self->frame)->func_closure;
|
|
return closure == NULL ? PyTuple_New(0) : Py_XNewRef(closure);
|
|
#else
|
|
PyCodeObject* code = self->frame->f_code;
|
|
// Why this check? See
|
|
// https://github.com/python/cpython/blob/5f24da9d75bb0150781b17ee4706e93e6bb364ea/Objects/frameobject.c#L1058-L1065
|
|
if (code->co_flags & CO_OPTIMIZED) {
|
|
int size = PyTuple_GET_SIZE(code->co_freevars);
|
|
PyObject* freevars = PyTuple_New(size);
|
|
int ncells = PyTuple_GET_SIZE(code->co_cellvars);
|
|
PyObject** freevarArr =
|
|
self->frame->f_localsplus + code->co_nlocals + ncells;
|
|
for (int i = 0; i < size; i++) {
|
|
PyTuple_SET_ITEM(freevars, i, Py_XNewRef(freevarArr[i]));
|
|
}
|
|
return freevars;
|
|
}
|
|
return PyTuple_New(0);
|
|
#endif // IS_PYTHON_3_11_PLUS
|
|
}
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
|
|
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {
|
|
{"f_globals", (getter)THPPyInterpreterFrame_f_globals, NULL, NULL, NULL},
|
|
{"f_builtins", (getter)THPPyInterpreterFrame_f_builtins, NULL, NULL, NULL},
|
|
{"f_locals", (getter)THPPyInterpreterFrame_f_locals, NULL, NULL, NULL},
|
|
#if IS_PYTHON_3_13_PLUS
|
|
{"f_code", (getter)THPPyInterpreterFrame_f_executable, NULL, NULL, NULL},
|
|
#else
|
|
{"f_code", (getter)THPPyInterpreterFrame_f_code, NULL, NULL, NULL},
|
|
#endif
|
|
{"f_lasti", (getter)THPPyInterpreterFrame_f_lasti, NULL, NULL, NULL},
|
|
{"f_lineno", (getter)THPPyInterpreterFrame_f_lineno, NULL, NULL, NULL},
|
|
{"f_back", (getter)THPPyInterpreterFrame_f_back, NULL, NULL, NULL},
|
|
{"closure", (getter)THPPyInterpreterFrame_closure, NULL, NULL, NULL},
|
|
{NULL}};
|
|
|
|
static PyTypeObject THPPyInterpreterFrameType = {
|
|
PyVarObject_HEAD_INIT(NULL, 0)
|
|
.tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
|
|
.tp_basicsize = sizeof(THPPyInterpreterFrame),
|
|
.tp_flags = Py_TPFLAGS_DEFAULT,
|
|
.tp_getset = THPPyInterpreterFrame_properties,
|
|
};
|
|
|
|
THPPyInterpreterFrame* THPPyInterpreterFrame_New(
|
|
THP_EVAL_API_FRAME_OBJECT* frame) {
|
|
PyTypeObject* type = (PyTypeObject*)&THPPyInterpreterFrameType;
|
|
THPPyInterpreterFrame* self = (THPPyInterpreterFrame*)type->tp_alloc(type, 0);
|
|
if (!self)
|
|
return NULL;
|
|
self->frame = frame;
|
|
self->locals = NULL;
|
|
return self;
|
|
}
|
|
|
|
static PyObject* dynamo__custom_eval_frame_shim(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag);
|
|
static PyObject* (*previous_eval_frame)(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag) = NULL;
|
|
|
|
static PyObject* dynamo_custom_eval_frame_shim(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag) {
|
|
return dynamo__custom_eval_frame_shim(tstate, frame, throw_flag);
|
|
}
|
|
|
|
PyObject* dynamo_eval_frame_default(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag) {
|
|
if (tstate == NULL) {
|
|
tstate = PyThreadState_GET();
|
|
}
|
|
if (previous_eval_frame) {
|
|
return previous_eval_frame(tstate, frame, throw_flag);
|
|
} else {
|
|
return _PyEval_EvalFrameDefault(tstate, frame, throw_flag);
|
|
}
|
|
}
|
|
|
|
// Installs the Dynamo shim as the interpreter's frame-evaluation function and
// remembers the function it replaces. Does nothing if the shim is already
// the installed function.
static void enable_eval_frame_shim(PyThreadState* tstate) {
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      &dynamo_custom_eval_frame_shim) {
    return;
  }
  DEBUG_CHECK(previous_eval_frame == NULL);
  previous_eval_frame = _PyInterpreterState_GetEvalFrameFunc(tstate->interp);
  _PyInterpreterState_SetEvalFrameFunc(
      tstate->interp, &dynamo_custom_eval_frame_shim);
}
|
|
|
|
// Restores the frame-evaluation function recorded by enable_eval_frame_shim
// and forgets it. Does nothing if the interpreter is already using it.
static void enable_eval_frame_default(PyThreadState* tstate) {
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      previous_eval_frame) {
    return;
  }
  DEBUG_CHECK(previous_eval_frame != NULL);
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp, previous_eval_frame);
  previous_eval_frame = NULL;
}
|
|
|
|
// Returns the name (co_name) of the code object executing in `frame` as a
// UTF-8 C string, as produced by PyUnicode_AsUTF8.
const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
  DEBUG_CHECK(PyUnicode_Check(F_CODE(frame)->co_name));
  PyObject* code_name = F_CODE(frame)->co_name;
  return PyUnicode_AsUTF8(code_name);
}
|
|
|
|
#if IS_PYTHON_3_14_PLUS
|
|
static void dup_obj(_PyStackRef* dst, _PyStackRef src) {
|
|
*dst = PyStackRef_DUP(src);
|
|
}
|
|
#else
|
|
static void dup_obj(PyObject** dst, PyObject* src) {
|
|
Py_XINCREF(src);
|
|
*dst = src;
|
|
}
|
|
#endif
|
|
|
|
static PyObject* dynamo_eval_custom_code_impl(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
PyCodeObject* code,
|
|
int throw_flag) {
|
|
DEBUG_NULL_CHECK(tstate);
|
|
DEBUG_NULL_CHECK(frame);
|
|
DEBUG_NULL_CHECK(code);
|
|
|
|
#if IS_PYTHON_3_11_PLUS
|
|
|
|
// Generate Python function object and _PyInterpreterFrame in a way similar to
|
|
// https://github.com/python/cpython/blob/e715da6db1d1d70cd779dc48e1ba8110c51cc1bf/Python/ceval.c#L1130
|
|
PyFunctionObject* old_func = FUNC(frame);
|
|
#if IS_PYTHON_3_12_PLUS
|
|
size_t size = code->co_framesize;
|
|
#else
|
|
size_t size = code->co_nlocalsplus + code->co_stacksize + FRAME_SPECIALS_SIZE;
|
|
#endif
|
|
|
|
PyFunctionObject* func = _PyFunction_CopyWithNewCode(old_func, code);
|
|
if (func == NULL) {
|
|
return NULL;
|
|
}
|
|
|
|
THP_EVAL_API_FRAME_OBJECT* shadow =
|
|
THP_PyThreadState_BumpFramePointerSlow(tstate, size);
|
|
if (shadow == NULL) {
|
|
Py_DECREF(func);
|
|
return NULL;
|
|
}
|
|
|
|
Py_INCREF(func);
|
|
// consumes reference to func
|
|
#if IS_PYTHON_3_14_PLUS
|
|
_PyStackRef func_stackref = PyStackRef_FromPyObjectSteal((PyObject*)func);
|
|
_PyFrame_Initialize(
|
|
tstate, shadow, func_stackref, NULL, code, 0, frame->previous);
|
|
#elif IS_PYTHON_3_12_PLUS
|
|
_PyFrame_Initialize(shadow, func, NULL, code, 0);
|
|
#else
|
|
_PyFrame_InitializeSpecials(shadow, func, NULL, code->co_nlocalsplus);
|
|
#endif
|
|
|
|
#if IS_PYTHON_3_14_PLUS
|
|
_PyStackRef* fastlocals_old = frame->localsplus;
|
|
_PyStackRef* fastlocals_new = shadow->localsplus;
|
|
#else
|
|
PyObject** fastlocals_old = frame->localsplus;
|
|
PyObject** fastlocals_new = shadow->localsplus;
|
|
#endif
|
|
Py_ssize_t n_old = F_CODE(frame)->co_nlocalsplus;
|
|
Py_ssize_t n_new = code->co_nlocalsplus;
|
|
|
|
// localsplus are XINCREF'd by default eval frame, so all values must be
|
|
// valid.
|
|
#if !(IS_PYTHON_3_12_PLUS)
|
|
// _PyFrame_Initialize in 3.12 already does this
|
|
for (int i = 0; i < code->co_nlocalsplus; i++) {
|
|
fastlocals_new[i] = NULL;
|
|
}
|
|
#endif
|
|
|
|
#else
|
|
|
|
THP_EVAL_API_FRAME_OBJECT* shadow =
|
|
PyFrame_New(tstate, code, frame->f_globals, NULL);
|
|
if (shadow == NULL) {
|
|
return NULL;
|
|
}
|
|
|
|
PyObject** fastlocals_old = frame->f_localsplus;
|
|
PyObject** fastlocals_new = shadow->f_localsplus;
|
|
Py_ssize_t n_old = F_CODE(frame)->co_nlocals +
|
|
PyCode_GetNFreevars(F_CODE(frame)) + PyCode_GetNCellvars(F_CODE(frame));
|
|
Py_ssize_t n_new =
|
|
code->co_nlocals + PyCode_GetNFreevars(code) + PyCode_GetNCellvars(code);
|
|
|
|
#endif
|
|
|
|
// ============== Initialize new frame from old frame ============
|
|
// Python internal for executing a function:
|
|
// 1. CPython interpreter first creates an empty frame according to the code
|
|
// object
|
|
// 2. CPython interpreter initializes the frame by filling arguments/free
|
|
// variables into frame and initializing cell variables
|
|
// 3. CPython interpreter executes the code object
|
|
//
|
|
// Dynamo hooks the 3th step: before executing the code object, Dynamo
|
|
// transforms the code object into a new code object. Then, the old frame is
|
|
// not suitable for executing the new code. Therefore, Dynamo needs to
|
|
// manually create and initialize a new frame to execute the new code. The
|
|
// main task is to copy data in old frame to new frame, concerning a storage
|
|
// space named `localsplus`.
|
|
//
|
|
// localsplus storage is an array with the following layout:
|
|
// | args | new_locals | cell_variables | free_variables |
|
|
// | <--- from left to right, index from 0 to n - 1 ---> |
|
|
// code.co_varnames == args + new_locals, code.co_nlocals ==
|
|
// len(code.co_varnames) code.co_freevars == free_variables In Python 3.10 and
|
|
// lower, `n == code.co_nlocals + len(code.co_cellvars) +
|
|
// len(code.co_freevars)` (Python expression) In Python 3.11 and higher, `n <=
|
|
// code.co_nlocals + len(code.co_cellvars) + len(code.co_freevars)` (Python
|
|
// expression). There is an extra field in Python C-API: `n ==
|
|
// code->co_nlocalsplus` (C expression) to retrieve the length of array. The
|
|
// complexity happens if an argument becomes a cell variable:
|
|
// In Python 3.10 and lower, `code.co_cellvars == cell_variables`, and the
|
|
// corresponding slot in args becomes `NULL`. In Python 3.11 and higher,
|
|
// `code.co_cellvars > cell_variables`, that cell variable is still stored in
|
|
// args, with a flag set in corresponding item's `co_localspluskinds` .
|
|
//
|
|
// ideally, we need to look up new localsplus from old localsplus by name:
|
|
// for i, name, value in enumerate(localsplusnames_old):
|
|
// if value != NULL: (NULL happens for new local variables and arguments
|
|
// that becomes cell variables)
|
|
// name_to_idx[name] = i
|
|
// for i, name in enumerate(localsplusnames_new):
|
|
// if name in name_to_idx:
|
|
// fastlocals_new[i] = fastlocals_old[name_to_idx[name]]
|
|
//
|
|
// The above process of building a `name_to_idx` mapping is expensive.
|
|
// Dynamo makes the following assumptions:
|
|
// 1. new code has the same arguments as the old code (both the number and
|
|
// the order)
|
|
// 2. new code has the same cell variables as the old code (both the number
|
|
// and the order)
|
|
// 3. new code has the same free variables as the old code (both the number
|
|
// and the order) The only flexibility lies in new local variables: new code
|
|
// can introduce their own variables.
|
|
// With these assumptions, Dynamo can copy data directly by index. Dynamo just
|
|
// needs to take care of copying cell variables correctly. To avoid runtime
|
|
// cost, the assumptions are checked when we first generate the code object in
|
|
// pytorch/torch/_dynamo/convert_frame.py .
|
|
|
|
// copy args
|
|
// according to https://docs.python.org/3/library/inspect.html , `co_argcount`
|
|
// is the number of arguments (not including keyword only arguments, * or **
|
|
// args). so we need to add `co_kwonlyargcount` and `co_flags` to get the
|
|
// total number of arguments.
|
|
// !!(F_CODE(frame)->co_flags & CO_VARARGS) is 1 if the function has *args, 0
|
|
// otherwise
|
|
// !!(F_CODE(frame)->co_flags & CO_VARKEYWORDS) is 1 if the function has
|
|
// **kwargs, 0 otherwise they convert bit flags to 0 or 1, and avoid
|
|
// branching. This is performance critical code, so we really care about
|
|
// performance.
|
|
Py_ssize_t total_argcount_old = F_CODE(frame)->co_argcount +
|
|
F_CODE(frame)->co_kwonlyargcount +
|
|
!!(F_CODE(frame)->co_flags & CO_VARARGS) +
|
|
!!(F_CODE(frame)->co_flags & CO_VARKEYWORDS);
|
|
|
|
for (Py_ssize_t i = 0; i < total_argcount_old; i++) {
|
|
dup_obj(&fastlocals_new[i], fastlocals_old[i]);
|
|
}
|
|
|
|
// copy free vars
|
|
Py_ssize_t nfrees_old = PyCode_GetNFreevars(F_CODE(frame));
|
|
|
|
for (Py_ssize_t i = 0; i < nfrees_old; i++) {
|
|
dup_obj(&fastlocals_new[n_new - 1 - i], fastlocals_old[n_old - 1 - i]);
|
|
}
|
|
|
|
// copy cell vars, from high index to low index, until it meets a variable
|
|
// that is not cell variable.
|
|
for (Py_ssize_t i = n_old - nfrees_old - 1, j = n_new - nfrees_old - 1;
|
|
i >= total_argcount_old;
|
|
i--, j--) {
|
|
// conditional test to tell if a variable is not a cell variable
|
|
// this is straightforward in Python 3.11 and higher, as there are bit flags
|
|
// in `co_localspluskinds` to tell if a variable is a cell variable. in
|
|
// Python 3.10 and lower, essentially we are checking if a variable is a new
|
|
// local variable (because of the layout mentioned above, the first variable
|
|
// that is not cell variable is the first new local variable). the
|
|
// corresponding slot in `flocalsplus` is NULL for new local variables.
|
|
#if IS_PYTHON_3_11_PLUS
|
|
if (!(_PyLocals_GetKind(F_CODE(frame)->co_localspluskinds, i) &
|
|
CO_FAST_CELL)) {
|
|
break;
|
|
}
|
|
#else
|
|
if (fastlocals_old[i] == NULL) {
|
|
break;
|
|
}
|
|
#endif
|
|
|
|
dup_obj(&fastlocals_new[j], fastlocals_old[i]);
|
|
}
|
|
|
|
// NOTE: if you want to evaluate frame instead of shadow in 3.12+,
|
|
// you need to clear_old_frame_if_python_312_plus the shadow frame BEFORE
|
|
// calling eval_frame_default (i.e. here) and comment out the
|
|
// clear_old_frame_if_python_312_plus call on the original frame.
|
|
|
|
PyObject* result = dynamo_eval_frame_default(tstate, shadow, throw_flag);
|
|
|
|
#if IS_PYTHON_3_12_PLUS
|
|
|
|
// frame is cleared by caller
|
|
Py_DECREF(func);
|
|
|
|
#elif IS_PYTHON_3_11_PLUS
|
|
|
|
// In 3.11, shadow has is_entry set to true, so _PyEvalFrameClearAndPop is not
|
|
// called, so we manually clear and pop the shadow frame.
|
|
THP_PyFrame_Clear(shadow);
|
|
THP_PyThreadState_PopFrame(tstate, shadow);
|
|
Py_DECREF(func);
|
|
|
|
#else
|
|
|
|
Py_DECREF(shadow);
|
|
|
|
#endif
|
|
|
|
return result;
|
|
}
|
|
|
|
// This wrapper function adds a profiler event
|
|
PyObject* dynamo_eval_custom_code(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
PyCodeObject* code,
|
|
const char* trace_annotation,
|
|
int throw_flag) {
|
|
_PytorchRecordFunctionState* rf =
|
|
_pytorch_record_function_enter(trace_annotation);
|
|
PyObject* result =
|
|
dynamo_eval_custom_code_impl(tstate, frame, code, throw_flag);
|
|
_pytorch_record_function_exit(rf);
|
|
return result;
|
|
}
|
|
|
|
static PyObject* dynamo__custom_eval_frame_shim(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag) {
|
|
// Shims logic into one of three states. Can probably be refactored into a
|
|
// single func, later:
|
|
// - None: disables TorchDynamo
|
|
// - False: run-only mode (reuse existing compiles)
|
|
// - Python callable(): enables TorchDynamo
|
|
PyObject* callback = eval_frame_callback_get();
|
|
|
|
if (callback == Py_None) {
|
|
return dynamo_eval_frame_default(tstate, frame, throw_flag);
|
|
}
|
|
|
|
return dynamo__custom_eval_frame(tstate, frame, throw_flag, callback);
|
|
}
|
|
|
|
#else // !(IS_PYTHON_3_15_PLUS)
|
|
|
|
// Fake definitions for everything we removed
|
|
|
|
static void enable_eval_frame_shim(PyThreadState* tstate) {}
|
|
static void enable_eval_frame_default(PyThreadState* tstate) {}
|
|
PyObject* dynamo_eval_custom_code(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
PyCodeObject* code,
|
|
const char* trace_annotation,
|
|
int throw_flag) {
|
|
return NULL;
|
|
}
|
|
THPPyInterpreterFrame* THPPyInterpreterFrame_New(
|
|
THP_EVAL_API_FRAME_OBJECT* frame) {
|
|
return NULL;
|
|
}
|
|
PyObject* dynamo_eval_frame_default(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
|
int throw_flag) {
|
|
return NULL;
|
|
}
|
|
|
|
static struct PyGetSetDef THPPyInterpreterFrame_properties[] = {{NULL}};
|
|
|
|
static PyTypeObject THPPyInterpreterFrameType = {
|
|
PyVarObject_HEAD_INIT(NULL, 0)
|
|
.tp_name = "torch._C._dynamo.eval_frame._PyInterpreterFrame",
|
|
.tp_basicsize = sizeof(THPPyInterpreterFrame),
|
|
.tp_flags = Py_TPFLAGS_DEFAULT,
|
|
.tp_getset = THPPyInterpreterFrame_properties,
|
|
};
|
|
|
|
#endif // !(IS_PYTHON_3_15_PLUS)
|
|
|
|
void clear_old_frame_if_python_312_plus(
|
|
PyThreadState* tstate,
|
|
THP_EVAL_API_FRAME_OBJECT* frame) {
|
|
#if IS_PYTHON_3_12_PLUS
|
|
|
|
THP_PyFrame_Clear(frame);
|
|
THP_PyThreadState_PopFrame(tstate, frame);
|
|
|
|
#endif
|
|
}
|
|
|
|
static PyObject* increment_working_threads(
|
|
PyThreadState* tstate,
|
|
PyObject* module) {
|
|
ModuleState* state = PyModule_GetState(module);
|
|
|
|
if (state != NULL) {
|
|
state->active_dynamo_threads = state->active_dynamo_threads + 1;
|
|
if (state->active_dynamo_threads > 0) {
|
|
enable_eval_frame_shim(tstate);
|
|
}
|
|
}
|
|
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
static PyObject* decrement_working_threads(
|
|
PyThreadState* tstate,
|
|
PyObject* module) {
|
|
ModuleState* state = PyModule_GetState(module);
|
|
|
|
if (state != NULL) {
|
|
if (state->active_dynamo_threads > 0) {
|
|
state->active_dynamo_threads = state->active_dynamo_threads - 1;
|
|
if (state->active_dynamo_threads == 0) {
|
|
enable_eval_frame_default(tstate);
|
|
}
|
|
}
|
|
}
|
|
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
// Replaces the calling thread's eval-frame callback and returns the previous
// one as a reference owned by the caller.
//  - None: disables TorchDynamo
//  - False: run-only mode (reuse existing compiles)
//  - Python callable(): enables TorchDynamo
//
// The module's active-thread counter is adjusted only when this thread
// switches between "disabled" (None) and "enabled" (anything else). Replacing
// one non-None callback with another (e.g. False -> callable) must not count
// the thread a second time, otherwise the counter never returns to zero and
// the eval-frame shim is never uninstalled.
static PyObject* set_eval_frame(PyObject* new_callback, PyObject* module) {
  PyObject* old_callback = eval_frame_callback_get();

  // Common case: if Dynamo is actually off, we might see a lot of traffic
  // setting the callback to None when it was already None. Skip messing with
  // threading, thread-local storage, and the stored reference.
  if (old_callback == new_callback) {
    // old_callback stays installed and keeps its reference, so the caller
    // needs one of its own.
    Py_INCREF(old_callback);
    return old_callback;
  }

  PyObject* helper_result = NULL;
  if (new_callback == Py_None) {
    // old != new, so the old callback was not None: enabled -> disabled.
    helper_result = decrement_working_threads(PyThreadState_GET(), module);
  } else if (old_callback == Py_None) {
    // disabled -> enabled.
    helper_result = increment_working_threads(PyThreadState_GET(), module);
  }
  // The helpers return a new reference to None; release it instead of
  // dropping it on the floor.
  Py_XDECREF(helper_result);

  Py_INCREF(new_callback);

  // Set the thread-local callback. This drives the behavior of our shim,
  // if/when it is installed.
  eval_frame_callback_set(new_callback);

  // The reference previously held by the thread-local slot is handed to the
  // caller without an extra Py_DECREF/Py_INCREF pair.
  return old_callback;
}
|
|
|
|
// Python-facing entry point for set_eval_frame. Accepts None, False or any
// callable; anything else raises TypeError. Returns the previous callback.
static PyObject* set_eval_frame_py(PyObject* module, PyObject* callback) {
  int acceptable = callback == Py_None || callback == Py_False ||
      PyCallable_Check(callback);
  if (!acceptable) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a callable");
    return NULL;
  }

  DEBUG_TRACE(
      "python enabled=%d and is run_only=%d",
      callback != Py_None,
      callback == Py_False);
  return set_eval_frame(callback, module);
}
|
|
|
|
// Sets the global is_skip_guard_eval_unsafe flag from a Python bool and
// returns the value it had before. Raises TypeError unless the argument is
// exactly True or False.
static PyObject* set_skip_guard_eval_unsafe(
    PyObject* dummy,
    PyObject* skip_guard_unsafe_flag) {
  bool is_py_bool =
      skip_guard_unsafe_flag == Py_True || skip_guard_unsafe_flag == Py_False;
  if (!is_py_bool) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected True/False");
    return NULL;
  }

  bool previous = is_skip_guard_eval_unsafe;
  is_skip_guard_eval_unsafe = skip_guard_unsafe_flag == Py_True;
  if (!previous) {
    Py_RETURN_FALSE;
  }
  Py_RETURN_TRUE;
}
|
|
|
|
// Returns a new reference to the calling thread's current eval-frame
// callback (None when nothing has been stored for this thread).
static PyObject* get_eval_frame_callback_py(PyObject* dummy, PyObject* args) {
  PyObject* current = eval_frame_callback_get();
  // eval_frame_callback_get does not add a reference; add the caller's here.
  Py_INCREF(current);
  return current;
}
|
|
|
|
// Drops the extra state attached to a code object. Raises TypeError when the
// argument is not a code object; returns None otherwise.
static PyObject* reset_code(PyObject* dummy, PyObject* code) {
  if (PyCode_Check(code)) {
    // set_extra_state destroys the existing object on extra scratch space.
    set_extra_state((PyCodeObject*)code, NULL);
    Py_RETURN_NONE;
  }

  DEBUG_TRACE0("arg error");
  PyErr_SetString(PyExc_TypeError, "expected a code object");
  return NULL;
}
|
|
|
|
// Dummy C function used in testing: takes two objects and returns a new
// reference to the second one.
static PyObject* unsupported(PyObject* dummy, PyObject* args) {
  PyObject* first = NULL;
  PyObject* second = NULL;
  if (PyArg_ParseTuple(args, "OO", &first, &second)) {
    Py_INCREF(second);
    return second;
  }
  // Argument parsing failed and has already set the exception.
  return NULL;
}
|
|
|
|
// Installs `obj` as the global guard error hook, releasing the hook that was
// installed before. Passing None clears the hook. Returns None.
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* obj) {
  // None means "no hook", which is stored as NULL.
  PyObject* replacement = (obj == Py_None) ? NULL : Py_XNewRef(obj);
  Py_XSETREF(guard_error_hook, replacement);
  Py_RETURN_NONE;
}
|
|
|
|
// Installs `obj` as the global guard complete hook (None clears it) and
// returns the previously installed hook, or None if there was none. The
// reference that the global held on the old hook is handed to the caller.
static PyObject* set_guard_complete_hook(PyObject* dummy, PyObject* obj) {
  PyObject* previous = guard_complete_hook;

  // None means "no hook", which is stored as NULL.
  guard_complete_hook = (obj == Py_None) ? NULL : Py_XNewRef(obj);

  if (previous != NULL) {
    return previous;
  }
  Py_RETURN_NONE;
}
|
|
|
|
// Debugging function for GNU C only.
|
|
// Used to set gdb breakpoints in hot CPython sites from Python.
|
|
// Code example:
|
|
//
|
|
// def foo(x):
|
|
// x = x + 1
|
|
// torch._dynamo.eval_frame.raise_sigtrap()
|
|
// # (gdb) b bytecodes.c:1234 (whatever line CALL is handled)
|
|
// x = torch.sin(x) # gdb breakpoint hit when sin is called
|
|
//
|
|
// In this example, we want to breakpoint on CALL in bytecodes.c only when
|
|
// running foo. Otherwise, we would need to breakpoint before running the
|
|
// program, and that breakpoint would be hit every time Python makes a function
|
|
// call, leading to a spammy debugging experience.
|
|
static PyObject* raise_sigtrap(PyObject* dummy, PyObject* obj) {
|
|
#ifdef __GNUC__
|
|
raise(SIGTRAP);
|
|
#endif
|
|
Py_RETURN_NONE;
|
|
}
|
|
|
|
// m_clear handler: resets the module's active-thread counter.
// Returns 0 on success and -1 when the module has no state.
static int clear_state(PyObject* module) {
  ModuleState* state = PyModule_GetState(module);
  if (state == NULL) {
    return -1;
  }
  state->active_dynamo_threads = 0;
  return 0;
}
|
|
|
|
bool is_skip_guard_eval_unsafe = false;
|
|
|
|
static PyMethodDef _methods[] = {
|
|
{"set_eval_frame", set_eval_frame_py, METH_O, NULL},
|
|
{"set_skip_guard_eval_unsafe", set_skip_guard_eval_unsafe, METH_O, NULL},
|
|
{"get_eval_frame_callback", get_eval_frame_callback_py, METH_NOARGS, NULL},
|
|
{"reset_code", reset_code, METH_O, NULL},
|
|
{"unsupported", unsupported, METH_VARARGS, NULL},
|
|
{"set_code_exec_strategy", set_code_exec_strategy, METH_VARARGS, NULL},
|
|
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
|
|
{"set_guard_complete_hook", set_guard_complete_hook, METH_O, NULL},
|
|
{"raise_sigtrap", raise_sigtrap, METH_NOARGS, NULL},
|
|
{NULL, NULL, 0, NULL}};
|
|
|
|
static struct PyModuleDef _module = {
|
|
PyModuleDef_HEAD_INIT,
|
|
.m_name = "torch._C._dynamo.eval_frame",
|
|
.m_doc = "Module containing hooks to override eval_frame",
|
|
.m_size = sizeof(ModuleState),
|
|
.m_methods = _methods,
|
|
.m_clear = clear_state};
|
|
|
|
#if IS_PYTHON_3_12_PLUS
|
|
#define _PyEval_RequestCodeExtraIndex PyUnstable_Eval_RequestCodeExtraIndex
|
|
#endif
|
|
|
|
PyObject* torch_c_dynamo_eval_frame_init(void) {
|
|
extra_index = _PyEval_RequestCodeExtraIndex(destroy_extra_state);
|
|
if (extra_index < 0) {
|
|
PyErr_SetString(
|
|
PyExc_RuntimeError, "dynamo: unable to register extra index");
|
|
return NULL;
|
|
}
|
|
|
|
int result = PyThread_tss_create(&eval_frame_callback_key);
|
|
CHECK(result == 0);
|
|
|
|
Py_INCREF(Py_None);
|
|
eval_frame_callback_set(Py_None);
|
|
|
|
PyObject* module = PyModule_Create(&_module);
|
|
if (module == NULL) {
|
|
return NULL;
|
|
}
|
|
|
|
#ifdef Py_GIL_DISABLED
|
|
PyUnstable_Module_SetGIL(module, Py_MOD_GIL_NOT_USED);
|
|
#endif
|
|
|
|
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
|
|
return NULL;
|
|
}
|
|
Py_INCREF(&THPPyInterpreterFrameType);
|
|
if (PyModule_AddObject(
|
|
module,
|
|
"_PyInterpreterFrame",
|
|
(PyObject*)&THPPyInterpreterFrameType) != 0) {
|
|
return NULL;
|
|
}
|
|
|
|
return module;
|
|
}
|