Compare commits

...

30 Commits

Author SHA1 Message Date
c4d369369f Add error handling for self.stack when byte exceeding limit 2025-10-10 02:48:22 +00:00
a13f24980e Fix CI on the max length conversion 2025-10-10 02:48:22 +00:00
6869487ca4 Fix more byte output 2025-10-10 02:48:21 +00:00
5d9105f2ca Add support for byte in loggin stream 2025-10-10 02:48:21 +00:00
191e6bb367 Fix comment and CI again 2025-10-10 02:48:21 +00:00
a15a08725b Add linter 2025-10-10 02:48:21 +00:00
756ea14378 Fix linter thank you 2025-10-10 02:48:21 +00:00
d7c5ea03df Fix linter 2025-10-10 02:48:21 +00:00
d11e253ee3 Add linter 2025-10-10 02:48:21 +00:00
01d5211679 Fix more comment and CI 2025-10-10 02:48:21 +00:00
b496a04735 Fix comment and more CI 2025-10-10 02:48:21 +00:00
03be8d227b Fix comment 2025-10-10 02:48:21 +00:00
df1b8c3e41 Fix more CI 2025-10-10 02:48:21 +00:00
94f39d5749 Fix CI 2025-10-10 02:48:21 +00:00
2eb8b70d1b Fix more comments and the case where verbose is true 2025-10-10 02:48:21 +00:00
29680dd928 Fix comments and errors 2025-10-10 02:48:21 +00:00
69bcc97937 Add linter 2025-10-10 02:48:21 +00:00
babac1d561 Fix bytecode log to graph break with queue initialization with new tx 2025-10-10 02:48:21 +00:00
8594b98b0a Add user called graph break python version specific test 2025-10-10 02:48:21 +00:00
b3fc84229e Add user called graph break test on full graph true mode 2025-10-10 02:48:21 +00:00
e409e84a7a Add fullgraph testing for dynamo 2025-10-10 02:48:21 +00:00
9c3742e7a7 Add todo for the logging output of bytecode 2025-10-10 02:48:21 +00:00
664a137dbb Fix comments from github 2025-10-10 02:48:21 +00:00
4f5a0deb83 Revert "Update torch/_dynamo/symbolic_convert.py"
This reverts commit d3d658ba65c1d627076b79bbdbebfdb9fa0ad37c.
2025-10-10 02:48:21 +00:00
4752d8fec9 Revert "Update test/dynamo/test_exc.py"
This reverts commit 7996380dc95141bf855a30b5f9b7e2b21c384f88.
2025-10-10 02:48:21 +00:00
715f0a26d7 Revert "Update test/dynamo/test_error_messages.py"
This reverts commit 1b185d792048e875f48d0a3e0bc67d47a618e5a2.
2025-10-10 02:48:21 +00:00
e9e2553603 Update test/dynamo/test_error_messages.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-10 02:48:21 +00:00
43fac7f55d Update test/dynamo/test_exc.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-10 02:48:21 +00:00
a875f27482 Update torch/_dynamo/symbolic_convert.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-10-10 02:48:21 +00:00
f34e0a941a [dynamo] Add most recent bytecode to graph break with developer initiation
ghstack-source-id: 8b538f2e1ac703a4538468a758f08db0c89b91a7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163720

Add most recent bytecode to dynamo graph break called by user

Fix other user-initiated graph break and issues

Fix linter
2025-10-10 02:48:21 +00:00
3 changed files with 127 additions and 2 deletions

View File

