[dynamo, nested graph breaks] move cell codegen before side effects codegen (#160601)

This is needed because if we codegen cells for nested frames AFTER side effects have been codegen'd, reconstruction of those cells can fail. From below:

>The added test case demonstrates the reconstruction failure that occurs if cell codegen is kept at its original place (this only happens with nested graph breaks, since we reconstruct nested frame cells from VariableTracker rather than directly using LOAD_CLOSURE).

>At a high level, what happened before this change was that side_effects was pruning the cells (I don't recall exactly why this happens), and because cells were codegen'd after the side effects were applied, we were unable to properly reconstruct the cell. The error I was seeing was a list/tuple IndexError.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160601
Approved by: https://github.com/mlazos
This commit is contained in:
William Wen
2025-10-08 10:50:04 -07:00
committed by PyTorch MergeBot
parent 8f83b3e71c
commit 486b4d2414
8 changed files with 181 additions and 74 deletions

View File

@ -363,6 +363,31 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 13)
def test_cells_double_graph_break(self):
    """Check cell reconstruction across two back-to-back nested graph breaks.

    The inner closure mutates a nonlocal cell owned by its enclosing
    function and then hits two consecutive graph breaks, so the cell has to
    be reconstructed for the nested frame when resuming.
    """

    def f1(x1):
        cell1 = x1 + 1

        def f2(x2):
            # mutate the enclosing frame's cell before breaking the graph
            nonlocal cell1
            cell1 += 2
            torch._dynamo.graph_break()
            torch._dynamo.graph_break()
            return x2 + cell1

        # return the cell as well so the mutation is observable by the caller
        return f2(x1 + 4), cell1

    def outer(x):
        return f1(x)

    counter = torch._dynamo.testing.CompileCounter()
    compiled_outer = torch._dynamo.optimize(backend=counter)(outer)
    inp = torch.zeros(3)
    # run eagerly first, then through the compiled function
    eager_out = outer(inp)
    compiled_out = compiled_outer(inp)
    self.assertEqual(compiled_out, eager_out)
    self.assertEqual(counter.frame_count, 2)
    self.assertEqual(counter.op_count, 4)
def test_side_effects_cells(self):
cell1, cell2, cell3, cell4 = (torch.zeros(3),) * 4

View File

@ -397,19 +397,37 @@ def create_call_function(nargs: int, push_null: bool) -> list[Instruction]:
return [create_instruction("CALL_FUNCTION", arg=nargs)]
def create_call_function_ex(has_kwargs: bool) -> list[Instruction]:
def create_call_function_ex(
has_kwargs: bool, push_null: bool, ignore_314_kwargs_push: bool = False
) -> list[Instruction]:
"""
Assumes that in 3.14+, if has_kwargs=False, there is NOT a NULL
on the TOS for the kwargs. This utility function will add a PUSH_NULL.
If the caller has already pushed a NULL, then do not call this function -
just use create_instruction("CALL_FUNCTION_EX", arg=...).
If the caller has already pushed a NULL for the kwargs, then set ignore_314_kwargs_push=True
so we don't push another NULL for the kwargs.
"""
insts = []
if sys.version_info >= (3, 14) and not has_kwargs:
insts.append(create_instruction("PUSH_NULL"))
insts.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return insts
if sys.version_info >= (3, 11):
output = []
if (
sys.version_info >= (3, 14)
and not has_kwargs
and not ignore_314_kwargs_push
):
output.append(create_instruction("PUSH_NULL"))
if push_null:
output.append(create_instruction("PUSH_NULL"))
# 3.13 swapped NULL and callable
# if flags == 1, 2 values popped - otherwise if flags == 0, 1 value
rots = (
int(has_kwargs) + 2
if sys.version_info >= (3, 13)
else int(has_kwargs) + 3
)
output.extend(create_rot_n(rots))
output.append(create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs)))
return output
return [create_instruction("CALL_FUNCTION_EX", arg=int(has_kwargs))]
def create_call_method(nargs: int) -> list[Instruction]:

View File

