[dynamo][logging] Add most recent bytecode to graph break with torch._dynamo.graph_break() and verbose (#164422)

https://github.com/pytorch/pytorch/issues/162858 describes the feature implemented here.

This extends the existing graph break log with the latest 20 bytecode instructions traced (or as many as are available in the user frame). It covers the scenario where a graph break happens without an error, which is the case when the user calls torch._dynamo.graph_break().

Meanwhile, during testing, the number of instructions recorded via step() turned out not to be deterministic: sometimes it reached the maximum, and sometimes it was less than that. The generated bytecode is also Python-version dependent. Thus, the test plan does not check the exact bytecode output and instead checks the total number of bytecode lines.

This is a helpful process to understand bytecode transformation, symbolic convert, and convert frame. It is a helpful task to provide hands-on experience with dynamo workflow.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164422
Approved by: https://github.com/williamwen42, https://github.com/mlazos

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Xiao Fu
2025-10-10 17:33:04 +00:00
committed by PyTorch MergeBot
parent f975bd58af
commit dec9a59992
3 changed files with 127 additions and 2 deletions

View File

@ -14,7 +14,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch.utils._pytree as python_pytree
from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
from torch._dynamo.testing import skipIfNotPy312
from torch._dynamo.testing import skipIfNotPy312, skipIfOnlyNotPy312
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
IS_FBCODE,
@ -1015,6 +1015,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
"<Internal traceback>\n",
msg,
)
self.assertExpectedInline(
msg,
"""\
@ -1051,7 +1052,6 @@ from user code:
torch.compile(fn, backend="eager")(torch.randn(3))
# check the log for the 2nd torch._dynamo.graph_break()
self.assertExpectedInline(
munge_exc(records[-1].getMessage(), skip=0),
"""\
@ -1075,6 +1075,104 @@ User code traceback:
""",
)
# Checks the error message for a user-inserted graph break when compiling
# with fullgraph=True and verbose=True. The expected message below contains
# no "Most recent bytecode instructions traced" section.
# NOTE(review): assertExpectedInlineMunged presumably asserts that the callable
# raises `Unsupported` and compares the munged message - confirm against its
# definition in the test base class.
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_fullgraph(self, records):
# The expected message quotes the name `fn` and the source line
# `torch._dynamo.graph_break()`, so keep this function's text unchanged.
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(3)),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
# Pins the exact graph-break log text, including the list of most recently
# traced bytecode instructions, for a user-inserted graph break with
# verbose=True. Decorated with skipIfOnlyNotPy312 because the expected
# instruction listing below is specific to one Python version's bytecode.
@skipIfOnlyNotPy312
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_python_versioning(self, records):
# The expected log quotes this function's name, its local names
# ('x', 'y', 'z') and the `torch._dynamo.graph_break()` source line,
# so the body must not be reformatted or renamed.
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
# This call line is also quoted verbatim in the expected user traceback.
fn(torch.ones(3))
# Compare the first captured log record after munging
# (the expected text uses "line N" in place of real line numbers).
s = munge_exc(records[0].getMessage(), skip=0)
self.assertExpectedInline(
s,
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_latest_bytecode_to_graph_break_python_versioning
fn(torch.ones(3))
========== most recent `torch.compile` tracing attempt started here ==========
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
TRACE RESUME 0 []
TRACE LOAD_FAST 'x' []
TRACE LOAD_CONST 1 [LazyVariableTracker()]
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
TRACE STORE_FAST 'y' [TensorVariable()]
TRACE LOAD_FAST 'x' []
TRACE LOAD_FAST 'y' [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE STORE_FAST 'z' [TensorVariable()]
TRACE LOAD_GLOBAL 'torch' []
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break(self, records):
    # Version-independent check: instead of pinning the exact bytecode text,
    # only verify that the graph-break log carries the bytecode section and
    # that the number of "TRACE ..." lines falls in the expected range.
    @torch.compile(backend="eager")
    def fn(x):
        y = x + 1
        z = x + y
        torch._dynamo.graph_break()
        return z

    fn(torch.ones(3))

    pattern = r"TRACE.*"
    s = munge_exc(records[0].getMessage(), skip=0)
    matches = re.findall(pattern, s)

    # Use the dedicated comparison assertions so that a failure reports the
    # actual count rather than "False != True".
    self.assertGreater(len(matches), 10)
    # The log advertises "max 20", so the count must never exceed 20.
    self.assertLessEqual(len(matches), 20)
    self.assertIn("Most recent bytecode instructions traced (max 20):", s)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_graph_break_traceback_above_dynamo_shows_user_code(self, records):

View File

@ -43,6 +43,7 @@ import threading
import traceback
import types
import weakref
from collections import deque
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
@ -544,6 +545,7 @@ def log_graph_break(
reason: str = "",
exc_info: bool = False,
user_stack: Optional[StackSummary] = None,
latest_bytecode_log: Optional[str] = None,
) -> None:
if user_stack is None:
user_stack = torch._guards.TracingContext.extract_stack()
@ -606,6 +608,10 @@ def log_graph_break(
# This log line MUST contain the string "Graph break in user code",
# This log line is exercised from
# python test/dynamo/test_exc.py -k test_graph_break_log
if latest_bytecode_log and config.verbose:
user_stack_trace += "Most recent bytecode instructions traced (max 20):\n"
user_stack_trace += latest_bytecode_log
graph_break_log.debug(
user_stack_trace,
)
@ -933,6 +939,7 @@ def break_graph_if_unsupported(
exc_info=True,
reason=str(excp),
user_stack=excp.real_stack,
latest_bytecode_log="\n".join(self.latest_bytecode_queue),
)
if self.maybe_has_backedge():
@ -1184,6 +1191,8 @@ class InstructionTranslatorBase(
parent: Optional[InstructionTranslatorBase]
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
package: Optional[CompilePackage]
latest_bytecode_queue: deque[str]
# Store the latest bytecode before graph_break() call by user
def mark_inconsistent_side_effects(self) -> None:
"""
@ -1351,6 +1360,17 @@ class InstructionTranslatorBase(
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
)
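# Record the most recent bytecode instructions traced (the queue keeps at most 20),
# storing repr() strings of the argval and the stack as they are at this point.
# NOTE(review): the original comment mentioned limiting the length to 2048,
# but no truncation is applied here - confirm whether one is intended.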
try:
stack_repr = repr(self.stack)
except ValueError:
# Handle large integers that exceed sys.int_info.str_digits_check_threshold
stack_repr = "<self.stack repr truncated due to large integer>"
self.latest_bytecode_queue.append(
f"TRACE {inst.opname} {repr(inst.argval)} {stack_repr}"
)
self.update_block_stack(inst)
try:
@ -4083,6 +4103,7 @@ class InstructionTranslatorBase(
self.accept_prefix_inst = True
self.prefix_insts = []
self.exn_vt_stack = exn_vt_stack
self.latest_bytecode_queue = deque(maxlen=20)
# Properties of the input/output code
self.instructions: list[Instruction] = instructions

View File

@ -506,6 +506,12 @@ def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
return unittest.skip("Requires Python 3.12+")(fn)
def skipIfOnlyNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
    """Skip ``fn`` unless the interpreter is exactly Python 3.12.x.

    Returns ``fn`` unchanged on Python 3.12; on any older or newer version
    returns ``fn`` wrapped by ``unittest.skip`` with the reason
    "Requires Python 3.12".
    """
    # Comparing only (major, minor) is equivalent to ruling out both
    # "< 3.12" and ">= 3.13".
    running_py312 = sys.version_info[:2] == (3, 12)
    if running_py312:
        return fn
    return unittest.skip("Requires Python 3.12")(fn)
def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 12):
return unittest.expectedFailure(fn)