diff --git a/build_variables.bzl b/build_variables.bzl index d0f426de6118..c1d0b5dca25b 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -828,6 +828,7 @@ libtorch_python_core_sources = [ "torch/csrc/dynamo/cpython_defs.c", "torch/csrc/dynamo/eval_frame.c", "torch/csrc/dynamo/extra_state.cpp", + "torch/csrc/dynamo/framelocals_mapping.cpp", "torch/csrc/dynamo/guards.cpp", "torch/csrc/dynamo/init.cpp", "torch/csrc/functorch/init.cpp", diff --git a/test/dynamo_expected_failures/TestLinalgCPU.test_lobpcg_torchscript_cpu_float64 b/test/dynamo_skips/TestLinalgCPU.test_lobpcg_torchscript_cpu_float64 similarity index 100% rename from test/dynamo_expected_failures/TestLinalgCPU.test_lobpcg_torchscript_cpu_float64 rename to test/dynamo_skips/TestLinalgCPU.test_lobpcg_torchscript_cpu_float64 diff --git a/test/dynamo_expected_failures/TestScript.test_function_overloading_isinstance b/test/dynamo_skips/TestScript.test_function_overloading_isinstance similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_function_overloading_isinstance rename to test/dynamo_skips/TestScript.test_function_overloading_isinstance diff --git a/test/dynamo_expected_failures/TestScript.test_function_overloads b/test/dynamo_skips/TestScript.test_function_overloads similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_function_overloads rename to test/dynamo_skips/TestScript.test_function_overloads diff --git a/test/dynamo_expected_failures/TestScript.test_ignored_as_value b/test/dynamo_skips/TestScript.test_ignored_as_value similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_ignored_as_value rename to test/dynamo_skips/TestScript.test_ignored_as_value diff --git a/test/dynamo_expected_failures/TestScript.test_namedtuple_python b/test/dynamo_skips/TestScript.test_namedtuple_python similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_namedtuple_python rename to test/dynamo_skips/TestScript.test_namedtuple_python diff --git a/test/dynamo_expected_failures/TestScript.test_no_self_arg_ignore_function b/test/dynamo_skips/TestScript.test_no_self_arg_ignore_function similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_no_self_arg_ignore_function rename to test/dynamo_skips/TestScript.test_no_self_arg_ignore_function diff --git a/test/dynamo_expected_failures/TestScript.test_python_call_non_tensor_wrong b/test/dynamo_skips/TestScript.test_python_call_non_tensor_wrong similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_python_call_non_tensor_wrong rename to test/dynamo_skips/TestScript.test_python_call_non_tensor_wrong diff --git a/test/dynamo_expected_failures/TestScript.test_python_op_builtins b/test/dynamo_skips/TestScript.test_python_op_builtins similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_python_op_builtins rename to test/dynamo_skips/TestScript.test_python_op_builtins diff --git a/test/dynamo_expected_failures/TestScript.test_type_annotation_module b/test/dynamo_skips/TestScript.test_type_annotation_module similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_type_annotation_module rename to test/dynamo_skips/TestScript.test_type_annotation_module diff --git a/test/dynamo_expected_failures/TestScript.test_unused_decorator b/test/dynamo_skips/TestScript.test_unused_decorator similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_unused_decorator rename to test/dynamo_skips/TestScript.test_unused_decorator diff --git a/test/dynamo_expected_failures/TestScript.test_wrong_return_type b/test/dynamo_skips/TestScript.test_wrong_return_type similarity index 100% rename from test/dynamo_expected_failures/TestScript.test_wrong_return_type rename to test/dynamo_skips/TestScript.test_wrong_return_type diff --git a/torch/csrc/dynamo/cpython_defs.c b/torch/csrc/dynamo/cpython_defs.c index fe6ee6432b6c..ad0376e3bdac 100644 --- a/torch/csrc/dynamo/cpython_defs.c +++ b/torch/csrc/dynamo/cpython_defs.c @@ -1,31 +1,12 @@ #include - -#ifdef _WIN32 -#define unlikely(x) (x) -#else -#define unlikely(x) __builtin_expect((x), 0) -#endif - -#define CHECK(cond) \ - if (unlikely(!(cond))) { \ - fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \ - abort(); \ - } else { \ - } +#include +#include #if IS_PYTHON_3_11_PLUS -// Problem in CPython includes when mixing core and non-core build -// The fix was not backported to 3.12 so this is needed here -// https://github.com/python/cpython/issues/105268 -#if IS_PYTHON_3_12_PLUS -#undef _PyGC_FINALIZED -#endif - #define Py_BUILD_CORE -#include - #define NEED_OPCODE_TABLES // To get _PyOpcode_Deopt, _PyOpcode_Caches + #if IS_PYTHON_3_13_PLUS #include // To get PyUnstable_Code_GetFirstFree #define NEED_OPCODE_METADATA @@ -34,10 +15,8 @@ #else #include #endif + #undef NEED_OPCODE_TABLES - -#include - #undef Py_BUILD_CORE // As a simple way to reduce the impact of ABI changes on the CPython side, this check forces @@ -74,189 +53,10 @@ THP_PyFrame_OpAlreadyRan(_PyInterpreterFrame *frame, int opcode, int oparg) #if IS_PYTHON_3_12_PLUS -// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136 -// Initialize frame free variables if needed -// free_vars_copied argument added in order to let caller know that the COPY_FREE_VARS -// codepath occurred. -static void -frame_init_get_vars(_PyInterpreterFrame *frame, int *free_vars_copied) -{ - // COPY_FREE_VARS has no quickened forms, so no need to use _PyOpcode_Deopt - // here: - PyCodeObject *co = F_CODE(frame); - int lasti = _PyInterpreterFrame_LASTI(frame); - if (!(lasti < 0 && _PyCode_CODE(co)->op.code == COPY_FREE_VARS - && PyFunction_Check(frame->f_funcobj))) - { - /* Free vars are initialized */ - return; - } - - /* Free vars have not been initialized -- Do that */ - PyObject *closure = ((PyFunctionObject *)frame->f_funcobj)->func_closure; - #if IS_PYTHON_3_13_PLUS - int offset = PyUnstable_Code_GetFirstFree(co); - #else - int offset = PyCode_GetFirstFree(co); - #endif - for (int i = 0; i < co->co_nfreevars; ++i) { - PyObject *o = PyTuple_GET_ITEM(closure, i); - frame->localsplus[offset + i] = Py_NewRef(o); - } - // COPY_FREE_VARS doesn't have inline CACHEs, either: - PREV_INSTR(frame) = _PyCode_CODE(F_CODE(frame)); - - *free_vars_copied = 1; -} - -// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162 -static int -frame_get_var(_PyInterpreterFrame *frame, PyCodeObject *co, int i, - PyObject **pvalue) -{ - _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); - - /* If the namespace is unoptimized, then one of the - following cases applies: - 1. It does not contain free variables, because it - uses import * or is a top-level namespace. - 2. It is a class namespace. - We don't want to accidentally copy free variables - into the locals dict used by the class. - */ - if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) { - return 0; - } - - PyObject *value = frame->localsplus[i]; - if (frame->stacktop) { - if (kind & CO_FAST_FREE) { - // The cell was set by COPY_FREE_VARS. - CHECK(value != NULL && PyCell_Check(value)); - value = PyCell_GET(value); - } - else if (kind & CO_FAST_CELL) { - // Note that no *_DEREF ops can happen before MAKE_CELL - // executes. So there's no need to duplicate the work - // that MAKE_CELL would otherwise do later, if it hasn't - // run yet. - if (value != NULL) { - if (PyCell_Check(value) && - THP_PyFrame_OpAlreadyRan(frame, MAKE_CELL, i)) { - // (likely) MAKE_CELL must have executed already. - value = PyCell_GET(value); - } - // (likely) Otherwise it it is an arg (kind & CO_FAST_LOCAL), - // with the initial value set when the frame was created... - // (unlikely) ...or it was set to some initial value by - // an earlier call to PyFrame_LocalsToFast(). - } - } - } - else { - CHECK(value == NULL); - } - *pvalue = value; - return 1; -} - -// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213 -static PyObject * -THP_PyFrame_GetLocals(_PyInterpreterFrame *frame, int include_hidden, int *free_vars_copied) -{ - /* Merge fast locals into f->f_locals */ - PyObject *locals = frame->f_locals; - if (locals == NULL) { - locals = frame->f_locals = PyDict_New(); - if (locals == NULL) { - return NULL; - } - } - PyObject *hidden = NULL; - - /* If include_hidden, "hidden" fast locals (from inlined comprehensions in - module/class scopes) will be included in the returned dict, but not in - frame->f_locals; the returned dict will be a modified copy. Non-hidden - locals will still be updated in frame->f_locals. */ - if (include_hidden) { - hidden = PyDict_New(); - if (hidden == NULL) { - return NULL; - } - } - - frame_init_get_vars(frame, free_vars_copied); - - PyCodeObject *co = F_CODE(frame); - for (int i = 0; i < co->co_nlocalsplus; i++) { - PyObject *value; // borrowed reference - if (!frame_get_var(frame, co, i, &value)) { - continue; - } - - PyObject *name = PyTuple_GET_ITEM(co->co_localsplusnames, i); - _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); - if (kind & CO_FAST_HIDDEN) { - if (include_hidden && value != NULL) { - if (PyObject_SetItem(hidden, name, value) != 0) { - goto error; - } - } - continue; - } - if (value == NULL) { - if (PyObject_DelItem(locals, name) != 0) { - if (PyErr_ExceptionMatches(PyExc_KeyError)) { - PyErr_Clear(); - } - else { - goto error; - } - } - } - else { - if (PyObject_SetItem(locals, name, value) != 0) { - goto error; - } - } - } - - if (include_hidden && PyDict_Size(hidden)) { - PyObject *innerlocals = PyDict_New(); - if (innerlocals == NULL) { - goto error; - } - if (PyDict_Merge(innerlocals, locals, 1) != 0) { - Py_DECREF(innerlocals); - goto error; - } - if (PyDict_Merge(innerlocals, hidden, 1) != 0) { - Py_DECREF(innerlocals); - goto error; - } - locals = innerlocals; - } - else { - Py_INCREF(locals); - } - Py_CLEAR(hidden); - - return locals; - - error: - Py_XDECREF(hidden); - return NULL; -} - -// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1301 int THP_PyFrame_FastToLocalsWithError(_PyInterpreterFrame *frame, int *free_vars_copied) { - PyObject *locals = THP_PyFrame_GetLocals(frame, 0, free_vars_copied); - if (locals == NULL) { - return -1; - } - Py_DECREF(locals); + // functionality moved to framelocals_mapping.cpp return 0; } diff --git a/torch/csrc/dynamo/cpython_defs.h b/torch/csrc/dynamo/cpython_defs.h index f0bd61b9167f..cea7907be49a 100644 --- a/torch/csrc/dynamo/cpython_defs.h +++ b/torch/csrc/dynamo/cpython_defs.h @@ -6,19 +6,9 @@ // should go in cpython_defs.c. Copying is required when, e.g., // we need to call internal CPython functions that are not exposed. -#if IS_PYTHON_3_13_PLUS -#define F_CODE(x) ((PyCodeObject*)(x)->f_executable) -#define PREV_INSTR(x) (x)->instr_ptr -#else -#define F_CODE(x) ((PyCodeObject*)(x)->f_code) -#define PREV_INSTR(x) (x)->prev_instr -#endif - #if IS_PYTHON_3_11_PLUS -#define Py_BUILD_CORE -#include -#undef Py_BUILD_CORE +typedef struct _PyInterpreterFrame _PyInterpreterFrame; int THP_PyFrame_FastToLocalsWithError( _PyInterpreterFrame* frame, diff --git a/torch/csrc/dynamo/cpython_includes.h b/torch/csrc/dynamo/cpython_includes.h new file mode 100644 index 000000000000..6b99c1d5aec8 --- /dev/null +++ b/torch/csrc/dynamo/cpython_includes.h @@ -0,0 +1,49 @@ +#pragma once + +#include + +// Problem in CPython includes when mixing core and non-core build +// The fix was not backported to 3.12 so this is needed here +// https://github.com/python/cpython/issues/105268 +#if IS_PYTHON_3_12_PLUS +#undef _PyGC_FINALIZED +#endif + +// see https://bugs.python.org/issue35886 +#if PY_VERSION_HEX >= 0x03080000 +#define Py_BUILD_CORE + +#ifndef __cplusplus +// C-only headers +#include + +#endif // __cplusplus + +#if IS_PYTHON_3_11_PLUS +#include +#endif + +#undef Py_BUILD_CORE +#endif // PY_VERSION_HEX >= 0x03080000 + +#ifdef __cplusplus +extern "C" { +#endif + +#if IS_PYTHON_3_13_PLUS +#define F_CODE(x) ((PyCodeObject*)(x)->f_executable) +#define PREV_INSTR(x) (x)->instr_ptr +#else +#define F_CODE(x) ((PyCodeObject*)(x)->f_code) +#define PREV_INSTR(x) (x)->prev_instr +#endif + +#if IS_PYTHON_3_12_PLUS +#define FUNC(x) ((x)->f_funcobj) +#else +#define FUNC(x) ((x)->f_func) +#endif + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/torch/csrc/dynamo/debug_macros.h b/torch/csrc/dynamo/debug_macros.h index 2a05938cee6d..90ddcb457ad9 100644 --- a/torch/csrc/dynamo/debug_macros.h +++ b/torch/csrc/dynamo/debug_macros.h @@ -1,6 +1,14 @@ #pragma once +#ifdef __cplusplus +#include +#else #include +#endif + +#ifdef __cplusplus +extern "C" { +#endif #ifdef _WIN32 #define unlikely(x) (x) @@ -44,3 +52,7 @@ #define DEBUG_TRACE0(msg) #endif + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/torch/csrc/dynamo/eval_frame.c b/torch/csrc/dynamo/eval_frame.c index abf2f4c88c85..35032f077708 100644 --- a/torch/csrc/dynamo/eval_frame.c +++ b/torch/csrc/dynamo/eval_frame.c @@ -2,14 +2,14 @@ #include #include #include +#include #include #include +#include #include #include #include - - PyObject* guard_error_hook = NULL; const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; @@ -33,26 +33,6 @@ inline static void eval_frame_callback_set(PyObject* obj) { // 3.14 Not supported at all. See cpython_defs.c for hints #if !(IS_PYTHON_3_14_PLUS) -// Problem in CPython includes when mixing core and non-core build -// The fix was not backported to 3.12 so this is needed here -// https://github.com/python/cpython/issues/105268 -#if IS_PYTHON_3_12_PLUS -#undef _PyGC_FINALIZED -#endif - -// see https://bugs.python.org/issue35886 -#if PY_VERSION_HEX >= 0x03080000 -#define Py_BUILD_CORE -#include - -// These headers were added in 3.11 -#if IS_PYTHON_3_11_PLUS -#include -#endif - -#undef Py_BUILD_CORE -#endif // PY_VERSION_HEX >= 0x03080000 - // All the eval APIs change in 3.11 so we need to decide which one to use on the fly // https://docs.python.org/3/c-api/init.html#c._PyFrameEvalFunction #if IS_PYTHON_3_11_PLUS @@ -64,6 +44,7 @@ inline static void eval_frame_callback_set(PyObject* obj) { typedef struct THPPyInterpreterFrame { PyObject_HEAD _PyInterpreterFrame* frame; // Borrowed reference + PyObject* locals; } THPPyInterpreterFrame; THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame); @@ -80,14 +61,22 @@ DECLARE_PYOBJ_ATTR(f_funcobj) #else DECLARE_PYOBJ_ATTR(f_func) #endif + DECLARE_PYOBJ_ATTR(f_globals) DECLARE_PYOBJ_ATTR(f_builtins) -DECLARE_PYOBJ_ATTR(f_locals) + +static PyObject* THPPyInterpreterFrame_f_locals(THPPyInterpreterFrame* self, PyObject* _noargs) { + DEBUG_NULL_CHECK(self->locals); + Py_XINCREF(self->locals); + return self->locals; +} + #if IS_PYTHON_3_13_PLUS DECLARE_PYOBJ_ATTR(f_executable) #else DECLARE_PYOBJ_ATTR(f_code) #endif + DECLARE_PYOBJ_ATTR(frame_obj) #undef DECLARE_PYOBJ_ATTR @@ -158,6 +147,7 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) { if (!self) return NULL; self->frame = frame; + self->locals = NULL; return self; } @@ -261,6 +251,7 @@ inline static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) { static inline PyObject* call_callback( PyObject* callable, THP_EVAL_API_FRAME_OBJECT* _frame, + PyObject* locals, CacheEntry* cache_entry, FrameState* frame_state) { @@ -271,6 +262,7 @@ static inline PyObject* call_callback( if (frame == NULL) { return NULL; } + frame->locals = locals; #else PyObject* frame = Py_NewRef(_frame); #endif @@ -531,6 +523,7 @@ static PyObject* _custom_eval_frame_shim( // we are responsible for clearing it - via clear_old_frame_if_python_312_plus. // The should_clear_frame flag is used to indicate whether the frame should be // cleared by _custom_eval_frame's caller. +// Generally should_clear_frame should be set if and only we don't eval_frame_default. static PyObject* _custom_eval_frame( PyThreadState* tstate, THP_EVAL_API_FRAME_OBJECT* frame, @@ -589,13 +582,19 @@ static PyObject* _custom_eval_frame( extra = init_and_set_extra_state(F_CODE(frame)); } - // TODO(jansel): investigate directly using the "fast" representation + int free_vars_copied = 0; + #if IS_PYTHON_3_12_PLUS + PyObject *locals = get_framelocals_mapping(frame); + #else if (THP_PyFrame_FastToLocalsWithError(frame, &free_vars_copied) < 0) { DEBUG_TRACE("error %s", get_frame_name(frame)); *should_clear_frame = 1; return NULL; } + PyObject *locals = frame->f_locals; + Py_INCREF(locals); + #endif PyObject* backend = get_backend(callback); @@ -604,9 +603,11 @@ static PyObject* _custom_eval_frame( if (callback == Py_False) { DEBUG_TRACE("In run only mode %s", get_frame_name(frame)); _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str); - PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend); + PyObject* maybe_cached_code = lookup(extra, locals, backend); _pytorch_record_function_exit(rf); + Py_DECREF(locals); + if (maybe_cached_code == NULL) { // guard eval failed, keep propagating *should_clear_frame = 1; @@ -619,9 +620,9 @@ static PyObject* _custom_eval_frame( // used cached version DEBUG_TRACE("cache hit %s", get_frame_name(frame)); *should_clear_frame = 1; - return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied); + return eval_custom_code(tstate, frame, cached_code, throw_flag, 0); } - DEBUG_CHECK(PyDict_CheckExact(frame->f_locals)); + DEBUG_CHECK(PyDict_CheckExact(locals)); DEBUG_CHECK(PyDict_CheckExact(frame->f_globals)); DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins)); @@ -631,11 +632,12 @@ static PyObject* _custom_eval_frame( eval_frame_callback_set(Py_None); _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str); - PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend); + PyObject* maybe_cached_code = lookup(extra, locals, backend); _pytorch_record_function_exit(rf); if (maybe_cached_code == NULL) { // Python error *should_clear_frame = 1; + Py_DECREF(locals); return NULL; } else if (maybe_cached_code != Py_None) { PyCodeObject* cached_code = (PyCodeObject*)maybe_cached_code; @@ -644,13 +646,15 @@ static PyObject* _custom_eval_frame( // Re-enable custom behavior eval_frame_callback_set(callback); *should_clear_frame = 1; + Py_DECREF(locals); return eval_custom_code(tstate, frame, cached_code, throw_flag, free_vars_copied); } // cache miss CacheEntry* cache_entry = extract_cache_entry(extra); FrameState* frame_state = extract_frame_state(extra); PyObject* result = - call_callback(callback, frame, cache_entry, frame_state); + call_callback(callback, frame, locals, cache_entry, frame_state); + Py_DECREF(locals); if (result == NULL) { // internal exception, returning here will leak the exception into user code // this is useful for debugging -- but we dont want it to happen outside of diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp index 7c9b4be0009b..01d29eab5197 100644 --- a/torch/csrc/dynamo/extra_state.cpp +++ b/torch/csrc/dynamo/extra_state.cpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include diff --git a/torch/csrc/dynamo/framelocals_mapping.cpp b/torch/csrc/dynamo/framelocals_mapping.cpp new file mode 100644 index 000000000000..3a11c88af980 --- /dev/null +++ b/torch/csrc/dynamo/framelocals_mapping.cpp @@ -0,0 +1,71 @@ +#include + +#if IS_PYTHON_3_12_PLUS +#include +#include +#include +#include + +#include + +// Our own version of PyFrame_GetLocals. +// Also combines functionality from frame_init_get_vars and frame_get_var. +// PyFrame_GetLocals: +// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1213 +// frame_init_get_vars: +// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1136 +// frame_get_var: +// https://github.com/python/cpython/blob/0325a8a8cdba6c091bcbbb3c995f3bf1d1217012/Objects/frameobject.c#L1162 +// PyFrame_GetLocals returns the frame locals dict. +// frame_init_get_vars initializes free variables from the closure. +// frame_get_var fetches the variable value from the frame given the index +// NOTE: hidden variables are not included. +// Returns a new reference. +PyObject* get_framelocals_mapping(_PyInterpreterFrame* frame) { + if (!frame->stacktop) { + return py::dict().release().ptr(); + } + + PyCodeObject* co = F_CODE(frame); + py::dict mapping; + + auto update_mapping = [&](int i, PyObject* value) { + _PyLocals_Kind kind = _PyLocals_GetKind(co->co_localspluskinds, i); + + if (kind & CO_FAST_FREE && !(co->co_flags & CO_OPTIMIZED)) { + return; + } + if (kind & CO_FAST_HIDDEN) { + return; + } + + if (kind & CO_FAST_FREE) { + CHECK(value != nullptr && PyCell_Check(value)); + value = PyCell_GET(value); + } + + if (value != nullptr) { + py::str name = + py::cast(PyTuple_GET_ITEM(co->co_localsplusnames, i)); + mapping[name] = py::cast(value); + } + }; + + int offset = co->co_nlocalsplus - co->co_nfreevars; + for (int i = 0; i < offset; i++) { + update_mapping(i, frame->localsplus[i]); + } + // Get references to closure variables + PyObject* closure = ((PyFunctionObject*)FUNC(frame))->func_closure; + for (int i = 0; i < co->co_nfreevars; ++i) { + update_mapping(offset + i, PyTuple_GET_ITEM(closure, i)); + } + + // NOTE no need to move the instruction pointer to after COPY_FREE_VARS + // since we don't actually copy free vars from the closure to the frame + // localsplus. + + return mapping.release().ptr(); +} + +#endif diff --git a/torch/csrc/dynamo/framelocals_mapping.h b/torch/csrc/dynamo/framelocals_mapping.h new file mode 100644 index 000000000000..22f29e5657d9 --- /dev/null +++ b/torch/csrc/dynamo/framelocals_mapping.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +#if IS_PYTHON_3_12_PLUS +typedef struct _PyInterpreterFrame _PyInterpreterFrame; +PyObject* get_framelocals_mapping(_PyInterpreterFrame* frame); +#endif + +#ifdef __cplusplus +} // extern "C" +#endif