@ -519,7 +519,7 @@ class PyCodegen:
create_build_tuple(n),
self.create_load_const_unchecked(rot_n_helper(n)),
*create_rot_n(2),
*create_call_function_ex(False),
*create_call_function_ex(False, False),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
@ -540,51 +540,33 @@ class PyCodegen:
def make_function_with_closure(
self,
tx: "InstructionTranslatorBase",
fn_name: str,
code: types.CodeType,
push_null: bool,
num_on_stack: int = 0,
) -> None:
freevars = code.co_freevars
assert freevars
"""Creates a closure with code object `code`.
Expects the TOS to be the tuple of cells to use for this closure.
TOS will be popped to create the closure.
Args:
- fn_name: name of the function
- code: code object of the function
(does not include the tuple of cells on the TOS)
"""
output = self._output
def gen_fn() -> None:
self.clear_tos()
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `CellVariable.local_name`.
for var in freevars:
if tx is self.tx: # root frame
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))
else: # nested frame
assert var in tx.cell_and_freevars()
assert tx.post_prune_cell_and_freevars
self(tx.post_prune_cell_and_freevars[var])
output.append(create_build_tuple(len(freevars)))
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
if push_null and sys.version_info >= (3, 11):
self.add_push_null(gen_fn)
output.extend(self.rot_n(num_on_stack + 2))
output.extend(self.rot_n(num_on_stack + 2))
output.append(self.create_load_const(code))
if sys.version_info < (3, 11):
output.append(self.create_load_const(fn_name))
if sys.version_info >= (3, 13):
output.extend(
[
create_instruction("MAKE_FUNCTION"),
create_instruction("SET_FUNCTION_ATTRIBUTE", arg=0x08),
]
)
else:
gen_fn()
output.extend(self.rot_n(num_on_stack + 1))
output.append(create_instruction("MAKE_FUNCTION", arg=0x08))
self.clear_tos()
def create_load_python_module(self, mod: types.ModuleType) -> Instruction:

View File

