[dynamo] fix resume_execution.py KeyError in Python 3.11+ (#162318)
Fixes https://github.com/pytorch/pytorch/issues/162313

Differential Revision: [D81938289](https://our.internmc.facebook.com/intern/diff/D81938289)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162318
Approved by: https://github.com/Lucaskabela, https://github.com/mlazos, https://github.com/anijain2305
committed by PyTorch MergeBot
parent 8f114650eb
commit 26a1b9cce2
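A note on the failure mode, inferred from the diff below rather than stated in the commit message: `meta.block_target_offset_remap` used to be a single flat `dict[int, int]` shared by every resume point generated from one code object, so a second resume point (for example, the other branch of an `if`, as in the new test) could look up a block target the first resume point never recorded and hit a KeyError. The fix keys the remap by `new_offset`. A minimal, self-contained sketch of that shape change; the offset values and the `record` helper are hypothetical illustrations, not PyTorch APIs:

# Before the fix: one flat remap per code object (hypothetical values),
# populated by whichever resume point happened to be generated first.
flat_remap: dict[int, int] = {40: 20}

# After the fix: remaps are keyed by the resume offset, so resume points
# generated from the same original code cannot collide.
per_offset_remap: dict[int, dict[int, int]] = {}

def record(new_offset: int, block_remap: dict[int, int]) -> None:
    # Hypothetical helper: each resume point gets its own block-target remap.
    per_offset_remap.setdefault(new_offset, {}).update(block_remap)

record(10, {40: 20})  # first graph break resumes at offset 10
record(50, {88: 64})  # second graph break (other branch) resumes at offset 50

try:
    flat_remap[88]  # flat layout: the second resume point's target is missing
except KeyError:
    print("KeyError: 88 -- the shape of the bug in issue #162313")

print(per_offset_remap[50][88])  # 64 -- the per-offset layout resolves it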
@@ -7168,6 +7168,30 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
             fn(torch.ones(3)), torch.compile(fn, backend="eager")(torch.ones(3))
         )
 
+    def test_311_resume_block_keyerror(self):
+        # https://github.com/pytorch/pytorch/issues/162313
+        flag = True
+
+        def fn(x):
+            x = x + 1
+            torch._dynamo.graph_break()
+            x = x + 2
+            if flag:
+                with torch.no_grad():
+                    torch._dynamo.graph_break()
+                x = x + 4
+            else:
+                with torch.no_grad():
+                    torch._dynamo.graph_break()
+                x = x + 8
+            return x + 16
+
+        inp = torch.ones(3)
+        opt_fn = torch.compile(fn, backend="eager")
+        self.assertEqual(fn(inp), opt_fn(inp))
+        flag = False
+        self.assertEqual(fn(inp), opt_fn(inp))
+
     def test_unbind_copy_out(self):
         def f(eye, out):
             torch.unbind_copy(eye, out=out)
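The test above flips `flag` between two calls to the same compiled function so that both branches, each breaking the graph inside a `with torch.no_grad()` block, force Dynamo to generate resume functions from the same original code object, which populates two entries in the remap. A standalone sketch of the same repro, runnable outside the test suite (assumes only an installed `torch` with Dynamo):

import torch

flag = True

def fn(x):
    x = x + 1
    torch._dynamo.graph_break()
    x = x + 2
    if flag:
        with torch.no_grad():
            torch._dynamo.graph_break()
        x = x + 4
    else:
        with torch.no_grad():
            torch._dynamo.graph_break()
        x = x + 8
    return x + 16

inp = torch.ones(3)
opt_fn = torch.compile(fn, backend="eager")
print(torch.equal(fn(inp), opt_fn(inp)))  # True
flag = False  # recompile through the other branch's resume path
print(torch.equal(fn(inp), opt_fn(inp)))  # True after the fix; KeyError on 3.11+ before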
@@ -249,8 +249,10 @@ class ResumeFunctionMetadata:
     prefix_block_target_offset_remap: list[int] = dataclasses.field(
         default_factory=list
     )
-    # map from new block target offsets to original block target offsets
-    block_target_offset_remap: Optional[dict[int, int]] = None
+    # per-offset map from new block target offsets to original block target offsets
+    block_target_offset_remap: dict[int, dict[int, int]] = dataclasses.field(
+        default_factory=dict
+    )
 
 
 def _filter_iter(
@@ -588,7 +590,7 @@ class ContinueExecutionCache:
         meta: ResumeFunctionMetadata = ContinueExecutionCache.generated_code_metadata[
             code
         ]
-        new_offset = None
+        new_offset = -1
 
         def find_new_offset(
             instructions: list[Instruction], code_options: dict[str, Any]
@@ -602,17 +604,21 @@ class ContinueExecutionCache:
                 if i1 is target
             )
             assert target.opcode == new_target.opcode
+            assert new_target.offset is not None
             new_offset = new_target.offset
 
         transform_code_object(code, find_new_offset)
+        assert new_offset >= 0
 
         if sys.version_info >= (3, 11):
             # setup_fn_target_offsets currently contains the target offset of
             # each setup_fn, based on `code`. When we codegen the resume function
             # based on the original code object, `meta.code`, the offsets in
             # setup_fn_target_offsets must be based on `meta.code` instead.
-            if not meta.block_target_offset_remap:
-                block_target_offset_remap = meta.block_target_offset_remap = {}
+            if new_offset not in meta.block_target_offset_remap:
+                block_target_offset_remap = meta.block_target_offset_remap[
+                    new_offset
+                ] = {}
 
                 def remap_block_offsets(
                     instructions: list[Instruction], code_options: dict[str, Any]
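Two defensive changes ride along in the hunks above: `new_offset` now starts at `-1` instead of `None`, so it stays int-typed and the new `assert new_offset >= 0` after `transform_code_object` catches the case where `find_new_offset` never located the target, and `assert new_target.offset is not None` narrows the Optional before assignment. A generic sketch of that sentinel-plus-nonlocal pattern, with hypothetical names unrelated to the PyTorch code:

def locate(values: list[int], target: int) -> int:
    # -1 is an impossible offset, so it doubles as a "not found" sentinel
    # that keeps the variable int-typed (None would need narrowing later).
    found = -1

    def visit(i: int, value: int) -> None:
        nonlocal found  # the callback writes back into the enclosing scope
        if value == target:
            found = i

    for i, value in enumerate(values):
        visit(i, value)
    assert found >= 0, "target never visited"
    return found

print(locate([8, 16, 24], 16))  # 1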
@@ -660,7 +666,8 @@ class ContinueExecutionCache:
 
             # if offset is not in setup_fn_target_offsets, it is an error
             setup_fn_target_offsets = tuple(
-                meta.block_target_offset_remap[n] for n in setup_fn_target_offsets
+                meta.block_target_offset_remap[new_offset][n]
+                for n in setup_fn_target_offsets
             )
         return ContinueExecutionCache.lookup(
             meta.code, lineno, new_offset, setup_fn_target_offsets, *args