@ -14,7 +14,7 @@ import torch._dynamo.config
import torch._dynamo.test_case
import torch.utils._pytree as python_pytree
from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
from torch._dynamo.testing import skipIfNotPy312
from torch._dynamo.testing import skipIfNotPy312, skipIfOnlyNotPy312
from torch._dynamo.utils import counters
from torch.testing._internal.common_utils import (
IS_FBCODE,
@ -1015,6 +1015,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
"<Internal traceback>\n",
msg,
)
self.assertExpectedInline(
msg,
"""\
@ -1051,7 +1052,6 @@ from user code:
torch.compile(fn, backend="eager")(torch.randn(3))
# check the log for the 2nd torch._dynamo.graph_break()
self.assertExpectedInline(
munge_exc(records[-1].getMessage(), skip=0),
"""\
@ -1075,6 +1075,104 @@ User code traceback:
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_fullgraph(self, records):
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
self.assertExpectedInlineMunged(
Unsupported,
lambda: torch.compile(fn, backend="eager", fullgraph=True)(torch.randn(3)),
"""\
Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
from user code:
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
)
@skipIfOnlyNotPy312
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break_python_versioning(self, records):
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
fn(torch.ones(3))
s = munge_exc(records[0].getMessage(), skip=0)
self.assertExpectedInline(
s,
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_latest_bytecode_to_graph_break_python_versioning
fn(torch.ones(3))
========== most recent `torch.compile` tracing attempt started here ==========
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
TRACE RESUME 0 []
TRACE LOAD_FAST 'x' []
TRACE LOAD_CONST 1 [LazyVariableTracker()]
TRACE BINARY_OP 0 [LazyVariableTracker(), ConstantVariable(int: 1)]
TRACE STORE_FAST 'y' [TensorVariable()]
TRACE LOAD_FAST 'x' []
TRACE LOAD_FAST 'y' [TensorVariable()]
TRACE BINARY_OP 0 [TensorVariable(), TensorVariable()]
TRACE STORE_FAST 'z' [TensorVariable()]
TRACE LOAD_GLOBAL 'torch' []
TRACE LOAD_ATTR '_dynamo' [LazyVariableTracker()]
TRACE LOAD_ATTR 'graph_break' [LazyVariableTracker()]
TRACE CALL 0 [NullVariable, LazyVariableTracker()]""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_latest_bytecode_to_graph_break(self, records):
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
fn(torch.ones(3))
pattern = r"TRACE.*"
s = munge_exc(records[0].getMessage(), skip=0)
matches = re.findall(pattern, s)
self.assertEqual((len(matches) > 10), True)
self.assertEqual((len(matches) <= 20), True)
self.assertIn("Most recent bytecode instructions traced (max 20):", s)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True)
def test_graph_break_traceback_above_dynamo_shows_user_code(self, records):

View File

@ -43,6 +43,7 @@ import threading
import traceback
import types
import weakref
from collections import deque
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
@ -544,6 +545,7 @@ def log_graph_break(
reason: str = "",
exc_info: bool = False,
user_stack: Optional[StackSummary] = None,
latest_bytecode_log: Optional[str] = None,
) -> None:
if user_stack is None:
user_stack = torch._guards.TracingContext.extract_stack()
@ -606,6 +608,10 @@ def log_graph_break(
# This log line MUST contain the string "Graph break in user code",
# This log line is exercised from
# python test/dynamo/test_exc.py -k test_graph_break_log
if latest_bytecode_log and config.verbose:
user_stack_trace += "Most recent bytecode instructions traced (max 20):\n"
user_stack_trace += latest_bytecode_log
graph_break_log.debug(
user_stack_trace,
)
@ -933,6 +939,7 @@ def break_graph_if_unsupported(
exc_info=True,
reason=str(excp),
user_stack=excp.real_stack,
latest_bytecode_log="\n".join(self.latest_bytecode_queue),
)
if self.maybe_has_backedge():
@ -1184,6 +1191,8 @@ class InstructionTranslatorBase(
parent: Optional[InstructionTranslatorBase]
debug_locals: list[tuple[VariableTracker, list[VariableTracker]]]
package: Optional[CompilePackage]
latest_bytecode_queue: deque[str]
# Store the latest bytecode before graph_break() call by user
def mark_inconsistent_side_effects(self) -> None:
"""
@ -1351,6 +1360,17 @@ class InstructionTranslatorBase(
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
)
# Store the latest 20 bytecode execution for the process,
# Used repr for byte processing and limiting the length to 2048
try:
stack_repr = repr(self.stack)
except ValueError:
# Handle large integers that exceed sys.int_info.str_digits_check_threshold
stack_repr = "<self.stack repr truncated due to large integer>"
self.latest_bytecode_queue.append(
f"TRACE {inst.opname} {repr(inst.argval)} {stack_repr}"
)
self.update_block_stack(inst)
try:
@ -4083,6 +4103,7 @@ class InstructionTranslatorBase(
self.accept_prefix_inst = True
self.prefix_insts = []
self.exn_vt_stack = exn_vt_stack
self.latest_bytecode_queue = deque(maxlen=20)
# Properties of the input/output code
self.instructions: list[Instruction] = instructions

View File

@ -506,6 +506,12 @@ def skipIfNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
return unittest.skip("Requires Python 3.12+")(fn)
def skipIfOnlyNotPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 13) or sys.version_info < (3, 12):
return unittest.skip("Requires Python 3.12")(fn)
return fn
def xfailIfPy312(fn: Callable[_P, _T]) -> Callable[_P, _T]:
if sys.version_info >= (3, 12):
return unittest.expectedFailure(fn)