[dynamo][guards] Move backend match to eval_frame (#121954)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121954
Approved by: https://github.com/jansel
This commit is contained in:
Animesh Jain
2024-03-16 19:41:19 -07:00
committed by PyTorch MergeBot
parent fc504d719f
commit c568b84794
11 changed files with 76 additions and 127 deletions

View File

@ -223,13 +223,6 @@ y = TensorVariable()
'obj_weakref': None
'guarded_class': None
}
global '' BACKEND_MATCH
{
'guard_types': None,
'code': None,
'obj_weakref': None
'guarded_class': None
}
shape_env '' SHAPE_ENV
{
'guard_types': None,

View File

@ -4827,10 +4827,6 @@ def fn():
opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
self.assertEqual(exp_out, opt_out)
self.assertEqual(cnt.frame_count, exp_frame_count)
self.assertEqual(
len(torch._dynamo.eval_frame.cached_backends),
exp_n_cached_backend,
)
def test_backend_match_guard(self):
x = torch.randn([3, 4])
@ -4912,12 +4908,6 @@ def fn():
for thread in threads:
thread.join()
# Threads are sharing the backend cache. We see two cnt backends and one None backend
self.assertEqual(
len(torch._dynamo.eval_frame.cached_backends),
3,
)
self.assertEqual(len(thread_success), len(threads))
def test_dynamo_min_operator_with_shape(self):

View File

@ -1685,6 +1685,9 @@ class _TorchCompileInductorWrapper:
self.apply_mode(mode)
self.apply_options(options)
# Stash the compiler_fn to be used for backend match guard.
from torch._inductor.compile_fx import compile_fx
self.compiler_fn = compile_fx
if self.config.get("triton.cudagraphs", False):
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
# FIXME: CUDA Graph does not work well with CUPTI teardown.

View File

@ -16,7 +16,6 @@ import logging
import os
import sys
import textwrap
import threading
import traceback
import types
import warnings
@ -77,82 +76,19 @@ class Unset(Enum):
token = 0
unset = Unset.token
guarded_backend_cache = threading.local()
cached_backends: Dict[int, CompilerFn] = {}
def check_current_backend(backend_obj_id: int):
"""
Called from guards to check if we need to recompile due to a backend change
"""
# TODO(jansel): we should move guarded_backend_cache to C++
try:
if guarded_backend_cache.skip_backend_check_for_run_only_mode:
return True
except AttributeError:
# Go slightly faster next time
guarded_backend_cache.skip_backend_check_for_run_only_mode = False
try:
current_backend = guarded_backend_cache.current_backend
except AttributeError:
current_backend = None
return (
# Avoid the dict lookup in case of exact same object
id(current_backend) == backend_obj_id
or current_backend == cached_backends.get(backend_obj_id, None)
)
unset = Unset.token
def _reset_guarded_backend_cache():
global cached_backends
guarded_backend_cache.skip_backend_check_for_run_only_mode = False
guarded_backend_cache.current_backend = None
for backend in cached_backends.values():
if hasattr(backend, "reset"):
backend.reset()
cached_backends.clear()
def backend_cache_manager(callback: DynamoCallback):
# callback is False for RunOnlyContext. RunOnlyContext is used
# as a way to re-use the previous compiled cache.
# We therefore skip the check and re-use whatever code that's already cached.
# Note: the cache that's actually used depends on the caching policy.
if callback is False:
def change():
try:
prev_skip = guarded_backend_cache.skip_backend_check_for_run_only_mode
except AttributeError:
prev_skip = False
guarded_backend_cache.skip_backend_check_for_run_only_mode = True
def revert():
guarded_backend_cache.skip_backend_check_for_run_only_mode = prev_skip
return revert
else:
backend = innermost_fn(callback)
def change():
cached_backends.setdefault(id(backend), backend)
try:
prev_backend = guarded_backend_cache.current_backend
except AttributeError:
prev_backend = None
guarded_backend_cache.current_backend = backend
def revert():
guarded_backend_cache.current_backend = prev_backend
return revert
return change
DONT_WRAP_FILES = {
# For tracing into fx modules
inspect.getsourcefile(GraphModule),
@ -306,9 +242,13 @@ class _TorchDynamoContext:
self.export = export
self.compiler_config = compiler_config
self.cleanup_fns: List[Callable[[], Any]] = []
self.enter_exit_hooks = [backend_cache_manager(self.callback)]
self.enter_exit_hooks = []
patch_fn()
# Save the backends so that we can reset them during torch._dynamo.reset
backend = innermost_fn(callback)
cached_backends.setdefault(id(backend), backend)
if dynamic is not None:
self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
@ -672,6 +612,9 @@ def optimize(
dynamic=dynamic,
hooks=hooks,
)
# The backend function is stashed in the callable returned by
# _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
# be used by eval_frame.c to insert a guard on the backend.
return _optimize_catch_errors(
convert_frame.convert_frame(backend, hooks=hooks),
hooks,

View File

@ -647,15 +647,6 @@ class GuardBuilder(GuardBuilderBase):
guard, [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
)
def BACKEND_MATCH(self, guard: Guard):
"""Guard on backend matching based on id of current_backend"""
assert guard.source is GuardSource.GLOBAL
backend_id = (
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
)
code = [f"___check_current_backend({backend_id})"]
self._produce_guard_code(guard, code)
def SHAPE_ENV(self, guard: Guard):
# Let's handle ShapeEnv guards. To do this, we will resolve
# shape variables to sources from tracked_fakes. This must happen after
@ -1203,7 +1194,6 @@ class CheckFunctionManager:
"___check_tensors": check_tensors_fn,
"___check_tensors_verbose": check_tensors_verbose_fn,
"___check_global_state": global_state.check,
"___check_current_backend": torch._dynamo.eval_frame.check_current_backend,
"tensor_check_names": tensor_check_names,
**SYMPY_INTERP,
**CLOSURE_VARS,

View File

@ -471,8 +471,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
)
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
def synthetic_graph_input(self, fn, args):
"""
call fn(*args) before the graph runs and turn the result into a fake input.

View File

@ -4,9 +4,10 @@
#include <torch/csrc/dynamo/debug_macros.h>
#include <torch/csrc/dynamo/extra_state.h>
CacheEntry::CacheEntry(const py::handle& guarded_code) {
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
this->check_fn = guarded_code.attr("check_fn");
this->code = guarded_code.attr("code");
this->backend = backend;
// TODO - clean this up when enable_cpp_guard_manager is True by default
if (py::hasattr(this->check_fn, "root")) {
this->root_mgr = convert_to_root_guard_manager(this->check_fn.attr("root"));
@ -39,3 +40,14 @@ PyObject* CacheEntry_to_obj(CacheEntry* e) {
}
return py::cast(e, py::return_value_policy::reference).release().ptr();
}
PyObject* get_backend(PyObject* callback) {
py::handle handle = py::handle(callback);
while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
handle = handle.attr("_torchdynamo_orig_callable");
}
if (py::hasattr(handle, "compiler_fn")) {
handle = handle.attr("compiler_fn");
}
return handle.ptr();
}

View File

@ -45,12 +45,14 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
py::object code;
// root guard manager if exists
void* root_mgr{nullptr};
// backend used to create this cache entry
PyObject* backend{nullptr};
// Reference to owning ExtraState
ExtraState* _owner{nullptr};
// Reference to this CacheEntry's location in owner's linked list
std::list<CacheEntry>::iterator _owner_loc;
CacheEntry(const py::handle& guarded_code);
CacheEntry(const py::handle& guarded_code, PyObject* backend);
~CacheEntry();
// Warning: returns a reference whose lifetime is controlled by C++

View File

@ -530,12 +530,14 @@ static PyObject* _custom_eval_frame(
return NULL;
}
PyObject* backend = get_backend(callback);
// A callback of Py_False indicates "run only" mode, the cache is checked, but
// we never compile.
if (callback == Py_False) {
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
@ -560,7 +562,7 @@ static PyObject* _custom_eval_frame(
eval_frame_callback_set(Py_None);
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
_pytorch_record_function_exit(rf);
if (maybe_cached_code == NULL) {
// Python error
@ -594,7 +596,7 @@ static PyObject* _custom_eval_frame(
// extract_cache_entry returns a borrowed reference. Modifying a borrowed
// reference seems wrong. Therefore, we directly access the
// extra->cache_entry. extra wont be NULL here.
CacheEntry* new_cache_entry = create_cache_entry(extra, result);
CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
Py_DECREF(result);
// Update the existing cache_entry on the extra object. This extra object is
// sitting on the extra scratch space, we are just changing the cache_entry

View File

@ -82,34 +82,40 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) {
return extra_state;
}
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
PyObject* lookup(
ExtraState* extra_state,
PyObject* f_locals,
PyObject* backend) {
size_t index = 0;
CacheEntry* found = nullptr;
py::handle locals(f_locals);
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
bool valid = false;
try {
// TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is True
// by default
if (cache_entry.root_mgr != nullptr) {
valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
} else {
valid = cache_entry.check_fn(locals).cast<bool>();
// Check backend. Py_False means run only mode.
bool valid = backend == Py_False || cache_entry.backend == backend;
if (valid) {
try {
// TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
// True by default
if (cache_entry.root_mgr != nullptr) {
valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
} else {
valid = cache_entry.check_fn(locals).cast<bool>();
}
} catch (py::error_already_set& e) {
if (guard_error_hook) {
py::handle guard_error_hook_handle(guard_error_hook);
guard_error_hook_handle(
cache_entry.check_fn,
cache_entry.code,
locals,
index,
index == extra_state->cache_entry_list.size() - 1);
}
// this function is called from C, so we cannot repropagate
// the exception
e.restore();
return NULL;
}
} catch (py::error_already_set& e) {
if (guard_error_hook) {
py::handle guard_error_hook_handle(guard_error_hook);
guard_error_hook_handle(
cache_entry.check_fn,
cache_entry.code,
locals,
index,
index == extra_state->cache_entry_list.size() - 1);
}
// this function is called from C, so we cannot repropagate
// the exception
e.restore();
return NULL;
}
if (valid) {
found = &cache_entry;
@ -126,8 +132,9 @@ PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
CacheEntry* create_cache_entry(
ExtraState* extra_state,
PyObject* guarded_code) {
extra_state->cache_entry_list.emplace_front(guarded_code);
PyObject* guarded_code,
PyObject* backend) {
extra_state->cache_entry_list.emplace_front(guarded_code, backend);
auto new_iter = extra_state->cache_entry_list.begin();
new_iter->_owner = extra_state;
new_iter->_owner_loc = new_iter;

View File

@ -124,7 +124,10 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code);
// - f_locals: Borrowed
// return:
// - Py_None or PyCodeObject: Borrowed reference.
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
PyObject* lookup(
ExtraState* extra_state,
PyObject* f_locals,
    PyObject* backend);
// Create a new cache entry at extra_state holding on to guarded_code.
// Ownership contract
@ -133,7 +136,13 @@ PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
// - guarded_code: Borrowed
// return:
// - cache_entry: Borrowed reference
CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
CacheEntry* create_cache_entry(
ExtraState* extra_state,
    PyObject* guarded_code,
    PyObject* backend);
// Extracts the backend fn from the callback.
PyObject* get_backend(PyObject* callback);
#ifdef __cplusplus