[dynamo] refactor CacheEntry and ExtraState to eval_frame.c to C++ (#118438)

Part of implementing CacheEntry invalidation to fix https://github.com/pytorch/pytorch/issues/112090.

Changes:
- Move CacheEntry and ExtraState to C++
- Use pybind to control reference counting
- Use std::list instead of manually implementing a linked list

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118438
Approved by: https://github.com/jansel
This commit is contained in:
William Wen
2024-02-05 23:39:22 -08:00
committed by PyTorch MergeBot
parent 73f0fdea5b
commit ae4e866bba
13 changed files with 507 additions and 457 deletions

View File

@ -807,9 +807,11 @@ libtorch_python_core_sources = [
"torch/csrc/autograd/python_variable.cpp", "torch/csrc/autograd/python_variable.cpp",
"torch/csrc/autograd/python_variable_indexing.cpp", "torch/csrc/autograd/python_variable_indexing.cpp",
"torch/csrc/dynamo/python_compiled_autograd.cpp", "torch/csrc/dynamo/python_compiled_autograd.cpp",
"torch/csrc/dynamo/cache_entry.cpp",
"torch/csrc/dynamo/cpp_shim.cpp", "torch/csrc/dynamo/cpp_shim.cpp",
"torch/csrc/dynamo/cpython_defs.c", "torch/csrc/dynamo/cpython_defs.c",
"torch/csrc/dynamo/eval_frame.c", "torch/csrc/dynamo/eval_frame.c",
"torch/csrc/dynamo/extra_state.cpp",
"torch/csrc/dynamo/guards.cpp", "torch/csrc/dynamo/guards.cpp",
"torch/csrc/dynamo/init.cpp", "torch/csrc/dynamo/init.cpp",
"torch/csrc/functorch/init.cpp", "torch/csrc/functorch/init.cpp",

View File

@ -9394,6 +9394,57 @@ fn
self.assertIn(0, result) self.assertIn(0, result)
self.assertTrue(same(result[0], torch.tensor(3))) self.assertTrue(same(result[0], torch.tensor(3)))
def test_dynamo_reset_clears_cache(self):
"""Test that dynamo bytecode cache is freed
when dynamo reset is called
"""
def fn(x):
return torch.sin(x)
opt_fn = torch.compile(backend="eager")(fn)
opt_fn(torch.randn(3, 3))
c1 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c1), 1)
torch._dynamo.reset()
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c2), 0)
def test_dynamo_cache_move_to_front(self):
class Mod(torch.nn.Module):
def __init__(self):
super(Mod, self).__init__()
self.fc = torch.nn.Linear(3, 3)
def forward(self, out):
return self.fc(out)
def fn(x, mod):
return mod(x)
opt_fn = torch.compile(fn, backend="eager")
m1 = Mod()
m2 = Mod()
m3 = Mod()
inp = torch.randn(3, 3)
# NOTE: assumes that each cache entry is guarded
# on unique Mod instance
opt_fn(inp, m1)
opt_fn(inp, m2)
opt_fn(inp, m3)
c1 = _debug_get_cache_entry_list(fn.__code__)
self.assertEqual(len(c1), 3)
# move cache entry to front
opt_fn(inp, m2)
c2 = _debug_get_cache_entry_list(fn.__code__)
self.assertIs(c1[1], c2[0])
class TestTracer(JitTestCase): class TestTracer(JitTestCase):
def test_jit_save(self): def test_jit_save(self):

View File

@ -1,5 +1,5 @@
import types import types
from typing import NewType, Optional from typing import List, NewType, Optional
from torch._dynamo.types import DynamoCallback, DynamoGuardHook from torch._dynamo.types import DynamoCallback, DynamoGuardHook
@ -11,7 +11,6 @@ def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
def reset_code(code: types.CodeType) -> None: ... def reset_code(code: types.CodeType) -> None: ...
def unsupported(obj1: object, obj2: object) -> object: ... def unsupported(obj1: object, obj2: object) -> object: ...
def skip_code(code: types.CodeType) -> None: ... def skip_code(code: types.CodeType) -> None: ...
def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ... def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
class _CacheEntry: class _CacheEntry:
@ -19,4 +18,4 @@ class _CacheEntry:
code: types.CodeType code: types.CodeType
next: Optional[_CacheEntry] next: Optional[_CacheEntry]
def _debug_get_cache_entry_list(code: types.CodeType) -> Optional[_CacheEntry]: ... def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...

View File

@ -186,12 +186,7 @@ def _debug_get_cache_entry_list(
""" """
if callable(code): if callable(code):
code = code.__code__ code = code.__code__
cache_head = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code) return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
cache_list = []
while cache_head is not None:
cache_list.append(cache_head)
cache_head = cache_head.next
return cache_list
class OptimizedModule(torch.nn.Module): class OptimizedModule(torch.nn.Module):

View File

@ -0,0 +1,30 @@
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
CacheEntry::CacheEntry(const py::handle& guarded_code) {
this->check_fn = guarded_code.attr("check_fn");
this->code = guarded_code.attr("code");
}
py::object CacheEntry::next() {
NULL_CHECK(this->_owner);
auto it = this->_owner_loc;
++it;
if (it == this->_owner->cache_entry_list.end()) {
return py::none();
}
return py::cast(*it, py::return_value_policy::reference);
}
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
return (PyCodeObject*)e->code.ptr();
}
PyObject* CacheEntry_to_obj(CacheEntry* e) {
if (!e) {
return py::none().release().ptr();
}
return py::cast(e, py::return_value_policy::reference).release().ptr();
}

View File

@ -0,0 +1,68 @@
#pragma once
#include <Python.h>
#ifdef __cplusplus
#include <torch/csrc/dynamo/utils.h>
#include <torch/csrc/utils/pybind.h>
#include <list>
namespace py = pybind11;
extern "C" {
#endif
/*
Our cache resides on the extra scratch space of the code object. The structure
of the cache is as follows:
-> ExtraState
-> CacheEntry (list)
-> check_fn
-> code
-> FrameState
CacheEntry is a linked list node containing the check_fn for guards
and the optimized code.
The FrameState is a PyDict that enables sharing between different frames. This
is used to detect dynamism in automatic dynamic shapes.
These two are encapsulated into a ExtraState.
*/
typedef struct CacheEntry CacheEntry;
typedef struct ExtraState ExtraState;
#ifdef __cplusplus
typedef struct VISIBILITY_HIDDEN CacheEntry {
// check the guards: lambda: <locals of user function>: bool
py::object check_fn;
// modified user bytecode (protected by check_fn's guards)
py::object code;
// Reference to owning ExtraState
ExtraState* _owner{nullptr};
// Reference to this CacheEntry's location in owner's linked list
std::list<CacheEntry>::iterator _owner_loc;
CacheEntry(const py::handle& guarded_code);
// Warning: returns a reference whose lifetime is controlled by C++
py::object next();
} CacheEntry;
#endif
// Returns borrowed reference
PyCodeObject* CacheEntry_get_code(CacheEntry* e);
// Returns a borrowed reference to CacheEntry as a PyObject
// Warning: lifetime is controlled by C++
PyObject* CacheEntry_to_obj(CacheEntry* e);
#ifdef __cplusplus
} // extern "C"
#endif

View File

@ -0,0 +1,46 @@
#pragma once
#include <stdio.h>
#ifdef _WIN32
#define unlikely(x) (x)
#else
#define unlikely(x) __builtin_expect((x), 0)
#endif
#define NULL_CHECK(val) \
if (unlikely((val) == NULL)) { \
fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
PyErr_Print(); \
abort(); \
} else { \
}
// CHECK might be previously declared
#undef CHECK
#define CHECK(cond) \
if (unlikely(!(cond))) { \
fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
abort(); \
} else { \
}
// Uncomment next line to print debug message
// #define TORCHDYNAMO_DEBUG 1
#ifdef TORCHDYNAMO_DEBUG
#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
#else
#define DEBUG_CHECK(cond)
#define DEBUG_NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...)
#define DEBUG_TRACE0(msg)
#endif

View File

