mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Replace frame_traced_fn hook with get_traced_code() util (#155249)
#153622 introduced a hook for getting the relevant code objects after frame tracing. The idea is to have vLLM use this instead of monkey-patching `inline_call_()` to determine the source code files to hash. Unfortunately, the hook runs too late; the vLLM backend needs access to the set of source code filenames while it's running. This PR replaces the newly-added hook with a utility function that a backend can call to get this information. I've made the change in vLLM and can verify that this allows the information to be queried at the right time. Pull Request resolved: https://github.com/pytorch/pytorch/pull/155249 Approved by: https://github.com/zou3519
This commit is contained in:
committed by
PyTorch MergeBot
parent
8892b782a8
commit
c4b93e6579
@ -142,7 +142,7 @@ class TestUtils(TestCase):
|
||||
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
|
||||
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
|
||||
|
||||
def test_frame_traced_hook(self):
|
||||
def test_traced_code_query(self):
|
||||
try:
|
||||
from .utils import add, break_it
|
||||
except ImportError:
|
||||
@ -150,31 +150,33 @@ class TestUtils(TestCase):
|
||||
|
||||
traced_code_lists = []
|
||||
|
||||
def get_traced_code(s):
|
||||
nonlocal traced_code_lists
|
||||
traced_code_lists.append(s)
|
||||
|
||||
def get_filenames(traced_code_lists):
|
||||
return [
|
||||
[code.co_filename for code in code_list]
|
||||
for code_list in traced_code_lists
|
||||
]
|
||||
|
||||
def my_backend(gm, example_inputs):
|
||||
from torch._dynamo.utils import get_traced_code
|
||||
|
||||
nonlocal traced_code_lists
|
||||
traced_code_lists.append(get_traced_code())
|
||||
return gm.forward
|
||||
|
||||
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
|
||||
|
||||
# === no inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
@torch.compile(backend=my_backend)
|
||||
def fn(x):
|
||||
return x * 2
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# expect hook to be called once with this file
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
|
||||
|
||||
# === successful inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
@torch.compile(backend=my_backend)
|
||||
def fn(x):
|
||||
return add(x) * 2
|
||||
|
||||
@ -182,30 +184,28 @@ class TestUtils(TestCase):
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
|
||||
# expect hook to be called once with both this file and file of inlined func
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[__file__, utils_path]])
|
||||
|
||||
# === graph break occurs during inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
@torch.compile(backend=my_backend)
|
||||
def fn(x):
|
||||
y = break_it(x)
|
||||
z = x + 1
|
||||
y = break_it(z)
|
||||
return y * 2
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# expect hook to be called twice; once for this file one for file of inlined func
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
|
||||
|
||||
# === empty graph ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
@torch.compile(backend=my_backend)
|
||||
def fn(x):
|
||||
return x
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# hook is not expected to be called at all for an empty graph
|
||||
self.assertEqual(traced_code_lists, [])
|
||||
|
||||
|
||||
|
@ -2630,10 +2630,6 @@ def compile(
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
|
||||
frame_traced_fn = None
|
||||
if options and isinstance(options, dict):
|
||||
frame_traced_fn = options.pop("frame_traced_fn", None)
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
@ -2645,7 +2641,6 @@ def compile(
|
||||
dynamic=dynamic,
|
||||
disable=disable,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
frame_traced_fn=frame_traced_fn,
|
||||
)(model) # type: ignore[return-value]
|
||||
|
||||
|
||||
|
@ -947,17 +947,13 @@ def _compile(
|
||||
annotation_str,
|
||||
)
|
||||
|
||||
if not output.is_empty_graph():
|
||||
if hooks.guard_export_fn is not None:
|
||||
# We should not run the guard_export_fn when Dynamo does not
|
||||
# generate any graph. This can happen in export when TorchDynamo
|
||||
# generated bytecode has some reconstruction logic for mutated
|
||||
# variables which can trigger TorchDynamo on the children frames but
|
||||
# they are benign and do not generate any new graphs.
|
||||
hooks.guard_export_fn(output.guards)
|
||||
if hooks.frame_traced_fn is not None:
|
||||
output.tracing_context.traced_code.append(output.f_code)
|
||||
hooks.frame_traced_fn(output.tracing_context.traced_code)
|
||||
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
|
||||
# We should not run the guard_export_fn when Dynamo does not
|
||||
# generate any graph. This can happen in export when TorchDynamo
|
||||
# generated bytecode has some reconstruction logic for mutated
|
||||
# variables which can trigger TorchDynamo on the children frames but
|
||||
# they are benign and do not generate any new graphs.
|
||||
hooks.guard_export_fn(output.guards)
|
||||
|
||||
return wrap_guarded_code(guarded_code)
|
||||
|
||||
|
@ -1025,7 +1025,6 @@ def _optimize(
|
||||
guard_export_fn=None,
|
||||
guard_fail_fn=None,
|
||||
guard_filter_fn=None,
|
||||
frame_traced_fn=None,
|
||||
disable=False,
|
||||
dynamic=None,
|
||||
) -> Union[OptimizeContext, _NullDecorator]:
|
||||
@ -1065,7 +1064,6 @@ def _optimize(
|
||||
guard_export_fn=guard_export_fn,
|
||||
guard_fail_fn=guard_fail_fn,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
frame_traced_fn=frame_traced_fn,
|
||||
)
|
||||
torch._C._log_api_usage_once("torch._dynamo.optimize")
|
||||
if (
|
||||
|
@ -6,13 +6,10 @@ guard-related operations.
|
||||
The Hooks class manages two types of hook functions:
|
||||
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
|
||||
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input
|
||||
- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph.
|
||||
This hook will be passed the set of filenames containing traced code
|
||||
These hooks enable customization of guard export and failure handling behaviors.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from types import CodeType
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch._guards import GuardsSet
|
||||
@ -25,4 +22,3 @@ class Hooks:
|
||||
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
|
||||
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
|
||||
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
|
||||
frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None
|
||||
|
@ -392,8 +392,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# Set of globals installed via install_global* APIs
|
||||
self.installed_globals: set[str] = set()
|
||||
|
||||
self.f_code = f_code
|
||||
# TODO: maybe should only store the entire f_code
|
||||
# TODO: maybe should just pass the entire f_code in here? Not
|
||||
# sure...
|
||||
self.co_fields = {
|
||||
"co_name": f_code.co_name,
|
||||
"co_filename": f_code.co_filename,
|
||||
@ -437,6 +437,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
export=self.export,
|
||||
)
|
||||
self.tracing_context: TracingContext = TracingContext(fake_mode)
|
||||
self.tracing_context.traced_code.append(f_code)
|
||||
self.dynamo_compile_id: Optional[CompileId] = (
|
||||
CompileContext.current_compile_id()
|
||||
)
|
||||
|
@ -50,7 +50,7 @@ from collections import Counter, OrderedDict
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from dataclasses import is_dataclass
|
||||
from functools import lru_cache
|
||||
from types import MethodWrapperType
|
||||
from types import CodeType, MethodWrapperType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
@ -4672,3 +4672,11 @@ def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
|
||||
|
||||
def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
|
||||
cm.__exit__(None, None, None)
|
||||
|
||||
|
||||
# Returns a set of code objects present traced in the current TracingContext, or None
|
||||
# if there is no current TracingContext.
|
||||
def get_traced_code() -> list[CodeType]:
|
||||
from torch._guards import TracingContext
|
||||
|
||||
return TracingContext.get_traced_code()
|
||||
|
@ -986,6 +986,13 @@ class TracingContext:
|
||||
# framesummary.
|
||||
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
|
||||
|
||||
@staticmethod
|
||||
def get_traced_code():
|
||||
tc = TracingContext.try_get()
|
||||
if tc is None:
|
||||
return None
|
||||
return tc.traced_code
|
||||
|
||||
|
||||
@contextmanager
|
||||
def compile_context(context: Optional[CompileContext]):
|
||||
|
Reference in New Issue
Block a user