[dynamo][refactor][config-cleanp] Use guard_manager consistently instead of check_fn (#138896)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138896
Approved by: https://github.com/williamwen42, https://github.com/jansel
ghstack dependencies: #138512
This commit is contained in:
Animesh Jain
2024-10-25 12:14:35 -07:00
committed by PyTorch MergeBot
parent 49ed365b22
commit dba6887dc6
10 changed files with 78 additions and 84 deletions

View File

@ -87,7 +87,7 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
target_with_varkwargs.__code__: varkwargs_code2.__code__,
}
empty_guard_manager = torch._dynamo.guards.GuardManager()
empty_guard_manager = torch._dynamo.guards.GuardManagerWrapper()
def callback1(frame, cache_entry, frame_state):
if frame.f_code in code_map1:

View File

@ -15,10 +15,10 @@ log = logging.getLogger(__name__)
[Note on cache size limit]
Background - TorchDynamo cache is a linked list. Each cache entry is a
(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
(guard_manager, out_code, next pointer). These are stored on the f_code's co_extra
scratch space. When a frame is invoked, we walk this linked list and run
check_fn in each cache_entry to decide if the frame needs recompilation. If none
of the check_fn's returns True, we recompile and add a new entry. To ensure we
guard_manager in each cache_entry to decide if the frame needs recompilation. If none
of the guard_manager's returns True, we recompile and add a new entry. To ensure we
don't end up recompiling infinitely, we put limits on the cache size.
There are two limits
@ -121,7 +121,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
for (
local_name,
weakref_from_cache_entry,
) in cache_entry.check_fn.id_matched_objs.items():
) in cache_entry.guard_manager.id_matched_objs.items():
if weakref_from_cache_entry() is not None:
weakref_from_frame = _get_weakref_from_f_locals(frame, local_name)
if weakref_from_frame is not weakref_from_cache_entry:
@ -176,7 +176,7 @@ def exceeds_cache_size_limit(
if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
return True, "cache_size_limit"
# NOTE this check is needed in the case that the frame's cache doesn't grow
# and we keep recompiling. This can happen if the guard check_fn becomes invalidated,
# and we keep recompiling. This can happen if the guard guard_manager becomes invalidated,
# e.g. due to guarded objects being freed. This technically makes the
# will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
# check in case we have a better fix in the future.

View File

@ -842,7 +842,7 @@ def _compile(
compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
annotation_str = "Torch-Compiled Region: " + compile_id_str
guarded_code = GuardedCode(
out_code, check_fn.check_fn, compile_id, annotation_str
out_code, check_fn.guard_manager, compile_id, annotation_str # type: ignore[arg-type]
)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:

View File

@ -145,7 +145,7 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
class GuardManager:
class GuardManagerWrapper:
"""
A helper class that contains the root guard manager. An instance of this
class is stored in the Dynamo cache entry, so that the cache entry can
@ -526,7 +526,7 @@ class GuardBuilder(GuardBuilderBase):
lookup_weakrefs: Callable[[object], ReferenceType[object]],
local_scope: Dict[str, object],
global_scope: Dict[str, object],
guard_manager: GuardManager,
guard_manager: GuardManagerWrapper,
check_fn_manager: CheckFunctionManager,
):
self.id_ref = id_ref
@ -570,7 +570,7 @@ class GuardBuilder(GuardBuilderBase):
self.tensor_check_names: List[str] = []
self.tensor_check_examples: List[torch.Tensor] = []
self.tensor_check_guards: List[Guard] = []
self.tensor_check_guard_managers: List[GuardManager] = []
self.tensor_check_guard_managers: List[GuardManagerWrapper] = []
self.check_fn_manager: CheckFunctionManager = check_fn_manager
@ -583,7 +583,7 @@ class GuardBuilder(GuardBuilderBase):
self.key_order_guarded_dict_ids.add(id(self.get(source_name)))
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and check_fn and is used to
# info is stored alongside optimized_code and guard_manager and is used to
# limit the number of cache entries with same ID_MATCH'd object.
self.id_matched_objs: Dict[str, ReferenceType[object]] = {}
@ -591,7 +591,6 @@ class GuardBuilder(GuardBuilderBase):
self._cached_guard_managers: Dict[
str, torch._C._dynamo.guards.GuardManager
] = {}
self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()
def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
@ -2111,7 +2110,7 @@ class CheckFunctionManager:
):
guards = output_graph.guards if output_graph else None
self._weakrefs: Dict[int, ReferenceType[object]] = {}
self.guard_manager = GuardManager()
self.guard_manager = GuardManagerWrapper()
self.output_graph = output_graph
w_builder = None
@ -2171,17 +2170,17 @@ class CheckFunctionManager:
guard.create(builder)
self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
self.compile_check_fn(builder, guards, guard_fail_fn)
# Keep track of weak references of objects with ID_MATCH guard. This
# info is stored alongside optimized_code and check_fn and is used to
# info is stored alongside optimized_code and guard_manager and is used to
# limit the number of cache entries with same ID_MATCH'd object.
# TODO(anijain2305) - Currently this information is stored as an attr on
# the check_fn itself to avoid changing CacehEntry datastructure in
# eval_frame.c. In future, we should probably replace check_fn with a
# the guard_manager itself to avoid changing CacheEntry data structure in
# eval_frame.c. In future, we should probably replace guard_manager with a
# queryable data structure such that this information is already present
# in some form.
self.check_fn.id_matched_objs = builder.id_matched_objs
self.guard_manager.id_matched_objs = builder.id_matched_objs
# TODO: don't do the string rep, do something more structured here
torch._logging.trace_structured(
@ -2189,7 +2188,6 @@ class CheckFunctionManager:
)
guards_log.debug("%s", self.guard_manager)
self.guard_manager.id_matched_objs = builder.id_matched_objs
self.check_fn = self.guard_manager
# Check that the guard returns True. False means that we will always
# recompile.
@ -2351,45 +2349,39 @@ class CheckFunctionManager:
}
globals_for_guard_fn = {"G": builder.scope["G"]}
# Guard manager construction is complete
# TODO (anijain2305) - When enable_cpp_guard_manager is ON by
# default, change the guard_fn name to be guard_manager everywhere
# to avoid confusion.
guard_fn = self.guard_manager
# Ensure we did not miss to insert a guard in cpp guard manager.
# Guard manager construction is complete. Ensure we did not miss to
# insert a guard in cpp guard manager.
assert len(code_parts) == 0
guard_fn.closure_vars = closure_vars
# TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
guard_fn.args = largs
guard_fn.populate_code_parts_for_debugging()
guard_fn.verbose_code_parts = verbose_code_parts
self.guard_manager.closure_vars = closure_vars
self.guard_manager.args = largs
self.guard_manager.populate_code_parts_for_debugging()
self.guard_manager.verbose_code_parts = verbose_code_parts
# Grab only G, but preserve "G" because guards access it as "G"
guard_fn.global_scope = globals_for_guard_fn
guard_fn.guard_fail_fn = guard_fail_fn
self.guard_manager.global_scope = globals_for_guard_fn
self.guard_manager.guard_fail_fn = guard_fail_fn
# will be populated by a non-owning reference to CacheEntry/ExtraState
# when the CacheEntry is constructed
guard_fn.cache_entry = None
guard_fn.extra_state = None
guard_fn.no_tensor_aliasing_sources = tensor_check_names
return guard_fn
self.guard_manager.cache_entry = None
self.guard_manager.extra_state = None
self.guard_manager.no_tensor_aliasing_sources = tensor_check_names
def invalidate(self):
# Some tests reveal that CheckFunctionManager has no attribute
# check_fn, but this case should not be of any concern.
# guard_manager, but this case should not be of any concern.
# This case doesn't seem easy to repro.
if (
hasattr(self, "check_fn")
and self.check_fn is not DeletedGuardFn
and (cache_entry := self.check_fn.cache_entry) is not None
and (extra_state := self.check_fn.extra_state) is not None
hasattr(self, "guard_manager")
and self.guard_manager is not DeletedGuardFn
and (cache_entry := self.guard_manager.cache_entry) is not None
and (extra_state := self.guard_manager.extra_state) is not None
):
assert isinstance(cache_entry, CacheEntry)
assert isinstance(extra_state, ExtraState)
extra_state.invalidate(cache_entry)
self.check_fn.cache_entry = None
self.check_fn.extra_state = None
self.check_fn = DeletedGuardFn
self.guard_manager.cache_entry = None
self.guard_manager.extra_state = None
self.guard_manager = DeletedGuardFn # type: ignore[assignment]
def id_ref(self, obj):
"""add a weakref, return the id"""
@ -2499,23 +2491,22 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
def get_guard_fail_reason_helper(
guard_fn: GuardFn,
guard_manager: GuardFn,
f_locals: Dict[str, object],
compile_id: CompileId,
) -> str:
"""
Return the reason why `guard_fn` failed.
Return the reason why `guard_manager` failed.
Updates `guard_failures` with the generated reason.
Only the first failed check of guard_fn is reported.
Only the first failed check of guard_manager is reported.
"""
scope = {"L": f_locals, "G": guard_fn.global_scope["G"]}
scope.update(guard_fn.closure_vars)
scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
scope.update(guard_manager.closure_vars)
reasons: List[str] = []
no_tensor_aliasing_check_failed = False
verbose_code_parts: List[str] = []
guard_manager = guard_fn
guard_debug_info = guard_manager.check_verbose(f_locals) # type: ignore[attr-defined]
# For test_export_with_map_cond, the check_verbose fail even without the
# C++ guard manager. We need to fix the issue to remove the comment.
@ -2537,10 +2528,12 @@ def get_guard_fail_reason_helper(
verbose_code_parts = []
if no_tensor_aliasing_check_failed:
reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope)
reasons = recompilation_reason_for_no_tensor_aliasing_guard(
guard_manager, scope
)
else:
for part in verbose_code_parts:
global_scope = dict(guard_fn.global_scope)
global_scope = dict(guard_manager.global_scope)
global_scope["__compile_source__"] = part
with report_compile_source_on_error():
try:
@ -2565,17 +2558,17 @@ def get_guard_fail_reason_helper(
def get_guard_fail_reason(
guard_fn: GuardFn,
guard_manager: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
compile_id: CompileId,
) -> str:
reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
guard_failures[orig_code_map[code]].append(reason_str)
try:
if guard_fn.guard_fail_fn is not None:
guard_fn.guard_fail_fn(
if guard_manager.guard_fail_fn is not None:
guard_manager.guard_fail_fn(
GuardFail(reason_str or "unknown reason", orig_code_map[code])
)
except Exception as e:
@ -2597,7 +2590,7 @@ def get_and_maybe_log_recompilation_reason(
reasons = []
while cache_entry is not None:
reason = get_guard_fail_reason(
cache_entry.check_fn,
cache_entry.guard_manager,
cache_entry.code,
frame.f_locals,
cache_entry.compile_id,
@ -2647,7 +2640,7 @@ def get_and_maybe_log_recompilation_reason(
def guard_error_hook(
guard_fn: GuardFn,
guard_manager: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
index: int,
@ -2656,15 +2649,15 @@ def guard_error_hook(
print(
f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
)
print("lambda " + ", ".join(guard_fn.args) + ":")
print(" ", " and\n ".join(guard_fn.code_parts))
print("lambda " + ", ".join(guard_manager.args) + ":")
print(" ", " and\n ".join(guard_manager.code_parts))
print(guard_fn)
print(guard_manager)
local_scope = {"L": f_locals, **guard_fn.closure_vars}
for guard in guard_fn.code_parts:
local_scope = {"L": f_locals, **guard_manager.closure_vars}
for guard in guard_manager.code_parts:
try:
eval(guard, guard_fn.global_scope, local_scope)
eval(guard, guard_manager.global_scope, local_scope)
except: # noqa: B001,E722
print(f"Malformed guard:\n{guard}")

View File

@ -191,7 +191,7 @@ def debug_insert_nops(
torch_function_mode_stack=[],
)
return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
return GuardedCode(code, CheckFunctionManager(graph).guard_manager, CompileId(0, 0)) # type: ignore[arg-type]
class CompileCounter:

View File

@ -3,7 +3,7 @@ import sys
import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
# CacheEntry has a `guard_manager` field for the guard, and a `code` field for the code object.
from torch._C._dynamo.eval_frame import (
_CacheEntry as CacheEntry,
_ExtraState as ExtraState,
@ -46,7 +46,7 @@ class GuardFn(Protocol):
@dataclasses.dataclass
class GuardedCode:
code: types.CodeType
check_fn: GuardFn
guard_manager: GuardFn
compile_id: CompileId
trace_annotation: str = "Unknown"
@ -67,7 +67,7 @@ DynamoCallback = Union[DynamoCallbackFn, None, bool]
class DynamoGuardHook(Protocol):
def __call__(
self,
guard_fn: GuardFn,
guard_manager: GuardFn,
code: types.CodeType,
f_locals: Dict[str, object],
index: int,

View File

@ -6,7 +6,7 @@
CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
: backend{backend} {
this->check_fn = guarded_code.attr("check_fn");
this->guard_manager = guarded_code.attr("guard_manager");
this->code = guarded_code.attr("code");
this->compile_id = guarded_code.attr("compile_id");
py::object trace_annotation = guarded_code.attr("trace_annotation");
@ -16,8 +16,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
} else {
this->trace_annotation = "Unknown";
}
this->root_mgr =
torch::dynamo::convert_to_root_guard_manager(this->check_fn.attr("root"));
this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
this->guard_manager.attr("root"));
}
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
@ -25,9 +25,9 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
// NOLINTNEXTLINE(bugprone-exception-escape)
CacheEntry::~CacheEntry() {
// prevent check_fn from use-after-free when invalidating
this->check_fn.attr("cache_entry") = py::none();
this->check_fn.attr("extra_state") = py::none();
// prevent guard_manager from use-after-free when invalidating
this->guard_manager.attr("cache_entry") = py::none();
this->guard_manager.attr("extra_state") = py::none();
}
C10_DIAGNOSTIC_POP()
C10_DIAGNOSTIC_POP()

View File

@ -18,11 +18,12 @@ of the cache is as follows:
-> ExtraState
-> CacheEntry (list)
-> check_fn
-> guard_manager (a wrapper that contains the actual guard manager at its
attr named root)
-> code
-> FrameState
CacheEntry is a linked list node containing the check_fn for guards
CacheEntry is a linked list node containing the guard_manager for guards
and the optimized code.
The FrameState is a PyDict that enables sharing between different frames. This
@ -41,8 +42,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
typedef struct VISIBILITY_HIDDEN CacheEntry {
// check the guards: lambda: <locals of user function>: bool
py::object check_fn;
// modified user bytecode (protected by check_fn's guards)
py::object guard_manager;
// modified user bytecode (protected by guard_manager's guards)
py::object code;
// CompileId corresponding to this compilation
py::object compile_id;

View File

@ -132,7 +132,7 @@ void lookup(
if (guard_error_hook) {
py::handle guard_error_hook_handle(guard_error_hook);
guard_error_hook_handle(
cache_entry.check_fn,
cache_entry.guard_manager,
cache_entry.code,
locals,
index,
@ -168,12 +168,12 @@ CacheEntry* create_cache_entry(
auto new_iter = extra_state->cache_entry_list.begin();
new_iter->_owner = extra_state;
new_iter->_owner_loc = new_iter;
// Set check_fn references to extra_state and CacheEntry
// Set guard_manager references to extra_state and CacheEntry
// Warning: lifetime is controlled by C++!
py::handle check_fn = py::handle(guarded_code).attr("check_fn");
check_fn.attr("cache_entry") =
py::handle guard_manager = py::handle(guarded_code).attr("guard_manager");
guard_manager.attr("cache_entry") =
py::cast(*new_iter, py::return_value_policy::reference);
check_fn.attr("extra_state") =
guard_manager.attr("extra_state") =
py::cast(extra_state, py::return_value_policy::reference);
return &*new_iter;
}

View File

@ -67,7 +67,7 @@ void initDynamoBindings(PyObject* torch) {
auto m = py::handle(eval_frame).cast<py::module>();
py::class_<CacheEntry>(m, "_CacheEntry")
.def_readonly("check_fn", &CacheEntry::check_fn)
.def_readonly("guard_manager", &CacheEntry::guard_manager)
.def_readonly("code", &CacheEntry::code)
.def_readonly("compile_id", &CacheEntry::compile_id)
.def_readonly("trace_annotation", &CacheEntry::trace_annotation)