mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[Dynamo] Introduce hook receiving list of traced code objects (#153622)
This PR: * Expands `Hooks` with a new, optional `frame_traced_fn` field. It should be a callable receiving the list of traced code objects. * Maintains a list of `traced_code` objects in the `TracingContext` of an `OutputGraph`. * Whenever an `inline_call()` is encountered, the corresponding code object is added to this list. * `OutputGraph`'s associated `f_code` is added to the list just before the hook is called. I believe use of this hook should enable the source code hashing that vLLM does in a better way than monkey-patching `inline_call()`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153622 Approved by: https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
476e0a643a
commit
9db7bcb3fe
@ -1,5 +1,6 @@
|
||||
# Owner(s): ["module: dynamo"]
|
||||
import dataclasses
|
||||
import os
|
||||
import pprint
|
||||
import sys
|
||||
from unittest import mock
|
||||
@ -141,6 +142,69 @@ class TestUtils(TestCase):
|
||||
compilation_events = [arg[0][0] for arg in log_event.call_args_list]
|
||||
self.assertEqual(compilation_events[-1].num_graph_breaks, 2)
|
||||
|
||||
def test_frame_traced_hook(self):
|
||||
from utils import add, break_it
|
||||
|
||||
traced_code_lists = []
|
||||
|
||||
def get_traced_code(s):
|
||||
nonlocal traced_code_lists
|
||||
traced_code_lists.append(s)
|
||||
|
||||
def get_filenames(traced_code_lists):
|
||||
return [
|
||||
[code.co_filename for code in code_list]
|
||||
for code_list in traced_code_lists
|
||||
]
|
||||
|
||||
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
|
||||
|
||||
# === no inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
def fn(x):
|
||||
return x * 2
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# expect hook to be called once with this file
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[__file__]])
|
||||
|
||||
# === successful inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
def fn(x):
|
||||
return add(x) * 2
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
utils_path = os.path.join(os.path.dirname(__file__), "utils.py")
|
||||
# expect hook to be called once with both this file and file of inlined func
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[utils_path, __file__]])
|
||||
|
||||
# === graph break occurs during inlining ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
def fn(x):
|
||||
y = break_it(x)
|
||||
return y * 2
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# expect hook to be called twice: once for this file, and once for the file of the inlined func
|
||||
self.assertEqual(get_filenames(traced_code_lists), [[__file__], [utils_path]])
|
||||
|
||||
# === empty graph ===
|
||||
@torch.compile(options={"frame_traced_fn": get_traced_code})
|
||||
def fn(x):
|
||||
return x
|
||||
|
||||
x = torch.randn(3)
|
||||
traced_code_lists = []
|
||||
fn(x)
|
||||
# hook is not expected to be called at all for an empty graph
|
||||
self.assertEqual(traced_code_lists, [])
|
||||
|
||||
|
||||
class TestModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
|
@ -39,6 +39,10 @@ def add(x):
|
||||
return x + 1
|
||||
|
||||
|
||||
def break_it(x):
|
||||
return x.sum().item()
|
||||
|
||||
|
||||
def create_dummy_module_and_function():
|
||||
module = types.ModuleType("dummy_module")
|
||||
module.__spec__ = importlib.machinery.ModuleSpec(
|
||||
|
@ -2597,6 +2597,10 @@ def compile(
|
||||
if options and isinstance(options, dict):
|
||||
guard_filter_fn = options.pop("guard_filter_fn", None)
|
||||
|
||||
frame_traced_fn = None
|
||||
if options and isinstance(options, dict):
|
||||
frame_traced_fn = options.pop("frame_traced_fn", None)
|
||||
|
||||
if backend == "inductor":
|
||||
backend = _TorchCompileInductorWrapper(mode, options, dynamic)
|
||||
else:
|
||||
@ -2608,6 +2612,7 @@ def compile(
|
||||
dynamic=dynamic,
|
||||
disable=disable,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
frame_traced_fn=frame_traced_fn,
|
||||
)(model) # type: ignore[return-value]
|
||||
|
||||
|
||||
|
@ -941,13 +941,17 @@ def _compile(
|
||||
annotation_str,
|
||||
)
|
||||
|
||||
if not output.is_empty_graph() and hooks.guard_export_fn is not None:
|
||||
# We should not run the guard_export_fn when Dynamo does not
|
||||
# generate any graph. This can happen in export when TorchDynamo
|
||||
# generated bytecode has some reconstruction logic for mutated
|
||||
# variables which can trigger TorchDynamo on the children frames but
|
||||
# they are benign and do not generate any new graphs.
|
||||
hooks.guard_export_fn(output.guards)
|
||||
if not output.is_empty_graph():
|
||||
if hooks.guard_export_fn is not None:
|
||||
# We should not run the guard_export_fn when Dynamo does not
|
||||
# generate any graph. This can happen in export when TorchDynamo
|
||||
# generated bytecode has some reconstruction logic for mutated
|
||||
# variables which can trigger TorchDynamo on the children frames but
|
||||
# they are benign and do not generate any new graphs.
|
||||
hooks.guard_export_fn(output.guards)
|
||||
if hooks.frame_traced_fn is not None:
|
||||
output.tracing_context.traced_code.append(output.f_code)
|
||||
hooks.frame_traced_fn(output.tracing_context.traced_code)
|
||||
|
||||
return wrap_guarded_code(guarded_code)
|
||||
|
||||
|
@ -1015,6 +1015,7 @@ def _optimize(
|
||||
guard_export_fn=None,
|
||||
guard_fail_fn=None,
|
||||
guard_filter_fn=None,
|
||||
frame_traced_fn=None,
|
||||
disable=False,
|
||||
dynamic=None,
|
||||
) -> Union[OptimizeContext, _NullDecorator]:
|
||||
@ -1054,6 +1055,7 @@ def _optimize(
|
||||
guard_export_fn=guard_export_fn,
|
||||
guard_fail_fn=guard_fail_fn,
|
||||
guard_filter_fn=guard_filter_fn,
|
||||
frame_traced_fn=frame_traced_fn,
|
||||
)
|
||||
torch._C._log_api_usage_once("torch._dynamo.optimize")
|
||||
if (
|
||||
|
@ -6,11 +6,13 @@ guard-related operations.
|
||||
The Hooks class manages the following hook functions:
|
||||
- guard_export_fn: Called when guards need to be exported, taking a GuardsSet as input
|
||||
- guard_fail_fn: Called when a guard check fails, taking a GuardFail object as input
|
||||
|
||||
- frame_traced_fn: Called when a frame has finished tracing, resulting in a non-empty graph.
|
||||
This hook will be passed the list of code objects that were traced
|
||||
These hooks enable customization of guard export and failure handling behaviors.
|
||||
"""
|
||||
|
||||
import dataclasses
|
||||
from types import CodeType
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch._guards import GuardsSet
|
||||
@ -23,3 +25,4 @@ class Hooks:
|
||||
guard_export_fn: Optional[Callable[[GuardsSet], None]] = None
|
||||
guard_fail_fn: Optional[Callable[[GuardFail], None]] = None
|
||||
guard_filter_fn: Optional[Callable[[list[GuardFilterEntry]], list[bool]]] = None
|
||||
frame_traced_fn: Optional[Callable[[list[CodeType]], None]] = None
|
||||
|
@ -372,8 +372,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# Set of globals installed via install_global* APIs
|
||||
self.installed_globals: set[str] = set()
|
||||
|
||||
# TODO: maybe should just pass the entire f_code in here? Not
|
||||
# sure...
|
||||
self.f_code = f_code
|
||||
# TODO: maybe should only store the entire f_code
|
||||
self.co_fields = {
|
||||
"co_name": f_code.co_name,
|
||||
"co_filename": f_code.co_filename,
|
||||
|
@ -3956,6 +3956,7 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
parent.inconsistent_side_effects |= self.inconsistent_side_effects
|
||||
|
||||
log.debug("DONE INLINING %s", code)
|
||||
self.output.tracing_context.traced_code.append(code)
|
||||
|
||||
if config.enable_faithful_generator_behavior or (
|
||||
isinstance(self, InliningGeneratorInstructionTranslator)
|
||||
|
@ -37,6 +37,8 @@ log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from types import CodeType
|
||||
|
||||
import sympy
|
||||
|
||||
|
||||
@ -867,6 +869,8 @@ class TracingContext:
|
||||
# see note: [Returning Fake Tensors on First AOT Autograd Call]
|
||||
self.fakify_first_call = False
|
||||
self.hop_dispatch_set_cache = HopDispatchSetCache()
|
||||
# list of code objects for inlined functions (the root frame's code object is appended just before frame_traced_fn is called)
|
||||
self.traced_code: list[CodeType] = []
|
||||
|
||||
def clear(self):
|
||||
# Look at the note in output_graph.py in function `save_global_state`
|
||||
|
Reference in New Issue
Block a user