[reland][dynamo] Record the pre-graph bytecode using fast record function event (#154974)

reland of https://github.com/pytorch/pytorch/pull/154769

@diff-train-skip-merge
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154974
Approved by: https://github.com/Lucaskabela, https://github.com/jansel
Animesh Jain
2025-06-06 13:11:03 +00:00
committed by PyTorch MergeBot
parent 9656251bb1
commit 271ca679a8
4 changed files with 48 additions and 2 deletions
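
For context on what this buys you: once this lands, the pre-graph bytecode that dynamo emits (the code that rebuilds graph arguments before calling the compiled graph) shows up as a "Pregraph bytecode" event in profiler traces. A minimal sketch of how one might observe it (the toy function and tensor shapes are illustrative, not part of this PR):

```python
# Sketch: surfacing the new "Pregraph bytecode" event in a profile.
import torch
from torch.profiler import profile, ProfilerActivity

@torch.compile
def fn(x):
    return torch.sin(x) + torch.cos(x)

x = torch.randn(8)
fn(x)  # compile outside the profiled region so only steady-state work is traced

with profile(activities=[ProfilerActivity.CPU]) as prof:
    fn(x)

# With record_pre_graph_bytecode_in_traces=True (the default added below),
# time spent in the pre-graph bytecode is attributed to a "Pregraph bytecode" event.
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
```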


@@ -103,10 +103,15 @@ def reset():
 
 class TestCompiledAutograd(TestCase):
     def setUp(self) -> None:
+        self.exit_stack = contextlib.ExitStack()
+        self.exit_stack.enter_context(
+            config.patch("record_pre_graph_bytecode_in_traces", False)
+        )
         super().setUp()
         reset()
 
     def tearDown(self) -> None:
+        self.exit_stack.close()
         super().tearDown()
         reset()
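
The tests keep the new flag off for the whole test by holding a config patch open from setUp to tearDown via an ExitStack. A minimal sketch of that pattern outside the test harness (assuming torch._dynamo.config as the config module):

```python
# Sketch: holding a dynamo config patch open across setup/teardown.
import contextlib
import torch._dynamo.config as dynamo_config

stack = contextlib.ExitStack()
stack.enter_context(
    dynamo_config.patch("record_pre_graph_bytecode_in_traces", False)
)
try:
    assert dynamo_config.record_pre_graph_bytecode_in_traces is False
finally:
    stack.close()  # exits the patch and restores the previous value
```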


@@ -23,7 +23,7 @@ from typing import Optional, TYPE_CHECKING, Union
 import torch.nn
 from torch.utils._ordered_set import OrderedSet
 
-from . import graph_break_hints, utils
+from . import config, graph_break_hints, utils
 from .bytecode_transformation import (
     add_push_null,
     add_push_null_call_function_ex,
@@ -613,6 +613,18 @@ class PyCodegen:
             if arg.source is not None:
                 collect_temp_source(arg.source)
 
+        cm_var = None
+        if config.record_pre_graph_bytecode_in_traces:
+            # Record the pregraph bytecode start
+            self.add_push_null(
+                lambda: self.load_import_from(
+                    utils.__name__, "record_pregraph_bytecode_enter"
+                )
+            )
+            self.extend_output(create_call_function(0, False))
+            cm_var = self.new_var()
+            self.store(cm_var)
+
         for arg in graphargs:
             if arg.pass_arg_as_tensor:
                 self.add_push_null(
@@ -628,6 +640,18 @@ class PyCodegen:
             else:
                 self.call_reconstruct(arg)
 
+        if config.record_pre_graph_bytecode_in_traces:
+            # Record the pregraph bytecode end
+            self.add_push_null(
+                lambda: self.load_import_from(
+                    utils.__name__, "record_pregraph_bytecode_exit"
+                )
+            )
+            assert cm_var is not None
+            self.extend_output([self.create_load(cm_var)])
+            self.extend_output(create_call_function(1, False))
+            self.pop_top()
+
         self.extend_output(create_call_function(len(graphargs), False))
 
     def load_import_from(self, module_name, object_name) -> None:
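
To make the codegen change concrete: PyCodegen now brackets the graph-argument reconstruction with calls to the enter/exit helpers, so the generated bytecode behaves roughly like the Python below. This is a behavioral sketch only; `compiled_fn`, `arg0`, and `arg1` are illustrative stand-ins, not names dynamo actually emits:

```python
# Rough Python equivalent of the bytecode emitted around the compiled-graph call.
import torch
from torch._dynamo.utils import (
    record_pregraph_bytecode_enter,
    record_pregraph_bytecode_exit,
)

def compiled_fn(a, b):  # stand-in for the compiled graph callable
    return a + b

cm = record_pregraph_bytecode_enter()  # opens the "Pregraph bytecode" event
arg0 = torch.randn(4)                  # stands in for reconstructing graph args
arg1 = torch.randn(4)                  # from the original frame
record_pregraph_bytecode_exit(cm)      # closes the event
out = compiled_fn(arg0, arg1)          # the compiled graph runs outside the event
```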


@@ -615,6 +615,9 @@ run_gc_after_compile = Config(  # type: ignore[var-annotated]
 # wrapper. This ensures that nn.module hooks are also compiled in the same frame.
 wrap_top_frame = False
 
+# record pre-graph bytecode in profile traces
+record_pre_graph_bytecode_in_traces = True
+
 # HACK: this is for testing custom ops profiling only
 _custom_ops_profile: Optional[Any] = None
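
The flag defaults to on. If the extra event is unwanted in a trace, it can be turned off globally or per-scope (a sketch, assuming the standard torch._dynamo.config module):

```python
import torch._dynamo.config as dynamo_config

# Disable the extra trace event globally...
dynamo_config.record_pre_graph_bytecode_in_traces = False

# ...or override it only for a region, restoring the old value on exit.
with dynamo_config.patch(record_pre_graph_bytecode_in_traces=True):
    pass  # compile / profile here with the event enabled
```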


@@ -47,7 +47,7 @@ import uuid
 import warnings
 import weakref
 from collections import Counter, OrderedDict
-from contextlib import contextmanager
+from contextlib import AbstractContextManager, contextmanager
 from dataclasses import is_dataclass
 from functools import lru_cache
 from types import MethodWrapperType
@@ -4674,3 +4674,17 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
 
 def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
     return node is None or "example_value" in node.meta or "val" in node.meta
+
+
+def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
+    cm: AbstractContextManager[None] = (
+        torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
+        if torch.autograd.profiler._is_profiler_enabled
+        else contextlib.nullcontext()
+    )
+    cm.__enter__()
+    return cm
+
+
+def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
+    cm.__exit__(None, None, None)
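
Implementation note: the helpers are split into enter/exit functions (rather than a with block) because the two calls are emitted as separate bytecode sequences around the argument reconstruction, and they fall back to a nullcontext when no profiler is attached so the non-profiled path stays cheap. A sketch of using the same fast record-function API directly (the region name "my_region" is illustrative):

```python
import torch
from torch.profiler import profile

def timed_region(x):
    # Mirror the helpers above: only pay for the record function when a
    # profiler is actually running.
    if torch.autograd.profiler._is_profiler_enabled:
        with torch._C._profiler._RecordFunctionFast("my_region"):
            return x.sin().cos()
    return x.sin().cos()

x = torch.randn(16)
with profile() as prof:
    timed_region(x)

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))
```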