diff --git a/test/dynamo/test_frame_init.py b/test/dynamo/test_frame_init.py
index 00206d52e393..97aac1870e98 100644
--- a/test/dynamo/test_frame_init.py
+++ b/test/dynamo/test_frame_init.py
@@ -87,7 +87,7 @@ class FrameInitTests(torch._dynamo.test_case.TestCase):
             target_with_varkwargs.__code__: varkwargs_code2.__code__,
         }

-        empty_guard_manager = torch._dynamo.guards.GuardManager()
+        empty_guard_manager = torch._dynamo.guards.GuardManagerWrapper()

         def callback1(frame, cache_entry, frame_state):
             if frame.f_code in code_map1:
diff --git a/torch/_dynamo/cache_size.py b/torch/_dynamo/cache_size.py
index 5c675ad05290..1d0c169345d2 100644
--- a/torch/_dynamo/cache_size.py
+++ b/torch/_dynamo/cache_size.py
@@ -15,10 +15,10 @@ log = logging.getLogger(__name__)
 [Note on cache size limit]

 Background - TorchDynamo cache is a linked list. Each cache entry is a
-(check_fn, out_code, next pointer). These are stored on the f_code's co_extra
+(guard_manager, out_code, next pointer). These are stored on the f_code's co_extra
 scratch space. When a frame is invoked, we walk this linked list and run
-check_fn in each cache_entry to decide if the frame needs recompilation. If none
-of the check_fn's returns True, we recompile and add a new entry. To ensure we
+guard_manager in each cache_entry to decide if the frame needs recompilation. If none
+of the guard_manager's returns True, we recompile and add a new entry. To ensure we
 don't end up recompiling infinitely, we put limits on the cache size.

 There are two limits
@@ -121,7 +121,7 @@ def _has_same_id_matched_objs(frame: types.FrameType, cache_entry) -> bool:
     for (
         local_name,
         weakref_from_cache_entry,
-    ) in cache_entry.check_fn.id_matched_objs.items():
+    ) in cache_entry.guard_manager.id_matched_objs.items():
         if weakref_from_cache_entry() is not None:
             weakref_from_frame = _get_weakref_from_f_locals(frame, local_name)
             if weakref_from_frame is not weakref_from_cache_entry:
@@ -176,7 +176,7 @@ def exceeds_cache_size_limit(
     if cache_size.will_compilation_exceed_specific_limit(config.cache_size_limit):
         return True, "cache_size_limit"
     # NOTE this check is needed in the case that the frame's cache doesn't grow
-    # and we keep recompiling. This can happen if the guard check_fn becomes invalidated,
+    # and we keep recompiling. This can happen if the guard guard_manager becomes invalidated,
     # e.g. due to guarded objects being freed. This technically makes the
     # will_compilation_exceed_accumulated_limit check unnecessary, but we will keep the
     # check in case we have a better fix in the future.
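For orientation, the lookup described in the [Note on cache size limit] above amounts to the walk sketched below. This is an illustrative Python rendering only: the real walk happens in C/C++ (eval_frame.c and torch/csrc/dynamo), and find_cached_code, .check(), and .next are hypothetical stand-ins rather than the actual field names.

    # Illustrative sketch, not the real implementation. The actual walk lives in
    # C/C++; find_cached_code, .check() and .next are made-up names that mirror
    # the note above.
    def find_cached_code(cache_entry, f_locals):
        # Each entry pairs a guard_manager (formerly check_fn) with compiled code.
        while cache_entry is not None:
            if cache_entry.guard_manager.check(f_locals):
                # Guards hold: reuse the previously compiled bytecode.
                return cache_entry.code
            cache_entry = cache_entry.next
        # No entry's guards passed: Dynamo recompiles and adds a new entry,
        # subject to the cache size limits described in the note.
        return None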
diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py
index 6a313b08c64c..a3aa8eb00e4b 100644
--- a/torch/_dynamo/convert_frame.py
+++ b/torch/_dynamo/convert_frame.py
@@ -842,7 +842,7 @@ def _compile(
         compile_id_str = str(compile_id) if compile_id is not None else "Unknown"
         annotation_str = "Torch-Compiled Region: " + compile_id_str
         guarded_code = GuardedCode(
-            out_code, check_fn.check_fn, compile_id, annotation_str
+            out_code, check_fn.guard_manager, compile_id, annotation_str  # type: ignore[arg-type]
         )

         if not output.is_empty_graph() and hooks.guard_export_fn is not None:
diff --git a/torch/_dynamo/guards.py b/torch/_dynamo/guards.py
index 717eb1499c79..51706da78f12 100644
--- a/torch/_dynamo/guards.py
+++ b/torch/_dynamo/guards.py
@@ -145,7 +145,7 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
 verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")


-class GuardManager:
+class GuardManagerWrapper:
     """
     A helper class that contains the root guard manager. An instance of this
     class is stored in the Dynamo cache entry, so that the cache entry can
@@ -526,7 +526,7 @@ class GuardBuilder(GuardBuilderBase):
         lookup_weakrefs: Callable[[object], ReferenceType[object]],
         local_scope: Dict[str, object],
         global_scope: Dict[str, object],
-        guard_manager: GuardManager,
+        guard_manager: GuardManagerWrapper,
         check_fn_manager: CheckFunctionManager,
     ):
         self.id_ref = id_ref
@@ -570,7 +570,7 @@ class GuardBuilder(GuardBuilderBase):
         self.tensor_check_names: List[str] = []
         self.tensor_check_examples: List[torch.Tensor] = []
         self.tensor_check_guards: List[Guard] = []
-        self.tensor_check_guard_managers: List[GuardManager] = []
+        self.tensor_check_guard_managers: List[GuardManagerWrapper] = []

         self.check_fn_manager: CheckFunctionManager = check_fn_manager

@@ -583,7 +583,7 @@ class GuardBuilder(GuardBuilderBase):
             self.key_order_guarded_dict_ids.add(id(self.get(source_name)))

         # Keep track of weak references of objects with ID_MATCH guard. This
-        # info is stored alongside optimized_code and check_fn and is used to
+        # info is stored alongside optimized_code and guard_manager and is used to
         # limit the number of cache entries with same ID_MATCH'd object.
         self.id_matched_objs: Dict[str, ReferenceType[object]] = {}

@@ -591,7 +591,6 @@
         self._cached_guard_managers: Dict[
             str, torch._C._dynamo.guards.GuardManager
         ] = {}
-        self._cached_duplicate_input_guards: Set[Tuple[str, str]] = set()

     def guard_on_dict_keys_and_ignore_order(self, example_value, guard):
@@ -2111,7 +2110,7 @@ class CheckFunctionManager:
     ):
         guards = output_graph.guards if output_graph else None
         self._weakrefs: Dict[int, ReferenceType[object]] = {}
-        self.guard_manager = GuardManager()
+        self.guard_manager = GuardManagerWrapper()
         self.output_graph = output_graph

         w_builder = None
@@ -2171,17 +2170,17 @@

             guard.create(builder)

-        self.check_fn = self.compile_check_fn(builder, guards, guard_fail_fn)
+        self.compile_check_fn(builder, guards, guard_fail_fn)

         # Keep track of weak references of objects with ID_MATCH guard. This
-        # info is stored alongside optimized_code and check_fn and is used to
+        # info is stored alongside optimized_code and guard_manager and is used to
         # limit the number of cache entries with same ID_MATCH'd object.
         # TODO(anijain2305) - Currently this information is stored as an attr on
-        # the check_fn itself to avoid changing CacehEntry datastructure in
-        # eval_frame.c. In future, we should probably replace check_fn with a
+        # the guard_manager itself to avoid changing CacheEntry data structure in
+        # eval_frame.c. In future, we should probably replace guard_manager with a
         # queryable data structure such that this information is already present
         # in some form.
-        self.check_fn.id_matched_objs = builder.id_matched_objs
+        self.guard_manager.id_matched_objs = builder.id_matched_objs

         # TODO: don't do the string rep, do something more structured here
         torch._logging.trace_structured(
@@ -2189,7 +2188,6 @@
         )
         guards_log.debug("%s", self.guard_manager)
         self.guard_manager.id_matched_objs = builder.id_matched_objs
-        self.check_fn = self.guard_manager

         # Check that the guard returns True. False means that we will always
         # recompile.
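The id_matched_objs mapping that now hangs off self.guard_manager is what cache_size.py (first file in this diff) walks when deciding whether a recompilation should count toward the per-object accumulated limit. Roughly, and simplified to compare the referents directly rather than the weakref objects the real _has_same_id_matched_objs compares:

    # Simplified sketch of the consumer shown in the cache_size.py hunk above;
    # unlike the real helper it compares referents instead of weakref identity.
    def has_same_id_matched_objs(frame, cache_entry) -> bool:
        # id_matched_objs maps a local name to a weakref of the object that the
        # ID_MATCH guard was recorded against at compile time.
        items = cache_entry.guard_manager.id_matched_objs.items()
        for local_name, weakref_from_cache_entry in items:
            recorded = weakref_from_cache_entry()
            if recorded is not None and frame.f_locals.get(local_name) is not recorded:
                # A different object is now bound to this local, so this cache
                # entry's ID_MATCH guard can no longer pass.
                return False
        return True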
@@ -2351,45 +2349,39 @@ class CheckFunctionManager:
             }
         globals_for_guard_fn = {"G": builder.scope["G"]}

-        # Guard manager construction is complete
-        # TODO (anijain2305) - When enable_cpp_guard_manager is ON by
-        # default, change the guard_fn name to be guard_manager everywhere
-        # to avoid confusion.
-        guard_fn = self.guard_manager
-        # Ensure we did not miss to insert a guard in cpp guard manager.
+        # Guard manager construction is complete. Ensure we did not miss to
+        # insert a guard in cpp guard manager.
         assert len(code_parts) == 0

-        guard_fn.closure_vars = closure_vars
-        # TODO(whc) maybe '.code_parts' was only kept around for the guard callback? so we don't need both
-        guard_fn.args = largs
-        guard_fn.populate_code_parts_for_debugging()
-        guard_fn.verbose_code_parts = verbose_code_parts
+        self.guard_manager.closure_vars = closure_vars
+        self.guard_manager.args = largs
+        self.guard_manager.populate_code_parts_for_debugging()
+        self.guard_manager.verbose_code_parts = verbose_code_parts
         # Grab only G, but preserve "G" because guards access it as "G"
-        guard_fn.global_scope = globals_for_guard_fn
-        guard_fn.guard_fail_fn = guard_fail_fn
+        self.guard_manager.global_scope = globals_for_guard_fn
+        self.guard_manager.guard_fail_fn = guard_fail_fn
         # will be populated by a non-owning reference to CacheEntry/ExtraState
         # when the CacheEntry is constructed
-        guard_fn.cache_entry = None
-        guard_fn.extra_state = None
-        guard_fn.no_tensor_aliasing_sources = tensor_check_names
-        return guard_fn
+        self.guard_manager.cache_entry = None
+        self.guard_manager.extra_state = None
+        self.guard_manager.no_tensor_aliasing_sources = tensor_check_names

     def invalidate(self):
         # Some tests reveal that CheckFunctionManager has no attribute
-        # check_fn, but this case should not be of any concern.
+        # guard_manager, but this case should not be of any concern.
         # This case doesn't seem easy to repro.
         if (
-            hasattr(self, "check_fn")
-            and self.check_fn is not DeletedGuardFn
-            and (cache_entry := self.check_fn.cache_entry) is not None
-            and (extra_state := self.check_fn.extra_state) is not None
+            hasattr(self, "guard_manager")
+            and self.guard_manager is not DeletedGuardFn
+            and (cache_entry := self.guard_manager.cache_entry) is not None
+            and (extra_state := self.guard_manager.extra_state) is not None
         ):
             assert isinstance(cache_entry, CacheEntry)
             assert isinstance(extra_state, ExtraState)
             extra_state.invalidate(cache_entry)
-            self.check_fn.cache_entry = None
-            self.check_fn.extra_state = None
-            self.check_fn = DeletedGuardFn
+            self.guard_manager.cache_entry = None
+            self.guard_manager.extra_state = None
+            self.guard_manager = DeletedGuardFn  # type: ignore[assignment]

     def id_ref(self, obj):
         """add a weakref, return the id"""
@@ -2499,23 +2491,22 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):


 def get_guard_fail_reason_helper(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     f_locals: Dict[str, object],
     compile_id: CompileId,
 ) -> str:
     """
-    Return the reason why `guard_fn` failed.
+    Return the reason why `guard_manager` failed.
     Updates `guard_failures` with the generated reason.
-    Only the first failed check of guard_fn is reported.
+    Only the first failed check of guard_manager is reported.
     """
-    scope = {"L": f_locals, "G": guard_fn.global_scope["G"]}
-    scope.update(guard_fn.closure_vars)
+    scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
+    scope.update(guard_manager.closure_vars)
     reasons: List[str] = []

     no_tensor_aliasing_check_failed = False

     verbose_code_parts: List[str] = []
-    guard_manager = guard_fn
     guard_debug_info = guard_manager.check_verbose(f_locals)  # type: ignore[attr-defined]
     # For test_export_with_map_cond, the check_verbose fail even without the
     # C++ guard manager. We need to fix the issue to remove the comment.
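A rough sketch of the failure-reason flow that get_guard_fail_reason_helper implements after this rename: build an eval scope from the wrapper's global_scope and closure_vars, run check_verbose, and re-evaluate the returned verbose code parts to find the first failing guard. The GuardDebugInfo attribute names (.result, .verbose_code_parts) are assumed from the surrounding file, and first_failed_part is a made-up helper name; the real function also handles tensor-aliasing failures and logging, which are omitted here.

    # Hedged sketch only; see the hunks above and below for the real logic.
    def first_failed_part(guard_manager, f_locals):
        scope = {"L": f_locals, "G": guard_manager.global_scope["G"]}
        scope.update(guard_manager.closure_vars)
        guard_debug_info = guard_manager.check_verbose(f_locals)
        if guard_debug_info.result:
            return None  # guards passed, nothing to report
        # Re-evaluate the verbose code parts to pinpoint the first failing guard.
        for part in guard_debug_info.verbose_code_parts:
            if not eval(part, dict(guard_manager.global_scope), scope):
                return part
        return None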
@@ -2537,10 +2528,12 @@ def get_guard_fail_reason_helper(
             verbose_code_parts = []

     if no_tensor_aliasing_check_failed:
-        reasons = recompilation_reason_for_no_tensor_aliasing_guard(guard_fn, scope)
+        reasons = recompilation_reason_for_no_tensor_aliasing_guard(
+            guard_manager, scope
+        )
     else:
         for part in verbose_code_parts:
-            global_scope = dict(guard_fn.global_scope)
+            global_scope = dict(guard_manager.global_scope)
             global_scope["__compile_source__"] = part
             with report_compile_source_on_error():
                 try:
@@ -2565,17 +2558,17 @@ def get_guard_fail_reason_helper(


 def get_guard_fail_reason(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     code: types.CodeType,
     f_locals: Dict[str, object],
     compile_id: CompileId,
 ) -> str:
-    reason_str = get_guard_fail_reason_helper(guard_fn, f_locals, compile_id)
+    reason_str = get_guard_fail_reason_helper(guard_manager, f_locals, compile_id)
     guard_failures[orig_code_map[code]].append(reason_str)

     try:
-        if guard_fn.guard_fail_fn is not None:
-            guard_fn.guard_fail_fn(
+        if guard_manager.guard_fail_fn is not None:
+            guard_manager.guard_fail_fn(
                 GuardFail(reason_str or "unknown reason", orig_code_map[code])
             )
     except Exception as e:
@@ -2597,7 +2590,7 @@ def get_and_maybe_log_recompilation_reason(
     reasons = []
     while cache_entry is not None:
         reason = get_guard_fail_reason(
-            cache_entry.check_fn,
+            cache_entry.guard_manager,
             cache_entry.code,
             frame.f_locals,
             cache_entry.compile_id,
@@ -2647,7 +2640,7 @@ def get_and_maybe_log_recompilation_reason(


 def guard_error_hook(
-    guard_fn: GuardFn,
+    guard_manager: GuardFn,
     code: types.CodeType,
     f_locals: Dict[str, object],
     index: int,
@@ -2656,15 +2649,15 @@ def guard_error_hook(
     print(
         f"ERROR RUNNING GUARDS {code.co_name} {code.co_filename}:{code.co_firstlineno}"
     )
-    print("lambda " + ", ".join(guard_fn.args) + ":")
-    print(" ", " and\n ".join(guard_fn.code_parts))
+    print("lambda " + ", ".join(guard_manager.args) + ":")
+    print(" ", " and\n ".join(guard_manager.code_parts))

-    print(guard_fn)
+    print(guard_manager)

-    local_scope = {"L": f_locals, **guard_fn.closure_vars}
-    for guard in guard_fn.code_parts:
+    local_scope = {"L": f_locals, **guard_manager.closure_vars}
+    for guard in guard_manager.code_parts:
         try:
-            eval(guard, guard_fn.global_scope, local_scope)
+            eval(guard, guard_manager.global_scope, local_scope)
         except:  # noqa: B001,E722
             print(f"Malformed guard:\n{guard}")
diff --git a/torch/_dynamo/testing.py b/torch/_dynamo/testing.py
index bbc5f27713ab..9281c7c7e284 100644
--- a/torch/_dynamo/testing.py
+++ b/torch/_dynamo/testing.py
@@ -191,7 +191,7 @@ def debug_insert_nops(
         torch_function_mode_stack=[],
     )

-    return GuardedCode(code, CheckFunctionManager(graph).check_fn, CompileId(0, 0))
+    return GuardedCode(code, CheckFunctionManager(graph).guard_manager, CompileId(0, 0))  # type: ignore[arg-type]


 class CompileCounter:
diff --git a/torch/_dynamo/types.py b/torch/_dynamo/types.py
index 16ef7b5821c2..298741a4e958 100644
--- a/torch/_dynamo/types.py
+++ b/torch/_dynamo/types.py
@@ -3,7 +3,7 @@ import sys
 import types
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union

-# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
+# CacheEntry has a `guard_manager` field for the guard, and a `code` field for the code object.
 from torch._C._dynamo.eval_frame import (
     _CacheEntry as CacheEntry,
     _ExtraState as ExtraState,
@@ -46,7 +46,7 @@ class GuardFn(Protocol):
 @dataclasses.dataclass
 class GuardedCode:
     code: types.CodeType
-    check_fn: GuardFn
+    guard_manager: GuardFn
     compile_id: CompileId
     trace_annotation: str = "Unknown"

@@ -67,7 +67,7 @@ DynamoCallback = Union[DynamoCallbackFn, None, bool]
 class DynamoGuardHook(Protocol):
     def __call__(
         self,
-        guard_fn: GuardFn,
+        guard_manager: GuardFn,
         code: types.CodeType,
         f_locals: Dict[str, object],
         index: int,
diff --git a/torch/csrc/dynamo/cache_entry.cpp b/torch/csrc/dynamo/cache_entry.cpp
index 6ea8a441c48f..2dc4bbece04b 100644
--- a/torch/csrc/dynamo/cache_entry.cpp
+++ b/torch/csrc/dynamo/cache_entry.cpp
@@ -6,7 +6,7 @@

 CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
     : backend{backend} {
-  this->check_fn = guarded_code.attr("check_fn");
+  this->guard_manager = guarded_code.attr("guard_manager");
   this->code = guarded_code.attr("code");
   this->compile_id = guarded_code.attr("compile_id");
   py::object trace_annotation = guarded_code.attr("trace_annotation");
@@ -16,8 +16,8 @@ CacheEntry::CacheEntry(const py::handle& guarded_code, PyObject* backend)
   } else {
     this->trace_annotation = "Unknown";
   }
-  this->root_mgr =
-      torch::dynamo::convert_to_root_guard_manager(this->check_fn.attr("root"));
+  this->root_mgr = torch::dynamo::convert_to_root_guard_manager(
+      this->guard_manager.attr("root"));
 }

 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
@@ -25,9 +25,9 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
 // NOLINTNEXTLINE(bugprone-exception-escape)
 CacheEntry::~CacheEntry() {
-  // prevent check_fn from use-after-free when invalidating
-  this->check_fn.attr("cache_entry") = py::none();
-  this->check_fn.attr("extra_state") = py::none();
+  // prevent guard_manager from use-after-free when invalidating
+  this->guard_manager.attr("cache_entry") = py::none();
+  this->guard_manager.attr("extra_state") = py::none();
 }
 C10_DIAGNOSTIC_POP()
 C10_DIAGNOSTIC_POP()
diff --git a/torch/csrc/dynamo/cache_entry.h b/torch/csrc/dynamo/cache_entry.h
index 7d1d92084444..9747c0baa421 100644
--- a/torch/csrc/dynamo/cache_entry.h
+++ b/torch/csrc/dynamo/cache_entry.h
@@ -18,11 +18,12 @@ of the cache is as follows:

 -> ExtraState
   -> CacheEntry (list)
-    -> check_fn
+    -> guard_manager (a wrapper that contains the actual guard manager at its
+attr named root)
     -> code
   -> FrameState

-CacheEntry is a linked list node containing the check_fn for guards
+CacheEntry is a linked list node containing the guard_manager for guards
 and the optimized code. The FrameState is a PyDict that enables sharing
 between different frames. This
@@ -41,8 +42,8 @@ C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED(
 C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wdeprecated-copy-dtor")
 typedef struct VISIBILITY_HIDDEN CacheEntry {
   // check the guards: lambda: <locals of user function>: bool
-  py::object check_fn;
-  // modified user bytecode (protected by check_fn's guards)
+  py::object guard_manager;
+  // modified user bytecode (protected by guard_manager's guards)
   py::object code;
   // CompileId corresponding to this compilation
   py::object compile_id;
diff --git a/torch/csrc/dynamo/extra_state.cpp b/torch/csrc/dynamo/extra_state.cpp
index 1c1632b22746..7ee796109655 100644
--- a/torch/csrc/dynamo/extra_state.cpp
+++ b/torch/csrc/dynamo/extra_state.cpp
@@ -132,7 +132,7 @@ void lookup(
     if (guard_error_hook) {
       py::handle guard_error_hook_handle(guard_error_hook);
       guard_error_hook_handle(
-          cache_entry.check_fn,
+          cache_entry.guard_manager,
           cache_entry.code,
           locals,
           index,
@@ -168,12 +168,12 @@ CacheEntry* create_cache_entry(
   auto new_iter = extra_state->cache_entry_list.begin();
   new_iter->_owner = extra_state;
   new_iter->_owner_loc = new_iter;
-  // Set check_fn references to extra_state and CacheEntry
+  // Set guard_manager references to extra_state and CacheEntry
   // Warning: lifetime is controlled by C++!
-  py::handle check_fn = py::handle(guarded_code).attr("check_fn");
-  check_fn.attr("cache_entry") =
+  py::handle guard_manager = py::handle(guarded_code).attr("guard_manager");
+  guard_manager.attr("cache_entry") =
       py::cast(*new_iter, py::return_value_policy::reference);
-  check_fn.attr("extra_state") =
+  guard_manager.attr("extra_state") =
       py::cast(extra_state, py::return_value_policy::reference);
   return &*new_iter;
 }
diff --git a/torch/csrc/dynamo/init.cpp b/torch/csrc/dynamo/init.cpp
index 5993c25caace..16a3f1e2c973 100644
--- a/torch/csrc/dynamo/init.cpp
+++ b/torch/csrc/dynamo/init.cpp
@@ -67,7 +67,7 @@ void initDynamoBindings(PyObject* torch) {
   auto m = py::handle(eval_frame).cast<py::module>();

   py::class_<CacheEntry>(m, "_CacheEntry")
-      .def_readonly("check_fn", &CacheEntry::check_fn)
+      .def_readonly("guard_manager", &CacheEntry::guard_manager)
       .def_readonly("code", &CacheEntry::code)
       .def_readonly("compile_id", &CacheEntry::compile_id)
       .def_readonly("trace_annotation", &CacheEntry::trace_annotation)
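Pulling the C++ pieces together: after this rename, the Python object stored in GuardedCode.guard_manager is accessed from C++ purely by attribute, so it only needs to expose the handful of fields touched above. A rough Python summary of that surface follows (illustrative only, not the real class from torch/_dynamo/guards.py).

    # Rough summary of the attributes the C++ side relies on; illustrative only.
    class GuardManagerWrapperSketch:
        def __init__(self, root_guard_manager):
            # Read via guard_manager.attr("root") in cache_entry.cpp and passed
            # to convert_to_root_guard_manager.
            self.root = root_guard_manager
            # Non-owning backrefs: set in create_cache_entry(), cleared both in
            # ~CacheEntry() and in CheckFunctionManager.invalidate().
            self.cache_entry = None
            self.extra_state = None
            # Weakrefs of ID_MATCH'd objects, consulted by cache_size.py.
            self.id_matched_objs = {}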