@ -1,6 +1,9 @@
#define PY_SSIZE_T_CLEAN #define PY_SSIZE_T_CLEAN
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/cpp_shim.h> #include <torch/csrc/dynamo/cpp_shim.h>
#include <torch/csrc/dynamo/cpython_defs.h> #include <torch/csrc/dynamo/cpython_defs.h>
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/utils/python_compat.h> #include <torch/csrc/utils/python_compat.h>
#include <opcode.h> #include <opcode.h>
#include <stdbool.h> #include <stdbool.h>
@ -132,57 +135,9 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError #define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
#endif #endif
#ifdef _WIN32 PyObject* guard_error_hook = NULL;
#define unlikely(x) (x)
#else
#define unlikely(x) __builtin_expect((x), 0)
#endif
#define NULL_CHECK(val) \
if (unlikely((val) == NULL)) { \
fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
PyErr_Print(); \
abort(); \
} else { \
}
#define CHECK(cond) \
if (unlikely(!(cond))) { \
fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
abort(); \
} else { \
}
// Uncomment next line to print debug message
// #define TORCHDYNAMO_DEBUG 1
#ifdef TORCHDYNAMO_DEBUG
#define DEBUG_CHECK(cond) CHECK(cond)
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
#define DEBUG_TRACE0(msg) \
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
#else
#define DEBUG_CHECK(cond)
#define DEBUG_NULL_CHECK(val)
#define DEBUG_TRACE(msg, ...)
#define DEBUG_TRACE0(msg)
#endif
// Flag to just run a frame normally
#define SKIP_CODE ((void*)0x1)
static PyObject* guard_error_hook = NULL;
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup"; const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
// Points to the extra scratch space on the code object
static Py_ssize_t extra_index = -1;
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT; static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
inline static PyObject* eval_frame_callback_get(void) { inline static PyObject* eval_frame_callback_get(void) {
@ -284,323 +239,6 @@ inline static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
return PyUnicode_AsUTF8(frame->f_code->co_name); return PyUnicode_AsUTF8(frame->f_code->co_name);
} }
typedef PyObject FrameState;
/*
Our cache resides on the extra scratch space of the code object. The structure
of the cache is as follows:
-> ExtraState
-> CacheEntry
-> check_fn
-> optimized_code
-> next
-> FrameState
CacheEntry is a linked list, with each node containing the check_fn for guards
and the optimized code.
The frame_state is a PyDict that enables sharing between different frames. This
is used to detect dynamism in automatic dynamic shapes.
These two are encapsulated into a ExtraState.
*/
// Linked list of cache entries, where each cache entry stores
// the check_fn and the torch.compile optimized python bytecode.
typedef struct cache_entry {
PyObject_HEAD
// check the guards: lambda: <locals of user function>: bool
PyObject* check_fn;
// modified user bytecode (protected by check_fn's guards)
PyCodeObject* code;
// on a cache miss, linked list of next thing to try
struct cache_entry* next;
} CacheEntry;
static void cache_entry_dealloc(CacheEntry* e);
#define DECLARE_CACHE_ENTRY_ATTR(name) \
static PyObject* CacheEntry_##name(CacheEntry* self, PyObject* _noargs) { \
PyObject* res = (PyObject*)self->name; \
Py_INCREF(res); \
return res; \
}
DECLARE_CACHE_ENTRY_ATTR(check_fn)
DECLARE_CACHE_ENTRY_ATTR(code)
DECLARE_CACHE_ENTRY_ATTR(next)
static struct PyGetSetDef CacheEntry_properties[] = {
{"check_fn", (getter)CacheEntry_check_fn, NULL, NULL, NULL},
{"code", (getter)CacheEntry_code, NULL, NULL, NULL},
{"next", (getter)CacheEntry_next, NULL, NULL, NULL},
{NULL}};
static PyObject* cache_entry_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
CacheEntry *self = (CacheEntry*) type->tp_alloc(type, 0);
if (self != NULL) {
// The corresponding decrefs for Py_None are in cache_entry_init.
Py_INCREF(Py_None);
self->check_fn = Py_None;
Py_INCREF(Py_None);
self->code = (PyCodeObject*)Py_None;
Py_INCREF(Py_None);
self->next = (CacheEntry*)Py_None;
}
return (PyObject*)self;
}
static int cache_entry_init(CacheEntry* self, PyObject* args, PyObject* kwds) {
PyObject* check_fn = NULL;
PyCodeObject* code = NULL;
CacheEntry* next = NULL;
static char *kwlist[] = {"check_fn", "code", "next", NULL};
int ret = PyArg_ParseTupleAndKeywords(
args, kwds, "OOO", kwlist,
&check_fn, &code, &next);
if (!ret) return -1;
if (check_fn) {
PyObject* tmp = self->check_fn;
Py_INCREF(check_fn);
self->check_fn = check_fn;
Py_XDECREF(tmp);
}
if (code) {
PyCodeObject* tmp = self->code;
Py_INCREF(code);
self->code = code;
Py_XDECREF(tmp);
}
if (next) {
CacheEntry* tmp = self->next;
Py_INCREF(next);
self->next = next;
Py_XDECREF(tmp);
}
return 0;
}
static PyTypeObject CacheEntryType = {
PyVarObject_HEAD_INIT(NULL, 0)
.tp_name = "torch._C.dynamo.eval_frame.CacheEntryWrapper",
.tp_basicsize = sizeof(CacheEntry),
.tp_itemsize = 0,
.tp_flags = Py_TPFLAGS_DEFAULT,
.tp_new = cache_entry_new,
.tp_init = (initproc)cache_entry_init,
.tp_dealloc = (destructor)cache_entry_dealloc,
.tp_getset = CacheEntry_properties,
};
// ExtraState encasulates CacheEntry and FrameState. ExtraState is the highest
// level of abstraction of what is stored on the extra code object. Previously,
// we saved different parts on different extra indexes. We prefer this way
// because of cleaner abstraction and faster SetExtra access.
// TODO(anijain2305) - Consider making this a PyObject. Benefits are
// 1) Modular dealloc - destroy_extra_state just becomes Py_DECREF(extra)
// 2) We can directly send the extra object to convert_frame callback. One
// data structure - easier to understand code.
// There might be some perf impact of going through a PyObject on the critical
// path, but it should not be too bad.
typedef struct {
// Cache entry for the code object
CacheEntry* cache_entry;
// Frame state to detect dynamic shape dims
FrameState* frame_state;
} ExtraState;
/* CacheEntry helper functions begins */
static CacheEntry* create_cache_entry(
CacheEntry* next,
PyObject* guarded_code) {
// Ownership contract
// args
// - next: steals
// - guarded_code: Borrowed
// return
// - CacheEntry*: new reference.
PyObject* check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); // new reference
PyCodeObject* code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code"); // new reference
// equivalent to CacheEntry(check_fn, code, next) in Python
PyObject* args = Py_BuildValue("OOO", check_fn, code, next);
CacheEntry* e = (CacheEntry*)PyObject_CallObject((PyObject*)&CacheEntryType, args); // new reference
// CacheEntry e is the now the owner of old cachey entry next. This happens
// when we incref the next pointer in cache_entry_init.
Py_DECREF(next);
Py_DECREF(check_fn);
Py_DECREF(code);
Py_DECREF(args);
return e;
}
static void cache_entry_dealloc(CacheEntry* e) {
Py_XDECREF(e->check_fn);
Py_XDECREF(e->code);
// This will recursively call cache_entry_dealloc for the next items in the
// linked list.
Py_XDECREF(e->next);
Py_TYPE(e)->tp_free((PyObject*)e);
}
/* CacheEntry helper functions ends */
/* Extractions helper functions begins. They help with NULL and SKIP_CODE corner cases */
inline static CacheEntry* extract_cache_entry(ExtraState* extra_state) {
// Helper to extra the cache_entry from the extra state.
// Ownership contract
// args
// - extra_state: Borrowed
// return
// - CacheEntry: Borrowed.
if (extra_state == NULL || extra_state == SKIP_CODE) {
return NULL;
}
return extra_state->cache_entry;
}
inline static FrameState* extract_frame_state(ExtraState* extra_state) {
// Returns either the previously stored frame state or an empty dict.
// Ownership contract
// args
// - extra_state: Borrowed
// return
// - extra_state->frame_state: Borrowed.
if (extra_state == NULL || extra_state == SKIP_CODE) {
return NULL;
}
return extra_state->frame_state;
}
/* Extractions helper functions ends */
/* Extra state helper functions begins */
inline static ExtraState* get_extra_state(PyCodeObject* code) {
// Ownership contract
// args
// - code: Borrowed
// return
// - extra_state: Borrowed.
ExtraState* extra = NULL;
_PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra);
return extra;
}
inline static void destroy_extra_state(void* obj) {
// This is passed as freefunc to _PyEval_RequestCodeExtraIndex. This acts as a
// deleter for the object on extra scratch space. This function is called
// internally in _PyCode_SetExtra and also during the code deallocation.
// Destroys the extra state by deleting cache_entry, frame state and finally
// freeing the constructed extra state.
// Developer note - You should not call this function directly. This is called
// directly inside set_extra_state. If you are in a situation trying to call
// this function, consider if set_extra_state should be called.
ExtraState* extra = (ExtraState*)obj;
if (extra != NULL && extra != SKIP_CODE) {
// Cpython gc will call cache_entry_dealloc on its own when the ref count
// goes to 0.
Py_XDECREF(extra->cache_entry);
Py_XDECREF(extra->frame_state);
free(extra);
}
}
inline static void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
// Clears the existing object sitting on the extra scratch spance and sets it
// up with the new state. Note that _PyCode_SetExtra calls the
// destroy_extra_state deleter internally, and therefore we don't call it
// explicity here.
// Ownership contract
// args
// - extra_state: Stolen
// return
// - there is no return, but the extra_state is stolen, so it becomes
// set_extra_state responsibility to clean it up. It will be deleted during
// the reset_code/skip, when the set_extra_state is called with
// NULL/SKIP_CODE.
// Invariant - Dont set the extra state for the extra state that is already on
// the code object. Otherwise, we will first free up the old extra state
// (which is also the new extra state) and write something invalid on the
// scratch space.
ExtraState* old_extra_state = get_extra_state(code);
CHECK(old_extra_state == NULL || old_extra_state == SKIP_CODE || old_extra_state != extra_state);
_PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
}
inline static ExtraState* init_and_set_extra_state(PyCodeObject* code) {
// Creates a new extra state and put it on the extra scrach space of the code
// object.
// Ownership contract
// args
// - code: Borrowed
// return:
// - extra_state: New reference.
// These references are then further passed to set_extra_state which becomes
// the final owner of these references.
// Invariant - Extra state should not have been set before, therefore it should be NULL.
CHECK(get_extra_state(code) == NULL);
ExtraState* extra_state = (ExtraState*)malloc(sizeof(ExtraState));
DEBUG_NULL_CHECK(extra_state);
// We set the last node in the linked list to Py_None. We incref the Py_None
// here, the corresponding decref is in cache_entry_dealloc.
Py_INCREF(Py_None);
extra_state->cache_entry = (CacheEntry*)Py_None;
extra_state->frame_state = PyDict_New();
set_extra_state(code, extra_state);
return extra_state;
}
/* Extra state helper functions ends */
/*
Debugger helper functions.
*/
PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
// get the cache entry out of a code object
PyObject* object = NULL;
if (!PyArg_ParseTuple(args, "O", &object)) {
return NULL;
}
if (!PyCode_Check(object)) {
PyErr_SetString(PyExc_TypeError, "expected a code object!");
return NULL;
}
PyCodeObject* code = (PyCodeObject*)object;
ExtraState* extra = get_extra_state(code);
CacheEntry* current_node = extract_cache_entry(extra);
if (current_node == NULL)
{
Py_RETURN_NONE;
}
Py_INCREF(current_node);
return (PyObject*)current_node;
}
static inline PyObject* call_callback( static inline PyObject* call_callback(
PyObject* callable, PyObject* callable,
THP_EVAL_API_FRAME_OBJECT* _frame, THP_EVAL_API_FRAME_OBJECT* _frame,
@ -618,74 +256,18 @@ static inline PyObject* call_callback(
PyObject* frame = Py_NewRef(_frame); PyObject* frame = Py_NewRef(_frame);
#endif #endif
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
PyObject* res = PyObject_CallFunction( PyObject* res = PyObject_CallFunction(
callable, callable,
"OOO", "OOO",
frame, frame,
cache_entry, cache_entry_pyobj,
frame_state); frame_state);
Py_DECREF(frame); Py_DECREF(frame);
Py_DECREF(cache_entry_pyobj);
return res; return res;
} }
static PyObject* call_guard_fail_hook(
PyObject* hook,
CacheEntry* e,
size_t index,
PyObject* f_locals) {
// call debugging logic when a guard fails
return PyObject_CallFunction(
hook,
"OOOnO",
e->check_fn,
e->code,
f_locals,
(Py_ssize_t)index,
(e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
}
// Return value: borrowed reference
// Is either Py_None or a PyCodeObject
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
if (e == (CacheEntry*)Py_None) {
// NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
return Py_None;
}
PyObject *f_locals = frame->f_locals;
// remember to update the type signature for GuardFn.__call__ in torch/_dynamo/types.py
// if this calling convention changes
PyObject* valid = PyObject_CallOneArg(e->check_fn, f_locals);
if (unlikely(valid == NULL)) {
if (guard_error_hook != NULL) {
PyObject *type = NULL, *value = NULL, *traceback = NULL;
PyErr_Fetch(&type, &value, &traceback);
PyObject* r = call_guard_fail_hook(guard_error_hook, e, index, f_locals);
if (r == NULL) {
return NULL;
}
Py_DECREF(r);
PyErr_Restore(type, value, traceback);
}
return NULL;
}
Py_DECREF(valid);
if (valid == Py_True) {
// Keep the head as the most recently used cache entry.
// If the hit cache entry is not the head of the linked list,
// move it to the head
if (prev != NULL) {
ExtraState* extra = get_extra_state(frame->f_code);
// Override the extra state to reflect the updated cache line.
CacheEntry* old_cache_entry = extra->cache_entry;
prev->next = e->next;
e->next = old_cache_entry;
extra->cache_entry = e;
}
return (PyObject*)e->code;
}
return lookup(e->next, frame, e, index + 1);
}
inline static PyObject* eval_custom_code_impl( inline static PyObject* eval_custom_code_impl(
PyThreadState* tstate, PyThreadState* tstate,
THP_EVAL_API_FRAME_OBJECT* frame, THP_EVAL_API_FRAME_OBJECT* frame,
@ -942,12 +524,7 @@ static PyObject* _custom_eval_frame(
extra = init_and_set_extra_state(frame->f_code); extra = init_and_set_extra_state(frame->f_code);
} }
CacheEntry* cache_entry = extract_cache_entry(extra);
FrameState* frame_state = extract_frame_state(extra);
// TODO(jansel): investigate directly using the "fast" representation // TODO(jansel): investigate directly using the "fast" representation
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
// even though we should pass a PyFrameObject.
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) { if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
DEBUG_TRACE("error %s", get_frame_name(frame)); DEBUG_TRACE("error %s", get_frame_name(frame));
return NULL; return NULL;
@ -958,7 +535,7 @@ static PyObject* _custom_eval_frame(
if (callback == Py_False) { if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", get_frame_name(frame)); DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str); _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0); PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
_pytorch_record_function_exit(rf); _pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) { if (maybe_cached_code == NULL) {
@ -983,7 +560,7 @@ static PyObject* _custom_eval_frame(
eval_frame_callback_set(Py_None); eval_frame_callback_set(Py_None);
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str); _PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0); PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
_pytorch_record_function_exit(rf); _pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) { if (maybe_cached_code == NULL) {
// Python error // Python error
@ -997,8 +574,8 @@ static PyObject* _custom_eval_frame(
return eval_custom_code(tstate, frame, cached_code, throw_flag); return eval_custom_code(tstate, frame, cached_code, throw_flag);
} }
// cache miss // cache miss
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame CacheEntry* cache_entry = extract_cache_entry(extra);
// that gets re-interpreted as a PyObject (which it is NOT!) FrameState* frame_state = extract_frame_state(extra);
PyObject* result = PyObject* result =
call_callback(callback, frame, cache_entry, frame_state); call_callback(callback, frame, cache_entry, frame_state);
if (result == NULL) { if (result == NULL) {
@ -1017,7 +594,7 @@ static PyObject* _custom_eval_frame(
// extract_cache_entry returns a borrowed reference. Modifying a borrowed // extract_cache_entry returns a borrowed reference. Modifying a borrowed
// reference seems wrong. Therefore, we directly access the // reference seems wrong. Therefore, we directly access the
// extra->cache_entry. extra wont be NULL here. // extra->cache_entry. extra wont be NULL here.
extra->cache_entry = create_cache_entry(extra->cache_entry, result); CacheEntry* new_cache_entry = create_cache_entry(extra, result);
Py_DECREF(result); Py_DECREF(result);
// Update the existing cache_entry on the extra object. This extra object is // Update the existing cache_entry on the extra object. This extra object is
// sitting on the extra scratch space, we are just changing the cache_entry // sitting on the extra scratch space, we are just changing the cache_entry
@ -1025,7 +602,7 @@ static PyObject* _custom_eval_frame(
// will be cleaned up when set_extra_state is called. // will be cleaned up when set_extra_state is called.
// Re-enable custom behavior // Re-enable custom behavior
eval_frame_callback_set(callback); eval_frame_callback_set(callback);
return eval_custom_code(tstate, frame, extra->cache_entry->code, throw_flag); return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
} else { } else {
DEBUG_TRACE("create skip %s", get_frame_name(frame)); DEBUG_TRACE("create skip %s", get_frame_name(frame));
Py_DECREF(result); Py_DECREF(result);
@ -1144,7 +721,6 @@ static PyMethodDef _methods[] = {
{"unsupported", unsupported, METH_VARARGS, NULL}, {"unsupported", unsupported, METH_VARARGS, NULL},
{"skip_code", skip_code, METH_O, NULL}, {"skip_code", skip_code, METH_O, NULL},
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL}, {"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
{"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
{NULL, NULL, 0, NULL}}; {NULL, NULL, 0, NULL}};
static struct PyModuleDef _module = { static struct PyModuleDef _module = {
@ -1184,15 +760,5 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
} }
#endif #endif
if (PyType_Ready(&CacheEntryType) < 0) {
return NULL;
}
Py_INCREF(&CacheEntryType);
if (PyModule_AddObject(module, "_CacheEntry", (PyObject *) &CacheEntryType) < 0) {
Py_DECREF(&CacheEntryType);
return NULL;
}
return module; return module;
} }

View File

@ -0,0 +1,129 @@
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/debug_macros.h>
Py_ssize_t extra_index = -1;
CacheEntry* ExtraState::get_first_entry() {
if (this->cache_entry_list.empty()) {
return NULL;
}
return &this->cache_entry_list.front();
}
void ExtraState::move_to_front(CacheEntry* cache_entry) {
CHECK(cache_entry->_owner == this);
CHECK(!this->cache_entry_list.empty());
CHECK(cache_entry == &*cache_entry->_owner_loc);
this->cache_entry_list.splice(
this->cache_entry_list.begin(),
this->cache_entry_list,
cache_entry->_owner_loc);
}
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
if (extra_state == NULL || extra_state == SKIP_CODE) {
return NULL;
}
return extra_state->get_first_entry();
}
FrameState* extract_frame_state(ExtraState* extra_state) {
if (extra_state == NULL || extra_state == SKIP_CODE) {
return NULL;
}
return (FrameState*)extra_state->frame_state.ptr();
}
ExtraState* get_extra_state(PyCodeObject* code) {
ExtraState* extra = NULL;
_PyCode_GetExtra((PyObject*)code, extra_index, (void**)&extra);
return extra;
}
void destroy_extra_state(void* obj) {
ExtraState* extra = (ExtraState*)obj;
if (extra != NULL && extra != SKIP_CODE) {
delete extra;
}
}
void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
ExtraState* old_extra_state = get_extra_state(code);
CHECK(
old_extra_state == NULL || old_extra_state == SKIP_CODE ||
old_extra_state != extra_state);
_PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
}
ExtraState* init_and_set_extra_state(PyCodeObject* code) {
// Invariant - Extra state should not have been set before, therefore it
// should be NULL.
CHECK(get_extra_state(code) == NULL);
ExtraState* extra_state = new ExtraState();
NULL_CHECK(extra_state);
set_extra_state(code, extra_state);
return extra_state;
}
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
size_t index = 0;
CacheEntry* found = nullptr;
py::handle locals(f_locals);
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
py::object valid = py::none();
try {
valid = cache_entry.check_fn(locals);
} catch (py::error_already_set& e) {
if (guard_error_hook) {
py::handle guard_error_hook_handle(guard_error_hook);
guard_error_hook_handle(
cache_entry.check_fn,
cache_entry.code,
locals,
index,
index == extra_state->cache_entry_list.size() - 1);
}
// this function is called from C, so we cannot repropagate
// the exception
e.restore();
return NULL;
}
if (valid.cast<bool>()) {
found = &cache_entry;
break;
}
++index;
}
if (found) {
extra_state->move_to_front(found);
return found->code.ptr();
}
return py::none().ptr();
}
CacheEntry* create_cache_entry(
ExtraState* extra_state,
PyObject* guarded_code) {
extra_state->cache_entry_list.emplace_front(guarded_code);
auto new_iter = extra_state->cache_entry_list.begin();
new_iter->_owner = extra_state;
new_iter->_owner_loc = new_iter;
return &*new_iter;
}
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) {
throw py::type_error("expected a code object!");
}
PyCodeObject* code = (PyCodeObject*)code_obj.ptr();
ExtraState* extra = get_extra_state(code);
py::list result;
if (extra) {
for (CacheEntry& e : extra->cache_entry_list) {
result.append(py::cast(e, py::return_value_policy::reference));
}
}
return result;
}

View File

@ -0,0 +1,145 @@
#pragma once
#include <Python.h>
#ifdef __cplusplus
#include <torch/csrc/dynamo/utils.h>
#include <torch/csrc/utils/pybind.h>
#include <list>
namespace py = pybind11;
extern "C" {
#endif
// Flag to just run a frame normally
#define SKIP_CODE ((void*)0x1)
// Points to the extra scratch space on the code object
extern Py_ssize_t extra_index;
// function to call when cache lookup errors
extern PyObject* guard_error_hook;
typedef PyObject FrameState;
typedef struct CacheEntry CacheEntry;
// ExtraState encasulates CacheEntry and FrameState. ExtraState is the highest
// level of abstraction of what is stored on the extra code object. Previously,
// we saved different parts on different extra indexes. We prefer this way
// because of cleaner abstraction and faster SetExtra access.
#ifdef __cplusplus
typedef struct VISIBILITY_HIDDEN ExtraState {
// List of cache entries for compiled code objects
std::list<CacheEntry> cache_entry_list;
// Frame state to detect dynamic shape dims
py::dict frame_state;
CacheEntry* get_first_entry();
void move_to_front(CacheEntry* cache_entry);
} ExtraState;
#else
typedef struct ExtraState ExtraState;
#endif
// Helper to extra the cache_entry from the extra state.
// Ownership contract
// args
// - extra_state: Borrowed
// return
// - CacheEntry: Borrowed.
CacheEntry* extract_cache_entry(ExtraState* extra_state);
// Returns either the previously stored frame state or an empty dict.
// Ownership contract
// args
// - extra_state: Borrowed
// return
// - extra_state->frame_state: Borrowed.
FrameState* extract_frame_state(ExtraState* extra_state);
// Ownership contract
// args
// - code: Borrowed
// return
// - extra_state: Borrowed.
ExtraState* get_extra_state(PyCodeObject* code);
// This is passed as freefunc to _PyEval_RequestCodeExtraIndex. This acts as a
// deleter for the object on extra scratch space. This function is called
// internally in _PyCode_SetExtra and also during the code deallocation.
// Destroys the extra state by deleting cache_entry, frame state and finally
// freeing the constructed extra state.
// Developer note - You should not call this function directly. This is called
// directly inside set_extra_state. If you are in a situation trying to call
// this function, consider if set_extra_state should be called.
void destroy_extra_state(void* obj);
// Clears the existing object sitting on the extra scratch spance and sets it
// up with the new state. Note that _PyCode_SetExtra calls the
// destroy_extra_state deleter internally, and therefore we don't call it
// explicity here.
// Ownership contract
// args
// - extra_state: Stolen
// return
// - there is no return, but the extra_state is stolen, so it becomes
// set_extra_state responsibility to clean it up. It will be deleted during
// the reset_code/skip, when the set_extra_state is called with
// NULL/SKIP_CODE.
// Invariant - Dont set the extra state for the extra state that is already on
// the code object. Otherwise, we will first free up the old extra state
// (which is also the new extra state) and write something invalid on the
// scratch space.
void set_extra_state(PyCodeObject* code, ExtraState* extra_state);
// Creates a new extra state and put it on the extra scrach space of the code
// object.
// Ownership contract
// args
// - code: Borrowed
// return:
// - extra_state: New reference.
// These references are then further passed to set_extra_state which becomes
// the final owner of these references.
ExtraState* init_and_set_extra_state(PyCodeObject* code);
// Lookup the cache held by extra_state.
// Ownership contract
// args
// - extra_state: Borrowed
// - f_locals: Borrowed
// return:
// - Py_None or PyCodeObject: Borrowed reference.
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
// Create a new cache entry at extra_state holding on to guarded_code.
// Ownership contract
// args
// - extra_state: Borrowed
// - guarded_code: Borrowed
// return:
// - cache_entry: Borrowed reference
CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
#ifdef __cplusplus
} // extern "C"
// Returns the list of CacheEntry corresponding to code_obj.
// Warning: returns references whose lifetimes are controlled by C++
py::list _debug_get_cache_entry_list(const py::handle& code_obj);
#endif

View File

@ -1,7 +1,9 @@
#include <torch/csrc/dynamo/init.h> #include <torch/csrc/dynamo/init.h>
#include <torch/csrc/Exceptions.h> #include <torch/csrc/Exceptions.h>
#include <torch/csrc/dynamo/cache_entry.h>
#include <torch/csrc/dynamo/eval_frame.h> #include <torch/csrc/dynamo/eval_frame.h>
#include <torch/csrc/dynamo/extra_state.h>
#include <torch/csrc/dynamo/guards.h> #include <torch/csrc/dynamo/guards.h>
#include <torch/csrc/dynamo/python_compiled_autograd.h> #include <torch/csrc/dynamo/python_compiled_autograd.h>
@ -34,6 +36,15 @@ void initDynamoBindings(PyObject* torch) {
PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) { PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
throw python_error(); throw python_error();
} }
auto m = py::handle(eval_frame).cast<py::module>();
py::class_<CacheEntry>(m, "_CacheEntry")
.def_readonly("check_fn", &CacheEntry::check_fn)
.def_readonly("code", &CacheEntry::code)
.def_property_readonly("next", &CacheEntry::next);
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
} }
} // namespace dynamo } // namespace dynamo

View File

@ -2,7 +2,6 @@
// C2039 MSVC // C2039 MSVC
#include <pybind11/complex.h> #include <pybind11/complex.h>
#include <pybind11/pybind11.h>
#include <torch/csrc/utils/pybind.h> #include <torch/csrc/utils/pybind.h>
#include <Python.h> #include <Python.h>

View File

@ -0,0 +1,9 @@
#pragma once
// The visibility attribute is to avoid a warning about storing a field in the
// struct that has a different visibility (from pybind) than the struct.
#ifdef _WIN32
#define VISIBILITY_HIDDEN
#else
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
#endif