Files
pytorch/torch/csrc/dynamo/eval_frame.c
Edward Z. Yang 80ce1a934e Fix flaky Dynamo export tests (#96488)
I'm planning to do a full writeup later. The short story is that sometimes the following chain of events happens:

1. We turn on Dynamo's custom frame handler
2. GC triggers (and all of the finalizers run under Dynamo)
3. GC hits a GeneratorExit frame
4. You end up in the custom frame handler with throw_flag == TRUE and PyErr_Occurred() != NULL

If this happens and we blindly call into other Python functions (like the Python callback), the executed Python code will immediately raise an exception (because there's already an ambient exception set). This is very, very confusing. The fix is to defer to the regular handler when throw_flag is TRUE.

I triggered this locally with

```
PYTHONUNBUFFERED=1 pytest test/dynamo/test_dynamic_shapes.py   -k 'Unspec and export and not dupes and not reorder' -v -x -s
```

But I also have some tests which trigger the problem synthetically.

Fixes https://github.com/pytorch/pytorch/issues/93781

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96488
Approved by: https://github.com/albanD
2023-03-10 21:51:54 +00:00

978 lines
31 KiB
C

#define PY_SSIZE_T_CLEAN
#include <torch/csrc/utils/python_compat.h>
#include <opcode.h>
#include <stdbool.h>
// see https://bugs.python.org/issue35886
#if PY_VERSION_HEX >= 0x03080000
#define Py_BUILD_CORE
#include <internal/pycore_pystate.h>
// These headers were added in 3.11
#if IS_PYTHON_3_11_PLUS
#include <internal/pycore_frame.h>
#define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt
#include <internal/pycore_opcode.h>
#undef NEED_OPCODE_TABLES
#endif
#undef Py_BUILD_CORE
#endif // PY_VERSION_HEX >= 0x03080000
// All the eval APIs change in 3.11 so we need to decide which one to use on the fly
// https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction
#if IS_PYTHON_3_11_PLUS
#define THP_EVAL_API_FRAME_OBJECT _PyInterpreterFrame
// The next two functions are taken from
// https://github.com/python/cpython/blob/a7715ccfba5b86ab09f86ec56ac3755c93b46b48/Objects/frameobject.c#L1182
// These are not exported by the CPython binary and thus we have
// to get our own implementation of them.
// As a simple way to reduce the impact of ABI changes on the CPython side, this check forces
// us to manually re-check that the function didn't change on the next major version
#if PY_VERSION_HEX >= 0x030C0000 // 3.12
#error "Please ensure that the functions below still match the CPython implementation for 3.12"
#endif
// Returns 1 if the frame's bytecode, up to (but not including)
// frame->prev_instr, contains an instruction whose de-quickened opcode is
// `opcode` and whose full argument (EXTENDED_ARG prefixes included) equals
// `oparg`; returns 0 otherwise.
//
// Kept textually identical to the CPython original referenced above so the
// version #error check stays meaningful. Used below to decide whether
// MAKE_CELL has already executed for a given local slot.
static int
_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg)
{
  // This only works when opcode is a non-quickened form:
  assert(_PyOpcode_Deopt[opcode] == opcode);
  int check_oparg = 0;
  // Linear scan over the already-executed prefix of the code.
  for (_Py_CODEUNIT *instruction = _PyCode_CODE(frame->f_code);
       instruction < frame->prev_instr; instruction++)
  {
    // Map specialized (quickened) opcodes back to their generic form.
    int check_opcode = _PyOpcode_Deopt[_Py_OPCODE(*instruction)];
    check_oparg |= _Py_OPARG(*instruction);
    if (check_opcode == opcode && check_oparg == oparg) {
      return 1;
    }
    // EXTENDED_ARG supplies the high bits of the following instruction's
    // argument, so keep accumulating; any other opcode resets the argument.
    if (check_opcode == EXTENDED_ARG) {
      check_oparg <<= 8;
    }
    else {
      check_oparg = 0;
    }
    // Skip over the inline cache code units that follow this instruction.
    instruction += _PyOpcode_Caches[check_opcode];
  }
  return 0;
}
// Copies the frame's "fast" locals (plain locals, cell and free variables)
// into frame->f_locals, creating that dict first if it is NULL.
// Returns 0 on success, -1 on failure: the dict could not be created, an
// insertion failed, or a deletion failed with something other than KeyError
// (presumably with the Python exception from that failing call still set).
//
// Local copy of CPython's internal helper, which the CPython binary does not
// export (see the comment above _PyFrame_OpAlreadyRan).
int
THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame) {
  /* Merge fast locals into f->f_locals */
  PyObject *locals;
  PyObject **fast;
  PyCodeObject *co;
  locals = frame->f_locals;
  if (locals == NULL) {
    locals = frame->f_locals = PyDict_New();
    if (locals == NULL)
      return -1;
  }
  co = frame->f_code;
  fast = _PyFrame_GetLocalsArray(frame);
  // COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt
  // here:
  int lasti = _PyInterpreterFrame_LASTI(frame);
  if (lasti < 0 && _Py_OPCODE(_PyCode_CODE(co)[0]) == COPY_FREE_VARS) {
    /* Free vars have not been initialized -- Do that */
    // Nothing has executed yet and the first instruction is COPY_FREE_VARS:
    // copy the closure cells (new references) into the free-variable slots,
    // then point prev_instr at that first instruction so it is recorded as
    // already executed.
    // NB: this `co` shadows the outer one; both are frame->f_code.
    PyCodeObject *co = frame->f_code;
    PyObject *closure = frame->f_func->func_closure;
    int offset = co->co_nlocals + co->co_nplaincellvars;
    for (int i = 0; i < co->co_nfreevars; ++i) {
      PyObject *o = PyTuple_GET_ITEM(closure, i);
      Py_INCREF(o);
      frame->localsplus[offset + i] = o;
    }
    // COPY_FREE_VARS doesn't have inline CACHEs, either:
    frame->prev_instr = _PyCode_CODE(frame->f_code);
  }
  for (int i = 0; i < co->co_nlocalsplus; i++) {
    _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i);
    /* If the namespace is unoptimized, then one of the
       following cases applies:
       1. It does not contain free variables, because it
          uses import * or is a top-level namespace.
       2. It is a class namespace.
       We don't want to accidentally copy free variables
       into the locals dict used by the class.
    */
    if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) {
      continue;
    }
    PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i);
    PyObject *value = fast[i];
    if (frame->stacktop) {
      if (kind & CO_FAST_FREE) {
        // The cell was set by COPY_FREE_VARS.
        assert(value != NULL && PyCell_Check(value));
        value = PyCell_GET(value);
      }
      else if (kind & CO_FAST_CELL) {
        // Note that no *_DEREF ops can happen before MAKE_CELL
        // executes. So there's no need to duplicate the work
        // that MAKE_CELL would otherwise do later, if it hasn't
        // run yet.
        if (value != NULL) {
          if (PyCell_Check(value) &&
              _PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) {
            // (likely) MAKE_CELL must have executed already.
            value = PyCell_GET(value);
          }
          // (likely) Otherwise it it is an arg (kind & CO_FAST_LOCAL),
          // with the initial value set when the frame was created...
          // (unlikely) ...or it was set to some initial value by
          // an earlier call to PyFrame_LocalsToFast().
        }
      }
    }
    else {
      // A zero stacktop is expected to mean no slot holds a value yet.
      assert(value == NULL);
    }
    if (value == NULL) {
      // Unbound local: make sure a stale entry is not left in the dict.
      // A missing key (KeyError) is fine; any other error is propagated.
      if (PyObject_DelItem(locals, name) != 0) {
        if (PyErr_ExceptionMatches(PyExc_KeyError)) {
          PyErr_Clear();
        }
        else {
          return -1;
        }
      }
    }
    else {
      if (PyObject_SetItem(locals, name, value) != 0) {
        return -1;
      }
    }
  }
  return 0;
}
// We need to be able to return the _PyInterpreterFrame to python so create
// a python binding for it
// Python object that wraps a raw _PyInterpreterFrame pointer. The wrapper does
// not own the frame, so it must not be used after the frame is gone.
typedef struct THPPyInterpreterFrame {
  PyObject_HEAD
  _PyInterpreterFrame* frame; // Borrowed reference
} THPPyInterpreterFrame;
// Allocates a new wrapper around `frame` (new reference, NULL on failure).
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame);
// Generates a read-only getter that exposes the _PyInterpreterFrame field
// `name` as a new reference. Several of these fields may legitimately be NULL
// (e.g. frame_obj before a PyFrameObject is materialized, f_locals before
// locals are merged). A getter returning NULL without an exception set is
// reported by CPython as a SystemError, so a NULL field is surfaced as None.
#define DECLARE_PYOBJ_ATTR(name)                            \
  static PyObject* THPPyInterpreterFrame_##name(            \
      THPPyInterpreterFrame* self, PyObject* _noargs) {     \
    PyObject* res = (PyObject*)self->frame->name;           \
    if (res == NULL) {                                      \
      Py_RETURN_NONE;                                       \
    }                                                       \
    Py_INCREF(res);                                         \
    return res;                                             \
  }
