mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[reland][dynamo] Record the pre-graph bytecode using fast record function event (#154974)
reland of https://github.com/pytorch/pytorch/pull/154769 @diff-train-skip-merge Pull Request resolved: https://github.com/pytorch/pytorch/pull/154974 Approved by: https://github.com/Lucaskabela, https://github.com/jansel
This commit is contained in:
committed by
PyTorch MergeBot
parent
9656251bb1
commit
271ca679a8
@ -103,10 +103,15 @@ def reset():
|
||||
|
||||
class TestCompiledAutograd(TestCase):
|
||||
def setUp(self) -> None:
|
||||
self.exit_stack = contextlib.ExitStack()
|
||||
self.exit_stack.enter_context(
|
||||
config.patch("record_pre_graph_bytecode_in_traces", False)
|
||||
)
|
||||
super().setUp()
|
||||
reset()
|
||||
|
||||
def tearDown(self) -> None:
|
||||
self.exit_stack.close()
|
||||
super().tearDown()
|
||||
reset()
|
||||
|
||||
|
@ -23,7 +23,7 @@ from typing import Optional, TYPE_CHECKING, Union
|
||||
import torch.nn
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from . import graph_break_hints, utils
|
||||
from . import config, graph_break_hints, utils
|
||||
from .bytecode_transformation import (
|
||||
add_push_null,
|
||||
add_push_null_call_function_ex,
|
||||
@ -613,6 +613,18 @@ class PyCodegen:
|
||||
if arg.source is not None:
|
||||
collect_temp_source(arg.source)
|
||||
|
||||
cm_var = None
|
||||
if config.record_pre_graph_bytecode_in_traces:
|
||||
# Record the pregraph bytecode start
|
||||
self.add_push_null(
|
||||
lambda: self.load_import_from(
|
||||
utils.__name__, "record_pregraph_bytecode_enter"
|
||||
)
|
||||
)
|
||||
self.extend_output(create_call_function(0, False))
|
||||
cm_var = self.new_var()
|
||||
self.store(cm_var)
|
||||
|
||||
for arg in graphargs:
|
||||
if arg.pass_arg_as_tensor:
|
||||
self.add_push_null(
|
||||
@ -628,6 +640,18 @@ class PyCodegen:
|
||||
else:
|
||||
self.call_reconstruct(arg)
|
||||
|
||||
if config.record_pre_graph_bytecode_in_traces:
|
||||
# Record the pregraph bytecode end
|
||||
self.add_push_null(
|
||||
lambda: self.load_import_from(
|
||||
utils.__name__, "record_pregraph_bytecode_exit"
|
||||
)
|
||||
)
|
||||
assert cm_var is not None
|
||||
self.extend_output([self.create_load(cm_var)])
|
||||
self.extend_output(create_call_function(1, False))
|
||||
self.pop_top()
|
||||
|
||||
self.extend_output(create_call_function(len(graphargs), False))
|
||||
|
||||
def load_import_from(self, module_name, object_name) -> None:
|
||||
|
@ -615,6 +615,9 @@ run_gc_after_compile = Config( # type: ignore[var-annotated]
|
||||
# wrapper. This ensures that nn.module hooks are also compiled in the same frame.
|
||||
wrap_top_frame = False
|
||||
|
||||
# record pre-graph bytecode in profile traces
|
||||
record_pre_graph_bytecode_in_traces = True
|
||||
|
||||
# HACK: this is for testing custom ops profiling only
|
||||
_custom_ops_profile: Optional[Any] = None
|
||||
|
||||
|
@ -47,7 +47,7 @@ import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from collections import Counter, OrderedDict
|
||||
from contextlib import contextmanager
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from dataclasses import is_dataclass
|
||||
from functools import lru_cache
|
||||
from types import MethodWrapperType
|
||||
@ -4674,3 +4674,17 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
|
||||
|
||||
def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
|
||||
return node is None or "example_value" in node.meta or "val" in node.meta
|
||||
|
||||
|
||||
def record_pregraph_bytecode_enter() -> AbstractContextManager[None]:
|
||||
cm: AbstractContextManager[None] = (
|
||||
torch._C._profiler._RecordFunctionFast("Pregraph bytecode")
|
||||
if torch.autograd.profiler._is_profiler_enabled
|
||||
else contextlib.nullcontext()
|
||||
)
|
||||
cm.__enter__()
|
||||
return cm
|
||||
|
||||
|
||||
def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
|
||||
cm.__exit__(None, None, None)
|
||||
|
Reference in New Issue
Block a user