[dynamo] fix resume_execution.py KeyError in Python 3.11+ (#162318)

Fixes https://github.com/pytorch/pytorch/issues/162313

Differential Revision: [D81938289](https://our.internmc.facebook.com/intern/diff/D81938289)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162318
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos, https://github.com/anijain2305
This commit is contained in:
William Wen
2025-09-08 10:27:32 -07:00
committed by PyTorch MergeBot
parent 8f114650eb
commit 26a1b9cce2
2 changed files with 37 additions and 6 deletions

View File

@ -7168,6 +7168,30 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
fn(torch.ones(3)), torch.compile(fn, backend="eager")(torch.ones(3))
)
def test_311_resume_block_keyerror(self):
# https://github.com/pytorch/pytorch/issues/162313
flag = True
def fn(x):
x = x + 1
torch._dynamo.graph_break()
x = x + 2
if flag:
with torch.no_grad():
torch._dynamo.graph_break()
x = x + 4
else:
with torch.no_grad():
torch._dynamo.graph_break()
x = x + 8
return x + 16
inp = torch.ones(3)
opt_fn = torch.compile(fn, backend="eager")
self.assertEqual(fn(inp), opt_fn(inp))
flag = False
self.assertEqual(fn(inp), opt_fn(inp))
def test_unbind_copy_out(self):
def f(eye, out):
torch.unbind_copy(eye, out=out)

View File

@ -249,8 +249,10 @@ class ResumeFunctionMetadata:
prefix_block_target_offset_remap: list[int] = dataclasses.field(
default_factory=list
)
# map from new block target offsets to original block target offsets
block_target_offset_remap: Optional[dict[int, int]] = None
# per-offset map from new block target offsets to original block target offsets
block_target_offset_remap: dict[int, dict[int, int]] = dataclasses.field(
default_factory=dict
)
def _filter_iter(
@ -588,7 +590,7 @@ class ContinueExecutionCache:
meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
code
]
new_offset = None
new_offset = -1
def find_new_offset(
instructions: list[Instruction], code_options: dict[str, Any]
@ -602,17 +604,21 @@ class ContinueExecutionCache:
if i1 is target
)
assert target.opcode == new_target.opcode
assert new_target.offset is not None
new_offset = new_target.offset
transform_code_object(code, find_new_offset)
assert new_offset >= 0
if sys.version_info >= (3, 11):
# setup_fn_target_offsets currently contains the target offset of
# each setup_fn, based on `code`. When we codegen the resume function
# based on the original code object, `meta.code`, the offsets in
# setup_fn_target_offsets must be based on `meta.code` instead.
if not meta.block_target_offset_remap:
block_target_offset_remap = meta.block_target_offset_remap = {}
if new_offset not in meta.block_target_offset_remap:
block_target_offset_remap = meta.block_target_offset_remap[
new_offset
] = {}
def remap_block_offsets(
instructions: list[Instruction], code_options: dict[str, Any]
@ -660,7 +666,8 @@ class ContinueExecutionCache:
# if offset is not in setup_fn_target_offsets, it is an error
setup_fn_target_offsets = tuple(
meta.block_target_offset_remap[n] for n in setup_fn_target_offsets
meta.block_target_offset_remap[new_offset][n]
for n in setup_fn_target_offsets
)
return ContinueExecutionCache.lookup(
meta.code, lineno, new_offset, setup_fn_target_offsets, *args