[dynamo, nested graph breaks] implement new resume frame stack/locals/cell layout convention (#157971)

The comments/conventions are not exactly correct here, as the implementation at this PR is partial. They will be fixed in #160138.

No tests added, since there shouldn't be any overall semantic changes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157971
Approved by: https://github.com/anijain2305
This commit is contained in:
William Wen
2025-08-25 13:27:40 -07:00
committed by PyTorch MergeBot
parent 4e19c1906a
commit 2df9b437e3
5 changed files with 550 additions and 206 deletions

View File

@ -246,9 +246,21 @@ def create_rot_n(n: int) -> list[Instruction]:
# e.g. rotate 3 is equivalent to swap 3, swap 2
return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)]
# ensure desired rotate function exists
# ROT_N does not exist in Python <= 3.9, but we can simulate it
if sys.version_info < (3, 10) and n >= 5:
raise AttributeError(f"rotate {n} not supported for Python < 3.10")
"""
0 1 2 3 4
[0 1 2 3 4]
4 3 2 1 0
4 [3 2 1 0]
4 0 1 2 3
"""
return [
create_instruction("BUILD_TUPLE", arg=n),
create_instruction("UNPACK_SEQUENCE", arg=n),
create_instruction("BUILD_TUPLE", arg=n - 1),
create_instruction("UNPACK_SEQUENCE", arg=n - 1),
]
if n <= 4:
return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])]
@ -428,6 +440,10 @@ def create_swap(n: int) -> list[Instruction]:
# in Python < 3.11, SWAP is a macro that expands to multiple instructions
if n == 1:
return []
elif n == 2:
return [create_instruction("ROT_TWO")]
elif n == 3:
return [create_instruction("ROT_THREE"), create_instruction("ROT_TWO")]
"""
e.g. swap "a" and "b" in this stack:
0 a 1 2 3 b
@ -464,6 +480,38 @@ def create_swap(n: int) -> list[Instruction]:
]
def create_binary_slice(
start: Optional[int], end: Optional[int], store: bool = False
) -> list[Instruction]:
"""
BINARY_SLICE and STORE_SLICE (if `set` is True) for all Python versions
"""
if sys.version_info >= (3, 12):
inst_name = "STORE_SLICE" if store else "BINARY_SLICE"
return [
create_load_const(start),
create_load_const(end),
create_instruction(inst_name),
]
else:
inst_name = "STORE_SUBSCR" if store else "BINARY_SUBSCR"
return [
create_load_const(start),
create_load_const(end),
create_instruction("BUILD_SLICE", arg=2),
create_instruction(inst_name),
]
def create_reverse(n: int) -> list[Instruction]:
# Reverse the top n values on the stack
# UNPACK_SEQUENCE reverses the sequence
return [
create_instruction("BUILD_TUPLE", arg=n),
create_instruction("UNPACK_SEQUENCE", arg=n),
]
def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[list[int], Callable[[int, int], None]]:

View File

@ -448,6 +448,10 @@ inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated]
justknob="pytorch/compiler:inline_inbuilt_nn_modules",
)
# Resume tracing in nested frames if a nested graph break occurs
# Old behavior is to bubble up the graph break to the top level frame.
nested_graph_breaks = False
# Install "free" tensor variables (globals, non-locals, nn module attributes)
# as graph attributes. This is useful for export, as it
# produces a consistent number of inputs to the graph.

View File

@ -77,9 +77,13 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import config, exc, logging as torchdynamo_logging, variables
from .backends.registry import CompiledFn, CompilerFn
from .bytecode_transformation import (
create_binary_slice,
create_call_function,
create_dup_top,
create_instruction,
create_load_const,
create_rot_n,
create_swap,
Instruction,
unique_id,
)
@ -146,7 +150,7 @@ from .variables.builder import (
)
from .variables.ctx_manager import ContextWrappingVariable
from .variables.lists import BaseListVariable
from .variables.misc import CellVariable, NullVariable
from .variables.misc import NullVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
@ -348,6 +352,11 @@ class StackLocalsMetadata:
Stores metadata for a frame's stack and locals for the purposes of building resume functions
"""
num_stack: int = 0 # number of stack elements, minus removed NULLs
locals_names: dict[str, int] = dc_field(
default_factory=dict
) # order of locals codegen'd to the stack
cell_and_freevars: dict[str, int] = dc_field(default_factory=dict)
stack_null_idxes: list[int] = dc_field(default_factory=list)
locals_null_keys: list[str] = dc_field(default_factory=list)
stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
@ -1186,7 +1195,7 @@ class OutputGraph(OutputGraphGuardsState):
def _get_stack_values_to_restore(
self, tx: "InstructionTranslatorBase", stack_pops: int
) -> tuple[list[VariableTracker], list[str], StackLocalsMetadata]:
) -> tuple[list[VariableTracker], StackLocalsMetadata]:
"""
Gets the stack + locals values belonging to tx that need to be restored.
@ -1198,7 +1207,6 @@ class OutputGraph(OutputGraphGuardsState):
Returns:
- stack_values: stack and locals values that need to be restored
- restore_vars: names of locals corresponding to the locals part of `stack_values`
- meta: locations of NULLs and ContextWrappingVariables in the stack/locals
(ignores the top `stack_pops` values on the stack)
"""
@ -1227,9 +1235,13 @@ class OutputGraph(OutputGraphGuardsState):
meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
meta.stack_ctx_idxes_orig.append(i)
# Add all the local vars to the "stack" so restore at the end
restore_vars: list[str] = []
val_to_names: dict[VariableTracker, list[str]] = {}
meta.num_stack = len(stack_values)
cell_and_freevars = dict.fromkeys(tx.cellvars() + tx.freevars())
meta.cell_and_freevars = {
name: i for i, name in enumerate(cell_and_freevars.keys())
}
# NB: Typically (i.e., for graph compile from RETURN_VALUE),
# symbolic_locals will be empty at this point, as prune_dead_locals
# will clear out all of symbolic_locals because RETURN_VALUE is the
@ -1244,12 +1256,19 @@ class OutputGraph(OutputGraphGuardsState):
# This will in turn result in spurious variables showing up in the graph.
# This was very tricky to debug. For an example, dump the graph at call_user_compiler
# while running test_subgraphs.py
if isinstance(v.source, LocalSource) and v.source.local_name == k:
continue # no need to restore initial state
if isinstance(v, CellVariable) and v.local_name == k:
continue # no need to restore initial state
# Do not load unmodified locals (load them at a later time) from the top frame
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k
and tx is self.root_tx
):
continue
# Do not load cell/free vars
if k in meta.cell_and_freevars:
continue
# Do not load variable if it is NULL.
if sys.version_info >= (3, 12):
# NOTE: do not use isinstance, since it realizes lazy VT's
# Continuation function will load the NULL for v.
if type.__instancecheck__(NullVariable, v):
meta.locals_null_keys.append(k)
@ -1257,19 +1276,15 @@ class OutputGraph(OutputGraphGuardsState):
else:
# A variable should never be NULL in < 3.12
assert not type.__instancecheck__(NullVariable, v)
meta.locals_names[k] = len(meta.locals_names)
if isinstance(v, ContextWrappingVariable):
target_values = (
() if v.target_values is None else tuple(v.target_values)
)
meta.locals_ctx_args.append((k, target_values))
if v not in val_to_names:
val_to_names[v] = []
val_to_names[v].append(k)
for v in val_to_names.keys():
restore_vars.extend(val_to_names[v])
stack_values.extend([v] * len(val_to_names[v]))
stack_values.append(v)
return stack_values, restore_vars, meta
return stack_values, meta
def compile_subgraph(
self,
@ -1295,8 +1310,8 @@ class OutputGraph(OutputGraphGuardsState):
assert self.root_tx is not None
# FIXME temporary assert to make sure we're not accidentally compiling nested graph breaks
# before we're done the full implementation
if not config.nested_graph_breaks:
# expect to only compile 1 frame
assert self.root_tx is tx
# bytecode tracing has finished. Pop the context manager for dynamo_timed
@ -1311,12 +1326,8 @@ class OutputGraph(OutputGraphGuardsState):
# prefix instructions (Python 3.11+)
prefix_insts: list[Instruction] = []
if sys.version_info >= (3, 11):
for inst in tx.prefix_insts:
if inst.opname == "MAKE_CELL":
prefix_insts.append(
create_instruction("MAKE_CELL", argval=inst.argval)
)
elif inst.opname == "COPY_FREE_VARS":
for inst in self.root_tx.prefix_insts:
if inst.opname == "COPY_FREE_VARS":
prefix_insts.append(
create_instruction(
"COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
@ -1324,6 +1335,26 @@ class OutputGraph(OutputGraphGuardsState):
)
else:
prefix_insts.append(copy.copy(inst))
# stack values and restore vars for each frame are pushed in reverse order
# i.e. last element corresponds to root frame, first element corresponds to current frame
all_stack_values = []
all_stack_locals_metas = []
cur_tx: Optional[InstructionTranslatorBase] = tx
while True:
assert cur_tx is not None
# this should have been checked by the caller
assert all(block.can_restore() for block in cur_tx.block_stack)
stack_values, meta = self._get_stack_values_to_restore(
cur_tx, stack_pops if cur_tx is tx else 0
)
all_stack_values.append(stack_values)
all_stack_locals_metas.append(meta)
if cur_tx is self.root_tx:
break
cur_tx = cur_tx.parent
self.add_output_instructions(prefix_insts)
assert not (self.pregraph_bytecode and self.export), (
@ -1342,26 +1373,6 @@ class OutputGraph(OutputGraphGuardsState):
self.cleanup_graph()
# stack values and restore vars for each frame are pushed in reverse order
# i.e. last element corresponds to root frame, first element corresponds to current frame
all_stack_values = []
all_restore_vars = []
all_stack_locals_metas = []
cur_tx: Optional[InstructionTranslatorBase] = tx
while True:
assert cur_tx is not None
# this should have been checked by the caller
assert all(block.can_restore() for block in cur_tx.block_stack)
stack_values, restore_vars, meta = self._get_stack_values_to_restore(
cur_tx, stack_pops
)
all_stack_values.append(stack_values)
all_restore_vars.append(restore_vars)
all_stack_locals_metas.append(meta)
if cur_tx is self.root_tx:
break
cur_tx = tx.parent
# Use nn.Module "proxies" in the constructed GraphModule so that
# the resulting GM does not hold additional strong references to the original modules.
# This prevents a strong ref cycle where Dynamo created code holds on to references
@ -1396,13 +1407,44 @@ class OutputGraph(OutputGraphGuardsState):
)
self.add_output_instructions(random_calls_instructions)
# call compiled fx graph
graph_output_var = None
# FIXME: right now not dealing with cells because they're difficult to deal with
# codegen stack convention before the unsupported instruction
# NOTE: in this comment block, "cell" refers to a Python cell object - i.e. free and cell vars
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], top stack_pops values of frame N
# codegen stack convention after the unsupported instruction
# before calling resume function
# NOTE: need to push result of unsupported instruction to frame N stack
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], frame 1 stack + frame 1 non-cell locals
# (frame 1 cells should be loaded into the continuation function directly
# as part of the closure)
# NOTE: move the top stack_pops values from frame N to the beginning of the flat list.
# This is to prevent packing NULLs into a list.
cur_num_stack = all_stack_locals_metas[0].num_stack
stack_values_flat = (
all_stack_values[0][cur_num_stack - stack_pops : cur_num_stack]
+ all_stack_values[0][: cur_num_stack - stack_pops]
+ all_stack_values[0][cur_num_stack:]
+ [val for vals in all_stack_values[1:] for val in vals]
)
stored_graph_output_var = False
root_stack_values = all_stack_values[-1]
graph_output_var = None
# call compiled fx graph and codegen everything - stack, locals, cells
if (
self.root_tx is tx
and root_stack_values
self.root_tx is tx # single frame
and stack_values_flat
and all(
not isinstance(
v,
@ -1413,10 +1455,10 @@ class OutputGraph(OutputGraphGuardsState):
),
)
and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
for v in root_stack_values
for v in stack_values_flat
)
and all(isinstance(x, TensorVariable) for x in root_stack_values)
and len(set(root_stack_values)) == len(root_stack_values)
and all(isinstance(x, TensorVariable) for x in stack_values_flat)
and len(set(stack_values_flat)) == len(stack_values_flat)
and self.side_effects.is_empty()
and not tx.debug_locals
and not self.backward_state
@ -1425,17 +1467,19 @@ class OutputGraph(OutputGraphGuardsState):
):
# optimization to generate better code in a common case
self.add_output_instructions(
self.compile_and_call_fx_graph(
tx, list(reversed(root_stack_values)), root
)
+ [create_instruction("UNPACK_SEQUENCE", arg=len(root_stack_values))]
[
# load in reverse since UNPACK_SEQUENCE will reverse
*self.compile_and_call_fx_graph(
tx, list(reversed(stack_values_flat)), root
),
create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
]
)
# function output will be moved to the correct places below
else:
graph_output_var = self.new_var("graph_out")
# load stack values in a flat manner for now - will likely change later.
stack_values_flat = [
val for vals in reversed(all_stack_values) for val in vals
]
# load stack values in a flat manner - we will codegen bytecode to place them correctly
# according to our convention above
pass1 = PyCodegen(
self.root_tx,
root,
@ -1479,21 +1523,115 @@ class OutputGraph(OutputGraphGuardsState):
self.run_compiler_collective()
self.add_output_instructions(output + pass2.get_instructions())
# restore all the live local vars of the root
local_restore_cg = PyCodegen(
self.root_tx, overridden_sources=overridden_sources
)
# TODO this local restoration should be removed when fully implementing nested graph breaks
# store all stack, locals, cells for each frame
# current state of the stack:
# *(top stack_pops values), *(remaining stack_values_flat)
self.add_output_instructions(
[
local_restore_cg.create_store(var)
for var in reversed(all_restore_vars[-1])
create_instruction(
"BUILD_LIST", arg=len(stack_values_flat) - stack_pops
),
]
)
# iterate current frame to root frame
# sliding window over frame stack/locals/cells
start_idx = 0
end_idx = 0
for i, meta in enumerate(all_stack_locals_metas):
# stack, locals, cells
# account for removed stack_pops values in current frame
num_stack = meta.num_stack - stack_pops if i == 0 else meta.num_stack
counts = (
num_stack,
len(meta.locals_names),
# len(meta.cell_and_freevars),
)
self.add_output_instructions([create_dup_top()])
# values, values
for j, cnt in enumerate(counts):
end_idx += cnt
if start_idx == end_idx:
self.add_output_instructions(
[
create_instruction("BUILD_LIST", arg=0),
*create_swap(2),
]
)
# [], values
else:
self.add_output_instructions(
[
create_dup_top(),
*create_binary_slice(start_idx, end_idx),
*create_swap(2),
]
)
# values[x:y], values
# add root frame's unmodified locals here
if i == len(all_stack_locals_metas) - 1 and j == 1:
root_cg = PyCodegen(self.root_tx)
unmodified_locals_names: dict[str, int] = {}
for k, v in self.root_tx.symbolic_locals.items():
if (
isinstance(v.source, LocalSource)
and v.source.local_name == k
):
root_cg.append_output(root_cg.create_load(k))
unmodified_locals_names[k] = len(meta.locals_names) + len(
unmodified_locals_names
)
self.add_output_instructions(
root_cg.get_instructions()
+ [
create_instruction(
"BUILD_LIST", arg=len(unmodified_locals_names)
),
# arg=2 because we already swapped the locals list back
create_instruction("LIST_EXTEND", arg=2),
]
)
meta.locals_names.update(unmodified_locals_names)
start_idx += cnt
# pack stack, locals, cells together
# values, stack, locals, cells, values
self.add_output_instructions(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_TUPLE", arg=2),
*create_swap(2),
]
)
# (stack, locals, cells), values
# current state of the stack:
# *(top stack_pops values),
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# stack_values_flat
#
self.add_output_instructions(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
*create_rot_n(stack_pops + 1),
]
)
# final state of the stack before running the unsupported bytecode:
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], *(top stack_pops values of frame N)
if graph_output_var and stored_graph_output_var:
self.add_output_instructions(
[local_restore_cg.create_delete(graph_output_var)]
[create_instruction("DELETE_FAST", argval=graph_output_var)]
)
if self.export:

View File

@ -340,7 +340,8 @@ class ContinueExecutionCache:
) -> None:
meta.instructions = copy.deepcopy(instructions)
args = [f"___stack{i}" for i in range(nstack)]
args = ["__nested_frame_values"]
args += [f"___stack{i}" for i in range(nstack)]
args.extend(v for v in argnames if v not in args)
freevars = tuple(code_options["co_cellvars"] or []) + tuple(
code_options["co_freevars"] or []

View File

@ -73,8 +73,10 @@ from .bytecode_analysis import (
from .bytecode_transformation import (
cleaned_instructions,
create_call_function,
create_dup_top,
create_instruction,
create_jump_absolute,
create_reverse,
create_swap,
get_code_keys,
Instruction,
@ -668,12 +670,14 @@ def generic_jump(
self.pop()
if_next = self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata
self.next_instruction, 0, all_stack_locals_metadata
)
if push:
self.push(value)
assert inst.target is not None
if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata)
if_jump = self.create_call_resume_at(
inst.target, int(push), all_stack_locals_metadata
)
if sys.version_info >= (3, 13):
# 3.13 requires stack[-1] to be bool type
@ -1006,7 +1010,7 @@ def break_graph_if_unsupported(
self.push(UnknownVariable())
self.output.add_output_instructions(
self.create_call_resume_at(
self.next_instruction, all_stack_locals_metadata
self.next_instruction, push, all_stack_locals_metadata
)
)
@ -1404,13 +1408,45 @@ class InstructionTranslatorBase(
# where we call step_graph_break right now is when the stack is empty,
# so let's enforce that for now.
assert not self.stack
self.output.compile_subgraph(
# NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument
# below should be set to the stack length to ensure that the stack is codegen'd
# for the rest of the function.
all_stack_locals_metadata = self.output.compile_subgraph(
self,
partial_convert=True,
reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
)
# load locals from frame values
# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ],
cg = PyCodegen(self)
self.output.add_output_instructions(
[create_jump_absolute(continue_inst)] + self.instructions
[
cg.create_load_const(-1),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)
for local, idx in all_stack_locals_metadata[-1].locals_names.items():
self.output.add_output_instructions(
[
create_dup_top(),
cg.create_load_const(idx),
cg.create_binary_subscr(),
cg.create_store(local),
]
)
self.output.add_output_instructions(
[
create_instruction("POP_TOP"),
create_jump_absolute(continue_inst),
*self.instructions,
]
)
def run_ctx_mgr(self) -> Any:
@ -1510,7 +1546,7 @@ class InstructionTranslatorBase(
)
# for continuation functions
if name.startswith("__stack"):
if name.startswith("__stack") or name == "__nested_frame_values":
self.symbolic_locals.pop(name)
def LOAD_DEREF(self, inst: Instruction) -> None:
@ -2415,7 +2451,9 @@ class InstructionTranslatorBase(
self.output.add_output_instructions([copy.copy(inst)])
self.popn(2)
self.output.add_output_instructions(
self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata)
self.create_call_resume_at(
self.next_instruction, 0, all_stack_locals_metadata
)
)
def DELETE_ATTR(self, inst: Instruction) -> None:
@ -2427,15 +2465,240 @@ class InstructionTranslatorBase(
)
def create_call_resume_at(
self, offset: Instruction, all_stack_locals_metadata: Any
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
) -> list[Instruction]:
raise AssertionError(
f"create_call_resume_at not overridden by subclass {type(self)}"
self.instruction_pointer = None
if inst.opname == "RETURN_VALUE":
return [create_instruction("RETURN_VALUE")]
elif inst.opname == "RETURN_CONST":
return [create_instruction("RETURN_CONST", argval=inst.argval)]
cg = PyCodegen(self)
# current frame state
# [
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
# ...,
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ], `push` values from running the unsupported instruction
# move the `push` stack values to the frame N stack
cg.extend_output(
[
create_instruction("BUILD_LIST", arg=push),
# frames_list, push_values_list
*create_swap(2),
create_dup_top(),
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# push_values_list, frames_list, frames_list[0][0]
*create_swap(3),
# frames_list[0][0] += push_values_list
create_instruction("LIST_EXTEND", arg=2),
*create_swap(2),
# frames_list, frames_list[0][0]
create_instruction("POP_TOP"),
]
)
# current frame state
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
# ],
#
txes = []
cur_tx: Optional[InstructionTranslatorBase] = self
while cur_tx is not None:
txes.append(cur_tx)
cur_tx = cur_tx.parent
assert len(txes) == len(all_stack_locals_metadata)
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
for i, meta in enumerate(all_stack_locals_metadata):
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
# Replace the stack var with the context class
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
# frames[i][0][j] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(0),
cg.create_binary_subscr(),
cg.create_load_const(j),
create_instruction("STORE_SUBSCR"),
]
)
for name, _ in meta.locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
# frames[i][1][meta.locals_names[name]] = reconstructed_ctx
cg.append_output(create_dup_top())
ctx.reconstruct_type(cg)
cg.extend_output(
[
*create_swap(2),
cg.create_load_const(i),
cg.create_binary_subscr(),
cg.create_load_const(1),
cg.create_binary_subscr(),
cg.create_load_const(meta.locals_names[name]),
create_instruction("STORE_SUBSCR"),
]
)
name = unique_id(f"__resume_at_{inst.offset}")
assert not config.nested_graph_breaks, "NYI"
# more locals may have been pruned after the unsupported instruction (e.g. branch)
reads = livevars_analysis(self.instructions, inst)
all_argnames = tuple(
k
for k in self.symbolic_locals.keys()
if k in reads and k not in self.cell_and_freevars()
)
argnames_null_set = set(all_stack_locals_metadata[-1].locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(self.stack) - len(
all_stack_locals_metadata[-1].stack_null_idxes
)
nargs = stack_len + len(argnames)
new_code: types.CodeType = ContinueExecutionCache.lookup(
self.f_code,
self.lineno,
inst.offset,
tuple(b.target.offset for b in self.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
tuple(all_stack_locals_metadata[-1].stack_ctx_args),
tuple(all_stack_locals_metadata[-1].locals_ctx_args),
tuple(all_stack_locals_metadata[-1].stack_null_idxes),
)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
if new_code.co_freevars:
# expose code object for debugging purposes
self.output.install_global_unsafe(name, new_code)
cg.make_function_with_closure(name, new_code, True, 1)
package_name = None
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
name, types.FunctionType(new_code, self.f_globals, name)
)
cg.extend_output(cg.load_function_name(name, True, 1))
package_name = name
if self.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], package_name
)
# load top level-frame; final stack state should be:
# [
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
# ...,
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
# ], frame 1 stack + frame 1 non-cell locals
cg.extend_output(
[
create_dup_top(),
create_dup_top(),
# frames, frames, frames
cg.create_load_const(-1),
cg.create_binary_subscr(),
# frames, frames, frames[-1]
*create_swap(2),
# frames, frames[-1], frames
cg.create_load_const(-1),
create_instruction("DELETE_SUBSCR"),
# del frames[-1]; stack: frames, frames[-1]
create_dup_top(),
cg.create_load_const(0),
cg.create_binary_subscr(),
# frames, frames[-1], frames[-1][0]
*create_swap(2),
cg.create_load_const(1),
cg.create_binary_subscr(),
]
)
# frames, frames[-1][0], frames[-1][1]
for name in argnames:
cg.extend_output(
[
create_dup_top(),
cg.create_load_const(
all_stack_locals_metadata[-1].locals_names[name]
),
cg.create_binary_subscr(),
*create_swap(2),
],
)
# frames, frames[-1][0], *(live locals), frames[-1][1]
cg.extend_output(
[
create_instruction("POP_TOP"),
create_instruction("BUILD_LIST", arg=len(argnames)),
create_instruction("LIST_EXTEND", arg=1),
# UNPACK_SEQUENCE reverses elements
create_instruction("UNPACK_SEQUENCE", arg=nargs),
*create_reverse(nargs),
]
)
# frames, *(stack + live locals)
cg.extend_output(create_call_function(nargs + 1, False))
cg.append_output(create_instruction("RETURN_VALUE"))
return cg.get_instructions()
def should_compile_partial_graph(self) -> bool:
raise AssertionError(
f"should_compile_partial_graph not overridden by subclass {type(self)}"
if sys.version_info >= (3, 11):
# Do not compile if current instruction's block is not the top with block
entry = self.current_instruction.exn_tab_entry
if entry and (
not self.block_stack or entry.target is not self.block_stack[-1].target
):
return False
return (
all(b.can_restore() for b in self.block_stack)
and not self.one_graph
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
and not self.active_generic_context_managers
)
@break_graph_if_unsupported(push=0)
@ -3612,125 +3875,6 @@ class InstructionTranslator(InstructionTranslatorBase):
return self.f_globals[source.global_name]
raise KeyError
def run(self) -> None:
super().run()
def should_compile_partial_graph(self) -> bool:
if sys.version_info >= (3, 11):
# Do not compile if current instruction's block is not the top with block
entry = self.current_instruction.exn_tab_entry
if entry and (
not self.block_stack or entry.target is not self.block_stack[-1].target
):
return False
return (
all(b.can_restore() for b in self.block_stack)
and not self.one_graph
and not self.error_on_graph_break
and not self.is_tracing_resume_prologue
and not self.active_generic_context_managers
)
def create_call_resume_at(
self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
self.instruction_pointer = None
if inst.opname == "RETURN_VALUE":
return [create_instruction("RETURN_VALUE")]
elif inst.opname == "RETURN_CONST":
return [create_instruction("RETURN_CONST", argval=inst.argval)]
reads = livevars_analysis(self.instructions, inst)
all_argnames = tuple(
k
for k in self.symbolic_locals.keys()
if k in reads and k not in self.cell_and_freevars()
)
# NOTE: do not use isinstance, since it realizes lazy VT's
argnames_null_set = set(all_stack_locals_metadata[0].locals_null_keys)
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# compile_subgraph did not codegen any NULLs,
# so we should not count NullVariables
stack_len = len(self.stack) - len(all_stack_locals_metadata[0].stack_null_idxes)
nargs = stack_len + len(argnames)
cg = PyCodegen(self)
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
# result in silent incorrectness!
for (i, _), i_orig in zip(
all_stack_locals_metadata[0].stack_ctx_args,
all_stack_locals_metadata[0].stack_ctx_idxes_orig,
):
# Replace the current stack var with the context class
ctx = cast(ContextWrappingVariable, self.stack[i_orig])
ctx.reconstruct_type(cg)
cg.extend_output(create_swap(stack_len - i + 1))
cg.append_output(create_instruction("POP_TOP"))
for name, _ in all_stack_locals_metadata[0].locals_ctx_args:
# Replace the local with the context class
ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
ctx.reconstruct_type(cg)
cg.append_output(create_instruction("STORE_FAST", argval=name))
name = unique_id(f"__resume_at_{inst.offset}", with_uuid=True)
new_code: types.CodeType = ContinueExecutionCache.lookup(
self.f_code,
self.lineno,
inst.offset,
tuple(b.target.offset for b in self.block_stack),
stack_len,
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
tuple(all_stack_locals_metadata[0].stack_ctx_args),
tuple(all_stack_locals_metadata[0].locals_ctx_args),
tuple(all_stack_locals_metadata[0].stack_null_idxes),
)
# Add original GraphModule context to the resume function to handle
# the case of a graph break while tracing a GraphModule
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
"orig_graphmodule", lambda: None
)()
if orig_graphmodule_maybe is not None:
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
orig_graphmodule_maybe
)
if new_code.co_freevars:
# expose code object for debugging purposes
self.output.install_global_unsafe(name, new_code)
cg.make_function_with_closure(name, new_code, True, stack_len)
package_name = None
else:
# This is safe: we pre-generate a unique name
self.output.install_global_unsafe(
name, types.FunctionType(new_code, self.f_globals, name)
)
cg.extend_output(cg.load_function_name(name, True, stack_len))
package_name = name
if self.package is not None:
if self.output.package is not None:
self.package.add_resume_function(
new_code, self.f_globals["__name__"], function_name=package_name
)
cg.extend_output([cg.create_load(k) for k in argnames])
cg.extend_output(create_call_function(nargs, False))
cg.append_output(create_instruction("RETURN_VALUE"))
return cg.get_instructions()
def symbolic_locals_contain_module_class(self) -> bool:
for v in self.symbolic_locals.values():
if isinstance(v, UserDefinedClassVariable) and issubclass(
@ -3781,6 +3925,8 @@ class InstructionTranslator(InstructionTranslatorBase):
reason=GraphCompileReason(
"return_value", [self.frame_summary()], graph_break=False
),
# the value to be returned
stack_pops=1 if inst.opname == "RETURN_VALUE" else 0,
)
# check that our stack/locals meta are correct:
# we should only be tracing 1 frame, and there should not be any NULLs on the stack
@ -3791,6 +3937,7 @@ class InstructionTranslator(InstructionTranslatorBase):
if inst.opname == "RETURN_VALUE"
else create_instruction("RETURN_CONST", argval=inst.argval)
)
# NOTE: does the stack need to be empty after the return?
self.output.add_output_instructions([return_inst])
raise ReturnValueOp
@ -4147,11 +4294,17 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
return TracingContext.current_frame(self.parent.frame_summary())
def should_compile_partial_graph(self) -> bool:
if config.nested_graph_breaks:
if not self.parent.should_compile_partial_graph():
return False
return super().should_compile_partial_graph()
return False # inlining functions is all-or-nothing
def create_call_resume_at(
self, inst: Instruction, all_stack_locals_metadata: Any
) -> NoReturn:
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
) -> list[Instruction]:
if config.nested_graph_breaks:
return super().create_call_resume_at(inst, push, all_stack_locals_metadata)
unimplemented_v2(
gb_type="Graph break in inlined function",
context="",