[dynamo, nested graph breaks] support very simple nested graph breaks (#159329)

For example, the following program now incurs a single graph break:
```python
import torch

torch._dynamo.config.nested_graph_breaks = True

def inner(x):
    x = x + 1
    torch._dynamo.graph_break()
    return x + 2

@torch.compile(backend="eager")
def outer(x):
    return inner(x)

print(outer(torch.ones(3)))
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159329
Approved by: https://github.com/anijain2305
ghstack dependencies: #157971, #159281, #144516
This commit is contained in:
William Wen
2025-08-25 13:27:42 -07:00
committed by PyTorch MergeBot
parent 9a756c2d71
commit 8dab6d4c41
5 changed files with 226 additions and 108 deletions

View File

@ -97,8 +97,11 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
super().tearDown()
torch._dynamo.config.nested_graph_breaks = False
@unittest.expectedFailure
def test_single_graph_break(self):
# NOTE marking f1, f2, f3 as global
# prevents them from being freevars
global f1, f2, f3
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
@ -118,8 +121,9 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_single_graph_break_repeat(self):
global f1, f2, f3
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
@ -141,8 +145,9 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_doubly_nested_graph_break(self):
global f1, f2, f3
def f1(x1):
x1 = x1 + 1
torch._dynamo.graph_break()
@ -164,8 +169,9 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 3)
@unittest.expectedFailure
def test_differing_arg_nums(self):
global f1, f2, f3, f4
def f1(x1, x2):
x = x1 + x2
torch._dynamo.graph_break()
@ -188,8 +194,9 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_differing_locals_nums(self):
global f1, f2, f3
def f1(x1):
loc1 = x1 + 1
torch._dynamo.graph_break()
@ -324,8 +331,8 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_globals(self):
global f1, f2, f3
global global1, global2, global3, global4
def f1():
@ -361,8 +368,8 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertEqual(cnts.frame_count, 2)
@unittest.expectedFailure
def test_side_effects_globals_different_module(self):
global f1, f2, _test_nested_graph_breaks_helper
try:
from . import _test_nested_graph_breaks_helper
except ImportError:

View File

@ -212,6 +212,10 @@ def create_jump_absolute(target: Instruction) -> Instruction:
return create_instruction(inst, target=target)
# Report whether an instruction's opname is "JUMP_FORWARD" or "JUMP_ABSOLUTE".
# NOTE: despite the parameter name, `target` is the instruction being tested,
# not the destination of a jump.
def is_jump_absolute(target: Instruction) -> bool:
return target.opname in ("JUMP_FORWARD", "JUMP_ABSOLUTE")
def create_load_const(val: Any, checked: bool = True) -> Instruction:
"""
In general we should only create `LOAD_CONST` for immutable objects, but
@ -504,15 +508,6 @@ def create_binary_slice(
]
# Build the bytecode that reverses the order of the top n stack values.
# Returns a two-instruction list: BUILD_TUPLE(n) followed by UNPACK_SEQUENCE(n).
def create_reverse(n: int) -> list[Instruction]:
# Reverse the top n values on the stack
# UNPACK_SEQUENCE reverses the sequence
# (BUILD_TUPLE first packs the n values into a single tuple for it to unpack)
return [
create_instruction("BUILD_TUPLE", arg=n),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[list[int], Callable[[int, int], None]]:

View File

@ -1456,15 +1456,7 @@ def _compile(
e, compile_id
)
tracer_output = getattr(e, "_torch_dynamo_tracer_output", None)
if tracer_output and tracer_output.is_tracing_resume_prologue:
# Do not allow any errors to be suppressed if tracer is currently tracing
# through resume function.
raise ResumePrologueTracingError(
"Error while tracing through a Dynamo-generated resume function prologue. "
"Errors are not allowed when tracing resume function prologues.\n"
f"{type(e).__qualname__}: {str(e)}"
).with_traceback(e.__traceback__) from None
elif isinstance(
if isinstance(
e,
(
Unsupported,
@ -1478,6 +1470,7 @@ def _compile(
BisectValidationException,
ShortenTraceback,
PackageError,
ResumePrologueTracingError,
),
):
raise

View File

@ -22,8 +22,10 @@ from contextlib import AbstractContextManager
from typing import Any, Callable, cast, Optional
from .bytecode_transformation import (
add_push_null,
bytecode_from_template,
create_call_function,
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,
@ -310,6 +312,7 @@ class ContinueExecutionCache:
stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
null_idxes: tuple[int, ...],
has_nested: bool,
) -> types.CodeType:
assert offset is not None
assert not (
@ -330,6 +333,7 @@ class ContinueExecutionCache:
stack_ctx_vars,
argnames_ctx_vars,
null_idxes,
has_nested,
)
is_py311_plus = sys.version_info >= (3, 11)
@ -340,7 +344,7 @@ class ContinueExecutionCache:
) -> None:
meta.instructions = copy.deepcopy(instructions)
args = ["__nested_frame_values"]
args = ["__nested_resume_fns", "__nested_frame_values"]
args += [f"___stack{i}" for i in range(nstack)]
args.extend(v for v in argnames if v not in args)
freevars = tuple(code_options["co_cellvars"] or []) + tuple(
@ -462,15 +466,74 @@ class ContinueExecutionCache:
]
)
# Set is_tracing_resume_prologue back to allow graph breaks.
prefix.extend(
[
create_instruction("LOAD_CONST", argval=False),
create_instruction(
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
),
]
)
# Call nested resume function
if has_nested:
prefix.extend(
[
# set up __nested_resume_fns[-1] call
*add_push_null(
[
create_instruction(
"LOAD_FAST", argval="__nested_resume_fns"
),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("BINARY_SUBSCR"),
]
),
# del __nested_resume_fns[-1]
create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("DELETE_SUBSCR"),
# load [__nested_resume_fns, __nested_frame_values]
create_instruction("LOAD_FAST", argval="__nested_resume_fns"),
create_instruction("LOAD_FAST", argval="__nested_frame_values"),
create_instruction("BUILD_LIST", arg=2),
# load __nested_frame_values[-1]
create_instruction("LOAD_FAST", argval="__nested_frame_values"),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("BINARY_SUBSCR"),
# create [
# __nested_resume_fns,
# __nested_frame_values,
# *__nested_frame_values[-1][0],
# *__nested_frame_values[-1][1],
# ]
create_dup_top(),
create_instruction("LOAD_CONST", argval=0),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=2),
create_instruction("LOAD_CONST", argval=1),
create_instruction("BINARY_SUBSCR"),
create_instruction("LIST_EXTEND", arg=1),
# del __nested_frame_values[-1]
create_instruction("LOAD_FAST", argval="__nested_frame_values"),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("DELETE_SUBSCR"),
# delete __nested values
create_instruction("DELETE_FAST", argval="__nested_resume_fns"),
create_instruction(
"DELETE_FAST", argval="__nested_frame_values"
),
# Set is_tracing_resume_prologue back to allow graph breaks
# in the nested resume
create_instruction("LOAD_CONST", argval=False),
create_instruction(
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
),
# finish the call
create_instruction("CALL_FUNCTION_EX", arg=0),
]
)
else:
# Set is_tracing_resume_prologue back to allow graph breaks after the jump
prefix.extend(
[
create_instruction("LOAD_CONST", argval=False),
create_instruction(
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
),
]
)
prefix.append(create_jump_absolute(target))

View File

@ -80,6 +80,7 @@ from .bytecode_transformation import (
get_code_keys,
Instruction,
is_generator,
is_jump_absolute,
unique_id,
)
from .code_context import code_context
@ -90,6 +91,7 @@ from .exc import (
collapse_resume_frames,
format_graph_break_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
unimplemented_v2,
Unsupported,
)
@ -1461,8 +1463,17 @@ class InstructionTranslatorBase(
try:
self.output.push_tx(self)
self.start_point = self.instruction_pointer
while self.step():
pass
try:
while self.step():
pass
except Exception as e:
if self.is_tracing_resume_prologue:
raise ResumePrologueTracingError(
"Error while tracing through a Dynamo-generated resume function prologue. "
"Errors are not allowed when tracing resume function prologues.\n"
f"{type(e).__qualname__}: {str(e)}"
).with_traceback(e.__traceback__) from None
raise
except TensorifyScalarRestartAnalysis:
raise
except BackendCompilerFailed:
@ -1546,7 +1557,7 @@ class InstructionTranslatorBase(
)
# for continuation functions
if name.startswith("__stack") or name == "__nested_frame_values":
if name.startswith("__stack"):
self.symbolic_locals.pop(name)
def LOAD_DEREF(self, inst: Instruction) -> None:
@ -2474,7 +2485,7 @@ class InstructionTranslatorBase(
elif inst.opname == "RETURN_CONST":
return [create_instruction("RETURN_CONST", argval=inst.argval)]
cg = PyCodegen(self)
cg = PyCodegen(self.output.root_tx)
# current frame state
# [
@ -2525,6 +2536,7 @@ class InstructionTranslatorBase(
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
argnames: tuple[str, ...] = ()
for i, meta in enumerate(all_stack_locals_metadata):
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
@ -2562,76 +2574,118 @@ class InstructionTranslatorBase(
]
)
name = unique_id(f"__resume_at_{inst.offset}")
# build the resume function for each frame
resume_names = []
resume_codes = []
for i, meta in enumerate(all_stack_locals_metadata):
cur_tx = txes[i]
if cur_tx is self:
resume_inst = inst
else:
resume_inst = cur_tx.next_instruction
# If the resume instruction is a jump absolute, then resume
# at the target instead. This handles the case where we
# graph break again in a nested function before jump-resuming
# this frame.
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_names.append(name)
assert not config.nested_graph_breaks, "NYI"
# more locals may have been pruned after the unsupported instruction (e.g. branch)
reads = livevars_analysis(self.instructions, inst)
all_argnames = tuple(
k
for k in self.symbolic_locals.keys()
if k in reads and k not in self.cell_and_freevars()
)
argnames_null_set = set(all_stack_locals_metadata[-1].locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(self.stack) - len(
all_stack_locals_metadata[-1].stack_null_idxes
)
new_code: types.CodeType = ContinueExecutionCache.lookup(
self.f_code,
self.lineno,
inst.offset,
tuple(b.target.offset for b in self.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
tuple(all_stack_locals_metadata[-1].stack_ctx_args),
tuple(all_stack_locals_metadata[-1].locals_ctx_args),
tuple(all_stack_locals_metadata[-1].stack_null_idxes),
)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
# more locals may have been pruned after the unsupported instruction (e.g. branch)
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
if new_code.co_freevars:
# expose code object for debugging purposes
self.output.install_global_unsafe(name, new_code)
cg.make_function_with_closure(name, new_code, True, 1)
package_name = None
new_code: types.CodeType = ContinueExecutionCache.lookup(
cur_tx.f_code,
cur_tx.lineno,
resume_inst.offset,
tuple(b.target.offset for b in cur_tx.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in cur_tx.block_stack),
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
self is not cur_tx,
)
resume_codes.append(new_code)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
name, types.FunctionType(new_code, cur_tx.f_globals, name)
)
package_name = name
if cur_tx.package is not None:
cur_tx.package.add_resume_function(
new_code, cur_tx.f_globals["__name__"], package_name
)
# load first resume function (to be called this frame)
if resume_codes[-1].co_freevars:
cg.make_function_with_closure(resume_names[-1], resume_codes[-1], True, 1)
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
name, types.FunctionType(new_code, self.f_globals, name)
)
cg.extend_output(cg.load_function_name(name, True, 1))
package_name = name
cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))
if self.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], package_name
)
# load all other resume functions (to be called later)
resume_names.pop()
resume_codes.pop()
for name, code in zip(resume_names, resume_codes):
if code.co_freevars:
assert not config.nested_graph_breaks, "NYI"
cg.make_function_with_closure(name, code, False, 0)
else:
cg.extend_output(cg.load_function_name(name, False, 0))
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=len(resume_codes)),
*create_swap(2),
]
)
# resume 1 (+ NULL), [resume N, ..., resume 2], frames
# load top level-frame; final stack state should be:
# first resume function (+ NULL),
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], frame 1 stack + frame 1 non-cell locals
# [resume N, ..., resume 2],
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], *(frame 1 stack + frame 1 non-cell locals)
# ]
cg.extend_output(
[
create_dup_top(),
@ -2655,7 +2709,7 @@ class InstructionTranslatorBase(
]
)
# frames, frames[-1][0], frames[-1][1]
# resumes, frames, frames[-1][0], frames[-1][1]
for name in argnames:
cg.extend_output(
[
@ -2667,22 +2721,24 @@ class InstructionTranslatorBase(
*create_swap(2),
],
)
# frames, frames[-1][0], *(live locals), frames[-1][1]
# resumes, frames, frames[-1][0], *(live locals), frames[-1][1]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(3),
# live_locals, frames[-1][0], frames
*create_swap(4),
# live_locals, frames, frames[-1][0], resumes
create_instruction("BUILD_LIST", arg=1),
*create_swap(2),
# live_locals, [frames], frames[-1][0]
*create_swap(3),
# live_locals, [resumes], frames[-1][0], frames
create_instruction("LIST_APPEND", arg=2),
create_instruction("LIST_EXTEND", arg=1),
# live_locals, [resumes, frames, *stack]
*create_swap(2),
create_instruction("LIST_EXTEND", arg=1),
]
)
# [frames, *(stack + live locals)]
# [resumes, frames, *(stack + live locals)]
cg.extend_output(
[
@ -4208,6 +4264,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
finally:
parent.error_on_graph_break = self.error_on_graph_break
if self.output.should_exit:
# graph break
return ConstantVariable.create(None) # return dummy variable
assert self.symbolic_result is not None
if self.f_globals is parent.f_globals: