diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index ee74fba237f7..007c56e6a26e 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -142,7 +142,7 @@ class TestUtils(TestCase): compilation_events = [arg[0][0] for arg in log_event.call_args_list] self.assertEqual(compilation_events[-1].num_graph_breaks, 2) - def test_frame_traced_hook(self): + def test_traced_code_query(self): try: from .utils import add, break_it except ImportError: @@ -150,31 +150,33 @@ class TestUtils(TestCase): traced_code_lists = [] - def get_traced_code(s): - nonlocal traced_code_lists - traced_code_lists.append(s) - def get_filenames(traced_code_lists): return [ [code.co_filename for code in code_list] for code_list in traced_code_lists ] + def my_backend(gm, example_inputs): + from torch._dynamo.utils import get_traced_code + + nonlocal traced_code_lists + traced_code_lists.append(get_traced_code()) + return gm.forward + utils_path = os.path.join(os.path.dirname(__file__), "utils.py") # === no inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return x * 2 x = torch.randn(3) traced_code_lists = [] fn(x) - # expect hook to be called once with this file self.assertEqual(get_filenames(traced_code_lists), [[__file__]]) # === successful inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return add(x) * 2 @@ -182,30 +184,28 @@ class TestUtils(TestCase): traced_code_lists = [] fn(x) utils_path = os.path.join(os.path.dirname(__file__), "utils.py") - # expect hook to be called once with both this file and file of inlined func - self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]]) + self.assertEqual(get_filenames(traced_code_lists), [[__file__, utils_path]]) # === graph break occurs during inlining === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): - y = break_it(x) + z = x + 1 + y = break_it(z) return y * 2 x = torch.randn(3) traced_code_lists = [] fn(x) - # expect hook to be called twice; once for this file one for file of inlined func self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]]) # === empty graph === - @torch.compile(options={"frame_traced_fn": get_traced_code}) + @torch.compile(backend=my_backend) def fn(x): return x x = torch.randn(3) traced_code_lists = [] fn(x) - # hook is not expected to be called at all for an empty graph self.assertEqual(traced_code_lists, []) diff --git a/torch/__init__.py b/torch/__init__.py index bb1735531f5b..c0b69f05ca26 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -2630,10 +2630,6 @@ def compile( if options and isinstance(options, dict): guard_filter_fn = options.pop("guard_filter_fn", None) - frame_traced_fn = None - if options and isinstance(options, dict): - frame_traced_fn = options.pop("frame_traced_fn", None) - if backend == "inductor": backend = _TorchCompileInductorWrapper(mode, options, dynamic) else: @@ -2645,7 +2641,6 @@ def compile( dynamic=dynamic, disable=disable, guard_filter_fn=guard_filter_fn, - frame_traced_fn=frame_traced_fn, )(model) # type: ignore[return-value] diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index b0465b14dd29..df0d04557707 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -947,17 +947,13 @@ def _compile( annotation_str, ) - if not output.is_empty_graph(): - if hooks.guard_export_fn is not None: - # We should not run the guard_export_fn when Dynamo does not - # generate any graph. This can happen in export when TorchDynamo - # generated bytecode has some reconstruction logic for mutated - # variables which can trigger TorchDynamo on the children frames but - # they are benign and do not generate any new graphs. - hooks.guard_export_fn(output.guards) - if hooks.frame_traced_fn is not None: - output.tracing_context.traced_code.append(output.f_code) - hooks.frame_traced_fn(output.tracing_context.traced_code) + if not output.is_empty_graph() and hooks.guard_export_fn is not None: + # We should not run the guard_export_fn when Dynamo does not + # generate any graph. This can happen in export when TorchDynamo + # generated bytecode has some reconstruction logic for mutated + # variables which can trigger TorchDynamo on the children frames but + # they are benign and do not generate any new graphs. + hooks.guard_export_fn(output.guards) return wrap_guarded_code(guarded_code) diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 62c0ccffb7c7..291b175fb031 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -1025,7 +1025,6 @@ def _optimize( guard_export_fn=None, guard_fail_fn=None, guard_filter_fn=None, - frame_traced_fn=None, disable=False, dynamic=None, ) -> Union[OptimizeContext, _NullDecorator]: @@ -1065,7 +1064,6 @@ def _optimize( guard_export_fn=guard_export_fn, guard_fail_fn=guard_fail_fn, guard_filter_fn=guard_filter_fn, - frame_traced_fn=frame_traced_fn, ) torch._C._log_api_usage_once("torch._dynamo.optimize") if ( diff --git a/torch/_dynamo/hooks.py b/torch/_dynamo/hooks.py index c362f04ebe70..e180ad6dedf0 100644 --- a/torch/_dynamo/hooks.py +++ b/torch/_dynamo/hooks.py @@ -6,13 +6,10 @@ guard-related operations. The Hooks class manages two types of hook functions: - guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input - guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input -- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph. - This hook will be passed the set of filenames containing traced code These hooks enable customization of guard export and failure handling behaviors. """ import dataclasses -from types import CodeType from typing import Callable, Optional from torch._guards import GuardsSet @@ -25,4 +22,3 @@ class Hooks: guard_export_fn: Optional[Callable[[GuardsSet], None]] = None guard_fail_fn: Optional[Callable[[GuardFail], None]] = None guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None - frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None diff --git a/torch/_dynamo/output_graph.py b/torch/_dynamo/output_graph.py index 65b0baba9a20..21d88b7f1b0b 100644 --- a/torch/_dynamo/output_graph.py +++ b/torch/_dynamo/output_graph.py @@ -392,8 +392,8 @@ class OutputGraph(OutputGraphGuardsState): # Set of globals installed via install_global* APIs self.installed_globals: set[str] = set() - self.f_code = f_code - # TODO: maybe should only store the entire f_code + # TODO: maybe should just pass the entire f_code in here? Not + # sure... self.co_fields = { "co_name": f_code.co_name, "co_filename": f_code.co_filename, @@ -437,6 +437,7 @@ class OutputGraph(OutputGraphGuardsState): export=self.export, ) self.tracing_context: TracingContext = TracingContext(fake_mode) + self.tracing_context.traced_code.append(f_code) self.dynamo_compile_id: Optional[CompileId] = ( CompileContext.current_compile_id() ) diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index 117a180ccb0e..1f35afeb90b2 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -50,7 +50,7 @@ from collections import Counter, OrderedDict from contextlib import AbstractContextManager, contextmanager from dataclasses import is_dataclass from functools import lru_cache -from types import MethodWrapperType +from types import CodeType, MethodWrapperType from typing import ( Any, Callable, @@ -4672,3 +4672,11 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]: def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None: cm.__exit__(None, None, None) + + +# Returns a set of code objects present traced in the current TracingContext, or None +# if there is no current TracingContext. +def get_traced_code() -> list[CodeType]: + from torch._guards import TracingContext + + return TracingContext.get_traced_code() diff --git a/torch/_guards.py b/torch/_guards.py index 818696c1f3e7..28becfac5865 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -986,6 +986,13 @@ class TracingContext: # framesummary. TracingContext.get().loc_in_frame = (filename, lineno, frame_name) + @staticmethod + def get_traced_code(): + tc = TracingContext.try_get() + if tc is None: + return None + return tc.traced_code + @contextmanager def compile_context(context: Optional[CompileContext]):