// Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
//
// This is a remix of https://github.com/pytorch/pytorch/pull/155558
// Instead of mediating guard collective via a config option, in this one it's
// done via a `set_stance` like API. The motivation is that checking for the
// config value on entry on torch.compile is apparently quite expensive,
// according to functorch_maml_omniglot. So this makes it a bit cheaper.
//
// Signed-off-by: Edward Z. Yang <ezyang@meta.com>
// Pull Request resolved: https://github.com/pytorch/pytorch/pull/156562
// Approved by: https://github.com/Microve
#include <torch/csrc/dynamo/cache_entry.h>
|
|
#include <torch/csrc/dynamo/cpp_shim.h>
|
|
#include <torch/csrc/dynamo/cpython_includes.h>
|
|
#include <torch/csrc/dynamo/debug_macros.h>
|
|
#include <torch/csrc/dynamo/eval_frame.h>
|
|
#include <torch/csrc/dynamo/eval_frame_cpp.h>
|
|
#include <torch/csrc/dynamo/framelocals_mapping.h>
|
|
#include <torch/csrc/utils/python_compat.h>
|
|
|
|
extern "C" {
// Optional Python callable invoked after guard evaluation (see use below).
// May be nullptr (hook disabled). Defined/installed elsewhere — presumably
// via the set_stance-style guard-collectives API; confirm at the definition.
extern PyObject* guard_complete_hook;
}
|
|
|
|
// Profiler annotation name used to bracket Dynamo's guard-cache lookup
// (passed to _pytorch_record_function_enter below).
static constexpr const char* cache_lookup_profiler_str =
    "TorchDynamo Cache Lookup";
