[dynamo][refactor][config-cleanup] Use guard_manager consistently instead of check_fn (#138896)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138896
Approved by: https://github.com/williamwen42, https://github.com/jansel
ghstack dependencies: #138512

Committed by: PyTorch MergeBot
Parent: 49ed365b22
Commit: dba6887dc6
@@ -87,7 +87,7 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
             target_with_varkwargs.__code__: varkwargs_code2.__code__,
         }
 
-        empty_guard_manager = torch._dynamo.guards.GuardManager()
+        empty_guard_manager = torch._dynamo.guards.GuardManagerWrapper()
 
         def callback1(frame, cache_entry, frame_state):
             if frame.f_code in code_map1:
@@ -15,10 +15,10 @@ log = logging.getLogger(__name__)
 [Note on cache size limit]
 
 Background - TorchDynamo cache is a linked list. Each cache entry is a
-(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
+(guard_manager, out_code, next pointer). These are stored on the f_code's co_extra
 scratch space. When a frame is invoked, we walk this linked list and run
-check_fn in each cache_entry to decide if the frame needs recompilation. If none
-of the check_fn's returns True, we recompile and add a new entry. To ensure we
+guard_manager in each cache_entry to decide if the frame needs recompilation. If none
+of the guard_manager's returns True, we recompile and add a new entry. To ensure we
 don't end up recompiling infinitely, we put limits on the cache size.
 
 There are two limits
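The note above describes the TorchDynamo cache as a linked list of (guard_manager, out_code, next) entries kept in the code object's co_extra scratch space and walked on every frame invocation. Below is a minimal, purely illustrative Python sketch of that lookup loop; `CacheEntry` and `lookup` here are hypothetical stand-ins, not the real C implementation in eval_frame.c.

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional


@dataclass
class CacheEntry:
    # Hypothetical stand-in for the C-level cache entry in eval_frame.c.
    guard_manager: Callable[[Dict[str, object]], bool]  # True -> guards still hold
    out_code: object                                     # compiled bytecode to reuse
    next: Optional["CacheEntry"] = None


def lookup(head: Optional[CacheEntry], f_locals: Dict[str, object]) -> Optional[object]:
    """Walk the linked list and return the first out_code whose guards pass."""
    entry = head
    while entry is not None:
        if entry.guard_manager(f_locals):
            return entry.out_code
        entry = entry.next
    # No entry matched: Dynamo would recompile and prepend a fresh entry,
    # subject to the cache size limits described above.
    return None
```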
@@ -121,7 +121,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
     for (
         local_name,
         weakref_from_cache_entry,
-    ) in cache_entry.check_fn.id_matched_objs.items():
+    ) in cache_entry.guard_manager.id_matched_objs.items():
         if weakref_from_cache_entry() is not None:
             weakref_from_frame = _get_weakref_from_f_locals(frame, local_name)
             if weakref_from_frame is not weakref_from_cache_entry:
@@ -176,7 +176,7 @@ def exceeds_cache_size_limit(
     if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
         return True, "cache_size_limit"
     # NOTE this check is needed in the case that the frame's cache doesn't grow
-    # and we keep recompiling. This can happen if the guard check_fn becomes invalidated,
+    # and we keep recompiling. This can happen if the guard guard_manager becomes invalidated,
     # e.g. due to guarded objects being freed. This technically makes the
     # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
     # check in case we have a better fix in the future.
@@ -842,7 +842,7 @@ def _compile(
         compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
         annotation_str = "Torch-Compiled Region: " + compile_id_str
         guarded_code = GuardedCode(
-            out_code, check_fn.check_fn, compile_id, annotation_str
+            out_code, check_fn.guard_manager, compile_id, annotation_str  # type: ignore[arg-type]
         )
 
         if not output.is_empty_graph() and hooks.guard_export_fn is not None:
@@ -145,7 +145,7 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
 verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
 
 
-class GuardManager:
+class GuardManagerWrapper:
     """
     A helper class that contains the root guard manager. An instance of this
     class is stored in the Dynamo cache entry, so that the cache entry can
@@ -526,7 +526,7 @@ class GuardBuilder(GuardBuilderBase):
         lookup_weakrefs: Callable[[object], ReferenceType[object]],
         local_scope: Dict[str, object],
         global_scope: Dict[str, object],
-        guard_manager: GuardManager,
+        guard_manager: GuardManagerWrapper,
         check_fn_manager: CheckFunctionManager,
     ):
         self.id_ref = id_ref
@@ -570,7 +570,7 @@ class GuardBuilder(GuardBuilderBase):
         self.tensor_check_names: List[str] = []
         self.tensor_check_examples: List[torch.Tensor] = []
         self.tensor_check_guards: List[Guard] = []
-        self.tensor_check_guard_managers: List[GuardManager] = []
+        self.tensor_check_guard_managers: List[GuardManagerWrapper] = []
 
         self.check_fn_manager: CheckFunctionManager = check_fn_manager
 
@@ -583,7 +583,7 @@ class GuardBuilder(GuardBuilderBase):
             self.key_order_guarded_dict_ids.add(id(self.get(source_name)))
 
         # Keep track of weak references of objects with ID_MATCH guard. This
-        # info is stored alongside optimized_code and check_fn and is used to
+        # info is stored alongside optimized_code and guard_manager and is used to
         # limit the number of cache entries with same ID_MATCH'd object.
         self.id_matched_objs: Dict[str, ReferenceType[object]] = {}
 
@@ -591,7 +591,6 @@ class GuardBuilder(GuardBuilderBase):
         self._cached_guard_managers: Dict[
             str, torch._C._dynamo.guards.GuardManager
         ] = {}
 
         self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()
 
     def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
@@ -2111,7 +2110,7 @@ class CheckFunctionManager:
     ):
         guards = output_graph.guards if output_graph else None
         self._weakrefs: Dict[int, ReferenceType[object]] = {}
-        self.guard_manager = GuardManager()
+        self.guard_manager = GuardManagerWrapper()
         self.output_graph = output_graph
         w_builder = None
 
@@ -2171,17 +2170,17 @@ class CheckFunctionManager:
 
             guard.create(builder)
 
-        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
+        self.compile_check_fn(builder, guards, guard_fail_fn)
 
         # Keep track of weak references of objects with ID_MATCH guard. This
-        # info is stored alongside optimized_code and check_fn and is used to
+        # info is stored alongside optimized_code and guard_manager and is used to
         # limit the number of cache entries with same ID_MATCH'd object.
         # TODO(anijain2305) - Currently this information is stored as an attr on
-        # the check_fn itself to avoid changing CacehEntry datastructure in
-        # eval_frame.c. In future, we should probably replace check_fn with a
+        # the guard_manager itself to avoid changing CacheEntry data structure in
+        # eval_frame.c. In future, we should probably replace guard_manager with a
         # queryable data structure such that this information is already present
         # in some form.
-        self.check_fn.id_matched_objs = builder.id_matched_objs
+        self.guard_manager.id_matched_objs = builder.id_matched_objs
 
         # TODO: don't do the string rep, do something more structured here
         torch._logging.trace_structured(
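The comments above note that `id_matched_objs` is stashed as an attribute on the guard_manager so the CacheEntry structure in eval_frame.c does not have to change; cache_size.py later walks those weakrefs to decide whether a frame still holds the same ID_MATCH'd objects (see the `_has_same_id_matched_objs` hunk earlier). A self-contained sketch of that check, assuming CPython's reuse of callback-free weakrefs:

```python
import weakref
from typing import Dict, Optional
from weakref import ReferenceType


def has_same_id_matched_objs(
    f_locals: Dict[str, object],
    id_matched_objs: Dict[str, ReferenceType],
) -> bool:
    """Return False once any ID_MATCH'd local refers to a different object."""
    for local_name, weakref_from_cache_entry in id_matched_objs.items():
        if weakref_from_cache_entry() is None:
            continue  # the guarded object has been freed
        obj = f_locals.get(local_name)
        weakref_from_frame: Optional[ReferenceType] = None
        if obj is not None:
            try:
                # CPython reuses an existing callback-free weakref for an object,
                # so identity comparison detects whether the frame still holds
                # the exact same ID_MATCH'd object.
                weakref_from_frame = weakref.ref(obj)
            except TypeError:
                pass  # object is not weakref-able
        if weakref_from_frame is not weakref_from_cache_entry:
            return False
    return True
```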
@@ -2189,7 +2188,6 @@ class CheckFunctionManager:
         )
         guards_log.debug("%s", self.guard_manager)
-        self.guard_manager.id_matched_objs = builder.id_matched_objs
-        self.check_fn = self.guard_manager
 
         # Check that the guard returns True. False means that we will always
         # recompile.
@@ -2351,45 +2349,39 @@ class CheckFunctionManager:
         }
 
         globals_for_guard_fn = {"G": builder.scope["G"]}
-        # Guard manager construction is complete
-        # TODO (anijain2305) - When enable_cpp_guard_manager is ON by
-        # default, change the guard_fn name to be guard_manager everywhere
-        # to avoid confusion.
-        guard_fn = self.guard_manager
-        # Ensure we did not miss to insert a guard in cpp guard manager.
+        # Guard manager construction is complete. Ensure we did not miss to
+        # insert a guard in cpp guard manager.
         assert len(code_parts) == 0
 
-        guard_fn.closure_vars = closure_vars
         # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
-        guard_fn.args = largs
-        guard_fn.populate_code_parts_for_debugging()
-        guard_fn.verbose_code_parts = verbose_code_parts
+        self.guard_manager.closure_vars = closure_vars
+        self.guard_manager.args = largs
+        self.guard_manager.populate_code_parts_for_debugging()
+        self.guard_manager.verbose_code_parts = verbose_code_parts
         # Grab only G, but preserve "G" because guards access it as "G"
-        guard_fn.global_scope = globals_for_guard_fn
-        guard_fn.guard_fail_fn = guard_fail_fn
+        self.guard_manager.global_scope = globals_for_guard_fn
+        self.guard_manager.guard_fail_fn = guard_fail_fn
         # will be populated by a non-owning reference to CacheEntry/ExtraState
         # when the CacheEntry is constructed
-        guard_fn.cache_entry = None
-        guard_fn.extra_state = None
-        guard_fn.no_tensor_aliasing_sources = tensor_check_names
-        return guard_fn
+        self.guard_manager.cache_entry = None
+        self.guard_manager.extra_state = None
+        self.guard_manager.no_tensor_aliasing_sources = tensor_check_names
 
     def invalidate(self):
         # Some tests reveal that CheckFunctionManager has no attribute
-        # check_fn, but this case should not be of any concern.
+        # guard_manager, but this case should not be of any concern.
         # This case doesn't seem easy to repro.
         if (
-            hasattr(self, "check_fn")
-            and self.check_fn is not DeletedGuardFn
-            and (cache_entry := self.check_fn.cache_entry) is not None
-            and (extra_state := self.check_fn.extra_state) is not None
+            hasattr(self, "guard_manager")
+            and self.guard_manager is not DeletedGuardFn
+            and (cache_entry := self.guard_manager.cache_entry) is not None
+            and (extra_state := self.guard_manager.extra_state) is not None
         ):
             assert isinstance(cache_entry, CacheEntry)
             assert isinstance(extra_state, ExtraState)
             extra_state.invalidate(cache_entry)
-            self.check_fn.cache_entry = None
-            self.check_fn.extra_state = None
-            self.check_fn = DeletedGuardFn
+            self.guard_manager.cache_entry = None
+            self.guard_manager.extra_state = None
+            self.guard_manager = DeletedGuardFn  # type: ignore[assignment]
 
     def id_ref(self, obj):
         """add a weakref, return the id"""
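In the new flow, `compile_check_fn` hangs the debugging metadata and the non-owning `cache_entry`/`extra_state` back-references directly on `self.guard_manager`, and `invalidate` uses those back-references to drop the entry before detaching them. A rough Python sketch of that lifecycle; `FakeExtraState` and `FakeCacheEntry` are hypothetical stand-ins for the C++ objects that actually fill in the back-references (see the extra_state.cpp hunks below).

```python
from types import SimpleNamespace
from typing import Any, List


class FakeExtraState:
    # Hypothetical stand-in for the C++ ExtraState stored in co_extra.
    def __init__(self) -> None:
        self.cache_entry_list: List[Any] = []

    def invalidate(self, cache_entry: Any) -> None:
        self.cache_entry_list.remove(cache_entry)


class FakeCacheEntry:
    # Hypothetical stand-in for the C++ CacheEntry: on construction it writes
    # non-owning back-references onto the wrapper, mirroring create_cache_entry.
    def __init__(self, guard_manager: Any, extra_state: FakeExtraState) -> None:
        self.guard_manager = guard_manager
        extra_state.cache_entry_list.append(self)
        guard_manager.cache_entry = self
        guard_manager.extra_state = extra_state


def invalidate(guard_manager: Any) -> None:
    """Sketch of CheckFunctionManager.invalidate: drop the entry, detach refs."""
    cache_entry = guard_manager.cache_entry
    extra_state = guard_manager.extra_state
    if cache_entry is not None and extra_state is not None:
        extra_state.invalidate(cache_entry)
        guard_manager.cache_entry = None
        guard_manager.extra_state = None


# Usage: the wrapper starts with empty back-references, the cache entry fills
# them in, and invalidation removes the entry again.
gm = SimpleNamespace(cache_entry=None, extra_state=None)
state = FakeExtraState()
FakeCacheEntry(gm, state)
invalidate(gm)
assert state.cache_entry_list == [] and gm.cache_entry is None
```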
@@ -2499,23 +2491,22 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
 
 
 def get_guard_fail_reason_helper(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     f_locals: Dict[str, object],
    compile_id: CompileId,
 ) -> str:
     """
-    Return the reason why `guard_fn` failed.
+    Return the reason why `guard_manager` failed.
     Updates `guard_failures` with the generated reason.
-    Only the first failed check of guard_fn is reported.
+    Only the first failed check of guard_manager is reported.
     """
-    scope = {"L": f_locals, "G": guard_fn.global_scope["G"]}
-    scope.update(guard_fn.closure_vars)
+    scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
+    scope.update(guard_manager.closure_vars)
     reasons: List[str] = []
 
     no_tensor_aliasing_check_failed = False
 
     verbose_code_parts: List[str] = []
-    guard_manager = guard_fn
     guard_debug_info = guard_manager.check_verbose(f_locals)  # type: ignore[attr-defined]
     # For test_export_with_map_cond, the check_verbose fail even without the
     # C++ guard manager. We need to fix the issue to remove the comment.
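`get_guard_fail_reason_helper` rebuilds the guard evaluation scope from the frame's locals (bound as "L") and the guard's captured globals (bound as "G"), then reports only the first failing check. A simplified sketch of that scan, assuming each verbose code part is a boolean Python expression over L and G; the real helper additionally goes through `check_verbose` and special-cases tensor-aliasing failures:

```python
from typing import Dict, List, Optional


def first_failing_guard(
    verbose_code_parts: List[str],
    f_locals: Dict[str, object],
    global_scope: Dict[str, object],
    closure_vars: Optional[Dict[str, object]] = None,
) -> Optional[str]:
    """Evaluate each guard expression and return the first one that fails."""
    scope: Dict[str, object] = {"L": f_locals, "G": global_scope.get("G", {})}
    scope.update(closure_vars or {})
    for part in verbose_code_parts:
        try:
            ok = eval(part, {}, scope)  # guards are boolean expressions over L/G
        except Exception:
            return part  # a guard that raises is also treated as a failure
        if not ok:
            return part
    return None


# Example: an equality guard over a local fails once the value changes.
assert first_failing_guard(["L['x'] == 3"], {"x": 4}, {"G": {}}) == "L['x'] == 3"
```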
@@ -2537,10 +2528,12 @@ def get_guard_fail_reason_helper(
                 verbose_code_parts = []
 
     if no_tensor_aliasing_check_failed:
-        reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope)
+        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
+            guard_manager, scope
+        )
     else:
         for part in verbose_code_parts:
-            global_scope = dict(guard_fn.global_scope)
+            global_scope = dict(guard_manager.global_scope)
             global_scope["__compile_source__"] = part
             with report_compile_source_on_error():
                 try:
@@ -2565,17 +2558,17 @@ def get_guard_fail_reason_helper(
 
 
 def get_guard_fail_reason(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     code: types.CodeType,
     f_locals: Dict[str, object],
     compile_id: CompileId,
 ) -> str:
-    reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
+    reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
     guard_failures[orig_code_map[code]].append(reason_str)
 
     try:
-        if guard_fn.guard_fail_fn is not None:
-            guard_fn.guard_fail_fn(
+        if guard_manager.guard_fail_fn is not None:
+            guard_manager.guard_fail_fn(
                 GuardFail(reason_str or "unknown reason", orig_code_map[code])
             )
     except Exception as e:
@@ -2597,7 +2590,7 @@ def get_and_maybe_log_recompilation_reason(
     reasons = []
     while cache_entry is not None:
         reason = get_guard_fail_reason(
-            cache_entry.check_fn,
+            cache_entry.guard_manager,
             cache_entry.code,
             frame.f_locals,
             cache_entry.compile_id,
@@ -2647,7 +2640,7 @@ def get_and_maybe_log_recompilation_reason(
 
 
 def guard_error_hook(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     code: types.CodeType,
     f_locals: Dict[str, object],
     index: int,
@@ -2656,15 +2649,15 @@ def guard_error_hook(
     print(
         f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
     )
-    print("lambda " + ", ".join(guard_fn.args) + ":")
-    print(" ", " and\n ".join(guard_fn.code_parts))
+    print("lambda " + ", ".join(guard_manager.args) + ":")
+    print(" ", " and\n ".join(guard_manager.code_parts))
 
-    print(guard_fn)
+    print(guard_manager)
 
-    local_scope = {"L": f_locals, **guard_fn.closure_vars}
-    for guard in guard_fn.code_parts:
+    local_scope = {"L": f_locals, **guard_manager.closure_vars}
+    for guard in guard_manager.code_parts:
         try:
-            eval(guard, guard_fn.global_scope, local_scope)
+            eval(guard, guard_manager.global_scope, local_scope)
         except:  # noqa: B001,E722
             print(f"Malformed guard:\n{guard}")
@@ -191,7 +191,7 @@ def debug_insert_nops(
         torch_function_mode_stack=[],
     )
 
-    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
+    return GuardedCode(code, CheckFunctionManager(graph).guard_manager, CompileId(0, 0))  # type: ignore[arg-type]
 
 
 class CompileCounter:
@@ -3,7 +3,7 @@ import sys
 import types
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
 
-# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
+# CacheEntry has a `guard_manager` field for the guard, and a `code` field for the code object.
 from torch._C._dynamo.eval_frame import (
     _CacheEntry as CacheEntry,
     _ExtraState as ExtraState,
@@ -46,7 +46,7 @@ class GuardFn(Protocol):
 @dataclasses.dataclass
 class GuardedCode:
     code: types.CodeType
-    check_fn: GuardFn
+    guard_manager: GuardFn
     compile_id: CompileId
     trace_annotation: str = "Unknown"
 
@@ -67,7 +67,7 @@ DynamoCallback = Union[DynamoCallbackFn, None, bool]
 class DynamoGuardHook(Protocol):
     def __call__(
         self,
-        guard_fn: GuardFn,
+        guard_manager: GuardFn,
         code: types.CodeType,
         f_locals: Dict[str, object],
         index: int,
@@ -6,7 +6,7 @@
 
 CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
     : backend{backend} {
-  this->check_fn = guarded_code.attr("check_fn");
+  this->guard_manager = guarded_code.attr("guard_manager");
   this->code = guarded_code.attr("code");
   this->compile_id = guarded_code.attr("compile_id");
   py::object trace_annotation = guarded_code.attr("trace_annotation");
@@ -16,8 +16,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
   } else {
     this->trace_annotation = "Unknown";
   }
-  this->root_mgr =
-      torch::dynamo::convert_to_root_guard_manager(this->check_fn.attr("root"));
+  this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
+      this->guard_manager.attr("root"));
 }
 
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
@@ -25,9 +25,9 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
 // NOLINTNEXTLINE(bugprone-exception-escape)
 CacheEntry::~CacheEntry() {
-  // prevent check_fn from use-after-free when invalidating
-  this->check_fn.attr("cache_entry") = py::none();
-  this->check_fn.attr("extra_state") = py::none();
+  // prevent guard_manager from use-after-free when invalidating
+  this->guard_manager.attr("cache_entry") = py::none();
+  this->guard_manager.attr("extra_state") = py::none();
 }
 C10_DIAGNOSTIC_POP()
 C10_DIAGNOSTIC_POP()
@@ -18,11 +18,12 @@ of the cache is as follows:
 
 -> ExtraState
   -> CacheEntry (list)
-    -> check_fn
+    -> guard_manager (a wrapper that contains the actual guard manager at its
+       attr named root)
     -> code
 -> FrameState
 
-CacheEntry is a linked list node containing the check_fn for guards
+CacheEntry is a linked list node containing the guard_manager for guards
 and the optimized code.
 
 The FrameState is a PyDict that enables sharing between different frames. This
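The updated comment describes the per-code-object layout: an ExtraState in co_extra owning a list of CacheEntry nodes, each holding the guard_manager wrapper (whose `root` attribute is the actual C++ guard manager) and the protected code, alongside a FrameState dict shared across frames. A purely illustrative Python mirror of that ownership structure, with hypothetical `*Sketch` names:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class GuardManagerWrapperSketch:
    # The wrapper stored on each cache entry; its `root` attribute would hold
    # the actual C++ guard manager in the real implementation.
    root: Any = None


@dataclass
class CacheEntrySketch:
    guard_manager: GuardManagerWrapperSketch
    code: Any  # optimized bytecode protected by the guards


@dataclass
class ExtraStateSketch:
    # Lives in the code object's co_extra scratch space.
    cache_entry_list: List[CacheEntrySketch] = field(default_factory=list)
    frame_state: Dict[str, Any] = field(default_factory=dict)  # shared across frames
```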
@@ -41,8 +42,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
 typedef struct VISIBILITY_HIDDEN CacheEntry {
   // check the guards: lambda: <locals of user function>: bool
-  py::object check_fn;
-  // modified user bytecode (protected by check_fn's guards)
+  py::object guard_manager;
+  // modified user bytecode (protected by guard_manager's guards)
   py::object code;
   // CompileId corresponding to this compilation
   py::object compile_id;
@@ -132,7 +132,7 @@ void lookup(
     if (guard_error_hook) {
       py::handle guard_error_hook_handle(guard_error_hook);
       guard_error_hook_handle(
-          cache_entry.check_fn,
+          cache_entry.guard_manager,
           cache_entry.code,
           locals,
           index,
@@ -168,12 +168,12 @@ CacheEntry* create_cache_entry(
   auto new_iter = extra_state->cache_entry_list.begin();
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
-  // Set check_fn references to extra_state and CacheEntry
+  // Set guard_manager references to extra_state and CacheEntry
   // Warning: lifetime is controlled by C++!
-  py::handle check_fn = py::handle(guarded_code).attr("check_fn");
-  check_fn.attr("cache_entry") =
+  py::handle guard_manager = py::handle(guarded_code).attr("guard_manager");
+  guard_manager.attr("cache_entry") =
       py::cast(*new_iter, py::return_value_policy::reference);
-  check_fn.attr("extra_state") =
+  guard_manager.attr("extra_state") =
       py::cast(extra_state, py::return_value_policy::reference);
   return &*new_iter;
 }
@@ -67,7 +67,7 @@ void initDynamoBindings(PyObject* torch) {
   auto m = py::handle(eval_frame).cast<py::module>();
 
   py::class_<CacheEntry>(m, "_CacheEntry")
-      .def_readonly("check_fn", &CacheEntry::check_fn)
+      .def_readonly("guard_manager", &CacheEntry::guard_manager)
       .def_readonly("code", &CacheEntry::code)
       .def_readonly("compile_id", &CacheEntry::compile_id)
      .def_readonly("trace_annotation", &CacheEntry::trace_annotation)