mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138896 Approved by: https://github.com/williamwen42, https://github.com/jansel ghstack dependencies: #138512
67 lines
2.1 KiB
C++
67 lines
2.1 KiB
C++
#include <torch/csrc/dynamo/cache_entry.h>
|
|
#include <torch/csrc/dynamo/guards.h>
|
|
|
|
#include <torch/csrc/dynamo/debug_macros.h>
|
|
#include <torch/csrc/dynamo/extra_state.h>
|
|
|
|
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
|
|
: backend{backend} {
|
|
this->guard_manager = guarded_code.attr("guard_manager");
|
|
this->code = guarded_code.attr("code");
|
|
this->compile_id = guarded_code.attr("compile_id");
|
|
py::object trace_annotation = guarded_code.attr("trace_annotation");
|
|
const char* trace_annotation_str = PyUnicode_AsUTF8(trace_annotation.ptr());
|
|
if (trace_annotation) {
|
|
this->trace_annotation = std::string(trace_annotation_str);
|
|
} else {
|
|
this->trace_annotation = "Unknown";
|
|
}
|
|
this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
|
|
this->guard_manager.attr("root"));
|
|
}
|
|
|
|
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
|
|
"-Wdeprecated-copy-with-user-provided-dtor")
|
|
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
|
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
|
CacheEntry::~CacheEntry() {
|
|
// prevent guard_manager from use-after-free when invalidating
|
|
this->guard_manager.attr("cache_entry") = py::none();
|
|
this->guard_manager.attr("extra_state") = py::none();
|
|
}
|
|
C10_DIAGNOSTIC_POP()
|
|
C10_DIAGNOSTIC_POP()
|
|
|
|
py::object CacheEntry::next() {
|
|
NULL_CHECK(this->_owner);
|
|
auto it = this->_owner_loc;
|
|
++it;
|
|
if (it == this->_owner->cache_entry_list.end()) {
|
|
return py::none();
|
|
}
|
|
return py::cast(*it, py::return_value_policy::reference);
|
|
}
|
|
|
|
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
|
|
return (PyCodeObject*)e->code.ptr();
|
|
}
|
|
|
|
const char* CacheEntry_get_trace_annotation(CacheEntry* e) {
|
|
return e->trace_annotation.c_str();
|
|
}
|
|
|
|
PyObject* CacheEntry_to_obj(CacheEntry* e) {
|
|
if (!e) {
|
|
return py::none().release().ptr();
|
|
}
|
|
return py::cast(e, py::return_value_policy::reference).release().ptr();
|
|
}
|
|
|
|
PyObject* get_backend(PyObject* callback) {
|
|
py::handle handle = py::handle(callback);
|
|
while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
|
|
handle = handle.attr("_torchdynamo_orig_callable");
|
|
}
|
|
return handle.ptr();
|
|
}
|