Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[dynamo, nested graph breaks] fix nested step graph break related issues (#162737)
Turns out codegen'ing a nested step graph break is significantly more complicated than first thought. The optimized function should actually:
- call the graph / load values / do side effects etc.
- call into the leaf's resume function, but skipped (this is essentially a step graph break function for just the leaf function)
- call into all the other resume functions, traced

This PR also adds `torch._dynamo.step_unsupported()`, which can be used for internal testing purposes to better test step graph break handling.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162737
Approved by: https://github.com/Lucaskabela
ghstack dependencies: #160601
Committed by: PyTorch MergeBot
Parent: 486b4d2414
Commit: af4c29fea8
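As a usage sketch of the new debugging helper (directly mirroring the test_step_unsupported test added below): the two additions before the call end up in the compiled graph, and everything after the call runs eagerly in the skipped remainder of the frame.

    import torch
    import torch._dynamo
    import torch._dynamo.testing

    cnts = torch._dynamo.testing.CompileCounter()

    @torch.compile(backend=cnts)
    def fn(x):
        x = x + 1 + 2                     # two ops traced into the compiled graph
        torch._dynamo.step_unsupported()  # compile what was traced, skip the rest
        return x + 4                      # runs eagerly in the skipped remainder

    inp = torch.ones(3)
    assert torch.equal(fn(inp), inp + 7)
    assert cnts.frame_count == 1 and cnts.op_count == 2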
@@ -893,6 +893,29 @@ class DecoratorTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(gn(inp), inp + 3)
         self.assertEqual(cnts.frame_count, 1)
 
+    def test_step_unsupported(self):
+        cnts = torch._dynamo.testing.CompileCounter()
+
+        @torch.compile(backend=cnts)
+        def fn(x):
+            x = x + 1 + 2
+            torch._dynamo.step_unsupported()
+            return x + 4
+
+        inp = torch.ones(3)
+        self.assertEqual(fn(inp), inp + 7)
+        self.assertEqual(cnts.frame_count, 1)
+        self.assertEqual(cnts.op_count, 2)
+
+    def test_step_unsupported_empty_checkpoint(self):
+        @torch.compile(backend="eager")
+        def fn(x):
+            torch._dynamo.step_unsupported()
+            return x + 1
+
+        inp = torch.ones(3)
+        self.assertEqual(fn(inp), inp + 1)
+
     @skipIfWindows(
         msg="TODO: (xuhancn), confirm if torch.compiler.disable work on Windows."
     )
@@ -536,6 +536,7 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
         self.assertEqual(cnts.frame_count, 5)
         # 4 additions from f5+f4, 2 x 4 additions from f2+f1 (i == 5, i != 5)
         self.assertEqual(cnts.op_count, 12)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 6)
 
     def test_nested_graph_break_in_try_block(self):
+        # NOTE: this also tests nested step_graph_break
@@ -576,13 +577,40 @@ class NestedGraphBreakTests(torch._dynamo.test_case.TestCase):
         x = torch.zeros(3)
         res = f5(x)
         ref = opt_fn(x)
         print(ref, res)
         self.assertEqual(ref, res)
         # skip frame due to graph break in try block
         # 2 frames from f5+f4+(first part of f3), 2 frames from f2+f1
         self.assertEqual(cnts.frame_count, 4)
         # 5 additions from f5+f4+(first part of f3), 4 additions from f2+f1
         self.assertEqual(cnts.op_count, 9)
         self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 4)
 
+    def test_nested_step_unsupported(self):
+        global f1, f2, f3
+
+        def f1(x):
+            return x + 1
+
+        def f2(x):
+            x = x + 2
+            torch._dynamo.step_unsupported()
+            return f1(x) + 4
+
+        def f3(x):
+            x = x + 8
+            return f2(x) + 16
+
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch._dynamo.optimize(backend=cnts)(f3)
+        x = torch.zeros(3)
+        res = f3(x)
+        ref = opt_fn(x)
+        self.assertEqual(ref, res)
+        # 1 frame from start of f3 + start of f2, 1 frame from f1, 1 frame from the end of f3
+        self.assertEqual(cnts.frame_count, 3)
+        # all ops except + 4
+        self.assertEqual(cnts.op_count, 4)
+        self.assertEqual(torch._dynamo.utils.counters["frames"]["total"], 3)
+
 
 if __name__ == "__main__":
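To make the frame accounting in test_nested_step_unsupported concrete, here is a rough eager decomposition of what the optimized code ends up doing, per the commit message and the test's comments. The function names are illustrative only, not real APIs:

    def f1_traced(x):            # compiled frame 2
        return x + 1

    def f3_rest_traced(y):       # compiled frame 3: the rest of f3
        return y + 16

    def f2_resume_skipped(x):    # leaf resume function: skipped, runs eagerly
        return f1_traced(x) + 4  # the "+ 4" is the op that never lands in a graph

    def f3_f2_prologue(x):       # compiled frame 1: start of f3 + start of f2
        return (x + 8) + 2

    # 3 compiled frames, 4 compiled ops (everything except "+ 4"):
    assert f3_rest_traced(f2_resume_skipped(f3_f2_prologue(0.0))) == 31.0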
@@ -7256,6 +7256,26 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
             flag = False
         self.assertEqual(fn(inp), opt_fn(inp))
 
+    def test_cells_unsupported_step_exception(self):
+        # This error happened because:
+        # - we were generating cells into a list on the stack
+        # - we encountered an unsupported step, resulting in a step graph break
+        # - we encounter an exception, which pops the stack until it reaches a certain length;
+        #   the presence of the list of cells then messes things up.
+
+        cell = 0
+
+        @torch.compile(backend="eager")
+        def fn(x):
+            x = x + 1 + 2
+            torch._dynamo.step_unsupported()
+            with contextlib.nullcontext():
+                print(cell)
+                raise AssertionError
+
+        with self.assertRaises(AssertionError):
+            fn(torch.ones(3))
+
     def test_unbind_copy_out(self):
         def f(eye, out):
             torch.unbind_copy(eye, out=out)
@@ -40,6 +40,7 @@ from .decorators import (
     run,
     set_stance,
     skip_frame,
+    step_unsupported,
     substitute_in_graph,
 )
 from .eval_frame import (
@@ -102,6 +103,7 @@ __all__ = [
     "error_on_graph_break",
     "set_stance",
     "skip_frame",
+    "step_unsupported",
     "substitute_in_graph",
 ]
 
@@ -533,6 +533,8 @@ def create_binary_slice(
 def create_copy(i: int) -> list[Instruction]:
     if sys.version_info >= (3, 11):
         return [create_instruction("COPY", arg=i)]
+    if i == 1:
+        return [create_instruction("DUP_TOP")]
     # COPY 4
     # 0 1 2 3
     # 3 1 2 0
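The truncated comment above sketches the permutation used to emulate COPY 4 on pre-3.11 interpreters via rotations. For reference, a minimal sketch of the stack effect being emulated, modeling the stack as a Python list with TOS at the right:

    # COPY i (3.11+) pushes a copy of the i-th item from the top:
    def copy_effect(stack, i):
        stack.append(stack[-i])

    s = ["a", "b", "c", "d"]
    copy_effect(s, 4)
    assert s == ["a", "b", "c", "d", "a"]

    # DUP_TOP (pre-3.11) is exactly the COPY 1 special case added above:
    copy_effect(s, 1)
    assert s[-1] == s[-2]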
@@ -296,6 +296,14 @@ def skip_frame(msg: str = "") -> None:
     """Force a skipped frame"""
 
 
+@_disallow_in_graph_helper(throw_if_not_allowed=False)
+def step_unsupported(msg: str = "") -> None:
+    """Force a step unsupported graph break, which results in compiling
+    the traced FX graph so far, then skipping the rest of the frame.
+    In order to get expected behavior, there should be at least 2 ops
+    and a part of the code not contained in any try/with blocks."""
+
+
 def forbid_in_graph(fn: Any) -> Any:
     """
     Customize which functions TorchDynamo will assert are not present while tracing.
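Note that the new helper's body is a docstring only, so calling it outside of compilation should be inert, following the same pattern as skip_frame above; only tracing it has an effect. A small sanity sketch under that assumption:

    import torch._dynamo

    # Eager call (no torch.compile active): assumed to be a no-op returning None.
    assert torch._dynamo.step_unsupported() is None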
@@ -263,6 +263,11 @@ class RecompileLimitExceeded(Unsupported):
     pass
 
 
+# debug exception thrown when tracing torch._dynamo.step_unsupported()
+class StepUnsupported(TorchDynamoException):
+    pass
+
+
 class UnsafeScriptObjectError(TorchDynamoException):
     pass
 
@@ -2763,5 +2763,18 @@
         "This is likely to be a Dynamo bug. Please report an issue to PyTorch."
       ]
     }
   ],
+  "GB0275": [
+    {
+      "Gb_type": "torch._dynamo.step_unsupported() with empty checkpoint",
+      "Context": "",
+      "Explanation": "traced torch._dynamo.step_unsupported(), but there is no checkpoint to step_graph_break from. This graph break is used for debugging only.",
+      "Hints": [
+        "Remove the torch._dynamo.step_unsupported() call.",
+        "Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some ",
+        "line of code that is not in a try/with block, and has an empty Python stack.",
+        "This is likely to be a Dynamo bug. Please report an issue to PyTorch."
+      ]
+    }
+  ]
 }
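The registry entry makes the new GB0275 graph break discoverable by tooling. As a hedged sketch of inspecting it directly, assuming the file path matches where this registry lives in the repo:

    import json

    # Assumed path of the registry file edited above.
    with open("torch/_dynamo/graph_break_registry.json") as f:
        registry = json.load(f)

    entry = registry["GB0275"][0]
    print(entry["Gb_type"])   # torch._dynamo.step_unsupported() with empty checkpoint
    print(entry["Hints"])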
@@ -318,6 +318,7 @@ class ContinueExecutionCache:
         argnames: tuple[str, ...],
         argnames_null: tuple[str, ...],
         setup_fns: tuple[ReenterWith, ...],
+        handle_inactive_ctx: bool,
         stack_ctx_vars: tuple[tuple[int, tuple[Any, ...]], ...],
         argnames_ctx_vars: tuple[tuple[str, tuple[Any, ...]], ...],
         null_idxes: tuple[int, ...],
@@ -341,6 +342,7 @@ class ContinueExecutionCache:
             argnames,
             argnames_null,
             setup_fns,
+            handle_inactive_ctx,
             stack_ctx_vars,
             argnames_ctx_vars,
             null_idxes,
@@ -432,7 +434,7 @@ class ContinueExecutionCache:
                     prefix.append(
                         create_instruction("LOAD_FAST", argval=f"___stack{stack_i}")
                     )
-                    if stack_i in stack_ctx_vars_d:
+                    if handle_inactive_ctx and stack_i in stack_ctx_vars_d:
                         # NOTE: we assume that current stack var is a context manager CLASS!
                         # Load args for context variable and construct it
                         prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[stack_i]))
@@ -459,10 +461,11 @@ class ContinueExecutionCache:
 
             # NOTE: we assume that local var is a context manager CLASS!
             # initialize inactive context vars in argnames
-            for name, vals in argnames_ctx_vars:
-                prefix.append(create_instruction("LOAD_FAST", argval=name))
-                prefix.extend(_load_tuple_and_call(vals))
-                prefix.append(create_instruction("STORE_FAST", argval=name))
+            if handle_inactive_ctx:
+                for name, vals in argnames_ctx_vars:
+                    prefix.append(create_instruction("LOAD_FAST", argval=name))
+                    prefix.extend(_load_tuple_and_call(vals))
+                    prefix.append(create_instruction("STORE_FAST", argval=name))
 
             # 3.12+: store NULL into variables that were NULL
             if argnames_null:
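What "context variables are the class, NOT the object" means in practice: the resume function is handed the context manager's type plus its constructor args and rebuilds the instance itself. A minimal sketch of that convention (names are illustrative):

    import torch

    # Instead of carrying a live instance across the graph break ...
    live_ctx = torch.no_grad()

    # ... the resume prologue receives the CLASS and args and reconstructs it,
    # which is what the _load_tuple_and_call codegen above effectively emits:
    ctx_class, ctx_args = torch.no_grad, ()
    rebuilt = ctx_class(*ctx_args)
    assert isinstance(rebuilt, type(live_ctx))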
@@ -79,6 +79,7 @@ from .bytecode_transformation import (
     create_dup_top,
     create_instruction,
     create_jump_absolute,
+    create_load_const,
     create_rot_n,
     create_swap,
     get_code_keys,
@@ -96,6 +97,7 @@ from .exc import (
     format_graph_break_message,
     get_stack_above_dynamo,
     ResumePrologueTracingError,
+    StepUnsupported,
     unimplemented_v2,
     Unsupported,
 )
@@ -669,14 +671,20 @@ def generic_jump(
             )
             self.pop()
 
-            if_next = self.create_call_resume_at(
-                self.next_instruction, all_stack_locals_metadata, False
+            if_next = self.codegen_fix_leaf_stack(
+                all_stack_locals_metadata[0], self.next_instruction
+            ) + self.create_call_resume_at(
+                self.next_instruction,
+                all_stack_locals_metadata,
             )
             if push:
                 self.push(value)
             assert inst.target is not None
-            if_jump = self.create_call_resume_at(
-                inst.target, all_stack_locals_metadata, False
+            if_jump = self.codegen_fix_leaf_stack(
+                all_stack_locals_metadata[0], inst.target
+            ) + self.create_call_resume_at(
+                inst.target,
+                all_stack_locals_metadata,
             )
 
             if sys.version_info >= (3, 13):
@@ -960,7 +968,7 @@ def break_graph_if_unsupported(
             all_stack_locals_metadata = self.output.compile_subgraph(
                 self, reason=reason, stack_pops=push - stack_effect
             )
-            cg = PyCodegen(self)
+            cg = PyCodegen(self.output.root_tx)
             cleanup: list[Instruction] = []
             # Reconstruct the context variable CLASS in the block stack
             for b in self.block_stack:
@@ -1009,8 +1017,12 @@
             for _ in range(push):
                 self.push(UnknownVariable())
             self.output.add_output_instructions(
-                self.create_call_resume_at(
-                    self.next_instruction, all_stack_locals_metadata, False
+                self.codegen_fix_leaf_stack(
+                    all_stack_locals_metadata[0], self.next_instruction
                 )
+                + self.create_call_resume_at(
+                    self.next_instruction,
+                    all_stack_locals_metadata,
+                )
             )
 
@@ -1351,9 +1363,22 @@ class InstructionTranslatorBase(
             return True
         except (ReturnValueOp, YieldValueOp):
             return False
-        except Unsupported:
+        except (Unsupported, StepUnsupported) as e:
             if self.current_speculation is None:
                 log.debug("empty checkpoint")
+                if isinstance(e, StepUnsupported):
+                    unimplemented_v2(
+                        gb_type="torch._dynamo.step_unsupported() with empty checkpoint",
+                        context="",
+                        explanation="traced torch._dynamo.step_unsupported(), but there is no checkpoint "
+                        "to step_graph_break from. This graph break is used for debugging only.",
+                        hints=[
+                            "Remove the torch._dynamo.step_unsupported() call.",
+                            "Include at least one checkpoint: (1) include at least 2 ops and (2) make sure there is some "
+                            "line of code that is not in a try/with block, and has an empty Python stack.",
+                            *graph_break_hints.DYNAMO_BUG,
+                        ],
+                    )
                 raise
         log.debug("step triggered compile", exc_info=True)
 
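The "empty checkpoint" branch above is exactly what the new test_step_unsupported_empty_checkpoint exercises: with nothing traced before the call, there is no speculation checkpoint to step back to, so GB0275 fires and the whole frame falls back to eager. A compact sketch, mirroring that test:

    import torch
    import torch._dynamo

    @torch.compile(backend="eager")
    def fn(x):
        torch._dynamo.step_unsupported()  # traced first, no checkpoint yet: GB0275 path
        return x + 1

    # The frame is skipped and runs eagerly; results stay correct, just uncompiled.
    assert torch.equal(fn(torch.ones(3)), torch.ones(3) + 1)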
@@ -1427,12 +1452,98 @@
             partial_convert=True,
             reason=GraphCompileReason("step_unsupported", [self.frame_summary()]),
         )
+        # current frame state
+        # cells,
+        # [
+        #     frame N locals,
+        #     frame N-1 stack + locals,
+        #     ...,
+        #     frame 1 stack + locals,
+        # ],
         if self.parent:
+            from .eval_frame import skip_code
+
             # nested graph break
             assert config.nested_graph_breaks
+            cg = PyCodegen(self.output.root_tx)
+
+            # codegen cells and frame values only for frame N
+            cg.extend_output(
+                [
+                    *create_copy(2),
+                    cg.create_load_const(0),
+                    cg.create_binary_subscr(),
+                    create_instruction("BUILD_LIST", arg=1),
+                    *create_copy(2),
+                    cg.create_load_const(0),
+                    cg.create_binary_subscr(),
+                    create_instruction("BUILD_LIST", arg=1),
+                ]
+            )
+            # No need to fix stack, since stack is assumed to be empty here.
+            # Do NOT handle_inactive_ctx because we will be skipping this resume code.
+            leaf_resume_code, leaf_resume_name = self.create_resume(
+                0, continue_inst, all_stack_locals_metadata[0], [], cg, True, False
+            )
+            skip_code(leaf_resume_code)
+
+            # current frame state
+            # cells,
+            # [
+            #     frame N locals,
+            #     frame N-1 stack + locals,
+            #     ...,
+            #     frame 1 stack + locals,
+            # ], [frame N cells], [frame N locals],
+            self.codegen_call_resume([leaf_resume_code], [leaf_resume_name], cg)
+
+            # current frame state
+            # cells,
+            # [
+            #     frame N locals,
+            #     frame N-1 stack + locals,
+            #     ...,
+            #     frame 1 stack + locals,
+            # ], leaf_resume result
+
+            # add the leaf_resume result to frame N-1 stack
+            num_stack = all_stack_locals_metadata[1].num_stack
+            cg.extend_output(
+                [
+                    create_instruction("BUILD_LIST", arg=1),
+                    *create_copy(2),
+                    cg.create_load_const(1),
+                    cg.create_binary_subscr(),
+                    *create_binary_slice(num_stack, num_stack, True),
+                ]
+            )
+
+            # pop frame N cells and locals
+            cg.extend_output(
+                [
+                    *create_copy(1),
+                    cg.create_load_const(0),
+                    create_instruction("DELETE_SUBSCR"),
+                    *create_copy(2),
+                    cg.create_load_const(0),
+                    create_instruction("DELETE_SUBSCR"),
+                ]
+            )
+
+            # call the remaining resume functions
+            # current frame state
+            # [frame N-1 cells, ..., frame 1 cells],
+            # [
+            #     frame N-1 stack (including leaf_resume result) + locals,
+            #     ...,
+            #     frame 1 stack + locals,
+            # ],
+            self.parent.push(UnknownVariable())
+            all_stack_locals_metadata[1].num_stack += 1
             self.output.add_output_instructions(
-                self.create_call_resume_at(
-                    continue_inst, all_stack_locals_metadata, True
+                cg.get_instructions()
+                + self.parent.create_call_resume_at(
+                    self.parent.next_instruction, all_stack_locals_metadata[1:]
                 )
             )
         else:
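The stack-layout comments above are easier to follow with a concrete model. A sketch simulating the generated bytecode's list surgery in plain Python; the names and values are illustrative, not real Dynamo objects:

    # Two bookkeeping lists for a 3-frame nested break (N = 3):
    cells = [["f3_cells"], ["f2_cells"], ["f1_cells"]]  # [frame N cells, ..., frame 1 cells]
    frame_values = [
        ["f3_local"],              # frame N locals (leaf stack is empty at a step break)
        ["f2_stack", "f2_local"],  # frame N-1 stack + locals
        ["f1_stack", "f1_local"],  # frame 1 stack + locals
    ]

    leaf_result = "leaf_resume_result"  # returned by the skipped leaf resume function
    num_stack = 1                       # frame N-1 currently holds one stack value

    # "add the leaf_resume result to frame N-1 stack":
    frame_values[1][num_stack:num_stack] = [leaf_result]

    # "pop frame N cells and locals":
    del cells[0]
    del frame_values[0]

    assert frame_values[0] == ["f2_stack", "leaf_resume_result", "f2_local"]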
@@ -1444,14 +1555,7 @@
                 ]
             )
             # load locals from frame values
-            # current frame state
-            # [
-            #     frame N locals,
-            #     frame N-1 stack + locals,
-            #     ...,
-            #     frame 1 stack + locals,
-            # ],
-            cg = PyCodegen(self)
+            cg = PyCodegen(self.output.root_tx)
             self.output.add_output_instructions(
                 [
                     cg.create_load_const(-1),
@@ -2516,8 +2620,12 @@
         self.output.add_output_instructions([copy.copy(inst)])
         self.popn(2)
         self.output.add_output_instructions(
-            self.create_call_resume_at(
-                self.next_instruction, all_stack_locals_metadata, False
+            self.codegen_fix_leaf_stack(
+                all_stack_locals_metadata[0], self.next_instruction
             )
+            + self.create_call_resume_at(
+                self.next_instruction,
+                all_stack_locals_metadata,
+            )
         )
 
@@ -2530,15 +2638,26 @@
         )
 
     @staticmethod
-    def codegen_return_after_compile_subgraph(
-        inst: Instruction, meta: StackLocalsMetadata
+    def codegen_return_with_pops(
+        inst: Instruction, num_stack: int
     ) -> list[Instruction]:
+        """
+        Debug CPython expects the stack to be empty after the return.
+        Calling compile_subgraph will push cells and frame values to TOS.
+        This function will pop those 2 values from the stack before actually returning.
+
+        Expects the stack to be:
+            cells, frame values, current frame stack (0 or 1 values)
+
+        Pops cells and frame values, leaving the current frame stack as TOS.
+        A return instruction is included.
+        """
         insts = []
-        # NOTE: Debug CPython expects the stack to be empty after the return.
-        # Expect the current stack to be in the state
-        # cells, frame values, current frame stack (0 or 1 values)
-        assert meta.num_stack <= 1
-        if meta.num_stack == 1:
+        assert num_stack <= 1
+        if num_stack == 1:
             insts.extend(create_swap(3))
         return_inst = (
             create_instruction("RETURN_VALUE")
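The return-path contract documented above, modeled on a plain list. This is a sketch of the documented stack discipline, not the emitted bytecode itself:

    # Stack modeled as a list with TOS at the right:
    stack = ["cells", "frame_values", "retval"]

    # create_swap(3) exchanges TOS with the third item from the top:
    stack[-1], stack[-3] = stack[-3], stack[-1]
    assert stack == ["retval", "frame_values", "cells"]

    # Popping the two bookkeeping values leaves only retval for RETURN_VALUE:
    stack.pop(); stack.pop()
    assert stack == ["retval"]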
@@ -2550,31 +2669,261 @@ class InstructionTranslatorBase(
         )
         return insts
 
-    def create_call_resume_at(
-        self,
-        inst: Instruction,
-        all_stack_locals_metadata: Any,
-        disable_current_frame_resume: bool,
+    def codegen_fix_leaf_stack(
+        self, meta: StackLocalsMetadata, resume_inst: Instruction
     ) -> list[Instruction]:
         """
-        Codegen resume function(s) and call it.
-        Assumes that the unsupported instruction has already been run.
+        Fixes the stack values of the current/leaf frame (self).
 
-        Expects the stack to be in the state:
-        [frame N cells, ..., frame 1 cells],
+        Expects the TOS to be:
         [
             frame N locals,
             frame N-1 stack + locals,
             ...,
             frame 1 stack + locals
-        ], frame N stack (post-instruction)
+        ], *(frame N stack (post-unsupported instruction))
+
+        Rearranges the TOS to become:
+        [
+            frame N stack + locals,
+            ...,
+            frame 1 stack + locals
+        ]
+
+        Args:
+        - meta: metadata for the leaf frame returned from OutputGraph.compile_subgraph
+        - resume_inst: if the resume instruction is a return instruction, then don't return any instructions
         """
+        if resume_inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
+            return []
+        # move frame N stack to the frame values list
+        current_num_stack = len(self.stack) - len(meta.stack_null_idxes)
+        meta.num_stack = current_num_stack
+        return [
+            create_instruction("BUILD_LIST", arg=current_num_stack),
+            *create_copy(2),
+            # frame_values, frame N stack, frame_values
+            create_load_const(0),
+            create_instruction("BINARY_SUBSCR"),
+            *create_binary_slice(0, 0, True),
+            # frame_values[0][0:0] = frame N stack
+            # frame_values left on top of stack
+        ]
+
+    def create_resume(
+        self,
+        idx: int,
+        resume_inst: Instruction,
+        meta: StackLocalsMetadata,
+        resume_codes: list[types.CodeType],
+        cg: PyCodegen,
+        is_leaf: bool,
+        handle_inactive_ctx: bool,
+    ) -> tuple[types.CodeType, str]:
+        """
+        Creates the resume function for the frame corresponding to `self`.
+
+        Expects the TOS to be:
+        [frame N cells, ..., frame 1 cells],
+        [
+            frame N stack + locals,
+            ...,
+            frame 1 stack + locals
+        ]
+
+        Some additional codegen may happen to prepare the frame stack + locals values for the generated resume function:
+        - inactive context variables in the stack and locals will be replaced by their types
+        - if the frame is a leaf frame, prune dead locals
+
+        Regardless of codegen, the stack will be left in the same state as before.
+
+        Args:
+        - idx: depth of this frame: 0 corresponds to the leaf frame (frame N), N-1 to the root frame (frame 1).
+        - resume_inst: the instruction that this frame should resume at
+        - meta: metadata for this frame returned from OutputGraph.compile_subgraph
+        - resume_codes: nested resume code objects generated from previous create_resume calls.
+        - cg: codegen object to output to
+        - is_leaf: True if `self` corresponds to the leaf frame.
+        - handle_inactive_ctx: If True, handles inactive context variables as described above. This is necessary
+          iff the resume function is traced
+        """
+        # Handle inactive context variables.
+        # The resume function assumes that context variables are the class, NOT the object.
+        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
+        # NOTE: if the unsupported instruction modifies the inactive context variable, it may
+        # result in silent incorrectness!
+        if handle_inactive_ctx:
+            for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
+                # Replace the stack var with the context class
+                ctx = cast(ContextWrappingVariable, self.stack[j_orig])
+                # frames[idx][j] = reconstructed_ctx
+                cg.append_output(create_dup_top())
+                ctx.reconstruct_type(cg)
+                cg.extend_output(
+                    [
+                        *create_swap(2),
+                        cg.create_load_const(idx),
+                        cg.create_binary_subscr(),
+                        cg.create_load_const(j),
+                        create_instruction("STORE_SUBSCR"),
+                    ]
+                )
+
+            for name, _ in meta.locals_ctx_args:
+                # Replace the local with the context class
+                ctx = cast(ContextWrappingVariable, self.symbolic_locals[name])
+                # frames[idx][meta.num_stack + meta.locals_names[name]] = reconstructed_ctx
+                cg.append_output(create_dup_top())
+                ctx.reconstruct_type(cg)
+                cg.extend_output(
+                    [
+                        *create_swap(2),
+                        cg.create_load_const(idx),
+                        cg.create_binary_subscr(),
+                        cg.create_load_const(meta.num_stack + meta.locals_names[name]),
+                        create_instruction("STORE_SUBSCR"),
+                    ]
+                )
+
+        # If the resume instruction is a jump absolute, then resume
+        # at the target instead. This handles the case where we
+        # graph break again in a nested function before jump-resuming
+        # this frame.
+        if is_jump_absolute(resume_inst):
+            assert resume_inst.target
+            resume_inst = resume_inst.target
+
+        resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
+
+        # More locals may have been pruned in the current/leaf frame
+        # after the unsupported instruction (e.g. branch).
+        # There should not be any pruning in the other frames since
+        # the current instruction there should be a CALL.
+        if is_leaf:
+            reads = livevars_analysis(self.instructions, resume_inst)
+            all_argnames = tuple(
+                k
+                for k in self.symbolic_locals.keys()
+                if k in reads and k not in self.cell_and_freevars()
+            )
+            argnames_null_set = set(meta.locals_null_keys)
+            argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
+            argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
+
+            # codegen filter for current frame's locals
+            # current stack state: frames
+            cg.extend_output(
+                [
+                    create_dup_top(),
+                    cg.create_load_const(idx),
+                    cg.create_binary_subscr(),
+                    create_dup_top(),
+                ]
+            )
+            for arg in argnames:
+                # current stack state: frames, frames[i], *(prev locals), frames[i]
+                cg.extend_output(
+                    [
+                        create_dup_top(),
+                        cg.create_load_const(meta.num_stack + meta.locals_names[arg]),
+                        cg.create_binary_subscr(),
+                        *create_swap(2),
+                    ],
+                )
+            # current stack state: frames, frames[i], *(frame i live locals), frames[i]
+            cg.extend_output(
+                [
+                    create_instruction("POP_TOP"),
+                    create_instruction("BUILD_LIST", arg=len(argnames)),
+                    *create_swap(2),
+                    # frames, frames i live locals, frames[i]
+                    *create_binary_slice(meta.num_stack, None, True),
+                    # frames[i][num_stack:] = frame i live locals
+                ]
+            )
+            # current stack state: frames
+        else:
+            argnames = tuple(meta.locals_names.keys())
+            argnames_null = tuple(meta.locals_null_keys)
+
+        if sys.version_info < (3, 12):
+            assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
+
+        # compile_subgraph did not codegen any NULLs,
+        # so we should not count NullVariables
+        stack_len = len(self.stack) - len(meta.stack_null_idxes)
+
+        new_code: types.CodeType = ContinueExecutionCache.lookup(
+            self.f_code,
+            self.lineno,
+            resume_inst.offset,
+            tuple(b.target.offset for b in self.block_stack),
+            stack_len,
+            argnames,
+            argnames_null,
+            tuple(b.resume_fn() for b in self.block_stack),
+            handle_inactive_ctx,
+            tuple(meta.stack_ctx_args),
+            tuple(meta.locals_ctx_args),
+            tuple(meta.stack_null_idxes),
+            tuple(resume_codes),
+        )
+
+        # Add original GraphModule context to the resume function to handle
+        # the case of a graph break while tracing a GraphModule
+        orig_graphmodule_maybe = code_context.get_context(self.f_code).get(
+            "orig_graphmodule", lambda: None
+        )()
+        if orig_graphmodule_maybe is not None:
+            code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
+                orig_graphmodule_maybe
+            )
+
+        # add resume function to the global scope
+        if new_code.co_freevars:
+            # expose code object for debugging purposes
+            self.output.install_global_unsafe(resume_name, new_code)
+            package_name = None
+        else:
+            # This is safe: we pre-generate a unique name
+            self.output.install_global_unsafe(
+                resume_name,
+                types.FunctionType(new_code, self.f_globals, resume_name),
+            )
+            package_name = resume_name
+
+        if self.package is not None:
+            self.package.add_resume_function(
+                new_code, self.f_globals["__name__"], package_name
+            )
+
+        return new_code, resume_name
+
+    def create_call_resume_at(
+        self,
+        inst: Instruction,
+        all_stack_locals_metadata: list[StackLocalsMetadata],
+    ) -> list[Instruction]:
+        """
+        Codegen all resume function(s) from the frame stack starting at `self` and call them.
+        Assumes that the unsupported instruction has already been run.
+
+        Expects the stack to be in the state:
+        [frame N cells, ..., frame 1 cells],
+        [
+            frame N stack + locals,
+            frame N-1 stack + locals,
+            ...,
+            frame 1 stack + locals
+        ]
+
+        Pops the cells and frame values list from the stack.
+        Also includes a return instruction (stack expected to be empty after return).
+
+        Args:
+        - inst: the instruction of the current (deepest) frame to resume at
+        - all_stack_locals_metadata: metadata returned from OutputGraph.compile_subgraph - contains
+          metadata such as local names, NULL positions, stack length, etc.
-        - disable_current_frame_resume: If True, disable tracing on the current frame's resume function.
-          Used for implementing nested step_graph_break.
+        """
 
         self.instruction_pointer = None
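The list splice that codegen_fix_leaf_stack emits (`frame_values[0][0:0] = frame N stack`) is easiest to see in plain Python. A sketch assuming a two-value leaf stack; names are illustrative:

    frame_values = [
        ["f_local_a", "f_local_b"],  # frame N entry currently holds locals only
        ["parent_stack", "parent_local"],
    ]
    leaf_stack = ["stack_1", "stack_0"]  # BUILD_LIST packs the live leaf stack values

    # create_binary_slice(0, 0, True) performs the [0:0] slice assignment, i.e. a prepend:
    frame_values[0][0:0] = leaf_stack

    assert frame_values[0] == ["stack_1", "stack_0", "f_local_a", "f_local_b"]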
@@ -2585,212 +2934,63 @@ class InstructionTranslatorBase(
         all_stack_locals_metadata[0].num_stack = current_num_stack
 
         if inst.opname in ("RETURN_VALUE", "RETURN_CONST"):
-            return self.codegen_return_after_compile_subgraph(
-                inst, all_stack_locals_metadata[0]
+            return self.codegen_return_with_pops(
+                inst, all_stack_locals_metadata[0].num_stack
             )
 
         cg = PyCodegen(self.output.root_tx)
 
-        # move frame N stack to the frame values list
-        cg.extend_output(
-            [
-                create_instruction("BUILD_LIST", arg=current_num_stack),
-                *create_copy(2),
-                # frame_values, frame N stack, frame_values
-                cg.create_load_const(0),
-                cg.create_binary_subscr(),
-                *create_binary_slice(0, 0, True),
-                # frame_values[0][0:0] = frame N stack
-                # frame_values left on top of stack
-            ]
-        )
-
-        # current frame state
-        # all cells
-        # [
-        #     [frame N stack (fixed) + locals]
-        #     ...,
-        #     [frame 1 stack + locals]
-        # ],
 
         #
         txes = []
         cur_tx: Optional[InstructionTranslatorBase] = self
         while cur_tx is not None:
             txes.append(cur_tx)
             cur_tx = cur_tx.parent
         assert len(txes) == len(all_stack_locals_metadata)
 
-        # Handle inactive context variables.
-        # The resume function assumes that context variables are the class, NOT the object.
-        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
-        # NOTE: if the unsupported instruction modifies the inactive context variable, it may
-        # result in silent incorrectness!
-        for i, meta in enumerate(all_stack_locals_metadata):
-            if i == 0 and disable_current_frame_resume:
-                continue
-
-            for (j, _), j_orig in zip(meta.stack_ctx_args, meta.stack_ctx_idxes_orig):
-                # Replace the stack var with the context class
-                ctx = cast(ContextWrappingVariable, txes[i].stack[j_orig])
-                # frames[i][j] = reconstructed_ctx
-                cg.append_output(create_dup_top())
-                ctx.reconstruct_type(cg)
-                cg.extend_output(
-                    [
-                        *create_swap(2),
-                        cg.create_load_const(i),
-                        cg.create_binary_subscr(),
-                        cg.create_load_const(j),
-                        create_instruction("STORE_SUBSCR"),
-                    ]
-                )
-
-            for name, _ in meta.locals_ctx_args:
-                # Replace the local with the context class
-                ctx = cast(ContextWrappingVariable, txes[i].symbolic_locals[name])
-                # frames[i][meta.num_stack +meta.locals_names[name]] = reconstructed_ctx
-                cg.append_output(create_dup_top())
-                ctx.reconstruct_type(cg)
-                cg.extend_output(
-                    [
-                        *create_swap(2),
-                        cg.create_load_const(i),
-                        cg.create_binary_subscr(),
-                        cg.create_load_const(meta.num_stack + meta.locals_names[name]),
-                        create_instruction("STORE_SUBSCR"),
-                    ]
-                )
-
         # build the resume function for each frame
         resume_names = []
         idx = 0
         resume_codes: list[types.CodeType] = []
-        for i, meta in enumerate(all_stack_locals_metadata):
-            cur_tx = txes[i]
-            resume_names = []
         while cur_tx is not None:
             if cur_tx is self:
                 resume_inst = inst
             else:
                 resume_inst = cur_tx.next_instruction
-            # If the resume instruction is a jump absolute, then resume
-            # at the target instead. This handles the case where we
-            # graph break again in a nested function before jump-resuming
-            # this frame.
-            if is_jump_absolute(resume_inst):
-                assert resume_inst.target
-                resume_inst = resume_inst.target
-            resume_name = unique_id(f"__resume_at_{resume_inst.offset}")
+            resume_code, resume_name = cur_tx.create_resume(
+                idx,
+                resume_inst,
+                all_stack_locals_metadata[idx],
+                resume_codes,
+                cg,
+                cur_tx is self,
+                True,
+            )
+            resume_codes.append(resume_code)
+            resume_names.append(resume_name)
 
-            # More locals may have been pruned in the current frame
-            # after the unsupported instruction (e.g. branch).
-            # There should not be any pruning in the other frames since
-            # the current instruction is a CALL.
-            if cur_tx is self:
-                reads = livevars_analysis(cur_tx.instructions, resume_inst)
-                all_argnames = tuple(
-                    k
-                    for k in cur_tx.symbolic_locals.keys()
-                    if k in reads and k not in cur_tx.cell_and_freevars()
-                )
-                argnames_null_set = set(meta.locals_null_keys)
-                argnames = tuple(k for k in all_argnames if k not in argnames_null_set)
-                argnames_null = tuple(k for k in all_argnames if k in argnames_null_set)
+            cur_tx = cur_tx.parent
+            idx += 1
 
-                # codegen filter for current frame's locals
-                # current stack state: frames
-                cg.extend_output(
-                    [
-                        create_dup_top(),
-                        cg.create_load_const(i),
-                        cg.create_binary_subscr(),
-                        create_dup_top(),
-                    ]
-                )
-                for arg in argnames:
-                    # current stack state: frames, frames[i], *(prev locals), frames[i]
-                    cg.extend_output(
-                        [
-                            create_dup_top(),
-                            cg.create_load_const(
-                                meta.num_stack + meta.locals_names[arg]
-                            ),
-                            cg.create_binary_subscr(),
-                            *create_swap(2),
-                        ],
-                    )
-                # current stack state: frames, frames[i], *(frame i live locals), frames[i]
-                cg.extend_output(
-                    [
-                        create_instruction("POP_TOP"),
-                        create_instruction("BUILD_LIST", arg=len(argnames)),
-                        *create_swap(2),
-                        # frames, frames i live locals, frames[i]
-                        *create_binary_slice(meta.num_stack, None, True),
-                        # frames[i][num_stack:] = frame i live locals
-                    ]
-                )
-                # current stack state: frames
-            else:
-                argnames = tuple(meta.locals_names.keys())
-                argnames_null = tuple(meta.locals_null_keys)
+        self.codegen_call_resume(resume_codes, resume_names, cg)
+        return cg.get_instructions() + [create_instruction("RETURN_VALUE")]
 
-            if sys.version_info < (3, 12):
-                assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
+    @staticmethod
+    def codegen_call_resume(
+        resume_codes: list[types.CodeType], resume_names: list[str], cg: PyCodegen
+    ) -> None:
+        """
+        Calls the provided resume functions.
 
-            # compile_subgraph did not codegen any NULLs,
-            # so we should not count NullVariables
-            stack_len = len(cur_tx.stack) - len(meta.stack_null_idxes)
+        Expects the TOS to be in the state:
+        [frame N cells, ..., frame 1 cells],
+        [
+            frame N stack + locals,
+            frame N-1 stack + locals,
+            ...,
+            frame 1 stack + locals
+        ]
 
-            new_code: types.CodeType = ContinueExecutionCache.lookup(
-                cur_tx.f_code,
-                cur_tx.lineno,
-                resume_inst.offset,
-                tuple(b.target.offset for b in cur_tx.block_stack),
-                stack_len,
-                argnames,
-                argnames_null,
-                tuple(b.resume_fn() for b in cur_tx.block_stack),
-                tuple(meta.stack_ctx_args),
-                tuple(meta.locals_ctx_args),
-                tuple(meta.stack_null_idxes),
-                tuple(resume_codes),
-            )
-            resume_codes.append(new_code)
+        Pops the cells and frame values, leaving the result of calling the resume functions on TOS.
 
-            # Add original GraphModule context to the resume function to handle
-            # the case of a graph break while tracing a GraphModule
-            orig_graphmodule_maybe = code_context.get_context(cur_tx.f_code).get(
-                "orig_graphmodule", lambda: None
-            )()
-            if orig_graphmodule_maybe is not None:
-                code_context.get_context(new_code)["orig_graphmodule"] = weakref.ref(
-                    orig_graphmodule_maybe
-                )
-
-            # add resume function to the global scope
-            if new_code.co_freevars:
-                # expose code object for debugging purposes
-                cur_tx.output.install_global_unsafe(resume_name, new_code)
-                package_name = None
-            else:
-                # This is safe: we pre-generate a unique name
-                cur_tx.output.install_global_unsafe(
-                    resume_name,
-                    types.FunctionType(new_code, cur_tx.f_globals, resume_name),
-                )
-                package_name = resume_name
-
-            if cur_tx.package is not None:
-                cur_tx.package.add_resume_function(
-                    new_code, cur_tx.f_globals["__name__"], package_name
-                )
-
-        if disable_current_frame_resume:
-            from .eval_frame import skip_code
-
-            skip_code(resume_codes[0])
-
-        # load cells as we load resume functions
+        Args:
+        - resume_codes: list of resume function code objects to call
+        - resume_names: list of the corresponding names of the resume functions
+        - cg: PyCodegen object to output instructions to
+        """
+        # NOTE: We will load cells as we load resume functions
 
         # load resume functions except the root's
         cg.extend_output(create_copy(2))
@@ -2884,10 +3084,8 @@ class InstructionTranslatorBase(
         cg.extend_output(
             [
                 *create_call_function_ex(False, True),
-                create_instruction("RETURN_VALUE"),
             ]
         )
-        return cg.get_instructions()
 
     def should_compile_partial_graph(self) -> bool:
         if sys.version_info >= (3, 11):
@@ -3530,7 +3728,7 @@ class InstructionTranslatorBase(
             self.active_generic_context_managers.append(ctx)
 
         if sys.version_info >= (3, 11):
-            # See create_call_resume_at for block stack details.
+            # See update_block_stack/create_resume for block stack details.
             # Only push a block if the current instruction's block is a
             # with block that is not nested in a try block - that is, the current
             # instruction's block target is the same as the top block's target.
@@ -4230,9 +4428,7 @@ class InstructionTranslator(InstructionTranslatorBase):
         assert len(all_stack_locals_metadata) == 1
         assert not all_stack_locals_metadata[0].stack_null_idxes
         self.output.add_output_instructions(
-            self.codegen_return_after_compile_subgraph(
-                inst, all_stack_locals_metadata[0]
-            )
+            self.codegen_return_with_pops(inst, all_stack_locals_metadata[0].num_stack)
         )
         raise ReturnValueOp
 
@@ -4617,13 +4813,10 @@ class InliningInstructionTranslator(InstructionTranslatorBase):
     def create_call_resume_at(
         self,
         inst: Instruction,
-        all_stack_locals_metadata: Any,
-        disable_current_frame_resume: bool,
+        all_stack_locals_metadata: list[StackLocalsMetadata],
     ) -> list[Instruction]:
         if config.nested_graph_breaks:
-            return super().create_call_resume_at(
-                inst, all_stack_locals_metadata, disable_current_frame_resume
-            )
+            return super().create_call_resume_at(inst, all_stack_locals_metadata)
         unimplemented_v2(
             gb_type="Graph break in inlined function",
             context="",
@@ -52,6 +52,7 @@ from ..exc import (
     ObservedUserStopIteration,
     raise_observed_exception,
     SkipFrame,
+    StepUnsupported,
     unimplemented_v2,
     Unsupported,
 )
@@ -1527,6 +1528,8 @@ class SkipFunctionVariable(VariableTracker):
             raise SkipFrame(
                 f"Skip frame due to `torch._dynamo.skip_frame()`. Message: {skip_frame_msg}"
             )
+        elif self.value is torch._dynamo.step_unsupported:
+            raise StepUnsupported
         else:
             if config.dont_skip_tracing:
                 from .builder import SourcelessBuilder