[dynamo] raise hard error if error is encountered while tracing resume function prologue (#154564)

This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error.
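
For context, a sketch (illustrative only, not part of this PR) of where a resume function comes from: a graph break inside a compiled function makes Dynamo generate a resume function that continues execution after the break, and that generated function begins with a prologue that restores the stack and locals before user code runs again.

```python
import torch

@torch.compile(backend="eager")
def fn(x):
    x = x + 1                    # captured in the first graph
    torch._dynamo.graph_break()  # Dynamo generates a resume function for the rest
    return x + 2                 # traced via the generated resume function

fn(torch.randn(3))
```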

Implementation details:
- The resume function prologue is bracketed by `LOAD_CONST arg; STORE_FAST __is_tracing_resume_prologue` instruction pairs: the opening pair loads `True` and the closing pair loads `False`.
- InstructionTranslator knows it is entering or leaving a resume function prologue when it traces `STORE_FAST __is_tracing_resume_prologue`: a top-of-stack value of `True` marks the start of the prologue and `False` marks the end.
- When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we wrap the exception in `ResumePrologueTracingError` and hard error (see the sketch after this list).
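
A minimal sketch of the bracketing described above, using the `create_instruction` helper that this PR itself uses in `torch/_dynamo/resume_execution.py`; `bracket_prologue` is a hypothetical helper for illustration, not a function added by the PR:

```python
from torch._dynamo.bytecode_transformation import create_instruction

# Constant added by this PR in torch/_dynamo/resume_execution.py
IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"

def bracket_prologue(prologue_insts):
    # Opening marker: top of stack is True when the STORE_FAST is traced,
    # so InstructionTranslator sets is_tracing_resume_prologue = True.
    enter = [
        create_instruction("LOAD_CONST", argval=True),
        create_instruction("STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME),
    ]
    # Closing marker: top of stack is False, so the translator clears the flag.
    leave = [
        create_instruction("LOAD_CONST", argval=False),
        create_instruction("STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME),
    ]
    # While the flag is set, convert_frame.py re-raises any tracing error as
    # ResumePrologueTracingError instead of suppressing it with a graph break.
    return enter + list(prologue_insts) + leave
```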

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564
Approved by: https://github.com/jansel
ghstack dependencies: #154283, #154289, #154782, #155166
Author:       William Wen
Author date:  2025-06-18 20:01:28 -07:00
Committed by: PyTorch MergeBot
Parent:       24dc33b37b
Commit:       0aed855b2b
6 changed files with 112 additions and 18 deletions

View File

@@ -1,8 +1,8 @@
-add_loop_eager,compile_time_instruction_count,2937000000,0.015
+add_loop_eager,compile_time_instruction_count,3179757906,0.015
-add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
+add_loop_eager_dynamic,compile_time_instruction_count,4510896405,0.025
@@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,942514329,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,967638254,0.015
@@ -38,11 +38,11 @@ update_hint_regression,compile_time_instruction_count,1661000000,0.02
-sum_floordiv_regression,compile_time_instruction_count,984411080,0.015
+sum_floordiv_regression,compile_time_instruction_count,1009252114,0.015
-symint_sum,compile_time_instruction_count,3252000000,0.015
+symint_sum,compile_time_instruction_count,3162643519,0.015


View File

@@ -3,6 +3,7 @@
 import re
 import traceback
 import unittest
+import unittest.mock
 import warnings
 import torch
@@ -10,7 +11,7 @@ import torch._dynamo
 import torch._dynamo.config
 import torch._dynamo.test_case
 import torch.utils._pytree as python_pytree
-from torch._dynamo.exc import Unsupported
+from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
 from torch._dynamo.testing import skipIfNotPy312
 from torch._dynamo.utils import counters
 from torch.testing._internal.common_utils import (
@@ -1257,6 +1258,48 @@ from user code:
             post_munge=post_munge,
         )
 
+    # Test that errors while tracing resume function prologues do not get suppressed
+    def test_graph_break_in_buggy_resume_prologue(self):
+        import torch._dynamo.bytecode_transformation as bt
+        import torch._dynamo.resume_execution as rex
+
+        # NOTE: do not define non_global as a global in this file!
+        @torch.compile(backend="eager")
+        def fn(non_global):
+            non_global = non_global + 1
+            torch._dynamo.graph_break()
+            return non_global + 1
+
+        orig_clean_and_assemble_instructions = bt.clean_and_assemble_instructions
+
+        def bad_clean_and_assemble_instructions(instructions, *args):
+            # Inject an invalid LOAD_GLOBAL after the first STORE_FAST IS_TRACING_RESUME_PROLOGUE_VARNAME
+            for i, inst in enumerate(instructions):
+                if (
+                    inst.opname == "STORE_FAST"
+                    and inst.argval == rex.IS_TRACING_RESUME_PROLOGUE_VARNAME
+                ):
+                    instructions[:] = (
+                        instructions[: i + 1]
+                        + [
+                            # this should cause a graph break
+                            bt.create_instruction("LOAD_GLOBAL", argval="non_global"),
+                        ]
+                        + instructions[i + 1 :]
+                    )
+                    break
+            return orig_clean_and_assemble_instructions(instructions, *args)
+
+        with unittest.mock.patch(
+            "torch._dynamo.bytecode_transformation.clean_and_assemble_instructions",
+            bad_clean_and_assemble_instructions,
+        ):
+            with self.assertRaisesRegex(
+                ResumePrologueTracingError,
+                "Error while tracing through a Dynamo-generated resume function prologue.",
+            ):
+                fn(torch.randn(3))
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

View File

@@ -103,6 +103,7 @@ from .exc import (
     InternalTorchDynamoError,
     PackageError,
     RecompileLimitExceeded,
+    ResumePrologueTracingError,
     ShortenTraceback,
     SkipCodeRecursiveException,
     TorchRuntimeError,
@@ -478,6 +479,12 @@ class ConvertFrameBox:
     error_on_graph_break: Optional[bool] = None
 
 
+def _is_error_on_graph_break(tx: Optional[InstructionTranslator]) -> bool:
+    if tx is None:
+        return config.error_on_graph_break
+    return tx.error_on_graph_break
+
+
 class ConvertFrameAssert:
     def __init__(
         self,
@@ -873,12 +880,7 @@
                 code.co_filename,
                 code.co_firstlineno,
             )
-            error_on_graph_break = (
-                tracer.error_on_graph_break
-                if tracer
-                else config.error_on_graph_break
-            )
-            if one_graph or error_on_graph_break:
+            if one_graph or _is_error_on_graph_break(tracer):
                 log.debug(
                     "No graph captured with one_graph=True or torch._dynamo.config.error_on_graph_break=True"
                 )
@@ -1043,14 +1045,11 @@
                 recompile_reason,
                 troubleshooting_url,
             )
-            error_on_graph_break = (
-                tracer.error_on_graph_break if tracer else config.error_on_graph_break
-            )
             if config.fail_on_recompile_limit_hit:
                 raise FailOnRecompileLimitHit(
                     f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
                 )
-            elif one_graph or error_on_graph_break:
+            elif one_graph or _is_error_on_graph_break(tracer):
                 raise FailOnRecompileLimitHit(
                     f"{limit_type} reached with one_graph=True or torch._dynamo.config.error_on_graph_break=True. "
                     "Excessive recompilations can degrade "
@@ -1163,7 +1162,15 @@
         fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
             e, compile_id
         )
-        if isinstance(
+        if tracer and tracer.is_tracing_resume_prologue:
+            # Do not allow any errors to be suppressed if tracer is currently tracing
+            # through resume function.
+            raise ResumePrologueTracingError(
+                "Error while tracing through a Dynamo-generated resume function prologue. "
+                "Errors are not allowed when tracing resume function prologues.\n"
+                f"{type(e).__qualname__}: {str(e)}"
+            ).with_traceback(e.__traceback__) from None
+        elif isinstance(
             e,
             (
                 Unsupported,
@@ -1310,6 +1317,10 @@ class ConvertFrame:
            counters["frames"]["ok"] += 1
            return result
        except Exception as e:
+            # Do not allow errors to be suppressed if we're tracing a resume function prologue
+            if isinstance(e, ResumePrologueTracingError):
+                raise
+
            error_on_graph_break = (
                self._inner_convert._box.error_on_graph_break is not None
            )

View File

@@ -70,6 +70,10 @@ class InternalTorchDynamoError(TorchDynamoException):
     pass
 
 
+class ResumePrologueTracingError(TorchDynamoException):
+    pass
+
+
 class RestartAnalysis(TorchDynamoException):
     restart_reason: Optional[str]
View File

@@ -49,6 +49,7 @@ CO_ASYNC_GENERATOR = 0x0200
 
 # trace_rules.py import this constant for consistency
 TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
+IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
 
 
 def _initial_push_null(insts):
@@ -356,6 +357,7 @@ class ContinueExecutionCache:
                 for v in code_options["co_varnames"]
                 if v not in args and v not in freevars
             ]
+            + [IS_TRACING_RESUME_PROLOGUE_VARNAME]
         )
         code_options["co_flags"] = code_options["co_flags"] & ~(
             CO_VARARGS | CO_VARKEYWORDS
@@ -370,6 +372,18 @@
         )
         prefix.append(create_instruction("RESUME", arg=0))
 
+        # Set is_tracing_resume_prologue to prevent graph breaks.
+        # This doesn't really do anything at runtime, but dynamo will trace this
+        # and will know that we're in a resume function prologue.
+        prefix.extend(
+            [
+                create_instruction("LOAD_CONST", argval=True),
+                create_instruction(
+                    "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
+                ),
+            ]
+        )
+
         cleanup: list[Instruction] = []
         hooks = {fn.stack_index: fn for fn in setup_fns}
         hook_target_offsets = {
@@ -431,6 +445,16 @@
             ]
         )
 
+        # Set is_tracing_resume_prologue back to allow graph breaks.
+        prefix.extend(
+            [
+                create_instruction("LOAD_CONST", argval=False),
+                create_instruction(
+                    "STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
+                ),
+            ]
+        )
+
         prefix.append(create_jump_absolute(target))
 
         # because the line number table monotonically increases from co_firstlineno

View File

@@ -95,7 +95,11 @@ from .funcname_cache import get_funcname
 from .guards import GuardBuilder, install_guard
 from .output_graph import GraphCompileReason, OutputGraph
 from .replay_record import DummyModule, ExecutionRecorder
-from .resume_execution import ContinueExecutionCache, ReenterWith
+from .resume_execution import (
+    ContinueExecutionCache,
+    IS_TRACING_RESUME_PROLOGUE_VARNAME,
+    ReenterWith,
+)
 from .source import (
     AttrSource,
     DictGetItemSource,
@@ -1473,6 +1477,10 @@ class InstructionTranslatorBase(
         loaded_vt = self.pop()
         loaded_vt.set_name_hint(name)
         self.symbolic_locals[name] = loaded_vt
+        if name == IS_TRACING_RESUME_PROLOGUE_VARNAME:
+            val = loaded_vt.as_python_constant()
+            assert type(val) is bool
+            self.is_tracing_resume_prologue = val
 
     def DELETE_FAST(self, inst):
         del self.symbolic_locals[inst.argval]
@@ -3262,6 +3270,8 @@
         # the same instruction.
         self.one_graph = False
         self.error_on_graph_break = False
+        # Also do not graph break when tracing resume function prologues
+        self.is_tracing_resume_prologue = False
 
         self.current_speculation = None
@@ -3526,6 +3536,7 @@ class InstructionTranslator(InstructionTranslatorBase):
             all(b.can_restore() for b in self.block_stack)
             and not self.one_graph
             and not self.error_on_graph_break
+            and not self.is_tracing_resume_prologue
             and not self.active_generic_context_managers
         )
@@ -3661,6 +3672,7 @@
             and not self.export
             and not self.one_graph
             and not self.error_on_graph_break
+            and not self.is_tracing_resume_prologue
         ):
             raise exc.SkipFrame("because no content in function call")