mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo][guards] Move backend match to eval_frame (#121954)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/121954 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
fc504d719f
commit
c568b84794
@ -223,13 +223,6 @@ y = TensorVariable()
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
global '' BACKEND_MATCH
|
||||
{
|
||||
'guard_types': None,
|
||||
'code': None,
|
||||
'obj_weakref': None
|
||||
'guarded_class': None
|
||||
}
|
||||
shape_env '' SHAPE_ENV
|
||||
{
|
||||
'guard_types': None,
|
||||
|
@ -4827,10 +4827,6 @@ def fn():
|
||||
opt_out = torch._dynamo.optimize(backend=cnt)(foo)(*args)
|
||||
self.assertEqual(exp_out, opt_out)
|
||||
self.assertEqual(cnt.frame_count, exp_frame_count)
|
||||
self.assertEqual(
|
||||
len(torch._dynamo.eval_frame.cached_backends),
|
||||
exp_n_cached_backend,
|
||||
)
|
||||
|
||||
def test_backend_match_guard(self):
|
||||
x = torch.randn([3, 4])
|
||||
@ -4912,12 +4908,6 @@ def fn():
|
||||
for thread in threads:
|
||||
thread.join()
|
||||
|
||||
# Threads are sharing the backend cache. We see two cnt backends and one None backend
|
||||
self.assertEqual(
|
||||
len(torch._dynamo.eval_frame.cached_backends),
|
||||
3,
|
||||
)
|
||||
|
||||
self.assertEqual(len(thread_success), len(threads))
|
||||
|
||||
def test_dynamo_min_operator_with_shape(self):
|
||||
|
@ -1685,6 +1685,9 @@ class _TorchCompileInductorWrapper:
|
||||
self.apply_mode(mode)
|
||||
self.apply_options(options)
|
||||
|
||||
# Stash the compiler_fn to be used for backend match guard.
|
||||
from torch._inductor.compile_fx import compile_fx
|
||||
self.compiler_fn = compile_fx
|
||||
if self.config.get("triton.cudagraphs", False):
|
||||
os.environ["DISABLE_CUPTI_LAZY_REINIT"] = "1"
|
||||
# FIXME: CUDA Graph does not work well with CUPTI teardown.
|
||||
|
@ -16,7 +16,6 @@ import logging
|
||||
import os
|
||||
import sys
|
||||
import textwrap
|
||||
import threading
|
||||
import traceback
|
||||
import types
|
||||
import warnings
|
||||
@ -77,82 +76,19 @@ class Unset(Enum):
|
||||
token = 0
|
||||
|
||||
|
||||
unset = Unset.token
|
||||
|
||||
guarded_backend_cache = threading.local()
|
||||
cached_backends: Dict[int, CompilerFn] = {}
|
||||
|
||||
|
||||
def check_current_backend(backend_obj_id: int):
|
||||
"""
|
||||
Called from guards to check if we need to recompile due to a backend change
|
||||
"""
|
||||
# TODO(jansel): we should move guarded_backend_cache to C++
|
||||
try:
|
||||
if guarded_backend_cache.skip_backend_check_for_run_only_mode:
|
||||
return True
|
||||
except AttributeError:
|
||||
# Go slightly faster next time
|
||||
guarded_backend_cache.skip_backend_check_for_run_only_mode = False
|
||||
try:
|
||||
current_backend = guarded_backend_cache.current_backend
|
||||
except AttributeError:
|
||||
current_backend = None
|
||||
return (
|
||||
# Avoid the dict lookup in case of exact same object
|
||||
id(current_backend) == backend_obj_id
|
||||
or current_backend == cached_backends.get(backend_obj_id, None)
|
||||
)
|
||||
unset = Unset.token
|
||||
|
||||
|
||||
def _reset_guarded_backend_cache():
|
||||
global cached_backends
|
||||
guarded_backend_cache.skip_backend_check_for_run_only_mode = False
|
||||
guarded_backend_cache.current_backend = None
|
||||
for backend in cached_backends.values():
|
||||
if hasattr(backend, "reset"):
|
||||
backend.reset()
|
||||
cached_backends.clear()
|
||||
|
||||
|
||||
def backend_cache_manager(callback: DynamoCallback):
|
||||
# callback is False for RunOnlyContext. RunOnlyContext is used
|
||||
# as a way to re-use the previous compiled cache.
|
||||
# We therefore skip the check and re-use whatever code that's already cached.
|
||||
# Note: the cache that's actually used depends on the caching policy.
|
||||
if callback is False:
|
||||
|
||||
def change():
|
||||
try:
|
||||
prev_skip = guarded_backend_cache.skip_backend_check_for_run_only_mode
|
||||
except AttributeError:
|
||||
prev_skip = False
|
||||
guarded_backend_cache.skip_backend_check_for_run_only_mode = True
|
||||
|
||||
def revert():
|
||||
guarded_backend_cache.skip_backend_check_for_run_only_mode = prev_skip
|
||||
|
||||
return revert
|
||||
|
||||
else:
|
||||
backend = innermost_fn(callback)
|
||||
|
||||
def change():
|
||||
cached_backends.setdefault(id(backend), backend)
|
||||
try:
|
||||
prev_backend = guarded_backend_cache.current_backend
|
||||
except AttributeError:
|
||||
prev_backend = None
|
||||
guarded_backend_cache.current_backend = backend
|
||||
|
||||
def revert():
|
||||
guarded_backend_cache.current_backend = prev_backend
|
||||
|
||||
return revert
|
||||
|
||||
return change
|
||||
|
||||
|
||||
DONT_WRAP_FILES = {
|
||||
# For tracing into fx modules
|
||||
inspect.getsourcefile(GraphModule),
|
||||
@ -306,9 +242,13 @@ class _TorchDynamoContext:
|
||||
self.export = export
|
||||
self.compiler_config = compiler_config
|
||||
self.cleanup_fns: List[Callable[[], Any]] = []
|
||||
self.enter_exit_hooks = [backend_cache_manager(self.callback)]
|
||||
self.enter_exit_hooks = []
|
||||
patch_fn()
|
||||
|
||||
# Save the backends so that we can reset them during torch._dynamo.reset
|
||||
backend = innermost_fn(callback)
|
||||
cached_backends.setdefault(id(backend), backend)
|
||||
|
||||
if dynamic is not None:
|
||||
self.enter_exit_hooks.append(make_set_enable_dynamic(dynamic))
|
||||
|
||||
@ -672,6 +612,9 @@ def optimize(
|
||||
dynamic=dynamic,
|
||||
hooks=hooks,
|
||||
)
|
||||
# The backend function is stashed in the callable returned by
|
||||
# _optimize_catch_errors in the field _torchdynamo_orig_callable. This can
|
||||
# be used by eval_frame.c to insert a guard on the backend.
|
||||
return _optimize_catch_errors(
|
||||
convert_frame.convert_frame(backend, hooks=hooks),
|
||||
hooks,
|
||||
|
@ -647,15 +647,6 @@ class GuardBuilder(GuardBuilderBase):
|
||||
guard, [f"utils_device.CURRENT_DEVICE == {m.CURRENT_DEVICE!r}"]
|
||||
)
|
||||
|
||||
def BACKEND_MATCH(self, guard: Guard):
|
||||
"""Guard on backend matching based on id of current_backend"""
|
||||
assert guard.source is GuardSource.GLOBAL
|
||||
backend_id = (
|
||||
f"{id(torch._dynamo.eval_frame.guarded_backend_cache.current_backend)}"
|
||||
)
|
||||
code = [f"___check_current_backend({backend_id})"]
|
||||
self._produce_guard_code(guard, code)
|
||||
|
||||
def SHAPE_ENV(self, guard: Guard):
|
||||
# Let's handle ShapeEnv guards. To do this, we will resolve
|
||||
# shape variables to sources from tracked_fakes. This must happen after
|
||||
@ -1203,7 +1194,6 @@ class CheckFunctionManager:
|
||||
"___check_tensors": check_tensors_fn,
|
||||
"___check_tensors_verbose": check_tensors_verbose_fn,
|
||||
"___check_global_state": global_state.check,
|
||||
"___check_current_backend": torch._dynamo.eval_frame.check_current_backend,
|
||||
"tensor_check_names": tensor_check_names,
|
||||
**SYMPY_INTERP,
|
||||
**CLOSURE_VARS,
|
||||
|
@ -471,8 +471,6 @@ class OutputGraph(Checkpointable[OutputGraphState]):
|
||||
GlobalStateSource().make_guard(GuardBuilder.TORCH_FUNCTION_STATE)
|
||||
)
|
||||
|
||||
self.guards.add(GlobalStateSource().make_guard(GuardBuilder.BACKEND_MATCH))
|
||||
|
||||
def synthetic_graph_input(self, fn, args):
|
||||
"""
|
||||
call fn(*args) before the graph runs and turn the result into a fake input.
|
||||
|
@ -4,9 +4,10 @@
|
||||
#include <torch/csrc/dynamo/debug_macros.h>
|
||||
#include <torch/csrc/dynamo/extra_state.h>
|
||||
|
||||
CacheEntry::CacheEntry(const py::handle& guarded_code) {
|
||||
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend) {
|
||||
this->check_fn = guarded_code.attr("check_fn");
|
||||
this->code = guarded_code.attr("code");
|
||||
this->backend = backend;
|
||||
// TODO - clean this up when enable_cpp_guard_manager is True by default
|
||||
if (py::hasattr(this->check_fn, "root")) {
|
||||
this->root_mgr = convert_to_root_guard_manager(this->check_fn.attr("root"));
|
||||
@ -39,3 +40,14 @@ PyObject* CacheEntry_to_obj(CacheEntry* e) {
|
||||
}
|
||||
return py::cast(e, py::return_value_policy::reference).release().ptr();
|
||||
}
|
||||
|
||||
PyObject* get_backend(PyObject* callback) {
|
||||
py::handle handle = py::handle(callback);
|
||||
while (py::hasattr(handle, "_torchdynamo_orig_callable")) {
|
||||
handle = handle.attr("_torchdynamo_orig_callable");
|
||||
}
|
||||
if (py::hasattr(handle, "compiler_fn")) {
|
||||
handle = handle.attr("compiler_fn");
|
||||
}
|
||||
return handle.ptr();
|
||||
}
|
||||
|
@ -45,12 +45,14 @@ typedef struct VISIBILITY_HIDDEN CacheEntry {
|
||||
py::object code;
|
||||
// root guard manager if exists
|
||||
void* root_mgr{nullptr};
|
||||
// backend used to create this cache entry
|
||||
PyObject* backend{nullptr};
|
||||
// Reference to owning ExtraState
|
||||
ExtraState* _owner{nullptr};
|
||||
// Reference to this CacheEntry's location in owner's linked list
|
||||
std::list<CacheEntry>::iterator _owner_loc;
|
||||
|
||||
CacheEntry(const py::handle& guarded_code);
|
||||
CacheEntry(const py::handle& guarded_code, PyObject* backend);
|
||||
~CacheEntry();
|
||||
|
||||
// Warning: returns a reference whose lifetime is controlled by C++
|
||||
|
@ -530,12 +530,14 @@ static PyObject* _custom_eval_frame(
|
||||
return NULL;
|
||||
}
|
||||
|
||||
PyObject* backend = get_backend(callback);
|
||||
|
||||
// A callback of Py_False indicates "run only" mode, the cache is checked, but
|
||||
// we never compile.
|
||||
if (callback == Py_False) {
|
||||
DEBUG_TRACE("In run only mode %s", get_frame_name(frame));
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
|
||||
PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
|
||||
_pytorch_record_function_exit(rf);
|
||||
|
||||
if (maybe_cached_code == NULL) {
|
||||
@ -560,7 +562,7 @@ static PyObject* _custom_eval_frame(
|
||||
eval_frame_callback_set(Py_None);
|
||||
|
||||
_PytorchRecordFunctionState* rf = _pytorch_record_function_enter(cache_lookup_profiler_str);
|
||||
PyObject* maybe_cached_code = lookup(extra, frame->f_locals);
|
||||
PyObject* maybe_cached_code = lookup(extra, frame->f_locals, backend);
|
||||
_pytorch_record_function_exit(rf);
|
||||
if (maybe_cached_code == NULL) {
|
||||
// Python error
|
||||
@ -594,7 +596,7 @@ static PyObject* _custom_eval_frame(
|
||||
// extract_cache_entry returns a borrowed reference. Modifying a borrowed
|
||||
// reference seems wrong. Therefore, we directly access the
|
||||
// extra->cache_entry. extra wont be NULL here.
|
||||
CacheEntry* new_cache_entry = create_cache_entry(extra, result);
|
||||
CacheEntry* new_cache_entry = create_cache_entry(extra, result, backend);
|
||||
Py_DECREF(result);
|
||||
// Update the existing cache_entry on the extra object. This extra object is
|
||||
// sitting on the extra scratch space, we are just changing the cache_entry
|
||||
|
@ -82,34 +82,40 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code) {
|
||||
return extra_state;
|
||||
}
|
||||
|
||||
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
|
||||
PyObject* lookup(
|
||||
ExtraState* extra_state,
|
||||
PyObject* f_locals,
|
||||
PyObject* backend) {
|
||||
size_t index = 0;
|
||||
CacheEntry* found = nullptr;
|
||||
py::handle locals(f_locals);
|
||||
for (CacheEntry& cache_entry : extra_state->cache_entry_list) {
|
||||
bool valid = false;
|
||||
try {
|
||||
// TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is True
|
||||
// by default
|
||||
if (cache_entry.root_mgr != nullptr) {
|
||||
valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
|
||||
} else {
|
||||
valid = cache_entry.check_fn(locals).cast<bool>();
|
||||
// Check backend. Py_False means run only mode.
|
||||
bool valid = backend == Py_False || cache_entry.backend == backend;
|
||||
if (valid) {
|
||||
try {
|
||||
// TODO(anijain2305) - Clean this up when enable_cpp_guard_manager is
|
||||
// True by default
|
||||
if (cache_entry.root_mgr != nullptr) {
|
||||
valid = run_root_guard_manager(cache_entry.root_mgr, f_locals);
|
||||
} else {
|
||||
valid = cache_entry.check_fn(locals).cast<bool>();
|
||||
}
|
||||
} catch (py::error_already_set& e) {
|
||||
if (guard_error_hook) {
|
||||
py::handle guard_error_hook_handle(guard_error_hook);
|
||||
guard_error_hook_handle(
|
||||
cache_entry.check_fn,
|
||||
cache_entry.code,
|
||||
locals,
|
||||
index,
|
||||
index == extra_state->cache_entry_list.size() - 1);
|
||||
}
|
||||
// this function is called from C, so we cannot repropagate
|
||||
// the exception
|
||||
e.restore();
|
||||
return NULL;
|
||||
}
|
||||
} catch (py::error_already_set& e) {
|
||||
if (guard_error_hook) {
|
||||
py::handle guard_error_hook_handle(guard_error_hook);
|
||||
guard_error_hook_handle(
|
||||
cache_entry.check_fn,
|
||||
cache_entry.code,
|
||||
locals,
|
||||
index,
|
||||
index == extra_state->cache_entry_list.size() - 1);
|
||||
}
|
||||
// this function is called from C, so we cannot repropagate
|
||||
// the exception
|
||||
e.restore();
|
||||
return NULL;
|
||||
}
|
||||
if (valid) {
|
||||
found = &cache_entry;
|
||||
@ -126,8 +132,9 @@ PyObject* lookup(ExtraState* extra_state, PyObject* f_locals) {
|
||||
|
||||
CacheEntry* create_cache_entry(
|
||||
ExtraState* extra_state,
|
||||
PyObject* guarded_code) {
|
||||
extra_state->cache_entry_list.emplace_front(guarded_code);
|
||||
PyObject* guarded_code,
|
||||
PyObject* backend) {
|
||||
extra_state->cache_entry_list.emplace_front(guarded_code, backend);
|
||||
auto new_iter = extra_state->cache_entry_list.begin();
|
||||
new_iter->_owner = extra_state;
|
||||
new_iter->_owner_loc = new_iter;
|
||||
|
@ -124,7 +124,10 @@ ExtraState* init_and_set_extra_state(PyCodeObject* code);
|
||||
// - f_locals: Borrowed
|
||||
// return:
|
||||
// - Py_None or PyCodeObject: Borrowed reference.
|
||||
PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
|
||||
PyObject* lookup(
|
||||
ExtraState* extra_state,
|
||||
PyObject* f_locals,
|
||||
PyObject* callback);
|
||||
|
||||
// Create a new cache entry at extra_state holding on to guarded_code.
|
||||
// Ownership contract
|
||||
@ -133,7 +136,13 @@ PyObject* lookup(ExtraState* extra_state, PyObject* f_locals);
|
||||
// - guarded_code: Borrowed
|
||||
// return:
|
||||
// - cache_entry: Borrowed reference
|
||||
CacheEntry* create_cache_entry(ExtraState* extra_state, PyObject* guraded_code);
|
||||
CacheEntry* create_cache_entry(
|
||||
ExtraState* extra_state,
|
||||
PyObject* guraded_code,
|
||||
PyObject* callback);
|
||||
|
||||
// Extracts the backend fn from the callback.
|
||||
PyObject* get_backend(PyObject* callback);
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
||||
|
Reference in New Issue
Block a user