mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo, nested graph breaks] implement new resume frame stack/locals/cell layout convention (#157971)
The comments/conventions are not exactly correct here, as the implementation at this PR is partial. They will be fixed in #160138. No tests added, since there shouldn't be any overall semantic changes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157971 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
4e19c1906a
commit
2df9b437e3
@ -246,9 +246,21 @@ def create_rot_n(n: int) -> list[Instruction]:
|
||||
# e.g. rotate 3 is equivalent to swap 3, swap 2
|
||||
return [create_instruction("SWAP", arg=i) for i in range(n, 1, -1)]
|
||||
|
||||
# ensure desired rotate function exists
|
||||
# ROT_N does not exist in Python <= 3.9, but we can simulate it
|
||||
if sys.version_info < (3, 10) and n >= 5:
|
||||
raise AttributeError(f"rotate {n} not supported for Python < 3.10")
|
||||
"""
|
||||
0 1 2 3 4
|
||||
[0 1 2 3 4]
|
||||
4 3 2 1 0
|
||||
4 [3 2 1 0]
|
||||
4 0 1 2 3
|
||||
"""
|
||||
return [
|
||||
create_instruction("BUILD_TUPLE", arg=n),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=n),
|
||||
create_instruction("BUILD_TUPLE", arg=n - 1),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=n - 1),
|
||||
]
|
||||
|
||||
if n <= 4:
|
||||
return [create_instruction("ROT_" + ["TWO", "THREE", "FOUR"][n - 2])]
|
||||
@ -428,6 +440,10 @@ def create_swap(n: int) -> list[Instruction]:
|
||||
# in Python < 3.11, SWAP is a macro that expands to multiple instructions
|
||||
if n == 1:
|
||||
return []
|
||||
elif n == 2:
|
||||
return [create_instruction("ROT_TWO")]
|
||||
elif n == 3:
|
||||
return [create_instruction("ROT_THREE"), create_instruction("ROT_TWO")]
|
||||
"""
|
||||
e.g. swap "a" and "b" in this stack:
|
||||
0 a 1 2 3 b
|
||||
@ -464,6 +480,38 @@ def create_swap(n: int) -> list[Instruction]:
|
||||
]
|
||||
|
||||
|
||||
def create_binary_slice(
    start: Optional[int], end: Optional[int], store: bool = False
) -> list[Instruction]:
    """
    Build the instruction sequence for a slice read, or a slice write when
    `store` is True, that works on every supported Python version.

    Args:
        start: constant lower bound of the slice (None for an open bound).
        end: constant upper bound of the slice (None for an open bound).
        store: emit the storing variant (STORE_SLICE / STORE_SUBSCR) instead
            of the loading variant (BINARY_SLICE / BINARY_SUBSCR).

    Returns:
        Instructions that load both bounds as constants and then perform the
        slice operation.

    NOTE(review): the sequence presumably expects the container (and, when
    storing, the value underneath it) to already be on the stack - confirm
    against callers.
    """
    # Both code paths start by pushing the two slice bounds as constants.
    bound_loads = [create_load_const(start), create_load_const(end)]

    if sys.version_info >= (3, 12):
        # 3.12+ has dedicated slice opcodes that consume the bounds directly.
        slice_op = "STORE_SLICE" if store else "BINARY_SLICE"
        return [*bound_loads, create_instruction(slice_op)]

    # Older versions: materialize a slice object, then subscript with it.
    subscr_op = "STORE_SUBSCR" if store else "BINARY_SUBSCR"
    return [
        *bound_loads,
        create_instruction("BUILD_SLICE", arg=2),
        create_instruction(subscr_op),
    ]
|
||||
|
||||
|
||||
def create_reverse(n: int) -> list[Instruction]:
    """
    Build instructions that reverse the top `n` values on the stack.

    The values are packed into a tuple and immediately unpacked again;
    UNPACK_SEQUENCE reverses the sequence, so the net effect is a reversal
    of those `n` stack entries.
    """
    pack_then_unpack = ("BUILD_TUPLE", "UNPACK_SEQUENCE")
    return [create_instruction(opname, arg=n) for opname in pack_then_unpack]
|
||||
|
||||
|
||||
def lnotab_writer(
|
||||
lineno: int, byteno: int = 0
|
||||
) -> tuple[list[int], Callable[[int, int], None]]:
|
||||
|
@ -448,6 +448,10 @@ inline_inbuilt_nn_modules = Config( # type: ignore[var-annotated]
|
||||
justknob="pytorch/compiler:inline_inbuilt_nn_modules",
|
||||
)
|
||||
|
||||
# Resume tracing in nested frames if a nested graph break occurs
|
||||
# Old behavior is to bubble up the graph break to the top level frame.
|
||||
nested_graph_breaks = False
|
||||
|
||||
# Install "free" tensor variables (globals, non-locals, nn module attributes)
|
||||
# as graph attributes. This is useful for export, as it
|
||||
# produces a consistent number of inputs to the graph.
|
||||
|
@ -77,9 +77,13 @@ from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
from . import config, exc, logging as torchdynamo_logging, variables
|
||||
from .backends.registry import CompiledFn, CompilerFn
|
||||
from .bytecode_transformation import (
|
||||
create_binary_slice,
|
||||
create_call_function,
|
||||
create_dup_top,
|
||||
create_instruction,
|
||||
create_load_const,
|
||||
create_rot_n,
|
||||
create_swap,
|
||||
Instruction,
|
||||
unique_id,
|
||||
)
|
||||
@ -146,7 +150,7 @@ from .variables.builder import (
|
||||
)
|
||||
from .variables.ctx_manager import ContextWrappingVariable
|
||||
from .variables.lists import BaseListVariable
|
||||
from .variables.misc import CellVariable, NullVariable
|
||||
from .variables.misc import NullVariable
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
@ -348,6 +352,11 @@ class StackLocalsMetadata:
|
||||
Stores metadata for a frame's stack and locals for the purposes of building resume functions
|
||||
"""
|
||||
|
||||
num_stack: int = 0 # number of stack elements, minus removed NULLs
|
||||
locals_names: dict[str, int] = dc_field(
|
||||
default_factory=dict
|
||||
) # order of locals codegen'd to the stack
|
||||
cell_and_freevars: dict[str, int] = dc_field(default_factory=dict)
|
||||
stack_null_idxes: list[int] = dc_field(default_factory=list)
|
||||
locals_null_keys: list[str] = dc_field(default_factory=list)
|
||||
stack_ctx_args: list[tuple[int, tuple[Any, ...]]] = dc_field(default_factory=list)
|
||||
@ -1186,7 +1195,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
def _get_stack_values_to_restore(
|
||||
self, tx: "InstructionTranslatorBase", stack_pops: int
|
||||
) -> tuple[list[VariableTracker], list[str], StackLocalsMetadata]:
|
||||
) -> tuple[list[VariableTracker], StackLocalsMetadata]:
|
||||
"""
|
||||
Gets the stack + locals values belonging to tx that need to be restored.
|
||||
|
||||
@ -1198,7 +1207,6 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
Returns:
|
||||
- stack_values: stack and locals values that need to be restored
|
||||
- restore_vars: names of locals corresponding to the locals part of `stack_values`
|
||||
- meta: locations of NULLs and ContextWrappingVariables in the stack/locals
|
||||
(ignores the top `stack_pops` values on the stack)
|
||||
"""
|
||||
@ -1227,9 +1235,13 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
meta.stack_ctx_args.append((len(stack_values) - 1, target_values))
|
||||
meta.stack_ctx_idxes_orig.append(i)
|
||||
|
||||
# Add all the local vars to the "stack" so we can restore them at the end
|
||||
restore_vars: list[str] = []
|
||||
val_to_names: dict[VariableTracker, list[str]] = {}
|
||||
meta.num_stack = len(stack_values)
|
||||
|
||||
cell_and_freevars = dict.fromkeys(tx.cellvars() + tx.freevars())
|
||||
meta.cell_and_freevars = {
|
||||
name: i for i, name in enumerate(cell_and_freevars.keys())
|
||||
}
|
||||
|
||||
# NB: Typically (i.e., for graph compile from RETURN_VALUE),
|
||||
# symbolic_locals will be empty at this point, as prune_dead_locals
|
||||
# will clear out all of symbolic_locals because RETURN_VALUE is the
|
||||
@ -1244,12 +1256,19 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# This will in turn result in spurious variables showing up in the graph.
|
||||
# This was very tricky to debug. For an example, dump the graph at call_user_compiler
|
||||
# while running test_subgraphs.py
|
||||
if isinstance(v.source, LocalSource) and v.source.local_name == k:
|
||||
continue # no need to restore initial state
|
||||
if isinstance(v, CellVariable) and v.local_name == k:
|
||||
continue # no need to restore initial state
|
||||
# Do not load unmodified locals (load them at a later time) from the top frame
|
||||
if (
|
||||
isinstance(v.source, LocalSource)
|
||||
and v.source.local_name == k
|
||||
and tx is self.root_tx
|
||||
):
|
||||
continue
|
||||
# Do not load cell/free vars
|
||||
if k in meta.cell_and_freevars:
|
||||
continue
|
||||
# Do not load variable if it is NULL.
|
||||
if sys.version_info >= (3, 12):
|
||||
# NOTE: do not use isinstance, since it realizes lazy VT's
|
||||
# Continuation function will load the NULL for v.
|
||||
if type.__instancecheck__(NullVariable, v):
|
||||
meta.locals_null_keys.append(k)
|
||||
@ -1257,19 +1276,15 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
else:
|
||||
# A variable should never be NULL in < 3.12
|
||||
assert not type.__instancecheck__(NullVariable, v)
|
||||
meta.locals_names[k] = len(meta.locals_names)
|
||||
if isinstance(v, ContextWrappingVariable):
|
||||
target_values = (
|
||||
() if v.target_values is None else tuple(v.target_values)
|
||||
)
|
||||
meta.locals_ctx_args.append((k, target_values))
|
||||
if v not in val_to_names:
|
||||
val_to_names[v] = []
|
||||
val_to_names[v].append(k)
|
||||
for v in val_to_names.keys():
|
||||
restore_vars.extend(val_to_names[v])
|
||||
stack_values.extend([v] * len(val_to_names[v]))
|
||||
stack_values.append(v)
|
||||
|
||||
return stack_values, restore_vars, meta
|
||||
return stack_values, meta
|
||||
|
||||
def compile_subgraph(
|
||||
self,
|
||||
@ -1295,8 +1310,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
assert self.root_tx is not None
|
||||
|
||||
# FIXME temporary assert to make sure we're not accidentally compiling nested graph breaks
|
||||
# before we're done with the full implementation
|
||||
if not config.nested_graph_breaks:
|
||||
# expect to only compile 1 frame
|
||||
assert self.root_tx is tx
|
||||
|
||||
# bytecode tracing has finished. Pop the context manager for dynamo_timed
|
||||
@ -1311,12 +1326,8 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
# prefix instructions (Python 3.11+)
|
||||
prefix_insts: list[Instruction] = []
|
||||
if sys.version_info >= (3, 11):
|
||||
for inst in tx.prefix_insts:
|
||||
if inst.opname == "MAKE_CELL":
|
||||
prefix_insts.append(
|
||||
create_instruction("MAKE_CELL", argval=inst.argval)
|
||||
)
|
||||
elif inst.opname == "COPY_FREE_VARS":
|
||||
for inst in self.root_tx.prefix_insts:
|
||||
if inst.opname == "COPY_FREE_VARS":
|
||||
prefix_insts.append(
|
||||
create_instruction(
|
||||
"COPY_FREE_VARS", arg=len(tx.code_options["co_freevars"])
|
||||
@ -1324,6 +1335,26 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
)
|
||||
else:
|
||||
prefix_insts.append(copy.copy(inst))
|
||||
|
||||
# stack values and restore vars for each frame are pushed in reverse order
|
||||
# i.e. last element corresponds to root frame, first element corresponds to current frame
|
||||
all_stack_values = []
|
||||
all_stack_locals_metas = []
|
||||
cur_tx: Optional[InstructionTranslatorBase] = tx
|
||||
while True:
|
||||
assert cur_tx is not None
|
||||
# this should have been checked by the caller
|
||||
assert all(block.can_restore() for block in cur_tx.block_stack)
|
||||
|
||||
stack_values, meta = self._get_stack_values_to_restore(
|
||||
cur_tx, stack_pops if cur_tx is tx else 0
|
||||
)
|
||||
all_stack_values.append(stack_values)
|
||||
all_stack_locals_metas.append(meta)
|
||||
if cur_tx is self.root_tx:
|
||||
break
|
||||
cur_tx = cur_tx.parent
|
||||
|
||||
self.add_output_instructions(prefix_insts)
|
||||
|
||||
assert not (self.pregraph_bytecode and self.export), (
|
||||
@ -1342,26 +1373,6 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
|
||||
self.cleanup_graph()
|
||||
|
||||
# stack values and restore vars for each frame are pushed in reverse order
|
||||
# i.e. last element corresponds to root frame, first element corresponds to current frame
|
||||
all_stack_values = []
|
||||
all_restore_vars = []
|
||||
all_stack_locals_metas = []
|
||||
cur_tx: Optional[InstructionTranslatorBase] = tx
|
||||
while True:
|
||||
assert cur_tx is not None
|
||||
# this should have been checked by the caller
|
||||
assert all(block.can_restore() for block in cur_tx.block_stack)
|
||||
stack_values, restore_vars, meta = self._get_stack_values_to_restore(
|
||||
cur_tx, stack_pops
|
||||
)
|
||||
all_stack_values.append(stack_values)
|
||||
all_restore_vars.append(restore_vars)
|
||||
all_stack_locals_metas.append(meta)
|
||||
if cur_tx is self.root_tx:
|
||||
break
|
||||
cur_tx = tx.parent
|
||||
|
||||
# Use nn.Module "proxies" in the constructed GraphModule so that
|
||||
# the resulting GM does not hold additional strong references to the original modules.
|
||||
# This prevents a strong ref cycle where Dynamo created code holds on to references
|
||||
@ -1396,13 +1407,44 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
)
|
||||
self.add_output_instructions(random_calls_instructions)
|
||||
|
||||
# call compiled fx graph
|
||||
graph_output_var = None
|
||||
# FIXME: right now not dealing with cells because they're difficult to deal with
|
||||
# codegen stack convention before the unsupported instruction
|
||||
# NOTE: in this comment block, "cell" refers to a Python cell object - i.e. free and cell vars
|
||||
# [
|
||||
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# ], top stack_pops values of frame N
|
||||
|
||||
# codegen stack convention after the unsupported instruction
|
||||
# before calling resume function
|
||||
# NOTE: need to push result of unsupported instruction to frame N stack
|
||||
# [
|
||||
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
|
||||
# ], frame 1 stack + frame 1 non-cell locals
|
||||
|
||||
# (frame 1 cells should be loaded into the continuation function directly
|
||||
# as part of the closure)
|
||||
|
||||
# NOTE: move the top stack_pops values from frame N to the beginning of the flat list.
|
||||
# This is to prevent packing NULLs into a list.
|
||||
|
||||
cur_num_stack = all_stack_locals_metas[0].num_stack
|
||||
stack_values_flat = (
|
||||
all_stack_values[0][cur_num_stack - stack_pops : cur_num_stack]
|
||||
+ all_stack_values[0][: cur_num_stack - stack_pops]
|
||||
+ all_stack_values[0][cur_num_stack:]
|
||||
+ [val for vals in all_stack_values[1:] for val in vals]
|
||||
)
|
||||
stored_graph_output_var = False
|
||||
root_stack_values = all_stack_values[-1]
|
||||
graph_output_var = None
|
||||
|
||||
# call compiled fx graph and codegen everything - stack, locals, cells
|
||||
if (
|
||||
self.root_tx is tx
|
||||
and root_stack_values
|
||||
self.root_tx is tx # single frame
|
||||
and stack_values_flat
|
||||
and all(
|
||||
not isinstance(
|
||||
v,
|
||||
@ -1413,10 +1455,10 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
),
|
||||
)
|
||||
and not (isinstance(v, SymNodeVariable) and v.python_type() is float)
|
||||
for v in root_stack_values
|
||||
for v in stack_values_flat
|
||||
)
|
||||
and all(isinstance(x, TensorVariable) for x in root_stack_values)
|
||||
and len(set(root_stack_values)) == len(root_stack_values)
|
||||
and all(isinstance(x, TensorVariable) for x in stack_values_flat)
|
||||
and len(set(stack_values_flat)) == len(stack_values_flat)
|
||||
and self.side_effects.is_empty()
|
||||
and not tx.debug_locals
|
||||
and not self.backward_state
|
||||
@ -1425,17 +1467,19 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
):
|
||||
# optimization to generate better code in a common case
|
||||
self.add_output_instructions(
|
||||
self.compile_and_call_fx_graph(
|
||||
tx, list(reversed(root_stack_values)), root
|
||||
)
|
||||
+ [create_instruction("UNPACK_SEQUENCE", arg=len(root_stack_values))]
|
||||
[
|
||||
# load in reverse since UNPACK_SEQUENCE will reverse
|
||||
*self.compile_and_call_fx_graph(
|
||||
tx, list(reversed(stack_values_flat)), root
|
||||
),
|
||||
create_instruction("UNPACK_SEQUENCE", arg=len(stack_values_flat)),
|
||||
]
|
||||
)
|
||||
# function output will be moved to the correct places below
|
||||
else:
|
||||
graph_output_var = self.new_var("graph_out")
|
||||
# load stack values in a flat manner for now - will likely change later.
|
||||
stack_values_flat = [
|
||||
val for vals in reversed(all_stack_values) for val in vals
|
||||
]
|
||||
# load stack values in a flat manner - we will codegen bytecode to place them correctly
|
||||
# according to our convention above
|
||||
pass1 = PyCodegen(
|
||||
self.root_tx,
|
||||
root,
|
||||
@ -1479,21 +1523,115 @@ class OutputGraph(OutputGraphGuardsState):
|
||||
self.run_compiler_collective()
|
||||
self.add_output_instructions(output + pass2.get_instructions())
|
||||
|
||||
# restore all the live local vars of the root
|
||||
local_restore_cg = PyCodegen(
|
||||
self.root_tx, overridden_sources=overridden_sources
|
||||
)
|
||||
# TODO this local restoration should be removed when fully implementing nested graph breaks
|
||||
# store all stack, locals, cells for each frame
|
||||
# current state of the stack:
|
||||
# *(top stack_pops values), *(remaining stack_values_flat)
|
||||
|
||||
self.add_output_instructions(
|
||||
[
|
||||
local_restore_cg.create_store(var)
|
||||
for var in reversed(all_restore_vars[-1])
|
||||
create_instruction(
|
||||
"BUILD_LIST", arg=len(stack_values_flat) - stack_pops
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# iterate current frame to root frame
|
||||
# sliding window over frame stack/locals/cells
|
||||
start_idx = 0
|
||||
end_idx = 0
|
||||
for i, meta in enumerate(all_stack_locals_metas):
|
||||
# stack, locals, cells
|
||||
# account for removed stack_pops values in current frame
|
||||
num_stack = meta.num_stack - stack_pops if i == 0 else meta.num_stack
|
||||
counts = (
|
||||
num_stack,
|
||||
len(meta.locals_names),
|
||||
# len(meta.cell_and_freevars),
|
||||
)
|
||||
self.add_output_instructions([create_dup_top()])
|
||||
# values, values
|
||||
for j, cnt in enumerate(counts):
|
||||
end_idx += cnt
|
||||
if start_idx == end_idx:
|
||||
self.add_output_instructions(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=0),
|
||||
*create_swap(2),
|
||||
]
|
||||
)
|
||||
# [], values
|
||||
else:
|
||||
self.add_output_instructions(
|
||||
[
|
||||
create_dup_top(),
|
||||
*create_binary_slice(start_idx, end_idx),
|
||||
*create_swap(2),
|
||||
]
|
||||
)
|
||||
# values[x:y], values
|
||||
# add root frame's unmodified locals here
|
||||
if i == len(all_stack_locals_metas) - 1 and j == 1:
|
||||
root_cg = PyCodegen(self.root_tx)
|
||||
unmodified_locals_names: dict[str, int] = {}
|
||||
for k, v in self.root_tx.symbolic_locals.items():
|
||||
if (
|
||||
isinstance(v.source, LocalSource)
|
||||
and v.source.local_name == k
|
||||
):
|
||||
root_cg.append_output(root_cg.create_load(k))
|
||||
unmodified_locals_names[k] = len(meta.locals_names) + len(
|
||||
unmodified_locals_names
|
||||
)
|
||||
self.add_output_instructions(
|
||||
root_cg.get_instructions()
|
||||
+ [
|
||||
create_instruction(
|
||||
"BUILD_LIST", arg=len(unmodified_locals_names)
|
||||
),
|
||||
# arg=2 because we already swapped the locals list back
|
||||
create_instruction("LIST_EXTEND", arg=2),
|
||||
]
|
||||
)
|
||||
meta.locals_names.update(unmodified_locals_names)
|
||||
start_idx += cnt
|
||||
|
||||
# pack stack, locals, cells together
|
||||
# values, stack, locals, cells, values
|
||||
self.add_output_instructions(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("BUILD_TUPLE", arg=2),
|
||||
*create_swap(2),
|
||||
]
|
||||
)
|
||||
# (stack, locals, cells), values
|
||||
|
||||
# current state of the stack:
|
||||
# *(top stack_pops values),
|
||||
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# stack_values_flat
|
||||
#
|
||||
|
||||
self.add_output_instructions(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("BUILD_LIST", arg=len(all_stack_locals_metas)),
|
||||
*create_rot_n(stack_pops + 1),
|
||||
]
|
||||
)
|
||||
|
||||
# final state of the stack before running the unsupported bytecode:
|
||||
# [
|
||||
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# ], *(top stack_pops values of frame N)
|
||||
|
||||
if graph_output_var and stored_graph_output_var:
|
||||
self.add_output_instructions(
|
||||
[local_restore_cg.create_delete(graph_output_var)]
|
||||
[create_instruction("DELETE_FAST", argval=graph_output_var)]
|
||||
)
|
||||
|
||||
if self.export:
|
||||
|
@ -340,7 +340,8 @@ class ContinueExecutionCache:
|
||||
) -> None:
|
||||
meta.instructions = copy.deepcopy(instructions)
|
||||
|
||||
args = [f"___stack{i}" for i in range(nstack)]
|
||||
args = ["__nested_frame_values"]
|
||||
args += [f"___stack{i}" for i in range(nstack)]
|
||||
args.extend(v for v in argnames if v not in args)
|
||||
freevars = tuple(code_options["co_cellvars"] or []) + tuple(
|
||||
code_options["co_freevars"] or []
|
||||
|
@ -73,8 +73,10 @@ from .bytecode_analysis import (
|
||||
from .bytecode_transformation import (
|
||||
cleaned_instructions,
|
||||
create_call_function,
|
||||
create_dup_top,
|
||||
create_instruction,
|
||||
create_jump_absolute,
|
||||
create_reverse,
|
||||
create_swap,
|
||||
get_code_keys,
|
||||
Instruction,
|
||||
@ -668,12 +670,14 @@ def generic_jump(
|
||||
self.pop()
|
||||
|
||||
if_next = self.create_call_resume_at(
|
||||
self.next_instruction, all_stack_locals_metadata
|
||||
self.next_instruction, 0, all_stack_locals_metadata
|
||||
)
|
||||
if push:
|
||||
self.push(value)
|
||||
assert inst.target is not None
|
||||
if_jump = self.create_call_resume_at(inst.target, all_stack_locals_metadata)
|
||||
if_jump = self.create_call_resume_at(
|
||||
inst.target, int(push), all_stack_locals_metadata
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
# 3.13 requires stack[-1] to be bool type
|
||||
@ -1006,7 +1010,7 @@ def break_graph_if_unsupported(
|
||||
self.push(UnknownVariable())
|
||||
self.output.add_output_instructions(
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction, all_stack_locals_metadata
|
||||
self.next_instruction, push, all_stack_locals_metadata
|
||||
)
|
||||
)
|
||||
|
||||
@ -1404,13 +1408,45 @@ class InstructionTranslatorBase(
|
||||
# where we call step_graph_break right now is when the stack is empty,
|
||||
# so let's enforce that for now.
|
||||
assert not self.stack
|
||||
self.output.compile_subgraph(
|
||||
# NOTE: if we support non-empty self.stack in the future, the `stack_pops` argument
|
||||
# below should be set to the stack length to ensure that the stack is codegen'd
|
||||
# for the rest of the function.
|
||||
all_stack_locals_metadata = self.output.compile_subgraph(
|
||||
self,
|
||||
partial_convert=True,
|
||||
reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
|
||||
)
|
||||
# load locals from frame values
|
||||
# current frame state
|
||||
# [
|
||||
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# ],
|
||||
cg = PyCodegen(self)
|
||||
self.output.add_output_instructions(
|
||||
[create_jump_absolute(continue_inst)] + self.instructions
|
||||
[
|
||||
cg.create_load_const(-1),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(1),
|
||||
cg.create_binary_subscr(),
|
||||
]
|
||||
)
|
||||
for local, idx in all_stack_locals_metadata[-1].locals_names.items():
|
||||
self.output.add_output_instructions(
|
||||
[
|
||||
create_dup_top(),
|
||||
cg.create_load_const(idx),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_store(local),
|
||||
]
|
||||
)
|
||||
self.output.add_output_instructions(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_jump_absolute(continue_inst),
|
||||
*self.instructions,
|
||||
]
|
||||
)
|
||||
|
||||
def run_ctx_mgr(self) -> Any:
|
||||
@ -1510,7 +1546,7 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
|
||||
# for continuation functions
|
||||
if name.startswith("__stack"):
|
||||
if name.startswith("__stack") or name == "__nested_frame_values":
|
||||
self.symbolic_locals.pop(name)
|
||||
|
||||
def LOAD_DEREF(self, inst: Instruction) -> None:
|
||||
@ -2415,7 +2451,9 @@ class InstructionTranslatorBase(
|
||||
self.output.add_output_instructions([copy.copy(inst)])
|
||||
self.popn(2)
|
||||
self.output.add_output_instructions(
|
||||
self.create_call_resume_at(self.next_instruction, all_stack_locals_metadata)
|
||||
self.create_call_resume_at(
|
||||
self.next_instruction, 0, all_stack_locals_metadata
|
||||
)
|
||||
)
|
||||
|
||||
def DELETE_ATTR(self, inst: Instruction) -> None:
|
||||
@ -2427,15 +2465,240 @@ class InstructionTranslatorBase(
|
||||
)
|
||||
|
||||
def create_call_resume_at(
|
||||
self, offset: Instruction, all_stack_locals_metadata: Any
|
||||
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
|
||||
) -> list[Instruction]:
|
||||
raise AssertionError(
|
||||
f"create_call_resume_at not overridden by subclass {type(self)}"
|
||||
self.instruction_pointer = None
|
||||
|
||||
if inst.opname == "RETURN_VALUE":
|
||||
return [create_instruction("RETURN_VALUE")]
|
||||
elif inst.opname == "RETURN_CONST":
|
||||
return [create_instruction("RETURN_CONST", argval=inst.argval)]
|
||||
|
||||
cg = PyCodegen(self)
|
||||
|
||||
# current frame state
|
||||
# [
|
||||
# (frame N stack (minus top stack_pops values), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# ], `push` values from running the unsupported instruction
|
||||
|
||||
# move the `push` stack values to the frame N stack
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=push),
|
||||
# frames_list, push_values_list
|
||||
*create_swap(2),
|
||||
create_dup_top(),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
# push_values_list, frames_list, frames_list[0][0]
|
||||
*create_swap(3),
|
||||
# frames_list[0][0] += push_values_list
|
||||
create_instruction("LIST_EXTEND", arg=2),
|
||||
*create_swap(2),
|
||||
# frames_list, frames_list[0][0]
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
|
||||
# current frame state
|
||||
# [
|
||||
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
|
||||
# (frame 1 stack, frame 1 non-cell locals, frame 1 cells),
|
||||
# ],
|
||||
|
||||
#
|
||||
txes = []
|
||||
cur_tx: Optional[InstructionTranslatorBase] = self
|
||||
while cur_tx is not None:
|
||||
txes.append(cur_tx)
|
||||
cur_tx = cur_tx.parent
|
||||
assert len(txes) == len(all_stack_locals_metadata)
|
||||
|
||||
# Handle inactive context variables.
|
||||
# The resume function assumes that context variables are the class, NOT the object.
|
||||
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
|
||||
# NOTE: if the unsupported instruction modifies the inactive context variable, it may
|
||||
# result in silent incorrectness!
|
||||
for i, meta in enumerate(all_stack_locals_metadata):
|
||||
for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
|
||||
# Replace the stack var with the context class
|
||||
ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
|
||||
# frames[i][0][j] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(i),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(j),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
for name, _ in meta.locals_ctx_args:
|
||||
# Replace the local with the context class
|
||||
ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
|
||||
# frames[i][1][meta.locals_names[name]] = reconstructed_ctx
|
||||
cg.append_output(create_dup_top())
|
||||
ctx.reconstruct_type(cg)
|
||||
cg.extend_output(
|
||||
[
|
||||
*create_swap(2),
|
||||
cg.create_load_const(i),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(1),
|
||||
cg.create_binary_subscr(),
|
||||
cg.create_load_const(meta.locals_names[name]),
|
||||
create_instruction("STORE_SUBSCR"),
|
||||
]
|
||||
)
|
||||
|
||||
name = unique_id(f"__resume_at_{inst.offset}")
|
||||
|
||||
assert not config.nested_graph_breaks, "NYI"
|
||||
|
||||
# more locals may have been pruned after the unsupported instruction (e.g. branch)
|
||||
reads = livevars_analysis(self.instructions, inst)
|
||||
all_argnames = tuple(
|
||||
k
|
||||
for k in self.symbolic_locals.keys()
|
||||
if k in reads and k not in self.cell_and_freevars()
|
||||
)
|
||||
argnames_null_set = set(all_stack_locals_metadata[-1].locals_null_keys)
|
||||
argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
|
||||
argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
|
||||
if sys.version_info < (3, 12):
|
||||
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
|
||||
# compile_subgraph did not codegen any NULLs,
|
||||
# so we should not count NullVariables
|
||||
stack_len = len(self.stack) - len(
|
||||
all_stack_locals_metadata[-1].stack_null_idxes
|
||||
)
|
||||
nargs = stack_len + len(argnames)
|
||||
|
||||
new_code: types.CodeType = ContinueExecutionCache.lookup(
|
||||
self.f_code,
|
||||
self.lineno,
|
||||
inst.offset,
|
||||
tuple(b.target.offset for b in self.block_stack),
|
||||
stack_len,
|
||||
argnames,
|
||||
argnames_null,
|
||||
tuple(b.resume_fn() for b in self.block_stack),
|
||||
tuple(all_stack_locals_metadata[-1].stack_ctx_args),
|
||||
tuple(all_stack_locals_metadata[-1].locals_ctx_args),
|
||||
tuple(all_stack_locals_metadata[-1].stack_null_idxes),
|
||||
)
|
||||
|
||||
# Add original GraphModule context to the resume function to handle
|
||||
# the case of a graph break while tracing a GraphModule
|
||||
orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
|
||||
"orig_graphmodule", lambda: None
|
||||
)()
|
||||
if orig_graphmodule_maybe is not None:
|
||||
code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
|
||||
orig_graphmodule_maybe
|
||||
)
|
||||
|
||||
if new_code.co_freevars:
|
||||
# expose code object for debugging purposes
|
||||
self.output.install_global_unsafe(name, new_code)
|
||||
cg.make_function_with_closure(name, new_code, True, 1)
|
||||
package_name = None
|
||||
else:
|
||||
# This is safe: we pre-generate a unique name
|
||||
self.output.install_global_unsafe(
|
||||
name, types.FunctionType(new_code, self.f_globals, name)
|
||||
)
|
||||
cg.extend_output(cg.load_function_name(name, True, 1))
|
||||
package_name = name
|
||||
|
||||
if self.package is not None:
|
||||
self.package.add_resume_function(
|
||||
new_code, self.f_globals["__name__"], package_name
|
||||
)
|
||||
|
||||
# load top level-frame; final stack state should be:
|
||||
# [
|
||||
# (frame N stack (fixed), frame N non-cell locals, frame N cells),
|
||||
# ...,
|
||||
# (frame 2 stack, frame 2 non-cell locals, frame 2 cells),
|
||||
# ], frame 1 stack + frame 1 non-cell locals
|
||||
cg.extend_output(
|
||||
[
|
||||
create_dup_top(),
|
||||
create_dup_top(),
|
||||
# frames, frames, frames
|
||||
cg.create_load_const(-1),
|
||||
cg.create_binary_subscr(),
|
||||
# frames, frames, frames[-1]
|
||||
*create_swap(2),
|
||||
# frames, frames[-1], frames
|
||||
cg.create_load_const(-1),
|
||||
create_instruction("DELETE_SUBSCR"),
|
||||
# del frames[-1]; stack: frames, frames[-1]
|
||||
create_dup_top(),
|
||||
cg.create_load_const(0),
|
||||
cg.create_binary_subscr(),
|
||||
# frames, frames[-1], frames[-1][0]
|
||||
*create_swap(2),
|
||||
cg.create_load_const(1),
|
||||
cg.create_binary_subscr(),
|
||||
]
|
||||
)
|
||||
|
||||
# frames, frames[-1][0], frames[-1][1]
|
||||
for name in argnames:
|
||||
cg.extend_output(
|
||||
[
|
||||
create_dup_top(),
|
||||
cg.create_load_const(
|
||||
all_stack_locals_metadata[-1].locals_names[name]
|
||||
),
|
||||
cg.create_binary_subscr(),
|
||||
*create_swap(2),
|
||||
],
|
||||
)
|
||||
# frames, frames[-1][0], *(live locals), frames[-1][1]
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction("POP_TOP"),
|
||||
create_instruction("BUILD_LIST", arg=len(argnames)),
|
||||
create_instruction("LIST_EXTEND", arg=1),
|
||||
# UNPACK_SEQUENCE reverses elements
|
||||
create_instruction("UNPACK_SEQUENCE", arg=nargs),
|
||||
*create_reverse(nargs),
|
||||
]
|
||||
)
|
||||
# frames, *(stack + live locals)
|
||||
|
||||
cg.extend_output(create_call_function(nargs + 1, False))
|
||||
cg.append_output(create_instruction("RETURN_VALUE"))
|
||||
return cg.get_instructions()
|
||||
|
||||
def should_compile_partial_graph(self) -> bool:
    """Return whether a partial graph may be compiled at the current instruction.

    On Python 3.11+, the answer is ``False`` when the current instruction
    has an exception-table entry and either ``self.block_stack`` is empty
    or the entry's target is not the target of the innermost block on
    ``self.block_stack``.

    Otherwise the answer is ``True`` only when every block on
    ``self.block_stack`` reports ``can_restore()``, none of ``one_graph``,
    ``error_on_graph_break`` or ``is_tracing_resume_prologue`` is set, and
    ``active_generic_context_managers`` is empty.

    Returns:
        ``True`` if a partial graph may be compiled here, ``False`` otherwise.
    """
    if sys.version_info >= (3, 11):
        # Do not compile if current instruction's block is not the top with block
        entry = self.current_instruction.exn_tab_entry
        if entry and (
            not self.block_stack or entry.target is not self.block_stack[-1].target
        ):
            return False
    return (
        all(b.can_restore() for b in self.block_stack)
        and not self.one_graph
        and not self.error_on_graph_break
        and not self.is_tracing_resume_prologue
        and not self.active_generic_context_managers
    )
|
||||
|
||||
@break_graph_if_unsupported(push=0)
|
||||
@ -3612,125 +3875,6 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
return self.f_globals[source.global_name]
|
||||
raise KeyError
|
||||
|
||||
def run(self) -> None:
    """Run this translator by delegating to the parent class's ``run``."""
    super().run()
|
||||
|
||||
def should_compile_partial_graph(self) -> bool:
    """Decide whether a partial graph can be compiled at this instruction.

    Every condition below is a veto; the method returns ``True`` only when
    none of them fires.
    """
    if sys.version_info >= (3, 11):
        # Do not compile if current instruction's block is not the top with block
        handler = self.current_instruction.exn_tab_entry
        if handler:
            if not self.block_stack:
                return False
            if handler.target is not self.block_stack[-1].target:
                return False

    # Every block on the block stack must be restorable.
    for block in self.block_stack:
        if not block.can_restore():
            return False

    # Any of these flags being set rules out a partial graph.
    if (
        self.one_graph
        or self.error_on_graph_break
        or self.is_tracing_resume_prologue
    ):
        return False

    return not self.active_generic_context_managers
|
||||
|
||||
def create_call_resume_at(
    self, inst: Instruction, all_stack_locals_metadata: Any
) -> list[Instruction]:
    """Build the instructions that call a resume function starting at ``inst``.

    If ``inst`` is ``RETURN_VALUE`` or ``RETURN_CONST``, only that return
    instruction is produced. Otherwise a code object is obtained from
    ``ContinueExecutionCache.lookup``, installed as a global under a unique
    ``__resume_at_<offset>`` name, and instructions are generated that call
    it with ``nargs`` arguments and return the result.

    Args:
        inst: The instruction at which execution should resume.
        all_stack_locals_metadata: Indexable metadata; only element ``0`` is
            read here (``locals_null_keys``, ``stack_null_idxes``,
            ``stack_ctx_args``, ``stack_ctx_idxes_orig``, ``locals_ctx_args``).

    Returns:
        The generated instruction list.

    Side effects:
        Sets ``self.instruction_pointer`` to ``None``, installs a global via
        ``self.output.install_global_unsafe`` and may register the resume
        code with ``self.package``.
    """
    self.instruction_pointer = None

    # A return needs no resume function: emit the return itself.
    if inst.opname == "RETURN_VALUE":
        return [create_instruction("RETURN_VALUE")]
    elif inst.opname == "RETURN_CONST":
        return [create_instruction("RETURN_CONST", argval=inst.argval)]

    # presumably `reads` is the set of local names still read from `inst`
    # onwards -- confirm against livevars_analysis
    reads = livevars_analysis(self.instructions, inst)
    # Candidate arguments: symbolic locals that appear in `reads` and are
    # not cell/free variables.
    all_argnames = tuple(
        k
        for k in self.symbolic_locals.keys()
        if k in reads and k not in self.cell_and_freevars()
    )
    # NOTE: do not use isinstance, since it realizes lazy VT's
    argnames_null_set = set(all_stack_locals_metadata[0].locals_null_keys)
    # Split the candidates into non-NULL names (passed as arguments) and
    # NULL names (only forwarded to the resume-code lookup).
    argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
    argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
    if sys.version_info < (3, 12):
        assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
    # compile_subgraph did not codegen any NULLs,
    # so we should not count NullVariables
    stack_len = len(self.stack) - len(all_stack_locals_metadata[0].stack_null_idxes)
    # Arguments to the resume function: stack values plus live non-NULL locals.
    nargs = stack_len + len(argnames)

    cg = PyCodegen(self)

    # Handle inactive context variables.
    # The resume function assumes that context variables are the class, NOT the object.
    # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
    # NOTE: if the unsupported instruction modifies the inactive context variable, it may
    # result in silent incorrectness!
    for (i, _), i_orig in zip(
        all_stack_locals_metadata[0].stack_ctx_args,
        all_stack_locals_metadata[0].stack_ctx_idxes_orig,
    ):
        # Replace the current stack var with the context class
        ctx = cast(ContextWrappingVariable, self.stack[i_orig])
        ctx.reconstruct_type(cg)
        # presumably: swap the newly pushed class into stack slot `i`, then
        # pop the displaced value that ends up on top -- confirm
        cg.extend_output(create_swap(stack_len - i + 1))
        cg.append_output(create_instruction("POP_TOP"))

    for name, _ in all_stack_locals_metadata[0].locals_ctx_args:
        # Replace the local with the context class
        ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
        ctx.reconstruct_type(cg)
        cg.append_output(create_instruction("STORE_FAST", argval=name))

    # `name` is rebound here: above it was the loop variable, from now on it
    # is the unique global name of the resume function.
    name = unique_id(f"__resume_at_{inst.offset}", with_uuid=True)

    new_code: types.CodeType = ContinueExecutionCache.lookup(
        self.f_code,
        self.lineno,
        inst.offset,
        tuple(b.target.offset for b in self.block_stack),
        stack_len,
        argnames,
        argnames_null,
        tuple(b.resume_fn() for b in self.block_stack),
        tuple(all_stack_locals_metadata[0].stack_ctx_args),
        tuple(all_stack_locals_metadata[0].locals_ctx_args),
        tuple(all_stack_locals_metadata[0].stack_null_idxes),
    )

    # Add original GraphModule context to the resume function to handle
    # the case of a graph break while tracing a GraphModule
    orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
        "orig_graphmodule", lambda: None
    )()
    if orig_graphmodule_maybe is not None:
        # Stored as a weak reference, matching the callable-style lookup above.
        code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
            orig_graphmodule_maybe
        )

    if new_code.co_freevars:
        # expose code object for debugging purposes
        self.output.install_global_unsafe(name, new_code)
        cg.make_function_with_closure(name, new_code, True, stack_len)
        # No function name is recorded for the package in this branch.
        package_name = None
    else:
        # This is safe: we pre-generate a unique name
        self.output.install_global_unsafe(
            name, types.FunctionType(new_code, self.f_globals, name)
        )
        cg.extend_output(cg.load_function_name(name, True, stack_len))
        package_name = name

    # NOTE(review): this view has lost indentation and shows two consecutive
    # guards; they are read as nested, the only reading that is valid
    # Python -- confirm against the upstream file.
    if self.package is not None:
        if self.output.package is not None:
            self.package.add_resume_function(
                new_code, self.f_globals["__name__"], function_name=package_name
            )

    # Load the live non-NULL locals, call the resume function with all
    # `nargs` arguments, and return whatever it returns.
    cg.extend_output([cg.create_load(k) for k in argnames])
    cg.extend_output(create_call_function(nargs, False))
    cg.append_output(create_instruction("RETURN_VALUE"))
    return cg.get_instructions()
|
||||
|
||||
def symbolic_locals_contain_module_class(self) -> bool:
|
||||
for v in self.symbolic_locals.values():
|
||||
if isinstance(v, UserDefinedClassVariable) and issubclass(
|
||||
@ -3781,6 +3925,8 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
reason=GraphCompileReason(
|
||||
"return_value", [self.frame_summary()], graph_break=False
|
||||
),
|
||||
# the value to be returned
|
||||
stack_pops=1 if inst.opname == "RETURN_VALUE" else 0,
|
||||
)
|
||||
# check that our stack/locals meta are correct:
|
||||
# we should only be tracing 1 frame, and there should not be any NULLs on the stack
|
||||
@ -3791,6 +3937,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||
if inst.opname == "RETURN_VALUE"
|
||||
else create_instruction("RETURN_CONST", argval=inst.argval)
|
||||
)
|
||||
# NOTE: does the stack need to be empty after the return?
|
||||
self.output.add_output_instructions([return_inst])
|
||||
raise ReturnValueOp
|
||||
|
||||
@ -4147,11 +4294,17 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
|
||||
return TracingContext.current_frame(self.parent.frame_summary())
|
||||
|
||||
def should_compile_partial_graph(self) -> bool:
    """Decide whether a partial graph may be compiled inside an inlined call.

    Without ``config.nested_graph_breaks`` the answer is always ``False``.
    With it, the parent translator must agree first, and then the parent
    class's own check decides.
    """
    if not config.nested_graph_breaks:
        # inlining functions is all-or-nothing
        return False
    if self.parent.should_compile_partial_graph():
        return super().should_compile_partial_graph()
    return False
|
||||
|
||||
def create_call_resume_at(
|
||||
self, inst: Instruction, all_stack_locals_metadata: Any
|
||||
) -> NoReturn:
|
||||
self, inst: Instruction, push: int, all_stack_locals_metadata: Any
|
||||
) -> list[Instruction]:
|
||||
if config.nested_graph_breaks:
|
||||
return super().create_call_resume_at(inst, push, all_stack_locals_metadata)
|
||||
unimplemented_v2(
|
||||
gb_type="Graph break in inlined function",
|
||||
context="",
|
||||
|
Reference in New Issue
Block a user