|
|
|
|
// Remember to update the type signature for DynamoCallbackFn.__call__ in
|
|
// torch/_dynamo/types.py if this function's signature changes.
|
|
// Remember to update the type signature for DynamoCallbackFn.__call__ in
// torch/_dynamo/types.py if this function's signature changes.
//
// Invokes the Python-side Dynamo callback for a frame.
//
// Wraps the raw CPython frame in a THPPyInterpreterFrame, attaches a dict
// snapshot of the frame locals, and calls `callback(frame, cache_entry,
// frame_state)`.
//
// Args:
//   callback: the Dynamo compile callback (borrowed).
//   _frame: the CPython frame being evaluated (borrowed).
//   locals: mapping used to materialize the frame's locals as a dict.
//   cache_entry: existing cache entry for this code object, or nullptr.
//   frame_state: per-code-object frame state (borrowed).
//
// Returns: the callback's result object.
// Throws: std::runtime_error if the frame wrapper cannot be created;
//   py::error_already_set if the callback raises.
static py::object dynamo_call_callback(
    py::handle callback,
    THP_EVAL_API_FRAME_OBJECT* _frame,
    FrameLocalsMapping* locals,
    CacheEntry* cache_entry,
    FrameState* frame_state) {
  THPPyInterpreterFrame* frame = THPPyInterpreterFrame_New(_frame);
  if (frame == nullptr) {
    throw std::runtime_error(
        "Dynamo failed to initialize CPython interpreter frame wrapper");
  }
  // BUGFIX: own the wrapper via RAII. The previous code did a manual
  // Py_DECREF(frame) after the call, which was skipped when the callback
  // raised (py::error_already_set), leaking the wrapper on every
  // compile-time error.
  py::object frame_obj = py::reinterpret_steal<py::object>((PyObject*)frame);
  frame->locals = (PyObject*)framelocals_mapping_to_dict(locals);

  py::object cache_entry_obj = py::none();
  if (cache_entry) {
    // Non-owning cast: the cache entry is owned by the code object's extra
    // scratch space, not by Python.
    cache_entry_obj = py::cast(cache_entry, py::return_value_policy::reference);
  }

  return callback(frame_obj, cache_entry_obj, py::handle(frame_state));
}
|
|
|
|
// Translate a FrameAction into the callback object that the eval-frame
// machinery understands: Py_None means "skip this frame entirely",
// Py_False means "run only" (use cache, never compile), and any other
// action leaves the supplied callback in place.
//
// Returns a borrowed handle; no reference counts are touched.
static py::handle _callback_from_action(
    py::handle callback,
    FrameAction action) {
  switch (action) {
    case SKIP:
      return Py_None;
    case RUN_ONLY:
      return Py_False;
    default:
      return callback;
  }
}
|
|
|
|
// frame and callback are borrowed references.
// Returns new reference.
//
// Dynamo's custom frame-evaluation entry point (PEP 523 style hook).
// For each frame it chooses one of three outcomes:
//   1. default CPython evaluation (skip action, run-only with empty cache,
//      or exception unwinding),
//   2. evaluating previously compiled code on a guard-cache hit, or
//   3. calling the Python compile callback on a cache miss, then running
//      whatever code it produced.
// Throughout, the thread-local eval-frame callback is saved/swapped so that
// recursively invoked frames observe the intended recursive action.
PyObject* dynamo__custom_eval_frame(
    PyThreadState* tstate,
    THP_EVAL_API_FRAME_OBJECT* frame,
    int throw_flag,
    PyObject* callback_py) {
#if IS_PYTHON_3_11_PLUS
  DEBUG_TRACE(
      "begin %s %s %i %i",
      get_frame_name(frame),
      PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
      F_CODE(frame)->co_firstlineno,
      _PyInterpreterFrame_LASTI(frame));
#else
  DEBUG_TRACE(
      "begin %s %s %i %i %i",
      get_frame_name(frame),
      PyUnicode_AsUTF8(F_CODE(frame)->co_filename),
      frame->f_lineno,
      frame->f_lasti,
      frame->f_iblock);
#endif

  if (throw_flag) {
    // When unwinding generators, eval frame is called with throw_flag ==
    // true. Frame evaluation is supposed to continue unwinding by propagating
    // the exception. Dynamo doesn't really know how to do this, nor does it
    // really want to do this, because there's unlikely any code to capture
    // (you're going to immediately quit out of the frame, perhaps running
    // some unwinding logic along the way). So we just run the default
    // handler in this case.
    //
    // NB: A previous version of this patch returned NULL. This is wrong,
    // because returning NULL is *different* from unwinding an exception.
    // In particular, you will not execute things like context manager
    // __exit__ if you just return NULL.
    //
    // NB: It's /conceivable/ that you might want to actually still call the
    // Dynamo callback when throw_flag == TRUE, to give Dynamo a chance to
    // do any stack unwinding code. But this is not really useful because
    // (1) Dynamo doesn't actually know how to do stack unwinding, so it would
    // immediately skip the frame, and (2) even if it did, this would only
    // be profitable if there was tensor code in the unwinding code. Seems
    // unlikely.
    DEBUG_TRACE("throw %s", get_frame_name(frame));
    return dynamo_eval_frame_default(tstate, frame, throw_flag);
  }

  // Non-owning view of the compile callback (borrowed from the caller).
  py::handle callback(callback_py);

  // callback to run on recursively invoked frames
  py::handle recursive_callback = callback; // borrowed
  PyCodeObject* cached_code = nullptr; // borrowed
  const char* trace_annotation = "";
  PyObject* eval_result = nullptr; // strong reference

  // exit functions
  auto eval_default = [&]() {
    eval_frame_callback_set(recursive_callback.ptr());
    eval_result = dynamo_eval_frame_default(tstate, frame, throw_flag);
    if (!callback.is(recursive_callback)) {
      // NB: Only set the callback if it's different than the recursive
      // callback! Setting the callback is dangerous in the case that `frame`
      // also sets the eval frame callback. This happens in some functions in
      // eval_frame.py. These functions should be skipped with DEFAULT recursive
      // action, so we won't accidentally overwrite the callback.
      eval_frame_callback_set(callback.ptr());
    }
  };

  // NOTE: In 3.12+, the frame evaluation function (callee) is responsible for
  // clearing/popping the frame, meaning that unless we default evaluate the
  // original frame, we are responsible for clearing it - via
  // clear_old_frame_if_python_312_plus.
  auto eval_custom = [&]() {
    eval_frame_callback_set(recursive_callback.ptr());
    DEBUG_NULL_CHECK(cached_code);
    eval_result = dynamo_eval_custom_code(
        tstate, frame, cached_code, trace_annotation, throw_flag);
    if (!callback.is(recursive_callback)) {
      eval_frame_callback_set(callback.ptr());
    }
    clear_old_frame_if_python_312_plus(tstate, frame);
  };

  // Error exit: we did not evaluate the original frame, so on 3.12+ it is
  // still ours to clear.
  auto fail = [&]() { clear_old_frame_if_python_312_plus(tstate, frame); };

  ExtraState* extra = get_extra_state(F_CODE(frame));

  // Fast path: run-only mode (callback == Py_False) with no cache at all —
  // nothing could possibly hit, so evaluate normally.
  if (callback.is(py::bool_(false)) && extra == nullptr) {
    DEBUG_TRACE("skip (run only with empty cache) %s", get_frame_name(frame));
    eval_default();
    return eval_result;
  }

  // create cache
  if (extra == nullptr) {
    extra = init_and_set_extra_state(F_CODE(frame));
  }

  // Get recursive action
  FrameExecStrategy strategy = extra_state_get_exec_strategy(extra);
  recursive_callback =
      _callback_from_action(recursive_callback, strategy.recursive_action);

  // Skip this frame
  if (strategy.cur_action == SKIP) {
    DEBUG_TRACE("skip %s", get_frame_name(frame));
    eval_default();
    return eval_result;
  }

  // default and run-only mode require guard eval
  std::unique_ptr<FrameLocalsMapping> locals =
      std::make_unique<FrameLocalsMapping>(frame);
  PyObject* backend = get_backend(callback.ptr()); // borrowed

  // We don't run the current custom_eval_frame behavior for guards.
  // So we temporarily set the callback to Py_None to drive the correct behavior
  // in the shim.
  eval_frame_callback_set(Py_None);

  DEBUG_CHECK(PyDict_CheckExact(frame->f_globals));
  DEBUG_CHECK(PyDict_CheckExact(frame->f_builtins));

  // Time the guard-cache lookup under the profiler annotation.
  _PytorchRecordFunctionState* rf =
      _pytorch_record_function_enter(cache_lookup_profiler_str);
  PyObject* maybe_cached_code = nullptr;
  lookup(
      extra,
      locals.get(),
      backend,
      &maybe_cached_code,
      &trace_annotation,
      is_skip_guard_eval_unsafe);
  _pytorch_record_function_exit(rf);

  // A callback of Py_False indicates "run only" mode, the cache is checked,
  // but we never compile.
  bool run_only =
      strategy.cur_action == RUN_ONLY || callback.is(py::bool_(false));
  if (run_only) {
    DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
  }

  if (maybe_cached_code == nullptr) {
    // guard eval failed, keep propagating
    fail();
    return eval_result;
  }

  // NB: We only do guard collectives when there are any compiled code entries
  // at all; these reduces overtriggering and we don't need to do guard
  // collectives the very first time we've seen a frame
  // TODO: We could also check if we had just created extra for the first
  // time? Not too sure the best condition for extra->cache_entry_list
  if (guard_complete_hook != nullptr && !extra->cache_entry_list.empty()) {
    py::handle guard_complete_hook_handle(guard_complete_hook);
    // False means force compilation (someone cache missed)
    py::object res = guard_complete_hook_handle(maybe_cached_code != Py_None);
    if (!py::cast<bool>(res)) {
      maybe_cached_code = Py_None; // NB: non-owning
    }
  }

  if (maybe_cached_code != Py_None) {
    cached_code = (PyCodeObject*)maybe_cached_code;
    // used cached version
    DEBUG_TRACE("cache hit %s", get_frame_name(frame));
    eval_custom();
    return eval_result;
  }

  // cache miss
  DEBUG_TRACE("cache miss %s", get_frame_name(frame));
  if (is_skip_guard_eval_unsafe) {
    PyErr_SetString(
        PyExc_RuntimeError,
        "Recompilation triggered with skip_guard_eval_unsafe stance. "
        "This usually means that you have not warmed up your model "
        "with enough inputs such that you can guarantee no more recompilations.");
    fail();
    return eval_result;
  }

  if (run_only) {
    // Run-only mode never compiles: fall back to default evaluation.
    eval_default();
    return eval_result;
  }

  // call callback
  CacheEntry* cache_entry = extract_cache_entry(extra);
  FrameState* frame_state = extract_frame_state(extra);
  py::object callback_result;
  FrameExecStrategy new_strategy;
  bool apply_to_code = false;
  PyObject* guarded_code = nullptr;
  try {
    callback_result = dynamo_call_callback(
        callback, frame, locals.get(), cache_entry, frame_state);
    new_strategy =
        callback_result.attr("frame_exec_strategy").cast<FrameExecStrategy>();
    apply_to_code = callback_result.attr("apply_to_code").cast<bool>();
    guarded_code = callback_result.attr("guarded_code").ptr();
  } catch (py::error_already_set& e) {
    // internal exception, returning here will leak the exception into user
    // code this is useful for debugging -- but we dont want it to happen
    // outside of testing NB: we intentionally DO NOT re-enable custom
    // behavior to prevent cascading failure from internal exceptions. The
    // upshot is if Dynamo barfs, that's it for Dynamo, even if you catch the
    // exception inside the torch.compile block we won't try to Dynamo
    // anything else.
    fail();
    e.restore();
    return eval_result;
  }

  // recursive frame action
  if (strategy.recursive_action == DEFAULT) {
    // old recursive action overrides new recursive action
    recursive_callback = _callback_from_action(
        recursive_callback, new_strategy.recursive_action);
  }

  // possibly apply frame strategy to future frames with same code object
  if (apply_to_code) {
    if (new_strategy.cur_action != DEFAULT) {
      DEBUG_TRACE("create action: %d\n", new_strategy.cur_action);
    }
    if (new_strategy.recursive_action != DEFAULT) {
      DEBUG_TRACE(
          "create recursive action: %d\n", new_strategy.recursive_action);
    }
    extra_state_set_exec_strategy(extra, new_strategy);
  }

  if (guarded_code != Py_None) {
    DEBUG_TRACE("create cache %s", get_frame_name(frame));

    // NB: We could use extract_cache_entry to get the cache_entry, but
    // extract_cache_entry returns a borrowed reference. Modifying a borrowed
    // reference seems wrong. Therefore, we directly access the
    // extra->cache_entry. extra won't be NULL here.
    CacheEntry* new_cache_entry =
        create_cache_entry(extra, guarded_code, backend);

    // Update the existing cache_entry on the extra object. This extra object
    // is sitting on the extra scratch space, we are just changing the
    // cache_entry ptr. As a result, extra now becomes the owner of CacheEntry
    // object. This will be cleaned up when set_extra_state is called.
    // Re-enable custom behavior
    // NOTE(review): comma operator below — this is a single statement
    // performing both assignments; behaviorally equivalent to two statements.
    cached_code = CacheEntry_get_code(new_cache_entry),
    trace_annotation = CacheEntry_get_trace_annotation(new_cache_entry);
    eval_custom();
  } else {
    eval_default();
  }
  return eval_result;
}
|
|
|
|
// Python binding: set_code_exec_strategy(code, strategy) -> None.
// Attaches a FrameExecStrategy to a specific code object, creating the
// extra scratch state on the code object if it does not exist yet.
// Raises TypeError if the first argument is not a code object.
PyObject* set_code_exec_strategy(PyObject* dummy, PyObject* args) {
  PyObject* code_arg = nullptr;
  PyObject* strategy_arg = nullptr;
  if (!PyArg_ParseTuple(args, "OO", &code_arg, &strategy_arg)) {
    return nullptr; // PyArg_ParseTuple already set the exception
  }

  if (!PyCode_Check(code_arg)) {
    PyErr_SetString(PyExc_TypeError, "expected a code object");
    return nullptr;
  }
  auto* code = reinterpret_cast<PyCodeObject*>(code_arg);

  // Lazily create the per-code-object scratch state.
  ExtraState* extra = get_extra_state(code);
  if (extra == nullptr) {
    extra = init_and_set_extra_state(code);
  }

  extra_state_set_exec_strategy(
      extra, py::handle(strategy_arg).cast<FrameExecStrategy>());
  Py_RETURN_NONE;
}
|