[Dynamo] Introduce hook receiving list of traced code objects (#153622)

This PR:
* Expands `Hooks` with a new, optional `frame_traced_fn` field. It should be a callable receiving the list of traced code objects
* Maintains a list of `traced_code` objects in the `TracingContext` of an `OutputGraph`
    *  Whenever an `inline_call()` is encountered, the corresponding code object is added to this set
    * `OutputGraph`'s associated `f_code` is added to the list just before the hook is called

I believe use of this hook should enable the source code hashing that vLLM does in a better way than monkey-patching `inline_call()`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153622
Approved by: https://github.com/jansel
This commit is contained in:
Joel Schlosser
2025-05-19 10:32:49 -04:00
committed by PyTorch MergeBot
parent 476e0a643a
commit 9db7bcb3fe
9 changed files with 97 additions and 10 deletions

View File

@ -1,5 +1,6 @@
# Owner(s): ["module: dynamo"]
import dataclasses
import os
import pprint
import sys
from unittest import mock
@ -141,6 +142,69 @@ class TestUtils(TestCase):
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
def test_frame_traced_hook(self):
from utils import add, break_it
traced_code_lists = []
def get_traced_code(s):
nonlocal traced_code_lists
traced_code_lists.append(s)
def get_filenames(traced_code_lists):
return [
[code.co_filename for code in code_list]
for code_list in traced_code_lists
]
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# === no inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return x * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called once with this file
self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
# === successful inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return add(x) * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
# expect hook to be called once with both this file and file of inlined func
self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
# === graph break occurs during inlining ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
y = break_it(x)
return y * 2
x = torch.randn(3)
traced_code_lists = []
fn(x)
# expect hook to be called twice; once for this file one for file of inlined func
self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
# === empty graph ===
@torch.compile(options={"frame_traced_fn": get_traced_code})
def fn(x):
return x
x = torch.randn(3)
traced_code_lists = []
fn(x)
# hook is not expected to be called at all for an empty graph
self.assertEqual(traced_code_lists, [])
class TestModel(torch.nn.Module):
def __init__(self):

View File

@ -39,6 +39,10 @@ def add(x):
return x + 1
def break_it(x):
return x.sum().item()
def create_dummy_module_and_function():
module = types.ModuleType("dummy_module")
module.__spec__ = importlib.machinery.ModuleSpec(

View File

@ -2597,6 +2597,10 @@ def compile(
if options and isinstance(options, dict):
guard_filter_fn = options.pop("guard_filter_fn", None)
frame_traced_fn = None
if options and isinstance(options, dict):
frame_traced_fn = options.pop("frame_traced_fn", None)
if backend == "inductor":
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
else:
@ -2608,6 +2612,7 @@ def compile(
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)(model) # type: ignore[return-value]

View File

@ -941,13 +941,17 @@ def _compile(
annotation_str,
)
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
if not output.is_empty_graph():
if hooks.guard_export_fn is not None:
# We should not run the guard_export_fn when Dynamo does not
# generate any graph. This can happen in export when TorchDynamo
# generated bytecode has some reconstruction logic for mutated
# variables which can trigger TorchDynamo on the children frames but
# they are benign and do not generate any new graphs.
hooks.guard_export_fn(output.guards)
if hooks.frame_traced_fn is not None:
output.tracing_context.traced_code.append(output.f_code)
hooks.frame_traced_fn(output.tracing_context.traced_code)
return wrap_guarded_code(guarded_code)

View File

@ -1015,6 +1015,7 @@ def _optimize(
guard_export_fn=None,
guard_fail_fn=None,
guard_filter_fn=None,
frame_traced_fn=None,
disable=False,
dynamic=None,
) -> Union[OptimizeContext, _NullDecorator]:
@ -1054,6 +1055,7 @@ def _optimize(
guard_export_fn=guard_export_fn,
guard_fail_fn=guard_fail_fn,
guard_filter_fn=guard_filter_fn,
frame_traced_fn=frame_traced_fn,
)
torch._C._log_api_usage_once("torch._dynamo.optimize")
if (

View File

@ -6,11 +6,13 @@ guard-related operations.
The Hooks class manages two types of hook functions:
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input
- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph.
This hook will be passed the set of filenames containing traced code
These hooks enable customization of guard export and failure handling behaviors.
"""
import dataclasses
from types import CodeType
from typing import Callable, Optional
from torch._guards import GuardsSet
@ -23,3 +25,4 @@ class Hooks:
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None

View File

@ -372,8 +372,8 @@ class OutputGraph(OutputGraphGuardsState):
# Set of globals installed via install_global* APIs
self.installed_globals: set[str] = set()
# TODO: maybe should just pass the entire f_code in here? Not
# sure...
self.f_code = f_code
# TODO: maybe should only store the entire f_code
self.co_fields = {
"co_name": f_code.co_name,
"co_filename": f_code.co_filename,

View File

@ -3956,6 +3956,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
parent.inconsistent_side_effects |= self.inconsistent_side_effects
log.debug("DONE INLINING %s", code)
self.output.tracing_context.traced_code.append(code)
if config.enable_faithful_generator_behavior or (
isinstance(self, InliningGeneratorInstructionTranslator)

View File

@ -37,6 +37,8 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING:
from types import CodeType
import sympy
@ -867,6 +869,8 @@ class TracingContext:
# see note: [Returning Fake Tensors on First AOT Autograd Call]
self.fakify_first_call = False
self.hop_dispatch_set_cache = HopDispatchSetCache()
# list of code objects for inlined functions
self.traced_code: list[CodeType] = []
def clear(self):
# Look at the note in output_graph.py in function `save_global_state`