DECLARE_PYOBJ_ATTR(f_func)
DECLARE_PYOBJ_ATTR(f_globals)
DECLARE_PYOBJ_ATTR(f_builtins)
DECLARE_PYOBJ_ATTR(f_locals)
DECLARE_PYOBJ_ATTR(f_code)
DECLARE_PYOBJ_ATTR(frame_obj)
#undef DECLARE_PYOBJ_ATTR
// Getter for `previous`: wraps the caller's interpreter frame in a new
// _PyInterpreterFrame object (new reference), or returns None when there is no
// previous frame. Wrapping a NULL pointer would crash on the first attribute
// access, because every getter dereferences self->frame.
static PyObject* THPPyInterpreterFrame_previous(THPPyInterpreterFrame* self, PyObject* _noargs) {
  _PyInterpreterFrame* previous = self->frame->previous;
  if (previous == NULL) {
    Py_RETURN_NONE;
  }
  return (PyObject*)THPPyInterpreterFrame_New(previous);
}
// Getter for `f_lasti`. _PyInterpreterFrame has no such field, but Python-side
// code reads it and it is awkward to compute there, so derive it here from
// the frame's last-instruction position.
static PyObject* THPPyInterpreterFrame_f_lasti(THPPyInterpreterFrame* self, PyObject* _noargs) {
  const int last_instruction = _PyInterpreterFrame_LASTI(self->frame);
  return PyLong_FromLong(last_instruction);
}
// Attribute table of the _PyInterpreterFrame wrapper type. Every entry is a
// read-only getter: no setter, no docstring, no closure.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables,modernize-avoid-c-arrays)
static struct PyGetSetDef THPDevice_properties[] = {
    {.name = "f_func", .get = (getter)THPPyInterpreterFrame_f_func},
    {.name = "f_globals", .get = (getter)THPPyInterpreterFrame_f_globals},
    {.name = "f_builtins", .get = (getter)THPPyInterpreterFrame_f_builtins},
    {.name = "f_locals", .get = (getter)THPPyInterpreterFrame_f_locals},
    {.name = "f_code", .get = (getter)THPPyInterpreterFrame_f_code},
    {.name = "frame_obj", .get = (getter)THPPyInterpreterFrame_frame_obj},
    {.name = "previous", .get = (getter)THPPyInterpreterFrame_previous},
    {.name = "f_lasti", .get = (getter)THPPyInterpreterFrame_f_lasti},
    {.name = NULL} /* sentinel */
};
PyTypeObject THPPyInterpreterFrameType = {
PyVarObject_HEAD_INIT(NULL, 0) "torch._C.dynamo.eval_frame._PyInterpreterFrame", /* tp_name */
sizeof(THPPyInterpreterFrame), /* tp_basicsize */
0, /* tp_itemsize */
NULL, /* tp_dealloc */
0, /* tp_vectorcall_offset */
NULL, /* tp_getattr */
NULL, /* tp_setattr */
NULL, /* tp_reserved */
NULL, /* tp_repr */
NULL, /* tp_as_number */
NULL, /* tp_as_sequence */
NULL, /* tp_as_mapping */
NULL, /* tp_hash */
NULL, /* tp_call */
NULL, /* tp_str */
NULL, /* tp_getattro */
NULL, /* tp_setattro */
NULL, /* tp_as_buffer */
Py_TPFLAGS_DEFAULT, /* tp_flags */
NULL, /* tp_doc */
NULL, /* tp_traverse */
NULL, /* tp_clear */
NULL, /* tp_richcompare */
0, /* tp_weaklistoffset */
NULL, /* tp_iter */
NULL, /* tp_iternext */
NULL, /* tp_methods */
NULL, /* tp_members */
THPDevice_properties, /* tp_getset */
NULL, /* tp_base */
NULL, /* tp_dict */
NULL, /* tp_descr_get */
NULL, /* tp_descr_set */
0, /* tp_dictoffset */
NULL, /* tp_init */
NULL, /* tp_alloc */
NULL, /* tp_new */
};
// Creates a wrapper object around `frame` (the pointer is stored, not owned).
// Returns a new reference, or NULL if allocation fails.
THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
  PyTypeObject* const wrapper_type = &THPPyInterpreterFrameType;
  THPPyInterpreterFrame* wrapper =
      (THPPyInterpreterFrame*)wrapper_type->tp_alloc(wrapper_type, 0);
  if (wrapper != NULL) {
    wrapper->frame = frame;
  }
  return wrapper;
}
#else
#define THP_EVAL_API_FRAME_OBJECT PyFrameObject
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
#endif
// Branch-prediction hint; MSVC has no __builtin_expect, so it is a no-op there.
#if defined(_WIN32)
#define unlikely(x) (x)
#else
#define unlikely(x) __builtin_expect((x), 0)
#endif

