Compare commits

...

1 Commits

Author SHA1 Message Date
39e77ce851 [dynamo] Add most recent bytecode to graph break with developer initiation
ghstack-source-id: 8b538f2e1ac703a4538468a758f08db0c89b91a7
Pull Request resolved: https://github.com/pytorch/pytorch/pull/163720

Add most recent bytecode to dynamo graph break called by user

Fix other user-initiated graph break and issues

Fix linter
2025-10-01 17:21:03 -07:00
3 changed files with 82 additions and 6 deletions

View File

@ -773,6 +773,8 @@ from user code:
return f
def post_munge(s):
s = re.sub(r"TRACE.*\n", "", s, flags=re.MULTILINE)
s = re.sub(r"\nTRACE.*", "", s)
return re.sub(r"0x[0-9A-Fa-f]+", "0xmem_addr", s)
torch.compile(fn, backend="eager")()
@ -795,7 +797,7 @@ User code traceback:
torch.compile(fn, backend="eager")()
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
""",
Most recent bytecode instructions traced (max 20):""",
)
self.assertExpectedInline(
@ -1015,6 +1017,7 @@ Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especiall
"<Internal traceback>\n",
msg,
)
msg = re.sub(r"TRACE.*\n", "", msg, flags=re.MULTILINE)
self.assertExpectedInline(
msg,
"""\
@ -1051,9 +1054,12 @@ from user code:
torch.compile(fn, backend="eager")(torch.randn(3))
# check the log for the 2nd torch._dynamo.graph_break()
def post_munge(s):
s = re.sub(r"TRACE.*\n", "", s, flags=re.MULTILINE)
return re.sub(r"TRACE.*", "", s)
self.assertExpectedInline(
munge_exc(records[-1].getMessage(), skip=0),
post_munge(munge_exc(records[-1].getMessage(), skip=0)),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
@ -1072,6 +1078,54 @@ User code traceback:
hn(x + 1)
File "test_error_messages.py", line N, in hn
torch._dynamo.graph_break() # 1
Most recent bytecode instructions traced (max 20):
""",
)
@torch._dynamo.config.patch(verbose=True)
@make_logging_test(graph_breaks=True) # , bytecode=True)
def test_variable_tracker_bytecode_to_graph_break(self, records):
@torch.compile(backend="eager")
def fn(x):
y = x + 1
z = x + y
torch._dynamo.graph_break()
return z
fn(torch.ones(3))
pattern = r"TRACE.*"
s = munge_exc(records[-1].getMessage(), skip=0)
matches = re.findall(pattern, s)
self.assertIn(len(matches), [13, 20])
def post_munge(s):
s = re.sub(r"TRACE.*\n", "", s, flags=re.MULTILINE)
return re.sub(r"TRACE.*", "", s)
# check the log for the 2nd torch._dynamo.graph_break()
self.assertExpectedInline(
post_munge(s),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
Explanation: User-inserted graph break. Message: None
Hint: Remove the `torch._dynamo.graph_break()` call.
Developer debug context: Called `torch._dynamo.graph_break()` with args `[]`, kwargs `{}`
For more details about this graph break, please visit: https://meta-pytorch.github.io/compile-graph-break-site/gb/gb0025.html
User code traceback:
File "test_error_messages.py", line N, in test_variable_tracker_bytecode_to_graph_break
fn(torch.ones(3))
========== most recent `torch.compile` tracing attempt started here ==========
File "test_error_messages.py", line N, in fn
torch._dynamo.graph_break()
NOTE: the most recent `torch.compile` tracing attempt might not be where you applied `torch.compile`! This is due to how graph breaks are implemented - the optimized code object returned by Dynamo will call another Dynamo-generated resume function and tracing is re-enabled by calling the resume function as a normal Python function, which Dynamo intercepts as a top-level frame.
Most recent bytecode instructions traced (max 20):
""",
)
@ -1166,8 +1220,13 @@ NOTE: the most recent `torch.compile` tracing attempt might not be where you app
f1(torch.randn(3))
def post_munge(s):
s = re.sub(r"TRACE.*\n", "", s, flags=re.MULTILINE)
s = re.sub(r"\nTRACE.*", "", s)
return s
self.assertExpectedInline(
munge_exc(records[-1].getMessage(), skip=0),
post_munge(munge_exc(records[-1].getMessage(), skip=0)),
"""\
Graph break in user code at test_error_messages.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
@ -1186,7 +1245,7 @@ User code traceback:
f3(x)
File "test_error_messages.py", line N, in f3
torch._dynamo.graph_break() # correct
""",
Most recent bytecode instructions traced (max 20):""",
)
@make_logging_test(dynamo=logging.DEBUG)

View File

@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
import logging
import re
import unittest
import torch
@ -170,11 +171,13 @@ from user code:
torch.compile(fn001, backend="eager")(torch.randn(1))
record = self.getRecord(records, "Graph break in user code")
msg = re.sub(r"TRACE.*\n", "", record.getMessage(), flags=re.MULTILINE)
# msg =
# TODO: This should also report the enclosing frames; need to plumb
# frame object to it
self.assertExpectedInline(
munge_exc(record.getMessage()),
munge_exc(re.sub(r"TRACE.*", "", msg)),
"""\
Graph break in user code at test_exc.py:N
Graph Break Reason: Call to `torch._dynamo.graph_break()`
@ -191,6 +194,7 @@ User code traceback:
return fn002(x)
File "test_exc.py", line N, in fn002
torch._dynamo.graph_break()
Most recent bytecode instructions traced (max 20):
""", # noqa: B950
)

View File

@ -43,6 +43,7 @@ import threading
import traceback
import types
import weakref
from collections import deque
from traceback import StackSummary
from typing import Any, Callable, cast, NoReturn, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias, TypeIs
@ -206,6 +207,9 @@ compare_op_handlers["in"] = lambda tx, args, _: handle_contains(
compare_op_handlers["not in"] = lambda tx, args, _: handle_not(
tx, [handle_contains(tx, [*reversed(args)], {})], {}
)
latest_bytecode_queue: deque[str] = deque(
maxlen=20
) # Store the latest bytecode before graph_break() call by user
PT2_ISSUE_TRACKER_URL = "https://github.com/pytorch/pytorch/issues/new?&labels=oncall%3A+pt2&projects=&template=pt2-bug-report.yml"
@ -547,6 +551,7 @@ def log_graph_break(
reason: str = "",
exc_info: bool = False,
user_stack: Optional[StackSummary] = None,
latest_bytecode_log: Optional[str] = None,
) -> None:
if user_stack is None:
user_stack = torch._guards.TracingContext.extract_stack()
@ -608,6 +613,10 @@ def log_graph_break(
# This log line MUST contain the string "Graph break in user code",
# This log line is exercised from
# python test/dynamo/test_exc.py -k test_graph_break_log
if latest_bytecode_log:
user_stack_trace += "Most recent bytecode instructions traced (max 20):\n"
user_stack_trace += latest_bytecode_log
graph_break_log.debug(
user_stack_trace,
)
@ -929,6 +938,7 @@ def break_graph_if_unsupported(
exc_info=True,
reason=str(excp),
user_stack=excp.real_stack,
latest_bytecode_log="\n".join(latest_bytecode_queue),
)
if self.maybe_has_backedge():
@ -1342,6 +1352,9 @@ class InstructionTranslatorBase(
"TRACE %s %s %s", inst.opname, inst.argval, self.stack
)
# Store the latest 20 bytecode execution for the process
latest_bytecode_queue.append(f"TRACE {inst.opname} {inst.argval} {self.stack}")
self.update_block_stack(inst)
try: