mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
[dynamo] refactor CacheEntry and ExtraState to eval_frame.c to C++ (#118438)
Part of implementing CacheEntry invalidation to fix https://github.com/pytorch/pytorch/issues/112090.

Changes:
- Move CacheEntry and ExtraState to C++
- Use pybind to control reference counting
- Use std::list instead of manually implementing a linked list

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118438
Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
73f0fdea5b
commit
ae4e866bba
@ -807,9 +807,11 @@ libtorch_python_core_sources = [
|
|||||||
"torch/csrc/autograd/python_variable.cpp",
|
"torch/csrc/autograd/python_variable.cpp",
|
||||||
"torch/csrc/autograd/python_variable_indexing.cpp",
|
"torch/csrc/autograd/python_variable_indexing.cpp",
|
||||||
"torch/csrc/dynamo/python_compiled_autograd.cpp",
|
"torch/csrc/dynamo/python_compiled_autograd.cpp",
|
||||||
|
"torch/csrc/dynamo/cache_entry.cpp",
|
||||||
"torch/csrc/dynamo/cpp_shim.cpp",
|
"torch/csrc/dynamo/cpp_shim.cpp",
|
||||||
"torch/csrc/dynamo/cpython_defs.c",
|
"torch/csrc/dynamo/cpython_defs.c",
|
||||||
"torch/csrc/dynamo/eval_frame.c",
|
"torch/csrc/dynamo/eval_frame.c",
|
||||||
|
"torch/csrc/dynamo/extra_state.cpp",
|
||||||
"torch/csrc/dynamo/guards.cpp",
|
"torch/csrc/dynamo/guards.cpp",
|
||||||
"torch/csrc/dynamo/init.cpp",
|
"torch/csrc/dynamo/init.cpp",
|
||||||
"torch/csrc/functorch/init.cpp",
|
"torch/csrc/functorch/init.cpp",
|
||||||
|
@ -9394,6 +9394,57 @@ fn
|
|||||||
self.assertIn(0, result)
|
self.assertIn(0, result)
|
||||||
self.assertTrue(same(result[0], torch.tensor(3)))
|
self.assertTrue(same(result[0], torch.tensor(3)))
|
||||||
|
|
||||||
|
def test_dynamo_reset_clears_cache(self):
|
||||||
|
"""Test that dynamo bytecode cache is freed
|
||||||
|
when dynamo reset is called
|
||||||
|
"""
|
||||||
|
|
||||||
|
def fn(x):
|
||||||
|
return torch.sin(x)
|
||||||
|
|
||||||
|
opt_fn = torch.compile(backend="eager")(fn)
|
||||||
|
opt_fn(torch.randn(3, 3))
|
||||||
|
|
||||||
|
c1 = _debug_get_cache_entry_list(fn.__code__)
|
||||||
|
self.assertEqual(len(c1), 1)
|
||||||
|
|
||||||
|
torch._dynamo.reset()
|
||||||
|
c2 = _debug_get_cache_entry_list(fn.__code__)
|
||||||
|
self.assertEqual(len(c2), 0)
|
||||||
|
|
||||||
|
def test_dynamo_cache_move_to_front(self):
|
||||||
|
class Mod(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(Mod, self).__init__()
|
||||||
|
self.fc = torch.nn.Linear(3, 3)
|
||||||
|
|
||||||
|
def forward(self, out):
|
||||||
|
return self.fc(out)
|
||||||
|
|
||||||
|
def fn(x, mod):
|
||||||
|
return mod(x)
|
||||||
|
|
||||||
|
opt_fn = torch.compile(fn, backend="eager")
|
||||||
|
|
||||||
|
m1 = Mod()
|
||||||
|
m2 = Mod()
|
||||||
|
m3 = Mod()
|
||||||
|
inp = torch.randn(3, 3)
|
||||||
|
|
||||||
|
# NOTE: assumes that each cache entry is guarded
|
||||||
|
# on a unique Mod instance
|
||||||
|
opt_fn(inp, m1)
|
||||||
|
opt_fn(inp, m2)
|
||||||
|
opt_fn(inp, m3)
|
||||||
|
|
||||||
|
c1 = _debug_get_cache_entry_list(fn.__code__)
|
||||||
|
self.assertEqual(len(c1), 3)
|
||||||
|
|
||||||
|
# move cache entry to front
|
||||||
|
opt_fn(inp, m2)
|
||||||
|
c2 = _debug_get_cache_entry_list(fn.__code__)
|
||||||
|
self.assertIs(c1[1], c2[0])
|
||||||
|
|
||||||
|
|
||||||
class TestTracer(JitTestCase):
|
class TestTracer(JitTestCase):
|
||||||
def test_jit_save(self):
|
def test_jit_save(self):
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
import types
|
import types
|
||||||
from typing import NewType, Optional
|
from typing import List, NewType, Optional
|
||||||
|
|
||||||
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
|
from torch._dynamo.types import DynamoCallback, DynamoGuardHook
|
||||||
|
|
||||||
@ -11,7 +11,6 @@ def set_eval_frame(callback: DynamoCallback) -> DynamoCallback: ...
|
|||||||
def reset_code(code: types.CodeType) -> None: ...
|
def reset_code(code: types.CodeType) -> None: ...
|
||||||
def unsupported(obj1: object, obj2: object) -> object: ...
|
def unsupported(obj1: object, obj2: object) -> object: ...
|
||||||
def skip_code(code: types.CodeType) -> None: ...
|
def skip_code(code: types.CodeType) -> None: ...
|
||||||
def set_guard_fail_hook(hook: DynamoGuardHook) -> None: ...
|
|
||||||
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
|
def set_guard_error_hook(hook: DynamoGuardHook) -> None: ...
|
||||||
|
|
||||||
class _CacheEntry:
|
class _CacheEntry:
|
||||||
@ -19,4 +18,4 @@ class _CacheEntry:
|
|||||||
code: types.CodeType
|
code: types.CodeType
|
||||||
next: Optional[_CacheEntry]
|
next: Optional[_CacheEntry]
|
||||||
|
|
||||||
def _debug_get_cache_entry_list(code: types.CodeType) -> Optional[_CacheEntry]: ...
|
def _debug_get_cache_entry_list(code: types.CodeType) -> List[_CacheEntry]: ...
|
||||||
|
@ -186,12 +186,7 @@ def _debug_get_cache_entry_list(
|
|||||||
"""
|
"""
|
||||||
if callable(code):
|
if callable(code):
|
||||||
code = code.__code__
|
code = code.__code__
|
||||||
cache_head = torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
|
return torch._C._dynamo.eval_frame._debug_get_cache_entry_list(code)
|
||||||
cache_list = []
|
|
||||||
while cache_head is not None:
|
|
||||||
cache_list.append(cache_head)
|
|
||||||
cache_head = cache_head.next
|
|
||||||
return cache_list
|
|
||||||
|
|
||||||
|
|
||||||
class OptimizedModule(torch.nn.Module):
|
class OptimizedModule(torch.nn.Module):
|
||||||
|
30
torch/csrc/dynamo/cache_entry.cpp
Normal file
30
torch/csrc/dynamo/cache_entry.cpp
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
#include <torch/csrc/dynamo/cache_entry.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/dynamo/debug_macros.h>
|
||||||
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
|
|
||||||
|
CacheEntry::CacheEntry(const py::handle& guarded_code) {
|
||||||
|
this->check_fn = guarded_code.attr("check_fn");
|
||||||
|
this->code = guarded_code.attr("code");
|
||||||
|
}
|
||||||
|
|
||||||
|
py::object CacheEntry::next() {
|
||||||
|
NULL_CHECK(this->_owner);
|
||||||
|
auto it = this->_owner_loc;
|
||||||
|
++it;
|
||||||
|
if (it == this->_owner->cache_entry_list.end()) {
|
||||||
|
return py::none();
|
||||||
|
}
|
||||||
|
return py::cast(*it, py::return_value_policy::reference);
|
||||||
|
}
|
||||||
|
|
||||||
|
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
|
||||||
|
return (PyCodeObject*)e->code.ptr();
|
||||||
|
}
|
||||||
|
|
||||||
|
PyObject* CacheEntry_to_obj(CacheEntry* e) {
|
||||||
|
if (!e) {
|
||||||
|
return py::none().release().ptr();
|
||||||
|
}
|
||||||
|
return py::cast(e, py::return_value_policy::reference).release().ptr();
|
||||||
|
}
|
68
torch/csrc/dynamo/cache_entry.h
Normal file
68
torch/csrc/dynamo/cache_entry.h
Normal file
@ -0,0 +1,68 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
#include <torch/csrc/dynamo/utils.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
#include <list>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
/*
|
||||||
|
Our cache resides on the extra scratch space of the code object. The structure
|
||||||
|
of the cache is as follows:
|
||||||
|
|
||||||
|
-> ExtraState
|
||||||
|
-> CacheEntry (list)
|
||||||
|
-> check_fn
|
||||||
|
-> code
|
||||||
|
-> FrameState
|
||||||
|
|
||||||
|
CacheEntry is a linked list node containing the check_fn for guards
|
||||||
|
and the optimized code.
|
||||||
|
|
||||||
|
The FrameState is a PyDict that enables sharing between different frames. This
|
||||||
|
is used to detect dynamism in automatic dynamic shapes.
|
||||||
|
|
||||||
|
These two are encapsulated into an ExtraState.
|
||||||
|
*/
|
||||||
|
|
||||||
|
typedef struct CacheEntry CacheEntry;
|
||||||
|
typedef struct ExtraState ExtraState;
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
typedef struct VISIBILITY_HIDDEN CacheEntry {
|
||||||
|
// check the guards: lambda: <locals of user function>: bool
|
||||||
|
py::object check_fn;
|
||||||
|
// modified user bytecode (protected by check_fn's guards)
|
||||||
|
py::object code;
|
||||||
|
// Reference to owning ExtraState
|
||||||
|
ExtraState* _owner{nullptr};
|
||||||
|
// Reference to this CacheEntry's location in owner's linked list
|
||||||
|
std::list<CacheEntry>::iterator _owner_loc;
|
||||||
|
|
||||||
|
CacheEntry(const py::handle& guarded_code);
|
||||||
|
|
||||||
|
// Warning: returns a reference whose lifetime is controlled by C++
|
||||||
|
py::object next();
|
||||||
|
} CacheEntry;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Returns borrowed reference
|
||||||
|
PyCodeObject* CacheEntry_get_code(CacheEntry* e);
|
||||||
|
|
||||||
|
// Returns a borrowed reference to CacheEntry as a PyObject
|
||||||
|
// Warning: lifetime is controlled by C++
|
||||||
|
PyObject* CacheEntry_to_obj(CacheEntry* e);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
} // extern "C"
|
||||||
|
#endif
|
46
torch/csrc/dynamo/debug_macros.h
Normal file
46
torch/csrc/dynamo/debug_macros.h
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define unlikely(x) (x)
|
||||||
|
#else
|
||||||
|
#define unlikely(x) __builtin_expect((x), 0)
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#define NULL_CHECK(val) \
|
||||||
|
if (unlikely((val) == NULL)) { \
|
||||||
|
fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
|
||||||
|
PyErr_Print(); \
|
||||||
|
abort(); \
|
||||||
|
} else { \
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK might be previously declared
|
||||||
|
#undef CHECK
|
||||||
|
#define CHECK(cond) \
|
||||||
|
if (unlikely(!(cond))) { \
|
||||||
|
fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
|
||||||
|
abort(); \
|
||||||
|
} else { \
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uncomment next line to print debug message
|
||||||
|
// #define TORCHDYNAMO_DEBUG 1
|
||||||
|
#ifdef TORCHDYNAMO_DEBUG
|
||||||
|
|
||||||
|
#define DEBUG_CHECK(cond) CHECK(cond)
|
||||||
|
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
|
||||||
|
#define DEBUG_TRACE(msg, ...) \
|
||||||
|
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
|
||||||
|
#define DEBUG_TRACE0(msg) \
|
||||||
|
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
#define DEBUG_CHECK(cond)
|
||||||
|
#define DEBUG_NULL_CHECK(val)
|
||||||
|
#define DEBUG_TRACE(msg, ...)
|
||||||
|
#define DEBUG_TRACE0(msg)
|
||||||
|
|
||||||
|
#endif
|
@ -1,6 +1,9 @@
|
|||||||
#define PY_SSIZE_T_CLEAN
|
#define PY_SSIZE_T_CLEAN
|
||||||
|
#include <torch/csrc/dynamo/cache_entry.h>
|
||||||
#include <torch/csrc/dynamo/cpp_shim.h>
|
#include <torch/csrc/dynamo/cpp_shim.h>
|
||||||
#include <torch/csrc/dynamo/cpython_defs.h>
|
#include <torch/csrc/dynamo/cpython_defs.h>
|
||||||
|
#include <torch/csrc/dynamo/debug_macros.h>
|
||||||
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
#include <torch/csrc/utils/python_compat.h>
|
#include <torch/csrc/utils/python_compat.h>
|
||||||
#include <opcode.h>
|
#include <opcode.h>
|
||||||
#include <stdbool.h>
|
#include <stdbool.h>
|
||||||
@ -132,57 +135,9 @@ THPPyInterpreterFrame* THPPyInterpreterFrame_New(_PyInterpreterFrame* frame) {
|
|||||||
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
|
#define THP_PyFrame_FastToLocalsWithError PyFrame_FastToLocalsWithError
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#ifdef _WIN32
|
PyObject* guard_error_hook = NULL;
|
||||||
#define unlikely(x) (x)
|
|
||||||
#else
|
|
||||||
#define unlikely(x) __builtin_expect((x), 0)
|
|
||||||
#endif
|
|
||||||
|
|
||||||
#define NULL_CHECK(val) \
|
|
||||||
if (unlikely((val) == NULL)) { \
|
|
||||||
fprintf(stderr, "NULL ERROR: %s:%d\n", __FILE__, __LINE__); \
|
|
||||||
PyErr_Print(); \
|
|
||||||
abort(); \
|
|
||||||
} else { \
|
|
||||||
}
|
|
||||||
|
|
||||||
#define CHECK(cond) \
|
|
||||||
if (unlikely(!(cond))) { \
|
|
||||||
fprintf(stderr, "DEBUG CHECK FAILED: %s:%d\n", __FILE__, __LINE__); \
|
|
||||||
abort(); \
|
|
||||||
} else { \
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uncomment next line to print debug message
|
|
||||||
// #define TORCHDYNAMO_DEBUG 1
|
|
||||||
|
|
||||||
#ifdef TORCHDYNAMO_DEBUG
|
|
||||||
|
|
||||||
#define DEBUG_CHECK(cond) CHECK(cond)
|
|
||||||
#define DEBUG_NULL_CHECK(val) NULL_CHECK(val)
|
|
||||||
#define DEBUG_TRACE(msg, ...) \
|
|
||||||
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__, __VA_ARGS__)
|
|
||||||
#define DEBUG_TRACE0(msg) \
|
|
||||||
fprintf(stderr, "TRACE[%s:%d] " msg "\n", __func__, __LINE__)
|
|
||||||
|
|
||||||
#else
|
|
||||||
|
|
||||||
#define DEBUG_CHECK(cond)
|
|
||||||
#define DEBUG_NULL_CHECK(val)
|
|
||||||
#define DEBUG_TRACE(msg, ...)
|
|
||||||
#define DEBUG_TRACE0(msg)
|
|
||||||
|
|
||||||
#endif
|
|
||||||
|
|
||||||
// Flag to just run a frame normally
|
|
||||||
#define SKIP_CODE ((void*)0x1)
|
|
||||||
|
|
||||||
static PyObject* guard_error_hook = NULL;
|
|
||||||
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
const char* cache_lookup_profiler_str = "TorchDynamo Cache Lookup";
|
||||||
|
|
||||||
// Points to the extra scratch space on the code object
|
|
||||||
static Py_ssize_t extra_index = -1;
|
|
||||||
|
|
||||||
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
|
static Py_tss_t eval_frame_callback_key = Py_tss_NEEDS_INIT;
|
||||||
|
|
||||||
inline static PyObject* eval_frame_callback_get(void) {
|
inline static PyObject* eval_frame_callback_get(void) {
|
||||||
@ -284,323 +239,6 @@ inline static const char* get_frame_name(THP_EVAL_API_FRAME_OBJECT* frame) {
|
|||||||
return PyUnicode_AsUTF8(frame->f_code->co_name);
|
return PyUnicode_AsUTF8(frame->f_code->co_name);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef PyObject FrameState;
|
|
||||||
/*
|
|
||||||
Our cache resides on the extra scratch space of the code object. The structure
|
|
||||||
of the cache is as follows:
|
|
||||||
|
|
||||||
-> ExtraState
|
|
||||||
-> CacheEntry
|
|
||||||
-> check_fn
|
|
||||||
-> optimized_code
|
|
||||||
-> next
|
|
||||||
-> FrameState
|
|
||||||
|
|
||||||
CacheEntry is a linked list, with each node containing the check_fn for guards
|
|
||||||
and the optimized code.
|
|
||||||
|
|
||||||
The frame_state is a PyDict that enables sharing between different frames. This
|
|
||||||
is used to detect dynamism in automatic dynamic shapes.
|
|
||||||
|
|
||||||
These two are encapsulated into an ExtraState.
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Linked list of cache entries, where each cache entry stores
|
|
||||||
// the check_fn and the torch.compile optimized python bytecode.
|
|
||||||
typedef struct cache_entry {
|
|
||||||
PyObject_HEAD
|
|
||||||
// check the guards: lambda: <locals of user function>: bool
|
|
||||||
PyObject* check_fn;
|
|
||||||
// modified user bytecode (protected by check_fn's guards)
|
|
||||||
PyCodeObject* code;
|
|
||||||
// on a cache miss, linked list of next thing to try
|
|
||||||
struct cache_entry* next;
|
|
||||||
} CacheEntry;
|
|
||||||
|
|
||||||
static void cache_entry_dealloc(CacheEntry* e);
|
|
||||||
|
|
||||||
#define DECLARE_CACHE_ENTRY_ATTR(name) \
|
|
||||||
static PyObject* CacheEntry_##name(CacheEntry* self, PyObject* _noargs) { \
|
|
||||||
PyObject* res = (PyObject*)self->name; \
|
|
||||||
Py_INCREF(res); \
|
|
||||||
return res; \
|
|
||||||
}
|
|
||||||
|
|
||||||
DECLARE_CACHE_ENTRY_ATTR(check_fn)
|
|
||||||
DECLARE_CACHE_ENTRY_ATTR(code)
|
|
||||||
DECLARE_CACHE_ENTRY_ATTR(next)
|
|
||||||
|
|
||||||
static struct PyGetSetDef CacheEntry_properties[] = {
|
|
||||||
{"check_fn", (getter)CacheEntry_check_fn, NULL, NULL, NULL},
|
|
||||||
{"code", (getter)CacheEntry_code, NULL, NULL, NULL},
|
|
||||||
{"next", (getter)CacheEntry_next, NULL, NULL, NULL},
|
|
||||||
{NULL}};
|
|
||||||
|
|
||||||
|
|
||||||
static PyObject* cache_entry_new(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
|
|
||||||
CacheEntry *self = (CacheEntry*) type->tp_alloc(type, 0);
|
|
||||||
if (self != NULL) {
|
|
||||||
// The corresponding decrefs for Py_None are in cache_entry_init.
|
|
||||||
Py_INCREF(Py_None);
|
|
||||||
self->check_fn = Py_None;
|
|
||||||
Py_INCREF(Py_None);
|
|
||||||
self->code = (PyCodeObject*)Py_None;
|
|
||||||
Py_INCREF(Py_None);
|
|
||||||
self->next = (CacheEntry*)Py_None;
|
|
||||||
}
|
|
||||||
return (PyObject*)self;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
static int cache_entry_init(CacheEntry* self, PyObject* args, PyObject* kwds) {
|
|
||||||
PyObject* check_fn = NULL;
|
|
||||||
PyCodeObject* code = NULL;
|
|
||||||
CacheEntry* next = NULL;
|
|
||||||
|
|
||||||
static char *kwlist[] = {"check_fn", "code", "next", NULL};
|
|
||||||
|
|
||||||
int ret = PyArg_ParseTupleAndKeywords(
|
|
||||||
args, kwds, "OOO", kwlist,
|
|
||||||
&check_fn, &code, &next);
|
|
||||||
|
|
||||||
if (!ret) return -1;
|
|
||||||
|
|
||||||
if (check_fn) {
|
|
||||||
PyObject* tmp = self->check_fn;
|
|
||||||
Py_INCREF(check_fn);
|
|
||||||
self->check_fn = check_fn;
|
|
||||||
Py_XDECREF(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (code) {
|
|
||||||
PyCodeObject* tmp = self->code;
|
|
||||||
Py_INCREF(code);
|
|
||||||
self->code = code;
|
|
||||||
Py_XDECREF(tmp);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (next) {
|
|
||||||
CacheEntry* tmp = self->next;
|
|
||||||
Py_INCREF(next);
|
|
||||||
self->next = next;
|
|
||||||
Py_XDECREF(tmp);
|
|
||||||
}
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
static PyTypeObject CacheEntryType = {
|
|
||||||
PyVarObject_HEAD_INIT(NULL, 0)
|
|
||||||
.tp_name = "torch._C.dynamo.eval_frame.CacheEntryWrapper",
|
|
||||||
.tp_basicsize = sizeof(CacheEntry),
|
|
||||||
.tp_itemsize = 0,
|
|
||||||
.tp_flags = Py_TPFLAGS_DEFAULT,
|
|
||||||
.tp_new = cache_entry_new,
|
|
||||||
.tp_init = (initproc)cache_entry_init,
|
|
||||||
.tp_dealloc = (destructor)cache_entry_dealloc,
|
|
||||||
.tp_getset = CacheEntry_properties,
|
|
||||||
};
|
|
||||||
|
|
||||||
// ExtraState encapsulates CacheEntry and FrameState. ExtraState is the highest
|
|
||||||
// level of abstraction of what is stored on the extra code object. Previously,
|
|
||||||
// we saved different parts on different extra indexes. We prefer this way
|
|
||||||
// because of cleaner abstraction and faster SetExtra access.
|
|
||||||
|
|
||||||
// TODO(anijain2305) - Consider making this a PyObject. Benefits are
|
|
||||||
// 1) Modular dealloc - destroy_extra_state just becomes Py_DECREF(extra)
|
|
||||||
// 2) We can directly send the extra object to convert_frame callback. One
|
|
||||||
// data structure - easier to understand code.
|
|
||||||
// There might be some perf impact of going through a PyObject on the critical
|
|
||||||
// path, but it should not be too bad.
|
|
||||||
typedef struct {
|
|
||||||
// Cache entry for the code object
|
|
||||||
CacheEntry* cache_entry;
|
|
||||||
// Frame state to detect dynamic shape dims
|
|
||||||
FrameState* frame_state;
|
|
||||||
} ExtraState;
|
|
||||||
|
|
||||||
|
|
||||||
/* CacheEntry helper functions begins */
|
|
||||||
|
|
||||||
static CacheEntry* create_cache_entry(
|
|
||||||
CacheEntry* next,
|
|
||||||
PyObject* guarded_code) {
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - next: steals
|
|
||||||
// - guarded_code: Borrowed
|
|
||||||
// return
|
|
||||||
// - CacheEntry*: new reference.
|
|
||||||
PyObject* check_fn = PyObject_GetAttrString(guarded_code, "check_fn"); // new reference
|
|
||||||
PyCodeObject* code = (PyCodeObject*)PyObject_GetAttrString(guarded_code, "code"); // new reference
|
|
||||||
|
|
||||||
// equivalent to CacheEntry(check_fn, code, next) in Python
|
|
||||||
PyObject* args = Py_BuildValue("OOO", check_fn, code, next);
|
|
||||||
CacheEntry* e = (CacheEntry*)PyObject_CallObject((PyObject*)&CacheEntryType, args); // new reference
|
|
||||||
// CacheEntry e is now the owner of the old cache entry next. This happens
|
|
||||||
// when we incref the next pointer in cache_entry_init.
|
|
||||||
Py_DECREF(next);
|
|
||||||
Py_DECREF(check_fn);
|
|
||||||
Py_DECREF(code);
|
|
||||||
Py_DECREF(args);
|
|
||||||
return e;
|
|
||||||
}
|
|
||||||
|
|
||||||
static void cache_entry_dealloc(CacheEntry* e) {
|
|
||||||
Py_XDECREF(e->check_fn);
|
|
||||||
Py_XDECREF(e->code);
|
|
||||||
// This will recursively call cache_entry_dealloc for the next items in the
|
|
||||||
// linked list.
|
|
||||||
Py_XDECREF(e->next);
|
|
||||||
Py_TYPE(e)->tp_free((PyObject*)e);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* CacheEntry helper functions ends */
|
|
||||||
|
|
||||||
/* Extractions helper functions begins. They help with NULL and SKIP_CODE corner cases */
|
|
||||||
|
|
||||||
inline static CacheEntry* extract_cache_entry(ExtraState* extra_state) {
|
|
||||||
// Helper to extract the cache_entry from the extra state.
|
|
||||||
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - extra_state: Borrowed
|
|
||||||
// return
|
|
||||||
// - CacheEntry: Borrowed.
|
|
||||||
if (extra_state == NULL || extra_state == SKIP_CODE) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
return extra_state->cache_entry;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
inline static FrameState* extract_frame_state(ExtraState* extra_state) {
|
|
||||||
// Returns either the previously stored frame state or an empty dict.
|
|
||||||
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - extra_state: Borrowed
|
|
||||||
// return
|
|
||||||
// - extra_state->frame_state: Borrowed.
|
|
||||||
if (extra_state == NULL || extra_state == SKIP_CODE) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
return extra_state->frame_state;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Extractions helper functions ends */
|
|
||||||
|
|
||||||
/* Extra state helper functions begins */
|
|
||||||
|
|
||||||
inline static ExtraState* get_extra_state(PyCodeObject* code) {
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - code: Borrowed
|
|
||||||
// return
|
|
||||||
// - extra_state: Borrowed.
|
|
||||||
ExtraState* extra = NULL;
|
|
||||||
_PyCode_GetExtra((PyObject*)code, extra_index, (void*)&extra);
|
|
||||||
return extra;
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static void destroy_extra_state(void* obj) {
|
|
||||||
// This is passed as freefunc to _PyEval_RequestCodeExtraIndex. This acts as a
|
|
||||||
// deleter for the object on extra scratch space. This function is called
|
|
||||||
// internally in _PyCode_SetExtra and also during the code deallocation.
|
|
||||||
|
|
||||||
// Destroys the extra state by deleting cache_entry, frame state and finally
|
|
||||||
// freeing the constructed extra state.
|
|
||||||
|
|
||||||
// Developer note - You should not call this function directly. This is called
|
|
||||||
// directly inside set_extra_state. If you are in a situation trying to call
|
|
||||||
// this function, consider if set_extra_state should be called.
|
|
||||||
|
|
||||||
ExtraState* extra = (ExtraState*)obj;
|
|
||||||
if (extra != NULL && extra != SKIP_CODE) {
|
|
||||||
// Cpython gc will call cache_entry_dealloc on its own when the ref count
|
|
||||||
// goes to 0.
|
|
||||||
Py_XDECREF(extra->cache_entry);
|
|
||||||
Py_XDECREF(extra->frame_state);
|
|
||||||
free(extra);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
|
|
||||||
// Clears the existing object sitting on the extra scratch space and sets it
|
|
||||||
// up with the new state. Note that _PyCode_SetExtra calls the
|
|
||||||
// destroy_extra_state deleter internally, and therefore we don't call it
|
|
||||||
// explicitly here.
|
|
||||||
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - extra_state: Stolen
|
|
||||||
// return
|
|
||||||
// - there is no return, but the extra_state is stolen, so it becomes
|
|
||||||
// set_extra_state responsibility to clean it up. It will be deleted during
|
|
||||||
// the reset_code/skip, when the set_extra_state is called with
|
|
||||||
// NULL/SKIP_CODE.
|
|
||||||
|
|
||||||
// Invariant - Don't set the extra state for the extra state that is already on
|
|
||||||
// the code object. Otherwise, we will first free up the old extra state
|
|
||||||
// (which is also the new extra state) and write something invalid on the
|
|
||||||
// scratch space.
|
|
||||||
ExtraState* old_extra_state = get_extra_state(code);
|
|
||||||
CHECK(old_extra_state == NULL || old_extra_state == SKIP_CODE || old_extra_state != extra_state);
|
|
||||||
_PyCode_SetExtra((PyObject*)code, extra_index, extra_state);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static ExtraState* init_and_set_extra_state(PyCodeObject* code) {
|
|
||||||
// Creates a new extra state and puts it on the extra scratch space of the code
|
|
||||||
// object.
|
|
||||||
|
|
||||||
// Ownership contract
|
|
||||||
// args
|
|
||||||
// - code: Borrowed
|
|
||||||
// return:
|
|
||||||
// - extra_state: New reference.
|
|
||||||
// These references are then further passed to set_extra_state which becomes
|
|
||||||
// the final owner of these references.
|
|
||||||
|
|
||||||
// Invariant - Extra state should not have been set before, therefore it should be NULL.
|
|
||||||
CHECK(get_extra_state(code) == NULL);
|
|
||||||
ExtraState* extra_state = (ExtraState*)malloc(sizeof(ExtraState));
|
|
||||||
DEBUG_NULL_CHECK(extra_state);
|
|
||||||
// We set the last node in the linked list to Py_None. We incref the Py_None
|
|
||||||
// here, the corresponding decref is in cache_entry_dealloc.
|
|
||||||
Py_INCREF(Py_None);
|
|
||||||
extra_state->cache_entry = (CacheEntry*)Py_None;
|
|
||||||
extra_state->frame_state = PyDict_New();
|
|
||||||
set_extra_state(code, extra_state);
|
|
||||||
return extra_state;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Extra state helper functions ends */
|
|
||||||
|
|
||||||
/*
|
|
||||||
Debugger helper functions.
|
|
||||||
*/
|
|
||||||
|
|
||||||
PyObject* _debug_get_cache_entry_list(PyObject* self, PyObject* args) {
|
|
||||||
// get the cache entry out of a code object
|
|
||||||
PyObject* object = NULL;
|
|
||||||
if (!PyArg_ParseTuple(args, "O", &object)) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
if (!PyCode_Check(object)) {
|
|
||||||
PyErr_SetString(PyExc_TypeError, "expected a code object!");
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
PyCodeObject* code = (PyCodeObject*)object;
|
|
||||||
|
|
||||||
ExtraState* extra = get_extra_state(code);
|
|
||||||
CacheEntry* current_node = extract_cache_entry(extra);
|
|
||||||
if (current_node == NULL)
|
|
||||||
{
|
|
||||||
Py_RETURN_NONE;
|
|
||||||
}
|
|
||||||
Py_INCREF(current_node);
|
|
||||||
return (PyObject*)current_node;
|
|
||||||
}
|
|
||||||
|
|
||||||
static inline PyObject* call_callback(
|
static inline PyObject* call_callback(
|
||||||
PyObject* callable,
|
PyObject* callable,
|
||||||
THP_EVAL_API_FRAME_OBJECT* _frame,
|
THP_EVAL_API_FRAME_OBJECT* _frame,
|
||||||
@ -618,74 +256,18 @@ static inline PyObject* call_callback(
|
|||||||
PyObject* frame = Py_NewRef(_frame);
|
PyObject* frame = Py_NewRef(_frame);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
PyObject* cache_entry_pyobj = CacheEntry_to_obj(cache_entry);
|
||||||
PyObject* res = PyObject_CallFunction(
|
PyObject* res = PyObject_CallFunction(
|
||||||
callable,
|
callable,
|
||||||
"OOO",
|
"OOO",
|
||||||
frame,
|
frame,
|
||||||
cache_entry,
|
cache_entry_pyobj,
|
||||||
frame_state);
|
frame_state);
|
||||||
Py_DECREF(frame);
|
Py_DECREF(frame);
|
||||||
|
Py_DECREF(cache_entry_pyobj);
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static PyObject* call_guard_fail_hook(
|
|
||||||
PyObject* hook,
|
|
||||||
CacheEntry* e,
|
|
||||||
size_t index,
|
|
||||||
PyObject* f_locals) {
|
|
||||||
// call debugging logic when a guard fails
|
|
||||||
return PyObject_CallFunction(
|
|
||||||
hook,
|
|
||||||
"OOOnO",
|
|
||||||
e->check_fn,
|
|
||||||
e->code,
|
|
||||||
f_locals,
|
|
||||||
(Py_ssize_t)index,
|
|
||||||
(e->next == (CacheEntry*)Py_None ? Py_True : Py_False));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Return value: borrowed reference
|
|
||||||
// Is either Py_None or a PyCodeObject
|
|
||||||
static PyObject* lookup(CacheEntry* e, THP_EVAL_API_FRAME_OBJECT *frame, CacheEntry* prev, size_t index) {
|
|
||||||
if (e == (CacheEntry*)Py_None) {
|
|
||||||
// NB: intentionally not using Py_RETURN_NONE, to return borrowed ref
|
|
||||||
return Py_None;
|
|
||||||
}
|
|
||||||
PyObject *f_locals = frame->f_locals;
|
|
||||||
// remember to update the type signature for GuardFn.__call__ in torch/_dynamo/types.py
|
|
||||||
// if this calling convention changes
|
|
||||||
PyObject* valid = PyObject_CallOneArg(e->check_fn, f_locals);
|
|
||||||
if (unlikely(valid == NULL)) {
|
|
||||||
if (guard_error_hook != NULL) {
|
|
||||||
PyObject *type = NULL, *value = NULL, *traceback = NULL;
|
|
||||||
PyErr_Fetch(&type, &value, &traceback);
|
|
||||||
PyObject* r = call_guard_fail_hook(guard_error_hook, e, index, f_locals);
|
|
||||||
if (r == NULL) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
Py_DECREF(r);
|
|
||||||
PyErr_Restore(type, value, traceback);
|
|
||||||
}
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
Py_DECREF(valid);
|
|
||||||
if (valid == Py_True) {
|
|
||||||
// Keep the head as the most recently used cache entry.
|
|
||||||
// If the hit cache entry is not the head of the linked list,
|
|
||||||
// move it to the head
|
|
||||||
if (prev != NULL) {
|
|
||||||
ExtraState* extra = get_extra_state(frame->f_code);
|
|
||||||
// Override the extra state to reflect the updated cache line.
|
|
||||||
CacheEntry* old_cache_entry = extra->cache_entry;
|
|
||||||
prev->next = e->next;
|
|
||||||
e->next = old_cache_entry;
|
|
||||||
extra->cache_entry = e;
|
|
||||||
}
|
|
||||||
return (PyObject*)e->code;
|
|
||||||
}
|
|
||||||
return lookup(e->next, frame, e, index + 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
inline static PyObject* eval_custom_code_impl(
|
inline static PyObject* eval_custom_code_impl(
|
||||||
PyThreadState* tstate,
|
PyThreadState* tstate,
|
||||||
THP_EVAL_API_FRAME_OBJECT* frame,
|
THP_EVAL_API_FRAME_OBJECT* frame,
|
||||||
@ -942,12 +524,7 @@ static PyObject* _custom_eval_frame(
|
|||||||
extra = init_and_set_extra_state(frame->f_code);
|
extra = init_and_set_extra_state(frame->f_code);
|
||||||
}
|
}
|
||||||
|
|
||||||
CacheEntry* cache_entry = extract_cache_entry(extra);
|
|
||||||
FrameState* frame_state = extract_frame_state(extra);
|
|
||||||
|
|
||||||
// TODO(jansel): investigate directly using the "fast" representation
|
// TODO(jansel): investigate directly using the "fast" representation
|
||||||
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
|
|
||||||
// even though we should pass a PyFrameObject.
|
|
||||||
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
|
if (THP_PyFrame_FastToLocalsWithError(frame) < 0) {
|
||||||
DEBUG_TRACE("error %s", get_frame_name(frame));
|
DEBUG_TRACE("error %s", get_frame_name(frame));
|
||||||
return NULL;
|
return NULL;
|
||||||
@ -958,7 +535,7 @@ static PyObject* _custom_eval_frame(
|
|||||||
if (callback == Py_False) {
|
if (callback == Py_False) {
|
||||||
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
|
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
|
||||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||||
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
|
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
|
||||||
_pytorch_record_function_exit(rf);
|
_pytorch_record_function_exit(rf);
|
||||||
|
|
||||||
if (maybe_cached_code == NULL) {
|
if (maybe_cached_code == NULL) {
|
||||||
@ -983,7 +560,7 @@ static PyObject* _custom_eval_frame(
|
|||||||
eval_frame_callback_set(Py_None);
|
eval_frame_callback_set(Py_None);
|
||||||
|
|
||||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||||
PyObject* maybe_cached_code = lookup(cache_entry, frame, NULL, 0);
|
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
|
||||||
_pytorch_record_function_exit(rf);
|
_pytorch_record_function_exit(rf);
|
||||||
if (maybe_cached_code == NULL) {
|
if (maybe_cached_code == NULL) {
|
||||||
// Python error
|
// Python error
|
||||||
@ -997,8 +574,8 @@ static PyObject* _custom_eval_frame(
|
|||||||
return eval_custom_code(tstate, frame, cached_code, throw_flag);
|
return eval_custom_code(tstate, frame, cached_code, throw_flag);
|
||||||
}
|
}
|
||||||
// cache miss
|
// cache miss
|
||||||
// TODO(alband): This is WRONG for python3.11+ we pass in a _PyInterpreterFrame
|
CacheEntry* cache_entry = extract_cache_entry(extra);
|
||||||
// that gets re-interpreted as a PyObject (which it is NOT!)
|
FrameState* frame_state = extract_frame_state(extra);
|
||||||
PyObject* result =
|
PyObject* result =
|
||||||
call_callback(callback, frame, cache_entry, frame_state);
|
call_callback(callback, frame, cache_entry, frame_state);
|
||||||
if (result == NULL) {
|
if (result == NULL) {
|
||||||
@ -1017,7 +594,7 @@ static PyObject* _custom_eval_frame(
|
|||||||
// extract_cache_entry returns a borrowed reference. Modifying a borrowed
|
// extract_cache_entry returns a borrowed reference. Modifying a borrowed
|
||||||
// reference seems wrong. Therefore, we directly access the
|
// reference seems wrong. Therefore, we directly access the
|
||||||
// extra->cache_entry. extra wont be NULL here.
|
// extra->cache_entry. extra wont be NULL here.
|
||||||
extra->cache_entry = create_cache_entry(extra->cache_entry, result);
|
CacheEntry* new_cache_entry = create_cache_entry(extra, result);
|
||||||
Py_DECREF(result);
|
Py_DECREF(result);
|
||||||
// Update the existing cache_entry on the extra object. This extra object is
|
// Update the existing cache_entry on the extra object. This extra object is
|
||||||
// sitting on the extra scratch space, we are just changing the cache_entry
|
// sitting on the extra scratch space, we are just changing the cache_entry
|
||||||
@ -1025,7 +602,7 @@ static PyObject* _custom_eval_frame(
|
|||||||
// will be cleaned up when set_extra_state is called.
|
// will be cleaned up when set_extra_state is called.
|
||||||
// Re-enable custom behavior
|
// Re-enable custom behavior
|
||||||
eval_frame_callback_set(callback);
|
eval_frame_callback_set(callback);
|
||||||
return eval_custom_code(tstate, frame, extra->cache_entry->code, throw_flag);
|
return eval_custom_code(tstate, frame, CacheEntry_get_code(new_cache_entry), throw_flag);
|
||||||
} else {
|
} else {
|
||||||
DEBUG_TRACE("create skip %s", get_frame_name(frame));
|
DEBUG_TRACE("create skip %s", get_frame_name(frame));
|
||||||
Py_DECREF(result);
|
Py_DECREF(result);
|
||||||
@ -1144,7 +721,6 @@ static PyMethodDef _methods[] = {
|
|||||||
{"unsupported", unsupported, METH_VARARGS, NULL},
|
{"unsupported", unsupported, METH_VARARGS, NULL},
|
||||||
{"skip_code", skip_code, METH_O, NULL},
|
{"skip_code", skip_code, METH_O, NULL},
|
||||||
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
|
{"set_guard_error_hook", set_guard_error_hook, METH_O, NULL},
|
||||||
{"_debug_get_cache_entry_list", _debug_get_cache_entry_list, METH_VARARGS, NULL},
|
|
||||||
{NULL, NULL, 0, NULL}};
|
{NULL, NULL, 0, NULL}};
|
||||||
|
|
||||||
static struct PyModuleDef _module = {
|
static struct PyModuleDef _module = {
|
||||||
@ -1184,15 +760,5 @@ PyObject* torch_c_dynamo_eval_frame_init(void) {
|
|||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
|
||||||
if (PyType_Ready(&CacheEntryType) < 0) {
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
Py_INCREF(&CacheEntryType);
|
|
||||||
if (PyModule_AddObject(module, "_CacheEntry", (PyObject *) &CacheEntryType) < 0) {
|
|
||||||
Py_DECREF(&CacheEntryType);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
return module;
|
return module;
|
||||||
}
|
}
|
||||||
|
129
torch/csrc/dynamo/extra_state.cpp
Normal file
129
torch/csrc/dynamo/extra_state.cpp
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
|
|
||||||
|
#include <torch/csrc/dynamo/cache_entry.h>
|
||||||
|
#include <torch/csrc/dynamo/debug_macros.h>
|
||||||
|
|
||||||
|
Py_ssize_t extra_index = -1;
|
||||||
|
|
||||||
|
// Returns a borrowed pointer to the entry at the head of cache_entry_list, or
// a null pointer when the list holds no entries.
CacheEntry* ExtraState::get_first_entry() {
  auto& entries = this->cache_entry_list;
  if (!entries.empty()) {
    return &entries.front();
  }
  return NULL;
}
|
||||||
|
|
||||||
|
// Relinks cache_entry so that it becomes the first element of
// cache_entry_list.
//
// The entry must belong to this ExtraState, and its stored _owner_loc iterator
// must refer to the entry itself; all three preconditions are enforced with
// CHECK before the list is touched.
void ExtraState::move_to_front(CacheEntry* cache_entry) {
  CHECK(cache_entry->_owner == this);
  CHECK(!this->cache_entry_list.empty());
  CHECK(cache_entry == &*cache_entry->_owner_loc);
  // Splicing within the same list only relinks the node: the element is
  // neither copied nor moved.
  auto& entries = this->cache_entry_list;
  entries.splice(entries.begin(), entries, cache_entry->_owner_loc);
}
|
||||||
|
|
||||||
|
// Returns the first cache entry stored on extra_state (borrowed), or NULL when
// extra_state is NULL, is the SKIP_CODE sentinel, or holds no entries.
CacheEntry* extract_cache_entry(ExtraState* extra_state) {
  if (extra_state != NULL && extra_state != SKIP_CODE) {
    return extra_state->get_first_entry();
  }
  return NULL;
}
|
||||||
|
|
||||||
|
// Returns the frame_state dict stored on extra_state as a borrowed pointer, or
// NULL when extra_state is NULL or the SKIP_CODE sentinel.
FrameState* extract_frame_state(ExtraState* extra_state) {
  if (extra_state != NULL && extra_state != SKIP_CODE) {
    // ptr() hands out the underlying PyObject* without touching its refcount.
    return (FrameState*)extra_state->frame_state.ptr();
  }
  return NULL;
}
|
||||||
|
|
||||||
|
// Reads whatever is stored in the extra_index slot of the code object and
// returns it as an ExtraState pointer. The result stays NULL if
// _PyCode_GetExtra does not write to the output slot.
ExtraState* get_extra_state(PyCodeObject* code) {
  void* slot = NULL;
  _PyCode_GetExtra((PyObject*)code, extra_index, &slot);
  return (ExtraState*)slot;
}
|
||||||
|
|
||||||
|
void destroy_extra_state(void* obj) {
|
||||||
|
ExtraState* extra = (ExtraState*)obj;
|
||||||
|
if (extra != NULL && extra != SKIP_CODE) {
|
||||||
|
delete extra;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stores extra_state in the extra_index slot of the code object.
//
// Before writing, asserts that the value currently in the slot is NULL, the
// SKIP_CODE sentinel, or a different pointer than the one being installed, so
// that a live ExtraState is never re-installed over itself.
void set_extra_state(PyCodeObject* code, ExtraState* extra_state) {
  ExtraState* old_extra_state = get_extra_state(code);
  CHECK(
      old_extra_state == NULL || old_extra_state == SKIP_CODE ||
      old_extra_state != extra_state);
  PyObject* code_object = (PyObject*)code;
  _PyCode_SetExtra(code_object, extra_index, extra_state);
}
|
||||||
|
|
||||||
|
// Allocates a fresh ExtraState, installs it on the code object through
// set_extra_state, and returns it.
ExtraState* init_and_set_extra_state(PyCodeObject* code) {
  // Precondition: nothing has been stored in the extra slot yet.
  CHECK(get_extra_state(code) == NULL);
  ExtraState* const extra_state = new ExtraState();
  NULL_CHECK(extra_state);
  set_extra_state(code, extra_state);
  return extra_state;
}
|
||||||
|
|
||||||
|
// Searches the cache entries held by extra_state, in list order, for the first
// one whose check_fn accepts f_locals.
//
// args
//  - extra_state: Borrowed. It is dereferenced unconditionally, so the caller
//    must pass a real ExtraState (not NULL or SKIP_CODE).
//  - f_locals: Borrowed. Passed unchanged to every check_fn.
// return
//  - the matching entry's code object (borrowed); the entry is moved to the
//    front of the list before returning,
//  - Py_None (borrowed) when no entry matches,
//  - NULL with a Python error set when a check_fn raises, when its result
//    cannot be converted to bool, or when the guard error hook itself fails.
//
// This function is called from C, so no C++ exception may leave it: every
// failure is turned into a pending Python error plus a NULL return.
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
  size_t index = 0;
  CacheEntry* found = nullptr;
  py::handle locals(f_locals);
  for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
    bool valid = false;
    try {
      // The bool conversion is kept inside the try block so that a guard
      // result that cannot be converted does not throw past this function.
      valid = cache_entry.check_fn(locals).cast<bool>();
    } catch (py::error_already_set& e) {
      if (guard_error_hook) {
        try {
          py::handle guard_error_hook_handle(guard_error_hook);
          guard_error_hook_handle(
              cache_entry.check_fn,
              cache_entry.code,
              locals,
              index,
              index == extra_state->cache_entry_list.size() - 1);
        } catch (py::error_already_set& hook_error) {
          // The hook raised: report the hook's error instead of letting the
          // exception unwind into the C caller.
          hook_error.restore();
          return NULL;
        } catch (py::builtin_exception& hook_error) {
          hook_error.set_error();
          return NULL;
        }
      }
      // this function is called from C, so we cannot repropagate
      // the exception
      e.restore();
      return NULL;
    } catch (py::builtin_exception& e) {
      // e.g. py::cast_error from cast<bool>(); surface it as a Python error.
      e.set_error();
      return NULL;
    }
    if (valid) {
      found = &cache_entry;
      break;
    }
    ++index;
  }
  if (found) {
    extra_state->move_to_front(found);
    return found->code.ptr();
  }
  // Borrowed reference to None, matching the documented ownership contract.
  return Py_None;
}
|
||||||
|
|
||||||
|
// Constructs a new CacheEntry from guarded_code at the front of
// extra_state's cache list, records the owning ExtraState and the entry's own
// list position on it, and returns a borrowed pointer to the entry.
CacheEntry* create_cache_entry(
    ExtraState* extra_state,
    PyObject* guarded_code) {
  auto& entries = extra_state->cache_entry_list;
  entries.emplace_front(guarded_code);
  CacheEntry& fresh = entries.front();
  fresh._owner = extra_state;
  // The entry remembers its own iterator so move_to_front can splice it
  // without searching the list.
  fresh._owner_loc = entries.begin();
  return &fresh;
}
|
||||||
|
|
||||||
|
// Returns a Python list with one element per CacheEntry stored on code_obj.
//
// Throws py::type_error when code_obj is not a types.CodeType instance.
// The list is empty when the code object has no extra state or has been
// marked with the SKIP_CODE sentinel.
//
// Warning: entries are exposed with return_value_policy::reference, so the
// Python objects do not own the underlying CacheEntry.
py::list _debug_get_cache_entry_list(const py::handle& code_obj) {
  if (!py::isinstance(code_obj, py::module::import("types").attr("CodeType"))) {
    throw py::type_error("expected a code object!");
  }
  PyCodeObject* code = (PyCodeObject*)code_obj.ptr();
  ExtraState* extra = get_extra_state(code);
  py::list result;
  // SKIP_CODE is a sentinel pointer value, not a real ExtraState; it must be
  // excluded before the cache list is touched.
  if (extra != NULL && extra != SKIP_CODE) {
    for (CacheEntry& e : extra->cache_entry_list) {
      result.append(py::cast(e, py::return_value_policy::reference));
    }
  }
  return result;
}
|
145
torch/csrc/dynamo/extra_state.h
Normal file
145
torch/csrc/dynamo/extra_state.h
Normal file
@ -0,0 +1,145 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <Python.h>
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
#include <torch/csrc/dynamo/utils.h>
|
||||||
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
#include <list>
|
||||||
|
|
||||||
|
namespace py = pybind11;
|
||||||
|
|
||||||
|
extern "C" {
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Flag to just run a frame normally
|
||||||
|
#define SKIP_CODE ((void*)0x1)
|
||||||
|
|
||||||
|
// Points to the extra scratch space on the code object
|
||||||
|
extern Py_ssize_t extra_index;
|
||||||
|
|
||||||
|
// function to call when cache lookup errors
|
||||||
|
extern PyObject* guard_error_hook;
|
||||||
|
|
||||||
|
typedef PyObject FrameState;
|
||||||
|
typedef struct CacheEntry CacheEntry;
|
||||||
|
|
||||||
|
// ExtraState encapsulates CacheEntry and FrameState. ExtraState is the highest
|
||||||
|
// level of abstraction of what is stored on the extra code object. Previously,
|
||||||
|
// we saved different parts on different extra indexes. We prefer this way
|
||||||
|
// because of cleaner abstraction and faster SetExtra access.
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
typedef struct VISIBILITY_HIDDEN ExtraState {
|
||||||
|
// List of cache entries for compiled code objects
|
||||||
|
std::list<CacheEntry> cache_entry_list;
|
||||||
|
// Frame state to detect dynamic shape dims
|
||||||
|
py::dict frame_state;
|
||||||
|
|
||||||
|
CacheEntry* get_first_entry();
|
||||||
|
void move_to_front(CacheEntry* cache_entry);
|
||||||
|
} ExtraState;
|
||||||
|
|
||||||
|
#else
|
||||||
|
|
||||||
|
typedef struct ExtraState ExtraState;
|
||||||
|
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// Helper to extract the cache_entry from the extra state.
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - extra_state: Borrowed
|
||||||
|
// return
|
||||||
|
// - CacheEntry: Borrowed.
|
||||||
|
CacheEntry* extract_cache_entry(ExtraState* extra_state);
|
||||||
|
|
||||||
|
// Returns either the previously stored frame state or an empty dict.
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - extra_state: Borrowed
|
||||||
|
// return
|
||||||
|
// - extra_state->frame_state: Borrowed.
|
||||||
|
FrameState* extract_frame_state(ExtraState* extra_state);
|
||||||
|
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - code: Borrowed
|
||||||
|
// return
|
||||||
|
// - extra_state: Borrowed.
|
||||||
|
ExtraState* get_extra_state(PyCodeObject* code);
|
||||||
|
|
||||||
|
// This is passed as freefunc to _PyEval_RequestCodeExtraIndex. This acts as a
|
||||||
|
// deleter for the object on extra scratch space. This function is called
|
||||||
|
// internally in _PyCode_SetExtra and also during the code deallocation.
|
||||||
|
|
||||||
|
// Destroys the extra state by deleting cache_entry, frame state and finally
|
||||||
|
// freeing the constructed extra state.
|
||||||
|
|
||||||
|
// Developer note - You should not call this function directly. This is called
|
||||||
|
// directly inside set_extra_state. If you are in a situation trying to call
|
||||||
|
// this function, consider if set_extra_state should be called.
|
||||||
|
void destroy_extra_state(void* obj);
|
||||||
|
|
||||||
|
// Clears the existing object sitting on the extra scratch space and sets it
|
||||||
|
// up with the new state. Note that _PyCode_SetExtra calls the
|
||||||
|
// destroy_extra_state deleter internally, and therefore we don't call it
|
||||||
|
// explicitly here.
|
||||||
|
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - extra_state: Stolen
|
||||||
|
// return
|
||||||
|
// - there is no return, but the extra_state is stolen, so it becomes
|
||||||
|
// set_extra_state responsibility to clean it up. It will be deleted during
|
||||||
|
// the reset_code/skip, when the set_extra_state is called with
|
||||||
|
// NULL/SKIP_CODE.
|
||||||
|
|
||||||
|
// Invariant - Dont set the extra state for the extra state that is already on
|
||||||
|
// the code object. Otherwise, we will first free up the old extra state
|
||||||
|
// (which is also the new extra state) and write something invalid on the
|
||||||
|
// scratch space.
|
||||||
|
void set_extra_state(PyCodeObject* code, ExtraState* extra_state);
|
||||||
|
|
||||||
|
// Creates a new extra state and put it on the extra scratch space of the code
|
||||||
|
// object.
|
||||||
|
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - code: Borrowed
|
||||||
|
// return:
|
||||||
|
// - extra_state: New reference.
|
||||||
|
// These references are then further passed to set_extra_state which becomes
|
||||||
|
// the final owner of these references.
|
||||||
|
ExtraState* init_and_set_extra_state(PyCodeObject* code);
|
||||||
|
|
||||||
|
// Lookup the cache held by extra_state.
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - extra_state: Borrowed
|
||||||
|
// - f_locals: Borrowed
|
||||||
|
// return:
|
||||||
|
// - Py_None or PyCodeObject: Borrowed reference.
|
||||||
|
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
|
||||||
|
|
||||||
|
// Create a new cache entry at extra_state holding on to guarded_code.
|
||||||
|
// Ownership contract
|
||||||
|
// args
|
||||||
|
// - extra_state: Borrowed
|
||||||
|
// - guarded_code: Borrowed
|
||||||
|
// return:
|
||||||
|
// - cache_entry: Borrowed reference
|
||||||
|
CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
|
||||||
|
|
||||||
|
#ifdef __cplusplus
|
||||||
|
|
||||||
|
} // extern "C"
|
||||||
|
|
||||||
|
// Returns the list of CacheEntry corresponding to code_obj.
|
||||||
|
// Warning: returns references whose lifetimes are controlled by C++
|
||||||
|
py::list _debug_get_cache_entry_list(const py::handle& code_obj);
|
||||||
|
|
||||||
|
#endif
|
@ -1,7 +1,9 @@
|
|||||||
#include <torch/csrc/dynamo/init.h>
|
#include <torch/csrc/dynamo/init.h>
|
||||||
|
|
||||||
#include <torch/csrc/Exceptions.h>
|
#include <torch/csrc/Exceptions.h>
|
||||||
|
#include <torch/csrc/dynamo/cache_entry.h>
|
||||||
#include <torch/csrc/dynamo/eval_frame.h>
|
#include <torch/csrc/dynamo/eval_frame.h>
|
||||||
|
#include <torch/csrc/dynamo/extra_state.h>
|
||||||
#include <torch/csrc/dynamo/guards.h>
|
#include <torch/csrc/dynamo/guards.h>
|
||||||
#include <torch/csrc/dynamo/python_compiled_autograd.h>
|
#include <torch/csrc/dynamo/python_compiled_autograd.h>
|
||||||
|
|
||||||
@ -34,6 +36,15 @@ void initDynamoBindings(PyObject* torch) {
|
|||||||
PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
|
PyModule_AddObject(dynamo, "compiled_autograd", compiled_autograd) != 0) {
|
||||||
throw python_error();
|
throw python_error();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
auto m = py::handle(eval_frame).cast<py::module>();
|
||||||
|
|
||||||
|
py::class_<CacheEntry>(m, "_CacheEntry")
|
||||||
|
.def_readonly("check_fn", &CacheEntry::check_fn)
|
||||||
|
.def_readonly("code", &CacheEntry::code)
|
||||||
|
.def_property_readonly("next", &CacheEntry::next);
|
||||||
|
|
||||||
|
m.def("_debug_get_cache_entry_list", &_debug_get_cache_entry_list);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace dynamo
|
} // namespace dynamo
|
||||||
|
@ -2,7 +2,6 @@
|
|||||||
|
|
||||||
// C2039 MSVC
|
// C2039 MSVC
|
||||||
#include <pybind11/complex.h>
|
#include <pybind11/complex.h>
|
||||||
#include <pybind11/pybind11.h>
|
|
||||||
#include <torch/csrc/utils/pybind.h>
|
#include <torch/csrc/utils/pybind.h>
|
||||||
|
|
||||||
#include <Python.h>
|
#include <Python.h>
|
||||||
|
9
torch/csrc/dynamo/utils.h
Normal file
9
torch/csrc/dynamo/utils.h
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
#pragma once
|
||||||
|
|
||||||
|
// The visibility attribute is to avoid a warning about storing a field in the
|
||||||
|
// struct that has a different visibility (from pybind) than the struct.
|
||||||
|
#ifdef _WIN32
|
||||||
|
#define VISIBILITY_HIDDEN
|
||||||
|
#else
|
||||||
|
#define VISIBILITY_HIDDEN __attribute__((visibility("hidden")))
|
||||||
|
#endif
|
Reference in New Issue
Block a user