mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] raise hard error if error is encountered while tracing resume function prologue (#154564)
This should prevent bad resume function prologues from slipping by. In particular, graph breaks in resume function prologues will now hard error. Implementation details: - The resume function prologue is surrounded by `LOAD_CONST arg, STORE_FAST __is_tracing_resume_prologue` instructions. The first sequence has `arg=True` and the second sequence has `arg=False`. - InstructionTranslator will know when it is tracing a resume function prologue when it detects `STORE_FAST __is_tracing_resume_prologue`. The top of stack will be True to mark the start of the prologue, False to mark the end. - When `convert_frame.py` detects that an error occurred while the InstructionTranslator was tracing a resume function prologue, we will wrap the exception and hard error Pull Request resolved: https://github.com/pytorch/pytorch/pull/154564 Approved by: https://github.com/jansel ghstack dependencies: #154283, #154289, #154782, #155166
This commit is contained in:
committed by
PyTorch MergeBot
parent
24dc33b37b
commit
0aed855b2b
@ -1,8 +1,8 @@
|
||||
add_loop_eager,compile_time_instruction_count,2937000000,0.015
|
||||
add_loop_eager,compile_time_instruction_count,3179757906,0.015
|
||||
|
||||
|
||||
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4300194436,0.025
|
||||
add_loop_eager_dynamic,compile_time_instruction_count,4510896405,0.025
|
||||
|
||||
|
||||
|
||||
@ -18,7 +18,7 @@ add_loop_inductor_gpu,compile_time_instruction_count,25900000000,0.015
|
||||
|
||||
|
||||
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,942514329,0.015
|
||||
basic_modules_ListOfLinears_eager,compile_time_instruction_count,967638254,0.015
|
||||
|
||||
|
||||
|
||||
@ -38,11 +38,11 @@ update_hint_regression,compile_time_instruction_count,1661000000,0.02
|
||||
|
||||
|
||||
|
||||
sum_floordiv_regression,compile_time_instruction_count,984411080,0.015
|
||||
sum_floordiv_regression,compile_time_instruction_count,1009252114,0.015
|
||||
|
||||
|
||||
|
||||
symint_sum,compile_time_instruction_count,3252000000,0.015
|
||||
symint_sum,compile_time_instruction_count,3162643519,0.015
|
||||
|
||||
|
||||
|
||||
|
|
@ -3,6 +3,7 @@
|
||||
import re
|
||||
import traceback
|
||||
import unittest
|
||||
import unittest.mock
|
||||
import warnings
|
||||
|
||||
import torch
|
||||
@ -10,7 +11,7 @@ import torch._dynamo
|
||||
import torch._dynamo.config
|
||||
import torch._dynamo.test_case
|
||||
import torch.utils._pytree as python_pytree
|
||||
from torch._dynamo.exc import Unsupported
|
||||
from torch._dynamo.exc import ResumePrologueTracingError, Unsupported
|
||||
from torch._dynamo.testing import skipIfNotPy312
|
||||
from torch._dynamo.utils import counters
|
||||
from torch.testing._internal.common_utils import (
|
||||
@ -1257,6 +1258,48 @@ from user code:
|
||||
post_munge=post_munge,
|
||||
)
|
||||
|
||||
# Test that errors while tracing resume function prologues do not get suppressed
|
||||
def test_graph_break_in_buggy_resume_prologue(self):
|
||||
import torch._dynamo.bytecode_transformation as bt
|
||||
import torch._dynamo.resume_execution as rex
|
||||
|
||||
# NOTE: do not define non_global as a global in this file!
|
||||
@torch.compile(backend="eager")
|
||||
def fn(non_global):
|
||||
non_global = non_global + 1
|
||||
torch._dynamo.graph_break()
|
||||
return non_global + 1
|
||||
|
||||
orig_clean_and_assemble_instructions = bt.clean_and_assemble_instructions
|
||||
|
||||
def bad_clean_and_assemble_instructions(instructions, *args):
|
||||
# Inject an invalid LOAD_GLOBAL after the first STORE_FAST IS_TRACING_RESUME_PROLOGUE_VARNAME
|
||||
for i, inst in enumerate(instructions):
|
||||
if (
|
||||
inst.opname == "STORE_FAST"
|
||||
and inst.argval == rex.IS_TRACING_RESUME_PROLOGUE_VARNAME
|
||||
):
|
||||
instructions[:] = (
|
||||
instructions[: i + 1]
|
||||
+ [
|
||||
# this should cause a graph break
|
||||
bt.create_instruction("LOAD_GLOBAL", argval="non_global"),
|
||||
]
|
||||
+ instructions[i + 1 :]
|
||||
)
|
||||
break
|
||||
return orig_clean_and_assemble_instructions(instructions, *args)
|
||||
|
||||
with unittest.mock.patch(
|
||||
"torch._dynamo.bytecode_transformation.clean_and_assemble_instructions",
|
||||
bad_clean_and_assemble_instructions,
|
||||
):
|
||||
with self.assertRaisesRegex(
|
||||
ResumePrologueTracingError,
|
||||
"Error while tracing through a Dynamo-generated resume function prologue.",
|
||||
):
|
||||
fn(torch.randn(3))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
@ -103,6 +103,7 @@ from .exc import (
|
||||
InternalTorchDynamoError,
|
||||
PackageError,
|
||||
RecompileLimitExceeded,
|
||||
ResumePrologueTracingError,
|
||||
ShortenTraceback,
|
||||
SkipCodeRecursiveException,
|
||||
TorchRuntimeError,
|
||||
@ -478,6 +479,12 @@ class ConvertFrameBox:
|
||||
error_on_graph_break: Optional[bool] = None
|
||||
|
||||
|
||||
def _is_error_on_graph_break(tx: Optional[InstructionTranslator]) -> bool:
|
||||
if tx is None:
|
||||
return config.error_on_graph_break
|
||||
return tx.error_on_graph_break
|
||||
|
||||
|
||||
class ConvertFrameAssert:
|
||||
def __init__(
|
||||
self,
|
||||
@ -873,12 +880,7 @@ def _compile(
|
||||
code.co_filename,
|
||||
code.co_firstlineno,
|
||||
)
|
||||
error_on_graph_break = (
|
||||
tracer.error_on_graph_break
|
||||
if tracer
|
||||
else config.error_on_graph_break
|
||||
)
|
||||
if one_graph or error_on_graph_break:
|
||||
if one_graph or _is_error_on_graph_break(tracer):
|
||||
log.debug(
|
||||
"No graph captured with one_graph=True or torch._dynamo.config.error_on_graph_break=True"
|
||||
)
|
||||
@ -1043,14 +1045,11 @@ def _compile(
|
||||
recompile_reason,
|
||||
troubleshooting_url,
|
||||
)
|
||||
error_on_graph_break = (
|
||||
tracer.error_on_graph_break if tracer else config.error_on_graph_break
|
||||
)
|
||||
if config.fail_on_recompile_limit_hit:
|
||||
raise FailOnRecompileLimitHit(
|
||||
f"{limit_type} reached, because fail_on_recompile_limit_hit = True this is a HARD failure"
|
||||
)
|
||||
elif one_graph or error_on_graph_break:
|
||||
elif one_graph or _is_error_on_graph_break(tracer):
|
||||
raise FailOnRecompileLimitHit(
|
||||
f"{limit_type} reached with one_graph=True or torch._dynamo.config.error_on_graph_break=True. "
|
||||
"Excessive recompilations can degrade "
|
||||
@ -1163,7 +1162,15 @@ def _compile(
|
||||
fail_user_frame_filename, fail_user_frame_lineno = exc.get_exc_message(
|
||||
e, compile_id
|
||||
)
|
||||
if isinstance(
|
||||
if tracer and tracer.is_tracing_resume_prologue:
|
||||
# Do not allow any errors to be suppressed if tracer is currently tracing
|
||||
# through resume function.
|
||||
raise ResumePrologueTracingError(
|
||||
"Error while tracing through a Dynamo-generated resume function prologue. "
|
||||
"Errors are not allowed when tracing resume function prologues.\n"
|
||||
f"{type(e).__qualname__}: {str(e)}"
|
||||
).with_traceback(e.__traceback__) from None
|
||||
elif isinstance(
|
||||
e,
|
||||
(
|
||||
Unsupported,
|
||||
@ -1310,6 +1317,10 @@ class ConvertFrame:
|
||||
counters["frames"]["ok"] += 1
|
||||
return result
|
||||
except Exception as e:
|
||||
# Do not allow errors to be suppressed if we're tracing a resume function prologue
|
||||
if isinstance(e, ResumePrologueTracingError):
|
||||
raise
|
||||
|
||||
error_on_graph_break = (
|
||||
self._inner_convert._box.error_on_graph_break is not None
|
||||
)
|
||||
|
@ -70,6 +70,10 @@ class InternalTorchDynamoError(TorchDynamoException):
|
||||
pass
|
||||
|
||||
|
||||
class ResumePrologueTracingError(TorchDynamoException):
|
||||
pass
|
||||
|
||||
|
||||
class RestartAnalysis(TorchDynamoException):
|
||||
restart_reason: Optional[str]
|
||||
|
||||
|
@ -49,6 +49,7 @@ CO_ASYNC_GENERATOR = 0x0200
|
||||
|
||||
# trace_rules.py import this constant for consistency
|
||||
TORCH_DYNAMO_RESUME_IN_PREFIX = "torch_dynamo_resume_in"
|
||||
IS_TRACING_RESUME_PROLOGUE_VARNAME = "__is_tracing_resume_prologue"
|
||||
|
||||
|
||||
def _initial_push_null(insts):
|
||||
@ -356,6 +357,7 @@ class ContinueExecutionCache:
|
||||
for v in code_options["co_varnames"]
|
||||
if v not in args and v not in freevars
|
||||
]
|
||||
+ [IS_TRACING_RESUME_PROLOGUE_VARNAME]
|
||||
)
|
||||
code_options["co_flags"] = code_options["co_flags"] & ~(
|
||||
CO_VARARGS | CO_VARKEYWORDS
|
||||
@ -370,6 +372,18 @@ class ContinueExecutionCache:
|
||||
)
|
||||
prefix.append(create_instruction("RESUME", arg=0))
|
||||
|
||||
# Set is_tracing_resume_prologue to prevent graph breaks.
|
||||
# This doesn't really do anything at runtime, but dynamo will trace this
|
||||
# and will know that we're in a resume function prologue.
|
||||
prefix.extend(
|
||||
[
|
||||
create_instruction("LOAD_CONST", argval=True),
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
cleanup: list[Instruction] = []
|
||||
hooks = {fn.stack_index: fn for fn in setup_fns}
|
||||
hook_target_offsets = {
|
||||
@ -431,6 +445,16 @@ class ContinueExecutionCache:
|
||||
]
|
||||
)
|
||||
|
||||
# Set is_tracing_resume_prologue back to allow graph breaks.
|
||||
prefix.extend(
|
||||
[
|
||||
create_instruction("LOAD_CONST", argval=False),
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
prefix.append(create_jump_absolute(target))
|
||||
|
||||
# because the line number table monotonically increases from co_firstlineno
|
||||
|
@ -95,7 +95,11 @@ from .funcname_cache import get_funcname
|
||||
from .guards import GuardBuilder, install_guard
|
||||
from .output_graph import GraphCompileReason, OutputGraph
|
||||
from .replay_record import DummyModule, ExecutionRecorder
|
||||
from .resume_execution import ContinueExecutionCache, ReenterWith
|
||||
from .resume_execution import (
|
||||
ContinueExecutionCache,
|
||||
IS_TRACING_RESUME_PROLOGUE_VARNAME,
|
||||
ReenterWith,
|
||||
)
|
||||
from .source import (
|
||||
AttrSource,
|
||||
DictGetItemSource,
|
||||
@ -1473,6 +1477,10 @@ class InstructionTranslatorBase(
|
||||
loaded_vt = self.pop()
|
||||
loaded_vt.set_name_hint(name)
|
||||
self.symbolic_locals[name] = loaded_vt
|
||||
if name == IS_TRACING_RESUME_PROLOGUE_VARNAME:
|
||||
val = loaded_vt.as_python_constant()
|
||||
assert type(val) is bool
|
||||
self.is_tracing_resume_prologue = val
|
||||
|
||||
def DELETE_FAST(self, inst):
|
||||
del self.symbolic_locals[inst.argval]
|
||||
@ -3262,6 +3270,8 @@ class InstructionTranslatorBase(
|
||||
# the same instruction.
|
||||
self.one_graph = False
|
||||
self.error_on_graph_break = False
|
||||
# Also do not graph break when tracing resume function prologues
|
||||
self.is_tracing_resume_prologue = False
|
||||
|
||||
self.current_speculation = None
|
||||
|
||||
@ -3526,6 +3536,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
all(b.can_restore() for b in self.block_stack)
|
||||
and not self.one_graph
|
||||
and not self.error_on_graph_break
|
||||
and not self.is_tracing_resume_prologue
|
||||
and not self.active_generic_context_managers
|
||||
)
|
||||
|
||||
@ -3661,6 +3672,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
and not self.export
|
||||
and not self.one_graph
|
||||
and not self.error_on_graph_break
|
||||
and not self.is_tracing_resume_prologue
|
||||
):
|
||||
raise exc.SkipFrame("because no content in function call")
|
||||
|
||||
|
Reference in New Issue
Block a user