// Fatal checks, active in every build: print the location (and, for
// NULL_CHECK, any pending Python error) and abort the process.
#define NULL_CHECK(val)                                             \
  do {                                                              \
    if (unlikely((val) == NULL)) {                                  \
      fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__);   \
      PyErr_Print();                                                \
      abort();                                                      \
    }                                                               \
  } while (0)

#define CHECK(cond)                                                       \
  do {                                                                    \
    if (unlikely(!(cond))) {                                              \
      fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
      abort();                                                            \
    }                                                                     \
  } while (0)

// Debug-only variants and tracing: they expand to nothing unless the build
// defines TORCHDYNAMO_DEBUG.
#if defined(TORCHDYNAMO_DEBUG)
#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
  fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
#else
#define DEBUG_CHECK(cond)
#define DEBUG_NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...)
#define DEBUG_TRACE0(msg)
#endif
// Flag to just run a frame normally
// (stored in a code object's extra slot in place of a CacheEntry list; frames
// of such code are handed to the default evaluator without consulting Dynamo)
#define SKIP_CODE ((void*)0x1)
static PyObject* noargs = NULL; /* cached empty tuple */
static PyObject* dotzerokey = NULL; /* ".0" */
// Optional Python hooks (NULL when unset): called when a guard returns a
// non-True value, and when a guard raises, respectively.
static PyObject* guard_fail_hook = NULL;
static PyObject* guard_error_hook = NULL;
// Optional profiler hooks wrapped around each cache lookup; the end hook
// receives the start hook's return value.
static PyObject* profiler_start_hook = NULL;
static PyObject* profiler_end_hook = NULL;
static PyObject* guard_profiler_name_str = NULL; /* cached py str */
// co_extra slot index reserved at module init; holds each code object's
// CacheEntry list (or SKIP_CODE).
size_t extra_index = -1;
// Thread-specific storage holding each thread's Dynamo callback.
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
inline static PyObject* eval_frame_callback_get(void) {
void* result = PyThread_tss_get(&eval_frame_callback_key);
if (unlikely(result == NULL)) {
Py_RETURN_NONE;
} else {
return (PyObject*)result;
}
}
inline static void eval_frame_callback_set(PyObject* obj) {
PyThread_tss_set(&eval_frame_callback_key, obj);
}
// Free callback registered together with the co_extra slot (see
// torch_c_dynamo_eval_frame_init). It intentionally does nothing with the
// pointer it receives.
static void ignored(void* obj) {
  (void)obj;
}
// Forward declarations; both are defined further down in this file.
static PyObject* _custom_eval_frame_shim(PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag);
static PyObject* _custom_eval_frame(PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag, PyObject* callback);

// Function installed as the interpreter's frame evaluator. Its signature has
// to match the evaluator type of the running CPython: 3.9+ receives the
// thread state as a parameter, older versions fetch it here.
#if PY_VERSION_HEX >= 0x03090000
static PyObject* custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
  return _custom_eval_frame_shim(tstate, frame, throw_flag);
}
#else
static PyObject* custom_eval_frame_shim(THP_EVAL_API_FRAME_OBJECT* frame, int throw_flag) {
  return _custom_eval_frame_shim(PyThreadState_GET(), frame, throw_flag);
}
#endif
// Evaluates `frame` with CPython's stock evaluator. On 3.9+ a NULL `tstate`
// is replaced by the current thread state; before 3.9 the stock evaluator
// takes no thread state, so `tstate` is unused.
inline static PyObject* eval_frame_default(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
#if PY_VERSION_HEX >= 0x03090000
  PyThreadState* const ts = (tstate != NULL) ? tstate : PyThreadState_GET();
  return _PyEval_EvalFrameDefault(ts, frame, throw_flag);
#else
  (void)tstate;
  return _PyEval_EvalFrameDefault(frame, throw_flag);
#endif
}
// Installs custom_eval_frame_shim as the evaluator of tstate's interpreter,
// unless it is already installed.
inline static void enable_eval_frame_shim(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      &custom_eval_frame_shim) {
    return;
  }
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp, &custom_eval_frame_shim);
#else
  if (tstate->interp->eval_frame == &custom_eval_frame_shim) {
    return;
  }
  // First call
  tstate->interp->eval_frame = &custom_eval_frame_shim;
#endif
}
// Restores CPython's stock evaluator for tstate's interpreter, unless it is
// already the active one.
inline static void enable_eval_frame_default(PyThreadState* tstate) {
#if PY_VERSION_HEX >= 0x03090000
  if (_PyInterpreterState_GetEvalFrameFunc(tstate->interp) ==
      &_PyEval_EvalFrameDefault) {
    return;
  }
  _PyInterpreterState_SetEvalFrameFunc(tstate->interp, &_PyEval_EvalFrameDefault);
#else
  if (tstate->interp->eval_frame == &_PyEval_EvalFrameDefault) {
    return;
  }
  // First call
  tstate->interp->eval_frame = &_PyEval_EvalFrameDefault;
#endif
}
// Invokes the Dynamo Python callback as `callable(frame, cache_len)` and
// returns its result (new reference) or NULL on error.
//
// On 3.11+ the frame passed to Python is a freshly allocated wrapper around
// the _PyInterpreterFrame (a new reference). Previously that reference was
// never released, so one wrapper leaked per cache miss. Both branches now
// hold exactly one reference to `frame` and release it once the argument
// tuple has taken its own.
static inline PyObject* call_callback(
    PyObject* callable,
    THP_EVAL_API_FRAME_OBJECT* _frame,
    long cache_len) {
#if IS_PYTHON_3_11_PLUS
  PyObject* frame = (PyObject*)THPPyInterpreterFrame_New(_frame);
  if (frame == NULL) {
    return NULL;
  }
#else
  PyObject* frame = (PyObject*)_frame;
  Py_INCREF(frame);
#endif
  PyObject* args = Py_BuildValue("(Ol)", frame, cache_len);
  Py_DECREF(frame);
  if (args == NULL) {
    return NULL;
  }
  PyObject* result = PyObject_CallObject(callable, args);
  Py_DECREF(args);
  return result;
}
// One compiled variant of a code object. Entries form a singly linked list
// stored in the code object's extra slot; lookup() moves a hit to the head.
// check_fn and code are owned (new) references, released by
// destroy_cache_entry().
typedef struct cache_entry {
  // check the guards: lambda: <locals of user function>: bool
  PyObject* check_fn;
  // modified user bytecode (protected by check_fn's guards)
  PyCodeObject* code;
  // on a cache miss, linked list of next thing to try
  struct cache_entry* next;
} CacheEntry;
// Allocates a cache entry for `guarded_code`, a Python object expected to
// expose `check_fn` and `code` attributes. The entry holds new references to
// both and is linked in front of `next`.
// Any failure (out of memory, missing attribute) aborts the process.
static CacheEntry* create_cache_entry(
    CacheEntry* next,
    PyObject* guarded_code) {
  CacheEntry* e = malloc(sizeof(*e));
  // Checked in every build: DEBUG_NULL_CHECK compiles to nothing unless
  // TORCHDYNAMO_DEBUG is defined, which left a NULL dereference just below.
  NULL_CHECK(e);
  e->check_fn = PyObject_GetAttrString(guarded_code, "check_fn");
  NULL_CHECK(e->check_fn);
  e->code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code");
  NULL_CHECK(e->code);
  e->next = next;
  return e;
}
// Releases an entire cache list starting at `e`: drops each entry's Python
// references and frees the node. NULL and the SKIP_CODE marker are ignored.
static void destroy_cache_entry(CacheEntry* e) {
  CacheEntry* node = e;
  while (node != NULL && node != SKIP_CODE) {
    Py_XDECREF(node->check_fn);
    Py_XDECREF(node->code);
    CacheEntry* following = node->next;
    free(node);
    node = following;
  }
}
// Reads the cache list (or SKIP_CODE) attached to `code`; NULL when nothing
// has been attached. The status returned by _PyCode_GetExtra is ignored.
inline static CacheEntry* get_extra(PyCodeObject* code) {
  void* slot = NULL;
  _PyCode_GetExtra((PyObject*)code, extra_index, &slot);
  return (CacheEntry*)slot;
}
// Replaces the value in `code`'s extra slot. The previous value is not freed
// here; callers destroy or relink it themselves.
inline static void set_extra(PyCodeObject* code, CacheEntry* extra) {
  // TODO(jansel): would it be faster to bypass this?
  (void)_PyCode_SetExtra((PyObject*)code, extra_index, extra);
}
#ifdef TORCHDYNAMO_DEBUG
// Debug helper: UTF-8 name of the code object running in `frame`.
inline static const char* name(THP_EVAL_API_FRAME_OBJECT* frame) {
  PyObject* const co_name = frame->f_code->co_name;
  DEBUG_CHECK(PyUnicode_Check(co_name));
  return PyUnicode_AsUTF8(co_name);
}
#endif
// Calls a debugging hook for cache entry `e` as
// hook(check_fn, code, f_locals, is_last_entry). Returns the hook's result
// (new reference) or NULL on error.
static PyObject* call_guard_fail_hook(
    PyObject* hook,
    CacheEntry* e,
    PyObject* f_locals) {
  PyObject* const is_last_entry = (e->next == NULL) ? Py_True : Py_False;
  PyObject* hook_args = PyTuple_Pack(
      4, e->check_fn, (PyObject*)e->code, f_locals, is_last_entry);
  if (hook_args == NULL) {
    return NULL;
  }
  PyObject* outcome = PyObject_CallObject(hook, hook_args);
  Py_DECREF(hook_args);
  return outcome;
}
// Calls profiler_start_hook(name_str) and returns its result (new reference).
// Returns NULL when no hook is installed, when name_str is NULL, or on error.
static PyObject* call_profiler_start_hook(PyObject* name_str) {
  if (profiler_start_hook == NULL || name_str == NULL) {
    return NULL;
  }
  PyObject* hook_args = PyTuple_Pack(1, name_str);
  if (hook_args == NULL) {
    return NULL;
  }
  PyObject* record = PyObject_CallObject(profiler_start_hook, hook_args);
  Py_DECREF(hook_args);
  return record;
}
// Calls profiler_end_hook(record), where `record` is what the start hook
// returned. Does nothing if there is no hook or no record; the hook's return
// value is discarded.
static void call_profiler_end_hook(PyObject* record) {
  if (profiler_end_hook == NULL || record == NULL) {
    return;
  }
  PyObject* hook_args = PyTuple_Pack(1, record);
  if (hook_args == NULL) {
    return;
  }
  PyObject* discarded = PyObject_CallObject(profiler_end_hook, hook_args);
  Py_XDECREF(discarded);
  Py_DECREF(hook_args);
}
// Walks the cache list starting at `e`, calling each entry's guard
// (check_fn) with the frame's locals as keyword arguments. `prev` is the
// entry before `e` in the list (NULL when `e` is the head); on a hit that is
// not the head, the hit entry is moved to the front (most recently used).
// Return value: borrowed reference
// Is either Py_None (no guard passed) or a PyCodeObject (the hit's code), or
// NULL if a guard or hook raised.
//
// Fixes: when guard_error_hook itself failed, the fetched guard exception
// (type/value/traceback) was leaked; it is now released and the hook's error
// propagates. `valid` is also no longer inspected after its reference is
// dropped.
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev) {
  if (e == NULL) {
    // NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
    return Py_None;
  }
  PyObject* f_locals = frame->f_locals;
  PyObject* dotzero = PyDict_GetItem(f_locals, dotzerokey);
  PyObject* valid = NULL;
  if (unlikely(dotzero != NULL)) {
    // .0 is a special variable name used for implicit args
    PyObject* args = PyTuple_Pack(1, dotzero);
    if (args == NULL) {
      return NULL;
    }
    valid = PyObject_Call(e->check_fn, args, f_locals);
    Py_DECREF(args);
  } else {
    valid = PyObject_Call(e->check_fn, noargs, f_locals);
  }
  if (unlikely(valid == NULL)) {
    if (guard_error_hook != NULL) {
      PyObject* type = NULL;
      PyObject* value = NULL;
      PyObject* traceback = NULL;
      PyErr_Fetch(&type, &value, &traceback);
      PyObject* r = call_guard_fail_hook(guard_error_hook, e, f_locals);
      if (r == NULL) {
        // The hook's own exception propagates; drop the saved guard error.
        Py_XDECREF(type);
        Py_XDECREF(value);
        Py_XDECREF(traceback);
        return NULL;
      }
      Py_DECREF(r);
      // Re-raise the original guard error.
      PyErr_Restore(type, value, traceback);
    }
    return NULL;
  }
  const int hit = (valid == Py_True);
  Py_DECREF(valid);
  if (hit) {
    // Keep the head as the most recently used cache entry.
    // If the hit cache entry is not the head of the linked list,
    // move it to the head
    if (prev != NULL) {
      CacheEntry* extra = get_extra(frame->f_code);
      prev->next = e->next;
      e->next = extra;
      set_extra(frame->f_code, e);
    }
    return (PyObject*)e->code;
  }
  if (unlikely(guard_fail_hook != NULL)) {
    PyObject* r = call_guard_fail_hook(guard_fail_hook, e, f_locals);
    if (r == NULL) {
      return NULL;
    }
    Py_DECREF(r);
  }
  return lookup(e->next, frame, e);
}
// Number of entries in the cache list starting at `e` (0 for NULL).
static long cache_size(CacheEntry* e) {
  long count = 0;
  for (CacheEntry* node = e; node != NULL; node = node->next) {
    ++count;
  }
  return count;
}
// Runs the transformed `code` in place of `frame`. A fresh "shadow" frame is
// built for `code` with the original frame's globals; the original's locals
// and its cell/free slots are copied into it (new references), and the
// shadow frame is evaluated with the default evaluator. Returns the
// evaluation result, or NULL on error.
//
// Fix: on 3.11+ the shadow frame's f_frame was read before checking whether
// PyFrame_New succeeded, which dereferenced NULL on allocation failure.
inline static PyObject* eval_custom_code(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    PyCodeObject* code,
    int throw_flag) {
  Py_ssize_t nlocals_new = code->co_nlocals;
  Py_ssize_t nlocals_old = frame->f_code->co_nlocals;
  Py_ssize_t ncells = PyCode_GetNCellvars(code);
  Py_ssize_t nfrees = PyCode_GetNFreevars(code);

  DEBUG_NULL_CHECK(tstate);
  DEBUG_NULL_CHECK(frame);
  DEBUG_NULL_CHECK(code);
  DEBUG_CHECK(ncells == PyTuple_GET_SIZE(frame->f_code->co_cellvars));
  DEBUG_CHECK(nfrees == PyTuple_GET_SIZE(frame->f_code->co_freevars));
  DEBUG_CHECK(nlocals_new >= nlocals_old);

  PyFrameObject* shadow_obj = PyFrame_New(tstate, code, frame->f_globals, NULL);
  if (shadow_obj == NULL) {
    return NULL;
  }
#if IS_PYTHON_3_11_PLUS
  THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj->f_frame;
#else
  THP_EVAL_API_FRAME_OBJECT* shadow = shadow_obj;
#endif
  if (shadow == NULL) {
    Py_DECREF(shadow_obj);
    return NULL;
  }

#if IS_PYTHON_3_11_PLUS
  PyObject** fastlocals_old = frame->localsplus;
  PyObject** fastlocals_new = shadow->localsplus;
#else
  PyObject** fastlocals_old = frame->f_localsplus;
  PyObject** fastlocals_new = shadow->f_localsplus;
#endif

  // Plain locals keep their positions.
  for (Py_ssize_t i = 0; i < nlocals_old; i++) {
    Py_XINCREF(fastlocals_old[i]);
    fastlocals_new[i] = fastlocals_old[i];
  }
  // Cells and free variables sit right after the locals, so they shift when
  // the new code has more locals than the original.
  for (Py_ssize_t i = 0; i < ncells + nfrees; i++) {
    Py_XINCREF(fastlocals_old[nlocals_old + i]);
    fastlocals_new[nlocals_new + i] = fastlocals_old[nlocals_old + i];
  }

  PyObject* result = eval_frame_default(tstate, shadow, throw_flag);
  Py_DECREF(shadow_obj);
  return result;
}
static PyObject* _custom_eval_frame_shim(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag) {
  // This thread's callback selects the behavior (could be folded into a
  // single function later):
  //  - None: TorchDynamo disabled, evaluate normally
  //  - False: run-only mode (reuse existing compiles)
  //  - Python callable(): TorchDynamo enabled
  PyObject* const callback = eval_frame_callback_get();
  const bool dynamo_disabled = (callback == Py_None);
  return dynamo_disabled
      ? eval_frame_default(tstate, frame, throw_flag)
      : _custom_eval_frame(tstate, frame, throw_flag, callback);
}
// Dynamo's frame evaluator, reached from the shim whenever this thread's
// callback is not None. `callback` is Py_False (run-only: consult the cache,
// never compile) or a Python callable (compile on a cache miss).
//
// Outline:
//  1. throw_flag set, code marked SKIP_CODE, or run-only with no cache:
//     run the frame with the default evaluator.
//  2. Merge fast locals into frame->f_locals so the guards can read them.
//  3. Run-only mode: look up the cache; on a hit run the cached code,
//     otherwise run the frame normally.
//  4. Otherwise look up the cache with the callback temporarily set to None.
//     On a hit, run the cached code. On a miss, call
//     callback(frame, cache_size):
//       - NULL:     propagate the error (callback left as None, see below)
//       - not None: prepend a new cache entry and run its code
//       - None:     mark the code SKIP_CODE and run the frame normally
// Returns the frame's result, or NULL on error.
static PyObject* _custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback) {
  // NOTE(review): this trace (TORCHDYNAMO_DEBUG builds only) reads
  // f_lineno/f_lasti/f_iblock/f_executing, which may not exist on every
  // supported CPython frame type; confirm before enabling debug builds.
  DEBUG_TRACE(
      "begin %s %s %i %i %i %i",
      name(frame),
      PyUnicode_AsUTF8(frame->f_code->co_filename),
      frame->f_lineno,
      frame->f_lasti,
      frame->f_iblock,
      frame->f_executing);
  if (throw_flag) {
    // When unwinding generators, eval frame is called with throw_flag ==
    // true. Frame evaluation is supposed to continue unwinding by propagating
    // the exception. Dynamo doesn't really know how to do this, nor does it
    // really want to do this, because there's unlikely any code to capture
    // (you're going to immediately quit out of the frame, perhaps running
    // some unwinding logic along the way). So we just run the default
    // handler in this case.
    //
    // NB: A previous version of this patch returned NULL. This is wrong,
    // because returning NULL is *different* from unwinding an exception.
    // In particular, you will not execute things like context manager
    // __exit__ if you just return NULL.
    //
    // NB: It's /conceivable/ that you might want to actually still call the
    // Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
    // do any stack unwinding code. But this is not really useful because
    // (1) Dynamo doesn't actually know how to do stack unwinding, so it would
    // immediately skip the frame, and (2) even if it did, this would only
    // be profitable if there was tensor code in the unwinding code. Seems
    // unlikely.
    DEBUG_TRACE("throw %s", name(frame));
    return eval_frame_default(tstate, frame, throw_flag);
  }
  CacheEntry* extra = get_extra(frame->f_code);
  // Nothing to do for code explicitly marked as skipped, or in run-only mode
  // when nothing has been compiled for this code yet.
  if (extra == SKIP_CODE || (callback == Py_False && extra == NULL)) {
    DEBUG_TRACE("skip %s", name(frame));
    return eval_frame_default(tstate, frame, throw_flag);
  }
  // TODO(jansel): investigate directly using the "fast" representation
  // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
  // even though we should pass a PyFrameObject.
  if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
    DEBUG_TRACE("error %s", name(frame));
    return NULL;
  }
  // A callback of Py_False indicates "run only" mode, the cache is checked, but
  // we never compile.
  // NB: unlike the compile path below, the callback is NOT switched to None
  // while the guards run here.
  if (callback == Py_False) {
    DEBUG_TRACE("In run only mode %s", name(frame));
    PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
    PyObject* maybe_cached_code = lookup(extra, frame, NULL);
    call_profiler_end_hook(hook_record);
    Py_XDECREF(hook_record);
    if (maybe_cached_code == NULL) {
      // guard eval failed, keep propagating
      return NULL;
    } else if (maybe_cached_code == Py_None) {
      DEBUG_TRACE("cache miss %s", name(frame));
      return eval_frame_default(tstate, frame, throw_flag);
    }
    PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
    // used cached version
    DEBUG_TRACE("cache hit %s", name(frame));
    return eval_custom_code(tstate, frame, cached_code, throw_flag);
  }
  DEBUG_CHECK(PyDict_CheckExact(frame->f_locals));
  DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
  DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));
  // We don't run the current custom_eval_frame behavior for guards.
  // So we temporarily set the callback to Py_None to drive the correct behavior
  // in the shim.
  // Every path below that runs user code restores `callback` first; the two
  // error returns deliberately leave it as None.
  eval_frame_callback_set(Py_None);
  PyObject* hook_record = call_profiler_start_hook(guard_profiler_name_str);
  PyObject* maybe_cached_code = lookup(extra, frame, NULL);
  call_profiler_end_hook(hook_record);
  Py_XDECREF(hook_record);
  if (maybe_cached_code == NULL) {
    // Python error
    // NB: the callback is not restored here, so it stays None for this thread.
    return NULL;
  } else if (maybe_cached_code != Py_None) {
    PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code;
    // used cached version
    DEBUG_TRACE("cache hit %s", name(frame));
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_custom_code(tstate, frame, cached_code, throw_flag);
  }
  // cache miss
  // TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
  // that gets re-interpreted as a PyObject (which it is NOT!)
  PyObject* result =
      call_callback(callback, frame, cache_size(extra));
  if (result == NULL) {
    // internal exception, returning here will leak the exception into user code
    // this is useful for debugging -- but we dont want it to happen outside of
    // testing
    // NB: we intentionally DO NOT re-enable custom behavior to prevent
    // cascading failure from internal exceptions. The upshot is if
    // Dynamo barfs, that's it for Dynamo, even if you catch the exception
    // inside the torch.compile block we won't try to Dynamo anything else.
    return NULL;
  } else if (result != Py_None) {
    DEBUG_TRACE("create cache %s", name(frame));
    // The new entry takes its own references to result.check_fn/.code, so
    // `result` itself can be released right away.
    extra = create_cache_entry(extra, result);
    Py_DECREF(result);
    set_extra(frame->f_code, extra);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_custom_code(tstate, frame, extra->code, throw_flag);
  } else {
    DEBUG_TRACE("create skip %s", name(frame));
    Py_DECREF(result);
    // The callback declined to compile: drop any existing entries and mark
    // the code so later calls go straight to the default evaluator.
    destroy_cache_entry(extra);
    set_extra(frame->f_code, SKIP_CODE);
    // Re-enable custom behavior
    eval_frame_callback_set(callback);
    return eval_frame_default(tstate, frame, throw_flag);
  }
}
// Number of threads that currently have Dynamo enabled. The shim is installed
// as the interpreter's evaluator while this is positive; the stock evaluator
// is restored once it drops back to zero.
static int active_dynamo_threads = 0;