@ -79,6 +79,7 @@ from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
create_binary_slice,
create_binary_subscr,
create_build_tuple,
create_call_function,
create_dup_top,
create_instruction,
@ -1534,8 +1535,9 @@ class OutputGraph(OutputGraphCommon):
# Codegen stack convention before the unsupported instruction
# NOTE: in these comment blocks, "locals" EXCLUDE free and cell vars.
# NOTE: stack and locals must be codegen'd BEFORE the unsupported instruction, since the latter
# NOTE: stack/locals/cells must be codegen'd BEFORE the unsupported instruction, since the latter
# can arbitrarily mutate the former.
# [frame N cells, .., frame 1 cells],
# [
# frame N locals,
# frame N-1 stack + locals,
@ -1545,7 +1547,7 @@ class OutputGraph(OutputGraphCommon):
# see symbolic_convert.py for
# codegen stack convention after the unsupported instruction
# NOTE: cells are loaded into continuation functions directly
# NOTE: cells will be loaded into continuation functions directly by symbolic_convert
# this determines the order that values are codegen'd to the stack
stack_values_flat = [val for vals in all_stack_values for val in vals]
@ -1577,12 +1579,19 @@ class OutputGraph(OutputGraphCommon):
and not all_stack_locals_metas[-1].locals_null_keys
):
# optimization to generate better code in a common case
# codegen cells
# no side effects, so no new cells created - no need to call side_effects.codegen_save_tempvars
cell_cg = PyCodegen(self.root_tx)
self.codegen_cells(tx, cell_cg)
self.add_output_instructions(
[
# load in reverse since UNPACK_SEQUENCE will reverse
*self.compile_and_call_fx_graph(
tx, list(reversed(stack_values_flat)), root
),
*cell_cg.get_instructions(),
*create_swap(2),
create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
]
)
@ -1684,6 +1693,7 @@ class OutputGraph(OutputGraphCommon):
# store all stack and locals for each frame
# current state of the stack:
# all cells,
# *(frame N stack), *(frame N locals),
# ...,
# *(frame 1 stack), *(frame 1 locals)
@ -1698,6 +1708,7 @@ class OutputGraph(OutputGraphCommon):
)
# current state of the stack:
# all cells,
# *(frame N stack), [
# *(frame N locals),
# *(frame N-1 stack), *(frame N-1 locals),
@ -1758,7 +1769,8 @@ class OutputGraph(OutputGraphCommon):
# *(frame N stack), metas[0] stack + locals, ..., metas[i] stack + locals, stack_values_flat
# current state of the stack:
# *(frame N stack)
# all cells,
# *(frame N stack),
# frame N locals,
# frame N-1 stack, frame N-1 locals,
# ...
@ -1775,6 +1787,7 @@ class OutputGraph(OutputGraphCommon):
)
# final state of the stack before running the unsupported bytecode:
# all cells,
# [
# [frame N locals],
# [frame N-1 stack + locals],
@ -1831,6 +1844,31 @@ class OutputGraph(OutputGraphCommon):
return all_stack_locals_metas
def codegen_cells(self, tx: "InstructionTranslatorBase", cg: PyCodegen) -> None:
    """Emit bytecode that pushes one list holding a tuple of cells per frame.

    Starting at ``tx`` and following ``.parent`` links until there is no
    parent, a tuple of that frame's cells is built and the tuples are then
    collected into a single list, i.e.
    ``[frame N cells, ..., frame 1 cells]``.

    If the subgraph is not being compiled because of a graph break, an
    empty list is pushed instead.

    Args:
        tx: the translator of the frame to start from.
        cg: the codegen object that receives the generated instructions.
    """
    if not self.compile_subgraph_reason.graph_break:
        # We won't resume, so there are no cells to hand over; still push a
        # list so the stack layout is the same in both cases.
        cg.append_output(create_instruction("BUILD_LIST", arg=0))
        return

    num_frames = 0
    frame_tx: Optional[InstructionTranslatorBase] = tx
    while frame_tx is not None:
        # NOTE: cells are generated in the same order as resume_execution.py:
        # sorted freevars + cellvars.
        # Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
        # requires that in the generated bytecode, these cells keep their
        # original local names, which is ensured via
        # `CellVariable.local_name`.
        cell_names = tuple(sorted(frame_tx.cell_and_freevars()))
        is_root_frame = frame_tx is self.root_tx
        for cell_name in cell_names:
            if is_root_frame:
                # root frame: the cell can be loaded directly by name
                cg.append_output(cg.create_load_closure(cell_name))
            else:
                # nested frame: reconstruct the cell from its tracked variable
                assert frame_tx.post_prune_cell_and_freevars
                cg(frame_tx.post_prune_cell_and_freevars[cell_name])
        cg.append_output(create_build_tuple(len(cell_names)))
        frame_tx = frame_tx.parent
        num_frames += 1

    cg.append_output(create_instruction("BUILD_LIST", arg=num_frames))
def codegen_suffix(
self,
tx: "InstructionTranslatorBase",
@ -1850,6 +1888,7 @@ class OutputGraph(OutputGraphCommon):
cg.store_attr(name)
self.side_effects.codegen_hooks(cg)
# TODO get debug_locals working for nested graph breaks
# Return variables used for logging at the end
for debug_var, args in tx.debug_locals:
cg.add_push_null(lambda: cg(debug_var))
@ -1858,6 +1897,9 @@ class OutputGraph(OutputGraphCommon):
cg.extend_output(create_call_function(len(args), False))
cg.extend_output([create_instruction("POP_TOP")])
# codegen cells before we apply side effects
self.codegen_cells(tx, cg)
cg.restore_stack(stack_values, value_from_source=not tx.export)
self.side_effects.codegen_update_mutated(cg)

View File

@ -524,7 +524,7 @@ class ContinueExecutionCache:
"STORE_FAST", argval=IS_TRACING_RESUME_PROLOGUE_VARNAME
),
# finish the call
*create_call_function_ex(False),
*create_call_function_ex(False, False),
]
)
else:

View File

