Replace frame_traced_fn hook with get_traced_code() util (#155249)

#153622 introduced a hook for getting the relevant code objects after frame tracing. The idea is to have vLLM use this instead of monkey-patching `inline_call_()` to determine the source code files to hash. Unfortunately, the hook runs too late; the vLLM backend needs access to the set of source code filenames while it's running.

This PR replaces the newly-added hook with a utility function that a backend can call to get this information. I've made the change in vLLM and can verify that this allows the information to be queried at the right time.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/155249
Approved by: https://github.com/zou3519
This commit is contained in:
Joel Schlosser
2025-06-10 13:24:34 -04:00
committed by PyTorch MergeBot
parent 8892b782a8
commit c4b93e6579
8 changed files with 41 additions and 40 deletions

View File

@ -142,7 +142,7 @@ class TestUtils(TestCase):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
def test_frame_traced_hook(self):
def test_traced_code_query(self):
try:
from .utils import add, break_it
except ImportError:
@ -150,31 +150,33 @@ class TestUtils(TestCase):
traced_code_lists = []
def get_traced_code(s):
nonlocal traced_code_lists
traced_code_lists.append(s)
def get_filenames(traced_code_lists):
return [
[code.co_filename for code in code_list]
for code_list in traced_code_lists
]
def my_backend(gm, example_inputs):
from torch._dynamo.utils import get_traced_code
nonlocal traced_code_lists
traced_code_lists.append(get_traced_code())
return gm.forward
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# === no inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
@torch.compile(backend=my_backend)
def fn(x):
return x * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called once with this file
self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
# === successful inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
@torch.compile(backend=my_backend)
def fn(x):
return add(x) * 2
@ -182,30 +184,28 @@ class TestUtils(TestCase):
traced_code_lists = []
fn(x)
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# expect hook to be called once with both this file and file of inlined func
self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
self.assertEqual(get_filenames(traced_code_lists), [[__file__, utils_path]])
# === graph break occurs during inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
@torch.compile(backend=my_backend)
def fn(x):
y = break_it(x)
z = x + 1
y = break_it(z)
return y * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called twice; once for this file, once for the file of the inlined func
self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
# === empty graph ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
@torch.compile(backend=my_backend)
def fn(x):
return x
x = torch.randn(3)
traced_code_lists = []
fn(x)
# hook is not expected to be called at all for an empty graph
self.assertEqual(traced_code_lists, [])

View File

@ -2630,10 +2630,6 @@ def compile(
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
frame_traced_fn = None
if options and isinstance(options, dict):
frame_traced_fn = options.pop("frame_traced_fn", None)
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
@ -2645,7 +2641,6 @@ def compile(
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)(model) # type: ignore[return-value]

View File

@ -947,17 +947,13 @@ def _compile(
annotation_str,
)
if not output.is_empty_graph():
if hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
if hooks.frame_traced_fn is not None:
output.tracing_context.traced_code.append(output.f_code)
hooks.frame_traced_fn(output.tracing_context.traced_code)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
return wrap_guarded_code(guarded_code)

View File

@ -1025,7 +1025,6 @@ def _optimize(
guard_export_fn=None,
guard_fail_fn=None,
guard_filter_fn=None,
frame_traced_fn=None,
disable=False,
dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:
@ -1065,7 +1064,6 @@ def _optimize(
guard_export_fn=guard_export_fn,
guard_fail_fn=guard_fail_fn,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)
torch._C._log_api_usage_once("torch._dynamo.optimize")
if (

View File

@ -6,13 +6,10 @@ guard-related operations.
The Hooks class manages two types of hook functions:
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input
- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph.
This hook will be passed the set of filenames containing traced code
These hooks enable customization of guard export and failure handling behaviors.
"""
import dataclasses
from types import CodeType
from typing import Callable, Optional
from torch._guards import GuardsSet
@ -25,4 +22,3 @@ class Hooks:
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None

View File

@ -392,8 +392,8 @@ class OutputGraph(OutputGraphGuardsState):
# Set of globals installed via install_global* APIs
self.installed_globals: set[str] = set()
self.f_code = f_code
# TODO: maybe should only store the entire f_code
# TODO: maybe should just pass the entire f_code in here? Not
# sure...
self.co_fields = {
"co_name": f_code.co_name,
"co_filename": f_code.co_filename,
@ -437,6 +437,7 @@ class OutputGraph(OutputGraphGuardsState):
export=self.export,
)
self.tracing_context: TracingContext = TracingContext(fake_mode)
self.tracing_context.traced_code.append(f_code)
self.dynamo_compile_id: Optional[CompileId] = (
CompileContext.current_compile_id()
)

View File

@ -50,7 +50,7 @@ from collections import Counter, OrderedDict
from contextlib import AbstractContextManager, contextmanager
from dataclasses import is_dataclass
from functools import lru_cache
from types import MethodWrapperType
from types import CodeType, MethodWrapperType
from typing import (
Any,
Callable,
@ -4672,3 +4672,11 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
cm.__exit__(None, None, None)
# Returns the list of code objects traced in the current TracingContext, or None
# if there is no current TracingContext.
def get_traced_code() -> list[CodeType]:
from torch._guards import TracingContext
return TracingContext.get_traced_code()

View File

@ -986,6 +986,13 @@ class TracingContext:
# framesummary.
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
@staticmethod
def get_traced_code():
tc = TracingContext.try_get()
if tc is None:
return None
return tc.traced_code
@contextmanager
def compile_context(context: Optional[CompileContext]):