// Records one more Dynamo-enabled thread and makes sure the shim is installed.
// Returns a new reference to None.
static PyObject* increment_working_threads(PyThreadState* tstate) {
  if (++active_dynamo_threads > 0) {
    enable_eval_frame_shim(tstate);
  }
  Py_RETURN_NONE;
}

// Records one fewer Dynamo-enabled thread (never going below zero) and
// restores the stock evaluator when the last one goes away. Returns a new
// reference to None.
static PyObject* decrement_working_threads(PyThreadState* tstate) {
  if (active_dynamo_threads <= 0) {
    Py_RETURN_NONE;
  }
  if (--active_dynamo_threads == 0) {
    enable_eval_frame_default(tstate);
  }
  Py_RETURN_NONE;
}
// Change the eval frame callback and return the old one
//  - None: disables TorchDynamo
//  - False: run-only mode (reuse existing compiles)
//  - Python callable(): enables TorchDynamo
// The returned old callback is a new reference owned by the caller. Moving
// between "None" and "not None" updates the interpreter-wide thread count,
// which installs or removes the shim.
//
// Fix: increment/decrement_working_threads return a new reference to None;
// it used to be discarded (leaked) on every enable/disable transition.
static PyObject* set_eval_frame(PyObject* new_callback, PyThreadState* tstate) {
  PyObject* old_callback = eval_frame_callback_get();

  // owned by caller
  Py_INCREF(old_callback);

  PyObject* transition = NULL;
  if (old_callback != Py_None && new_callback == Py_None) {
    transition = decrement_working_threads(tstate);
  } else if (old_callback == Py_None && new_callback != Py_None) {
    transition = increment_working_threads(tstate);
  }
  Py_XDECREF(transition);

  // The thread-local slot takes a reference to the new callback; the DECREF
  // releases the reference that backed the slot's previous value.
  Py_INCREF(new_callback);
  Py_DECREF(old_callback);

  // Set thread local callback. This will drive behavior of our shim, if/when it
  // is installed.
  eval_frame_callback_set(new_callback);
  return old_callback;
}
// Python entry point: set_eval_frame(callback) -> previous callback.
// `callback` must be None, False, or a callable; anything else raises
// TypeError.
static PyObject* set_eval_frame_py(PyObject* dummy, PyObject* args) {
  PyObject* callback = NULL;
  if (!PyArg_ParseTuple(args, "O:callback", &callback)) {
    DEBUG_TRACE0("arg error");
    return NULL;
  }
  const bool acceptable = callback == Py_None || callback == Py_False ||
      PyCallable_Check(callback);
  if (!acceptable) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a callable");
    return NULL;
  }
  DEBUG_TRACE(
      "python enabled=%d and is run_only=%d",
      callback != Py_None,
      callback == Py_False);
  PyThreadState* const tstate = PyThreadState_GET();
  return set_eval_frame(callback, tstate);
}
// Python entry point: reset_code(code). Frees every cache entry (or clears the
// skip marker) attached to `code`, returning it to the "never seen" state.
static PyObject* reset_code(PyObject* dummy, PyObject* args) {
  PyObject* code = NULL;
  if (!PyArg_ParseTuple(args, "O:code", &code)) {
    DEBUG_TRACE0("arg error");
    return NULL;
  }
  if (!PyCode_Check(code)) {
    DEBUG_TRACE0("arg error");
    PyErr_SetString(PyExc_TypeError, "expected a code object");
    return NULL;
  }
  PyCodeObject* const co = (PyCodeObject*)code;
  CacheEntry* const stale = get_extra(co);
  destroy_cache_entry(stale);
  set_extra(co, NULL);
  Py_RETURN_NONE;
}
// A dummy C function used in testing: unsupported(a, b) returns b.
static PyObject* unsupported(PyObject* dummy, PyObject* args) {
  PyObject* first = NULL;
  PyObject* second = NULL;
  if (PyArg_ParseTuple(args, "OO", &first, &second) == 0) {
    return NULL;
  }
  Py_INCREF(second);
  return second;
}
// Python entry point: skip_code(code). Marks `code` so the shim always runs it
// with the default evaluator.
//
// Fix: any cache entries already attached to the code are released first.
// Within this file the list is reachable only through the code's extra slot,
// so overwriting that slot with SKIP_CODE leaked the entries and the
// check_fn/code references they hold (reset_code already frees them the same
// way).
static PyObject* skip_code(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  if (!PyCode_Check(obj)) {
    PyErr_SetString(PyExc_TypeError, "expected a code object");
    return NULL;
  }
  PyCodeObject* code = (PyCodeObject*)obj;
  destroy_cache_entry(get_extra(code));
  set_extra(code, SKIP_CODE);
  Py_RETURN_NONE;
}
// Python entry point: set_guard_fail_hook(hook_or_None). Installs (or, with
// None, removes) the hook called whenever a guard does not return True.
//
// The new hook is installed before the old one is released. Dropping the last
// reference can run arbitrary Python code (finalizers), and that code must
// never observe a global pointing at a freed object.
static PyObject* set_guard_fail_hook(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  PyObject* previous = guard_fail_hook;
  if (obj == Py_None) {
    guard_fail_hook = NULL;
  } else {
    Py_INCREF(obj);
    guard_fail_hook = obj;
  }
  Py_XDECREF(previous);
  Py_RETURN_NONE;
}
// Python entry point: set_guard_error_hook(hook_or_None). Installs (or, with
// None, removes) the hook called when a guard raises.
//
// The new hook is installed before the old one is released, so finalizers
// run by the release never see a dangling guard_error_hook.
static PyObject* set_guard_error_hook(PyObject* dummy, PyObject* args) {
  PyObject* obj = NULL;
  if (!PyArg_ParseTuple(args, "O", &obj)) {
    return NULL;
  }
  PyObject* previous = guard_error_hook;
  if (obj == Py_None) {
    guard_error_hook = NULL;
  } else {
    Py_INCREF(obj);
    guard_error_hook = obj;
  }
  Py_XDECREF(previous);
  Py_RETURN_NONE;
}
// Python entry point: clear_profiler_hooks(). Removes both profiler hooks and
// the cached lookup name. Returns a new reference to None.
//
// Py_CLEAR sets each global to NULL before dropping its reference, so a
// finalizer triggered by the release cannot observe a dangling pointer.
static PyObject* clear_profiler_hooks(PyObject* dummy, PyObject* args) {
  Py_CLEAR(profiler_start_hook);
  Py_CLEAR(profiler_end_hook);
  Py_CLEAR(guard_profiler_name_str);
  Py_RETURN_NONE;
}
// Python entry point: set_profiler_hooks(start, end). Installs hooks called
// around each guard/cache lookup; passing None for either clears both.
//
// Fixes:
//  - The old hooks were DECREF'd up front and then DECREF'd again inside
//    clear_profiler_hooks when clearing, i.e. released twice.
//  - The None returned by clear_profiler_hooks was leaked.
//  - A failure to build the cached name string returned None with an
//    exception still set; it now returns NULL.
static PyObject* set_profiler_hooks(PyObject* dummy, PyObject* args) {
  PyObject* start = NULL;
  PyObject* end = NULL;
  if (!PyArg_ParseTuple(args, "OO", &start, &end)) {
    return NULL;
  }
  if (start == Py_None || end == Py_None) {
    // clear_profiler_hooks releases the current hooks itself.
    PyObject* none = clear_profiler_hooks(NULL, NULL);
    Py_XDECREF(none);
  } else {
    // Install the new hooks before releasing the old ones.
    PyObject* old_start = profiler_start_hook;
    PyObject* old_end = profiler_end_hook;
    Py_INCREF(start);
    Py_INCREF(end);
    profiler_start_hook = start;
    profiler_end_hook = end;
    Py_XDECREF(old_start);
    Py_XDECREF(old_end);
  }
  PyObject* old_name = guard_profiler_name_str;
  guard_profiler_name_str = Py_BuildValue("s", "TorchDynamo Cache Lookup");
  Py_XDECREF(old_name);
  if (guard_profiler_name_str == NULL) {
    return NULL;
  }
  Py_RETURN_NONE;
}
static PyMethodDef _methods[] = {
{"set_eval_frame", set_eval_frame_py, METH_VARARGS, NULL},
{"reset_code", reset_code, METH_VARARGS, NULL},
{"unsupported", unsupported, METH_VARARGS, NULL},
{"skip_code", skip_code, METH_VARARGS, NULL},
{"set_guard_fail_hook", set_guard_fail_hook, METH_VARARGS, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_VARARGS, NULL},
{"set_profiler_hooks", set_profiler_hooks, METH_VARARGS, NULL},
{"clear_profiler_hooks", clear_profiler_hooks, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}};
static struct PyModuleDef _module = {
PyModuleDef_HEAD_INIT,
"torch._C._dynamo.eval_frame",
"Module containing hooks to override eval_frame",
-1,
_methods};
PyObject* torch_c_dynamo_eval_frame_init(void) {
extra_index = _PyEval_RequestCodeExtraIndex(ignored);
int result = PyThread_tss_create(&eval_frame_callback_key);
CHECK(result == 0);
Py_INCREF(Py_None);
eval_frame_callback_set(Py_None);
noargs = PyTuple_New(0);
dotzerokey = PyUnicode_InternFromString(".0");
PyObject* module = PyModule_Create(&_module);
#if IS_PYTHON_3_11_PLUS
if (PyType_Ready(&THPPyInterpreterFrameType) < 0) {
return NULL;
}
Py_INCREF(&THPPyInterpreterFrameType);
if (PyModule_AddObject(module, "_PyInterpreterFrame", (PyObject*)&THPPyInterpreterFrameType) != 0) {
return NULL;
}
#endif
return module;
}