[dynamo, nested graph breaks] fix nested step graph break related issues (#162737)

It turns out that codegen'ing a nested step graph break is significantly more complicated than first thought. The optimized function should actually do the following (see the sketch below):
- call the compiled graph, load values, apply side effects, etc.
- call into the leaf frame's resume function with tracing skipped (this is essentially a step graph break for just the leaf frame)
- call into all the other resume functions, traced.
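
Roughly, for a hypothetical two-frame call chain where the leaf frame hits the step graph break, the generated code behaves like this Python-level sketch (the helper names and bodies are illustrative stand-ins, not the actual generated bytecode):

def compiled_graph(x):       # stands in for the FX graph compiled from what was traced so far
    return x + 1

def leaf_resume(x):          # leaf frame's resume function; skip_code() is applied to it,
    return x * 2             # so it runs eagerly, untraced (step graph break for the leaf only)

def parent_resume(x):        # resume function(s) for the enclosing frame(s); traced as usual
    return x - 3

def optimized_fn(x):
    y = compiled_graph(x)    # 1. call the graph / load values / apply side effects
    y = leaf_resume(y)       # 2. call into the leaf's resume function, skipped
    return parent_resume(y)  # 3. call into the remaining resume functions, traced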

This PR also adds `torch._dynamo.step_unsupported()`, which can be used internally to better exercise step graph break handling in tests.
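
As a quick illustration (this mirrors the new `test_step_unsupported` test added below, which is where the frame/op counts come from):

import torch
import torch._dynamo.testing

cnts = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnts)
def fn(x):
    x = x + 1 + 2                      # traced: becomes one compiled FX graph with 2 ops
    torch._dynamo.step_unsupported()   # compile the graph so far, skip the rest of the frame
    return x + 4                       # runs eagerly via the skipped resume path

inp = torch.ones(3)
assert torch.equal(fn(inp), inp + 7)
assert cnts.frame_count == 1 and cnts.op_count == 2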

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162737
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #160601
Author: William Wen
Date: 2025-10-08 10:50:05 -07:00
Committed by: PyTorch MergeBot
Parent: 486b4d2414
Commit: af4c29fea8
11 changed files with 542 additions and 242 deletions

View File

@ -893,6 +893,29 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
self.assertEqual(gn(inp), inp + 3)
self.assertEqual(cnts.frame_count, 1)
def test_step_unsupported(self):
cnts = torch._dynamo.testing.CompileCounter()
@torch.compile(backend=cnts)
def fn(x):
x = x + 1 + 2
torch._dynamo.step_unsupported()
return x + 4
inp = torch.ones(3)
self.assertEqual(fn(inp), inp + 7)
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 2)
def test_step_unsupported_empty_checkpoint(self):
@torch.compile(backend="eager")
def fn(x):
torch._dynamo.step_unsupported()
return x + 1
inp = torch.ones(3)
self.assertEqual(fn(inp), inp + 1)
@skipIfWindows(
msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows."
)

View File

@ -536,6 +536,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 5)
# 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
self.assertEqual(cnts.op_count, 12)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)
def test_nested_graph_break_in_try_block(self):
# NOTE: this also tests nested step_graph_break
@ -576,13 +577,40 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
x = torch.zeros(3)
res = f5(x)
ref = opt_fn(x)
print(ref, res)
self.assertEqual(ref, res)
# skip frame due to graph break in try block
# 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
self.assertEqual(cnts.frame_count, 4)
# 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
self.assertEqual(cnts.op_count, 9)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)
def test_nested_step_unsupported(self):
global f1, f2, f3
def f1(x):
return x + 1
def f2(x):
x = x + 2
torch._dynamo.step_unsupported()
return f1(x) + 4
def f3(x):
x = x + 8
return f2(x) + 16
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
x = torch.zeros(3)
res = f3(x)
ref = opt_fn(x)
self.assertEqual(ref, res)
# 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
self.assertEqual(cnts.frame_count, 3)
# all ops except + 4
self.assertEqual(cnts.op_count, 4)
self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
if __name__ == "__main__":

View File

@ -7256,6 +7256,26 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
flag = False
self.assertEqual(fn(inp), opt_fn(inp))
def test_cells_unsupported_step_exception(self):
# This error happened because:
# - we were generating cells into a list on the stack
# - we encountered an unsupported step, resulting in a step graph break
# - we encountered an exception, which pops the stack until it reaches a certain length;
# the presence of the list of cells then messes things up.
cell = 0
@torch.compile(backend="eager")
def fn(x):
x = x + 1 + 2
torch._dynamo.step_unsupported()
with contextlib.nullcontext():
print(cell)
raise AssertionError
with self.assertRaises(AssertionError):
fn(torch.ones(3))
def test_unbind_copy_out(self):
def f(eye, out):
torch.unbind_copy(eye, out=out)

View File

@ -40,6 +40,7 @@ from .decorators import (
run,
set_stance,
skip_frame,
step_unsupported,
substitute_in_graph,
)
from .eval_frame import (
@ -102,6 +103,7 @@ __all__ = [
"error_on_graph_break",
"set_stance",
"skip_frame",
"step_unsupported",
"substitute_in_graph",
]

View File

@ -533,6 +533,8 @@ def create_binary_slice(
def create_copy(i: int) -> list[Instruction]:
if sys.version_info >= (3, 11):
return [create_instruction("COPY", arg=i)]
if i == 1:
return [create_instruction("DUP_TOP")]
# COPY 4
# 0 1 2 3
# 3 1 2 0

View File

@ -296,6 +296,14 @@ def skip_frame(msg: str = "") -> None:
"""Force a skipped frame"""
@_disallow_in_graph_helper(throw_if_not_allowed=False)
def step_unsupported(msg: str = "") -> None:
"""Force a step unsupported graph break, which results in compiling
the traced FX graph so far, then skipping the rest of the frame.
In order to get expected behavior, there should be at least 2 ops
and a part of the code not contained in any try/with blocks."""
def forbid_in_graph(fn: Any) -> Any:
"""
Customize which functions TorchDynamo will assert are not present while tracing.
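
As a usage note on that constraint, here is a minimal sketch adapted from the new `test_step_unsupported_empty_checkpoint` test: if `step_unsupported()` is hit before anything has been traced, there is no checkpoint to step-graph-break from, so Dynamo reports the new GB0275 graph break and, per the test, the function still returns the correct eager result:

import torch
import torch._dynamo

@torch.compile(backend="eager")
def fn(x):
    # Nothing traced yet, so there is no checkpoint to step_graph_break from (GB0275).
    torch._dynamo.step_unsupported()
    return x + 1

inp = torch.ones(3)
assert torch.equal(fn(inp), inp + 1)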

View File

@ -263,6 +263,11 @@ class RecompileLimitExceeded(Unsupported):
pass
# debug exception thrown when tracing torch._dynamo.step_unsupported()
class StepUnsupported(TorchDynamoException):
pass
class UnsafeScriptObjectError(TorchDynamoException):
pass

View File

@ -2763,5 +2763,18 @@
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
],
"GB0275": [
{
"Gb_type": "torch._dynamo.step_unsupported() with empty checkpoint",
"Context": "",
"Explanation": "traced torch._dynamo.step_unsupported(), but there is no checkpoint to step_graph_break from. This graph break is used for debugging only.",
"Hints": [
"Remove the torch._dynamo.step_unsupported() call.",
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some ",
"line of code that is not in a try/with block, and has an empty Python stack.",
"This is likely to be a Dynamo bug. Please report an issue to PyTorch."
]
}
]
}

View File

@ -318,6 +318,7 @@ class ContinueExecutionCache:
argnames: tuple[str, ...],
argnames_null: tuple[str, ...],
setup_fns: tuple[ReenterWith, ...],
handle_inactive_ctx: bool,
stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
null_idxes: tuple[int, ...],
@ -341,6 +342,7 @@ class ContinueExecutionCache:
argnames,
argnames_null,
setup_fns,
handle_inactive_ctx,
stack_ctx_vars,
argnames_ctx_vars,
null_idxes,
@ -432,7 +434,7 @@ class ContinueExecutionCache:
prefix.append(
create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
)
if stack_i in stack_ctx_vars_d:
if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
# NOTE: we assume that current stack var is a context manager CLASS!
# Load args for context variable and construct it
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
@ -459,10 +461,11 @@ class ContinueExecutionCache:
# NOTE: we assume that local var is a context manager CLASS!
# initialize inactive context vars in argnames
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))
if handle_inactive_ctx:
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))
# 3.12+: store NULL into variables that were NULL
if argnames_null:

View File

@ -79,6 +79,7 @@ from .bytecode_transformation import (
create_dup_top,
create_instruction,
create_jump_absolute,
create_load_const,
create_rot_n,
create_swap,
get_code_keys,
@ -96,6 +97,7 @@ from .exc import (
format_graph_break_message,
get_stack_above_dynamo,
ResumePrologueTracingError,
StepUnsupported,
unimplemented_v2,
Unsupported,
)
@ -669,14 +671,20 @@ def generic_jump(
)
self.pop()
if_next = self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
if_next = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
) + self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.create_call_resume_at(
inst.target, all_stack_locals_metadata, False
if_jump = self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], inst.target
) + self.create_call_resume_at(
inst.target,
all_stack_locals_metadata,
)
if sys.version_info >= (3, 13):
@ -960,7 +968,7 @@ def break_graph_if_unsupported(
all_stack_locals_metadata = self.output.compile_subgraph(
self, reason=reason, stack_pops=push - stack_effect
)
cg = PyCodegen(self)
cg = PyCodegen(self.output.root_tx)
cleanup: list[Instruction] = []
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
@ -1009,8 +1017,12 @@ def break_graph_if_unsupported(
for _ in range(push):
self.push(UnknownVariable())
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
)
@ -1351,9 +1363,22 @@ class InstructionTranslatorBase(
return True
except (ReturnValueOp, YieldValueOp):
return False
except Unsupported:
except (Unsupported, StepUnsupported) as e:
if self.current_speculation is None:
log.debug("empty checkpoint")
if isinstance(e, StepUnsupported):
unimplemented_v2(
gb_type="torch._dynamo.step_unsupported() with empty checkpoint",
context="",
explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint "
"to step_graph_break from. This graph break is used for debugging only.",
hints=[
"Remove the torch._dynamo.step_unsupported() call.",
"Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some "
"line of code that is not in a try/with block, and has an empty Python stack.",
*graph_break_hints.DYNAMO_BUG,
],
)
raise
log.debug("step triggered compile", exc_info=True)
@ -1427,12 +1452,98 @@ class InstructionTranslatorBase(
partial_convert=True,
reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ],
if self.parent:
from .eval_frame import skip_code
# nested graph break
assert config.nested_graph_breaks
cg = PyCodegen(self.output.root_tx)
# codegen cells and frame values only for frame N
cg.extend_output(
[
*create_copy(2),
cg.create_load_const(0),
cg.create_binary_subscr(),
create_instruction("BUILD_LIST", arg=1),
*create_copy(2),
cg.create_load_const(0),
cg.create_binary_subscr(),
create_instruction("BUILD_LIST", arg=1),
]
)
# No need to fix stack, since stack is assumed to be empty here.
# Do NOT handle_inactive_ctx because we will be skipping this resume code.
leaf_resume_code, leaf_resume_name = self.create_resume(
0, continue_inst, all_stack_locals_metadata[0], [], cg, True, False
)
skip_code(leaf_resume_code)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ], [frame N cells], [frame N locals],
self.codegen_call_resume([leaf_resume_code], [leaf_resume_name], cg)
# current frame state
# cells,
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ], leaf_resume result
# add the leaf_resume result to frame N-1 stack
num_stack = all_stack_locals_metadata[1].num_stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=1),
*create_copy(2),
cg.create_load_const(1),
cg.create_binary_subscr(),
*create_binary_slice(num_stack, num_stack, True),
]
)
# pop frame N cells and locals
cg.extend_output(
[
*create_copy(1),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
*create_copy(2),
cg.create_load_const(0),
create_instruction("DELETE_SUBSCR"),
]
)
# call the remaining resume functions
# current frame state
# [frame N-1 cells, ..., frame 1 cells],
# [
# frame N-1 stack (including leaf_resume result) + locals,
# ...,
# frame 1 stack + locals,
# ],
self.parent.push(UnknownVariable())
all_stack_locals_metadata[1].num_stack += 1
self.output.add_output_instructions(
self.create_call_resume_at(
continue_inst, all_stack_locals_metadata, True
cg.get_instructions()
+ self.parent.create_call_resume_at(
self.parent.next_instruction, all_stack_locals_metadata[1:]
)
)
else:
@ -1444,14 +1555,7 @@ class InstructionTranslatorBase(
]
)
# load locals from frame values
# current frame state
# [
# frame N locals,
# frame N-1 stack + locals,
# ...,
# frame 1 stack + locals,
# ],
cg = PyCodegen(self)
cg = PyCodegen(self.output.root_tx)
self.output.add_output_instructions(
[
cg.create_load_const(-1),
@ -2516,8 +2620,12 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata, False
self.codegen_fix_leaf_stack(
all_stack_locals_metadata[0], self.next_instruction
)
+ self.create_call_resume_at(
self.next_instruction,
all_stack_locals_metadata,
)
)
@ -2530,15 +2638,26 @@ class InstructionTranslatorBase(
)
@staticmethod
def codegen_return_after_compile_subgraph(
inst: Instruction, meta: StackLocalsMetadata
def codegen_return_with_pops(
inst: Instruction, num_stack: int
) -> list[Instruction]:
"""
Debug CPython expects the stack to be empty after the return.
Calling compile_subgraph will push cells and frame values to TOS.
This function will pop those 2 values from the stack before actually returning.
Expects the stack to be:
cells, frame values, current frame stack (0 or 1 values)
Pops cells and frame values, leaving the current frame stack as TOS.
A return instruction is included.
"""
insts = []
# NOTE: Debug CPython expects the stack to be empty after the return.
# Expect the current stack to be in the state
# cells, frame values, current frame stack (0 or 1 values)
assert meta.num_stack <= 1
if meta.num_stack == 1:
assert num_stack <= 1
if num_stack == 1:
insts.extend(create_swap(3))
return_inst = (
create_instruction("RETURN_VALUE")
@ -2550,31 +2669,261 @@ class InstructionTranslatorBase(
)
return insts
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: Any,
disable_current_frame_resume: bool,
def codegen_fix_leaf_stack(
self, meta: StackLocalsMetadata, resume_inst: Instruction
) -> list[Instruction]:
"""
Codegen resume function(s) and call it.
Assumes that the unsupported instruction has already been run.
Fixes the stack values of the current/leaf frame (self).
Expects the stack to be in the state:
[frame N cells, ..., frame 1 cells],
Expects the TOS to be:
[
frame N locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
], frame N stack (post-instruction)
], *(frame N stack (post-unsupported instruction))
Rearranges the TOS to become:
[
frame N stack + locals,
...,
frame 1 stack + locals
]
Args:
- meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
- resume_inst: if the resume instruction is a return instruction, then don't return any instructions
"""
if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return []
# move frame N stack to the frame values list
current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
meta.num_stack = current_num_stack
return [
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
create_load_const(0),
create_instruction("BINARY_SUBSCR"),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
def create_resume(
self,
idx: int,
resume_inst: Instruction,
meta: StackLocalsMetadata,
resume_codes: list[types.CodeType],
cg: PyCodegen,
is_leaf: bool,
handle_inactive_ctx: bool,
) -> tuple[types.CodeType, str]:
"""
Creates the resume function for the frame corresponding to `self`.
Expects the TOS to be:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
...,
frame 1 stack + locals
]
Some additional codegen may happen to prepare the frame stack + locals values for the generated resume function:
- inactive context variables in the stack and locals will be replaced by their types
- if the frame is a leaf frame, prune dead locals
Regardless of codegen, the stack will be left in the same state as before.
Args:
- idx: depth of this frame: 0 corresponds to the leaf frame (frame N), N-1 to the root frame (frame 1).
- resume_inst: the instruction that this frame should resume at
- meta: metadata for this frame returned from OutputGraph.compile_subgraph
- resume_codes: nested resume code objects generated from previous create_resume calls.
- cg: codegen object to output to
- is_leaf: True if `self` corresponds to the leaf frame.
- handle_inactive_ctx: If True, handles inactive context variables as described above. This is necessary
iff the resume function is traced
"""
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
if handle_inactive_ctx:
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, self.stack[j_orig])
# frames[idx][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(idx),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
)
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
# frames[idx][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(idx),
cg.create_binary_subscr(),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
# If the resume instruction is a jump absolute, then resume
# at the target instead. This handles the case where we
# graph break again in a nested function before jump-resuming
# this frame.
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
# More locals may have been pruned in the current/leaf frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction there should be a CALL.
if is_leaf:
reads = livevars_analysis(self.instructions, resume_inst)
all_argnames = tuple(
k
for k in self.symbolic_locals.keys()
if k in reads and k not in self.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(idx),
cg.create_binary_subscr(),
create_dup_top(),
]
)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(meta.num_stack + meta.locals_names[arg]),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frames i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(self.stack) - len(meta.stack_null_idxes)
new_code: types.CodeType = ContinueExecutionCache.lookup(
self.f_code,
self.lineno,
resume_inst.offset,
tuple(b.target.offset for b in self.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
handle_inactive_ctx,
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
tuple(resume_codes),
)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
self.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
resume_name,
types.FunctionType(new_code, self.f_globals, resume_name),
)
package_name = resume_name
if self.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], package_name
)
return new_code, resume_name
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: list[StackLocalsMetadata],
) -> list[Instruction]:
"""
Codegen all resume function(s) from the frame stack starting at `self` and call them.
Assumes that the unsupported instruction has already been run.
Expects the stack to be in the state:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
]
Pops the cells and frame values list from the stack.
Also includes a return instruction (stack expected to be empty after return).
Args:
- inst: the instruction of the current (deepest) frame to resume at
- all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
metadata such as local names, NULL positions, stack length, etc.
- disable_current_frame_resume: If True, disable tracing on the current frame's resume function.
Used for implementing nested step_graph_break.
"""
self.instruction_pointer = None
@ -2585,212 +2934,63 @@ class InstructionTranslatorBase(
all_stack_locals_metadata[0].num_stack = current_num_stack
if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
return self.codegen_return_after_compile_subgraph(
inst, all_stack_locals_metadata[0]
return self.codegen_return_with_pops(
inst, all_stack_locals_metadata[0].num_stack
)
cg = PyCodegen(self.output.root_tx)
# move frame N stack to the frame values list
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=current_num_stack),
*create_copy(2),
# frame_values, frame N stack, frame_values
cg.create_load_const(0),
cg.create_binary_subscr(),
*create_binary_slice(0, 0, True),
# frame_values[0][0:0] = frame N stack
# frame_values left on top of stack
]
)
# current frame state
# all cells
# [
# [frame N stack (fixed) + locals]
# ...,
# [frame 1 stack + locals]
# ],
#
txes = []
cur_tx: Optional[InstructionTranslatorBase] = self
while cur_tx is not None:
txes.append(cur_tx)
cur_tx = cur_tx.parent
assert len(txes) == len(all_stack_locals_metadata)
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
for i, meta in enumerate(all_stack_locals_metadata):
if i == 0 and disable_current_frame_resume:
continue
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
# frames[i][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
)
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
# frames[i][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(meta.num_stack + meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
# build the resume function for each frame
resume_names = []
idx = 0
resume_codes: list[types.CodeType] = []
for i, meta in enumerate(all_stack_locals_metadata):
cur_tx = txes[i]
resume_names = []
while cur_tx is not None:
if cur_tx is self:
resume_inst = inst
else:
resume_inst = cur_tx.next_instruction
# If the resume instruction is a jump absolute, then resume
# at the target instead. This handles the case where we
# graph break again in a nested function before jump-resuming
# this frame.
if is_jump_absolute(resume_inst):
assert resume_inst.target
resume_inst = resume_inst.target
resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
resume_code, resume_name = cur_tx.create_resume(
idx,
resume_inst,
all_stack_locals_metadata[idx],
resume_codes,
cg,
cur_tx is self,
True,
)
resume_codes.append(resume_code)
resume_names.append(resume_name)
# More locals may have been pruned in the current frame
# after the unsupported instruction (e.g. branch).
# There should not be any pruning in the other frames since
# the current instruction is a CALL.
if cur_tx is self:
reads = livevars_analysis(cur_tx.instructions, resume_inst)
all_argnames = tuple(
k
for k in cur_tx.symbolic_locals.keys()
if k in reads and k not in cur_tx.cell_and_freevars()
)
argnames_null_set = set(meta.locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
cur_tx = cur_tx.parent
idx += 1
# codegen filter for current frame's locals
# current stack state: frames
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
create_dup_top(),
]
)
for arg in argnames:
# current stack state: frames, frames[i], *(prev locals), frames[i]
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
meta.num_stack + meta.locals_names[arg]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# current stack state: frames, frames[i], *(frame i live locals), frames[i]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
*create_swap(2),
# frames, frames i live locals, frames[i]
*create_binary_slice(meta.num_stack, None, True),
# frames[i][num_stack:] = frame i live locals
]
)
# current stack state: frames
else:
argnames = tuple(meta.locals_names.keys())
argnames_null = tuple(meta.locals_null_keys)
self.codegen_call_resume(resume_codes, resume_names, cg)
return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
@staticmethod
def codegen_call_resume(
resume_codes: list[types.CodeType], resume_names: list[str], cg: PyCodegen
) -> None:
"""
Calls the provided resume functions.
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
Expects the TOS to be in the state:
[frame N cells, ..., frame 1 cells],
[
frame N stack + locals,
frame N-1 stack + locals,
...,
frame 1 stack + locals
]
new_code: types.CodeType = ContinueExecutionCache.lookup(
cur_tx.f_code,
cur_tx.lineno,
resume_inst.offset,
tuple(b.target.offset for b in cur_tx.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in cur_tx.block_stack),
tuple(meta.stack_ctx_args),
tuple(meta.locals_ctx_args),
tuple(meta.stack_null_idxes),
tuple(resume_codes),
)
resume_codes.append(new_code)
Pops the cells and frame values, leaving the result of calling the resume functions on TOS.
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
# add resume function to the global scope
if new_code.co_freevars:
# expose code object for debugging purposes
cur_tx.output.install_global_unsafe(resume_name, new_code)
package_name = None
else:
# This is safe: we pre-generate a unique name
cur_tx.output.install_global_unsafe(
resume_name,
types.FunctionType(new_code, cur_tx.f_globals, resume_name),
)
package_name = resume_name
if cur_tx.package is not None:
cur_tx.package.add_resume_function(
new_code, cur_tx.f_globals["__name__"], package_name
)
if disable_current_frame_resume:
from .eval_frame import skip_code
skip_code(resume_codes[0])
# load cells as we load resume functions
Args:
- resume_codes: list of resume function code objects to call
- resume_names: list of the corresponding names of the resume functions
- cg: PyCodegen object to output instructions to
"""
# NOTE: We will load cells as we load resume functions
# load resume functions except the root's
cg.extend_output(create_copy(2))
@ -2884,10 +3084,8 @@ class InstructionTranslatorBase(
cg.extend_output(
[
*create_call_function_ex(False, True),
create_instruction("RETURN_VALUE"),
]
)
return cg.get_instructions()
def should_compile_partial_graph(self) -> bool:
if sys.version_info >= (3, 11):
@ -3530,7 +3728,7 @@ class InstructionTranslatorBase(
self.active_generic_context_managers.append(ctx)
if sys.version_info >= (3, 11):
# See create_call_resume_at for block stack details.
# See update_block_stack/create_resume for block stack details.
# Only push a block if the current instruction's block is a
# with block that is not nested in a try block - that is, the current
# instruction's block target is the same as the top block's target.
@ -4230,9 +4428,7 @@ class InstructionTranslator(InstructionTranslatorBase):
assert len(all_stack_locals_metadata) == 1
assert not all_stack_locals_metadata[0].stack_null_idxes
self.output.add_output_instructions(
self.codegen_return_after_compile_subgraph(
inst, all_stack_locals_metadata[0]
)
self.codegen_return_with_pops(inst, all_stack_locals_metadata[0].num_stack)
)
raise ReturnValueOp
@ -4617,13 +4813,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
def create_call_resume_at(
self,
inst: Instruction,
all_stack_locals_metadata: Any,
disable_current_frame_resume: bool,
all_stack_locals_metadata: list[StackLocalsMetadata],
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(
inst, all_stack_locals_metadata, disable_current_frame_resume
)
return super().create_call_resume_at(inst, all_stack_locals_metadata)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",

View File

@ -52,6 +52,7 @@ from ..exc import (
ObservedUserStopIteration,
raise_observed_exception,
SkipFrame,
StepUnsupported,
unimplemented_v2,
Unsupported,
)
@ -1527,6 +1528,8 @@ class SkipFunctionVariable(VariableTracker):
raise SkipFrame(
f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
)
elif self.value is torch._dynamo.step_unsupported:
raise StepUnsupported
else:
if config.dont_skip_tracing:
from .builder import SourcelessBuilder