@ -1436,6 +1436,13 @@ class InstructionTranslatorBase(
)
)
else:
# pop cells
self.output.add_output_instructions(
[
*create_swap(2),
create_instruction("POP_TOP"),
]
)
# load locals from frame values
# current frame state
# [
@ -2529,16 +2536,18 @@ class InstructionTranslatorBase(
insts = []
# NOTE: Debug CPython expects the stack to be empty after the return.
# Expect the current stack to be in the state
# [[]] (empty frame values), current frame stack (0 or 1 values)
# cells, frame values, current frame stack (0 or 1 values)
assert meta.num_stack <= 1
if meta.num_stack == 1:
insts.extend(create_swap(2))
insts.extend(create_swap(3))
return_inst = (
create_instruction("RETURN_VALUE")
if inst.opname == "RETURN_VALUE"
else create_instruction("RETURN_CONST", argval=inst.argval)
)
insts.extend([create_instruction("POP_TOP"), return_inst])
insts.extend(
[create_instruction("POP_TOP"), create_instruction("POP_TOP"), return_inst]
)
return insts
def create_call_resume_at(
@ -2552,6 +2561,7 @@ class InstructionTranslatorBase(
Assumes that the unsupported instruction has already been run.
Expects the stack to be in the state:
[frame N cells, ..., frame 1 cells],
[
frame N locals,
frame N-1 stack + locals,
@ -2596,6 +2606,7 @@ class InstructionTranslatorBase(
)
# current frame state
# all cells
# [
# [frame N stack (fixed) + locals]
# ...,
@ -2779,30 +2790,59 @@ class InstructionTranslatorBase(
skip_code(resume_codes[0])
# load first resume function (to be called this frame)
if resume_codes[-1].co_freevars:
cg.make_function_with_closure(
txes[-1], resume_names[-1], resume_codes[-1], True, 1
)
else:
cg.extend_output(cg.load_function_name(resume_names[-1], True, 1))
# load cells as we load resume functions
# load all other resume functions (to be called later)
resume_names.pop()
resume_codes.pop()
for tx, name, code in zip(txes, resume_names, resume_codes):
# load resume functions except the root's
cg.extend_output(create_copy(2))
for i, (name, code) in enumerate(zip(resume_names, resume_codes)):
if i == len(resume_names) - 1:
break
# stack: cells, frames, *(resume 1, ...), cells
if code.co_freevars:
cg.make_function_with_closure(tx, name, code, False, 0)
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(i),
cg.create_binary_subscr(),
]
)
cg.make_function_with_closure(name, code)
else:
cg.extend_output(cg.load_function_name(name, False, 0))
cg.extend_output(create_swap(2))
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=len(resume_codes)),
*create_swap(2),
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(resume_codes) - 1),
]
)
# resume 1 (+ NULL), [resume N, ..., resume 2], frames
# stack: cells, frames, [resume 1, ..., resume N - 1]
# load root resume function
cg.extend_output(create_swap(3))
if resume_codes[-1].co_freevars:
cg.extend_output(
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
]
)
cg.make_function_with_closure(resume_names[-1], resume_codes[-1])
cg.extend_output(
[
*create_rot_n(3),
]
)
else:
cg.extend_output(
[
create_instruction("POP_TOP"),
*cg.load_function_name(resume_names[-1], False),
*create_rot_n(3),
]
)
# resume 1, [resume N, ..., resume 2], frames
# load top level-frame; final stack state should be:
# first resume function (+ NULL),
@ -2843,7 +2883,7 @@ class InstructionTranslatorBase(
# TOS: [resumes, frames, *(frame 1 stack + locals)]
cg.extend_output(
[
*create_call_function_ex(False),
*create_call_function_ex(False, True),
create_instruction("RETURN_VALUE"),
]
)

View File

@ -449,7 +449,7 @@ class ZipVariable(IteratorVariable):
codegen.create_load_const("strict"),
codegen.create_load_const(self.strict),
create_instruction("BUILD_MAP", arg=1),
*create_call_function_ex(True),
*create_call_function_ex(True, False),
]
)
@ -487,7 +487,7 @@ class MapVariable(ZipVariable):
codegen.extend_output(
[
create_build_tuple(len(self.iterables) + 1),
*create_call_function_ex(False),
*create_call_function_ex(False, False),
]
)

View File

@ -1579,7 +1579,7 @@ class StringFormatVariable(VariableTracker):
variables.ConstantVariable.create(k): v for k, v in self.sym_kwargs.items()
}
codegen(variables.ConstDictVariable(kwargs))
codegen.extend_output(create_call_function_ex(True))
codegen.extend_output(create_call_function_ex(True, False))
class DebuggingVariable(VariableTracker):