[dynamo, 3.12] fix positions and offsets of added instructions when we clean (#123991)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123991
Approved by: https://github.com/jansel
ghstack dependencies: #123978
This commit is contained in:
William Wen
2024-04-13 21:32:15 +00:00
committed by PyTorch MergeBot
parent 88a7159493
commit 0dfe72c63b
2 changed files with 31 additions and 3 deletions

View File

@ -10109,7 +10109,7 @@ fn
opt_fn = torch.compile(fn, backend="eager")
opt_fn(inp)
def test_312_binary_slice_with_graph_break(self):
def test_312_binary_slice_with_graph_break1(self):
l1 = torch.nn.Linear(5, 5)
l2 = torch.nn.Linear(5, 5)
@ -10122,6 +10122,31 @@ fn
opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(5, 5))
def test_312_binary_slice_with_graph_break2(self):
class Foo:
def __setitem__(self, key, val):
pass
def __getitem__(self, key):
torch._dynamo.graph_break()
return 1
foo = Foo()
def fn(x):
# graph break in a STORE_SLICE instruction
foo[:] = x
# graph break in BINARY_SLICE with has_backedge check
x = x + foo[:]
if x is None:
x = x + 1
else:
x = x + 1
return x
opt_fn = torch.compile(fn, backend="eager")
opt_fn(torch.randn(5, 5))
def test_super_after_graph_break(self):
class Foo(torch.nn.Sequential):
def __init__(self, layers):

View File

@ -804,6 +804,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
if "_NONE" in inst.opname:
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
is_op.argval = is_op.arg
is_op.positions = inst.positions
if sys.version_info < (3, 12):
jump_op = create_instruction(
"POP_JUMP_FORWARD_IF_TRUE"
@ -813,6 +814,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
)
else:
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
jump_op.positions = inst.positions
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = jump_op
@ -838,6 +840,7 @@ def remove_binary_store_slice(instructions: List[Instruction]) -> None:
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = subscr_inst
subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
subscr_inst.positions = inst.positions
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap["BUILD_SLICE"]
inst.opname = "BUILD_SLICE"
@ -1176,10 +1179,10 @@ def cleaned_instructions(code, safe=False) -> List[Instruction]:
explicit_super(code, instructions)
if sys.version_info >= (3, 11):
remove_jump_if_none(instructions)
if sys.version_info >= (3, 12):
remove_binary_store_slice(instructions)
update_offsets(instructions)
devirtualize_jumps(instructions)
if sys.version_info >= (3, 12):
remove_binary_store_slice(instructions)
return instructions