mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
---- - We now record on CacheEntry what the compile id that populated it was, so now we can say why a specific frame was rejected - Add structured log for recompiles under name artifact "recompile_reasons". As it stands, it's not terribly structured, but this was the easiest thing I could do to start - Slightly reformat multi-reason printing; since we only report one guard failure seems better to have it as a single line Example output: ``` V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] Recompiling function f in /data/users/ezyang/a/pytorch/b.py:3 V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] triggered by the following guard failure(s): V0703 10:34:13.273000 140345997743104 torch/_dynamo/guards.py:2590] [0/1] [__recompiles] - 0/0: tensor 'L['x']' size mismatch at index 0. expected 4, actual 5 ``` Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/130043 Approved by: https://github.com/anijain2305
56 lines
1.6 KiB
C++
#include <torch/csrc/dynamo/cache_entry.h>
|
|
#include <torch/csrc/dynamo/guards.h>
|
|
|
|
#include <torch/csrc/dynamo/debug_macros.h>
|
|
#include <torch/csrc/dynamo/extra_state.h>
|
|
|
|
// Build a cache entry from the Python-side GuardedCode object: capture its
// guard function, compiled code object, and compile id, and remember which
// backend produced the compilation.
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
  this->check_fn = guarded_code.attr("check_fn");
  this->code = guarded_code.attr("code");
  this->compile_id = guarded_code.attr("compile_id");
  this->backend = backend;
  // TODO - clean this up when enable_cpp_guard_manager is True by default
  // When the C++ guard manager is in use, check_fn carries a "root"
  // attribute; convert it once here so guard evaluation can use the C++
  // root manager directly.
  if (py::hasattr(this->check_fn, "root")) {
    py::object root = this->check_fn.attr("root");
    this->root_mgr = torch::dynamo::convert_to_root_guard_manager(root);
  }
}
|
|
|
|
CacheEntry::~CacheEntry() {
  // prevent check_fn from use-after-free when invalidating:
  // clear the back-references the guard function holds to this entry and
  // to its owning extra state before the entry is destroyed.
  for (const char* slot : {"cache_entry", "extra_state"}) {
    this->check_fn.attr(slot) = py::none();
  }
}
|
|
|
|
// Return the entry after this one in the owner's cache list, or py::none()
// when this entry is the last. The returned object is a borrowed reference
// into the list (return_value_policy::reference), not a new owner.
py::object CacheEntry::next() {
  NULL_CHECK(this->_owner);
  auto successor = this->_owner_loc;
  ++successor;
  if (successor != this->_owner->cache_entry_list.end()) {
    return py::cast(*successor, py::return_value_policy::reference);
  }
  return py::none();
}
|
|
|
|
// Return the code object stored on this cache entry as a PyCodeObject*.
// The entry keeps ownership through its `code` py::object, so the result is
// a borrowed pointer — callers must not decref it.
PyCodeObject* CacheEntry_get_code(CacheEntry* e) {
  // Named cast instead of the C-style `(PyCodeObject*)` — same semantics,
  // but greppable and intent-revealing per C++ cast conventions.
  return reinterpret_cast<PyCodeObject*>(e->code.ptr());
}
|
|
|
|
// Wrap a CacheEntry* for Python, mapping nullptr to None.
// release() hands the new reference to the caller; the reference policy
// keeps Python from taking over the C++ object's lifetime.
PyObject* CacheEntry_to_obj(CacheEntry* e) {
  if (e != nullptr) {
    return py::cast(e, py::return_value_policy::reference).release().ptr();
  }
  return py::none().release().ptr();
}
|
|
|
|
// Unwrap dynamo's decorated callback down to the innermost backend:
// follow the _torchdynamo_orig_callable chain to its end, then step into
// compiler_fn when the unwrapped object exposes one. Returns the resulting
// pointer without adding a reference (handle.ptr() does not incref).
PyObject* get_backend(PyObject* callback) {
  py::handle cur = py::handle(callback);
  // Peel off every decoration layer stacked on the user's callback.
  while (py::hasattr(cur, "_torchdynamo_orig_callable")) {
    cur = cur.attr("_torchdynamo_orig_callable");
  }
  if (py::hasattr(cur, "compiler_fn")) {
    cur = cur.attr("compiler_fn");
  }
  return cur.ptr();
}
|