mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo, 3.12] fix positions and offsets of added instructions when we clean (#123991)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123991 Approved by: https://github.com/jansel ghstack dependencies: #123978
This commit is contained in:
committed by
PyTorch MergeBot
parent
88a7159493
commit
0dfe72c63b
@ -10109,7 +10109,7 @@ fn
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
opt_fn(inp)
|
||||
|
||||
def test_312_binary_slice_with_graph_break(self):
|
||||
def test_312_binary_slice_with_graph_break1(self):
|
||||
l1 = torch.nn.Linear(5, 5)
|
||||
l2 = torch.nn.Linear(5, 5)
|
||||
|
||||
@ -10122,6 +10122,31 @@ fn
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
opt_fn(torch.randn(5, 5))
|
||||
|
||||
def test_312_binary_slice_with_graph_break2(self):
|
||||
class Foo:
|
||||
def __setitem__(self, key, val):
|
||||
pass
|
||||
|
||||
def __getitem__(self, key):
|
||||
torch._dynamo.graph_break()
|
||||
return 1
|
||||
|
||||
foo = Foo()
|
||||
|
||||
def fn(x):
|
||||
# graph break in a STORE_SLICE instruction
|
||||
foo[:] = x
|
||||
# graph break in BINARY_SLICE with has_backedge check
|
||||
x = x + foo[:]
|
||||
if x is None:
|
||||
x = x + 1
|
||||
else:
|
||||
x = x + 1
|
||||
return x
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager")
|
||||
opt_fn(torch.randn(5, 5))
|
||||
|
||||
def test_super_after_graph_break(self):
|
||||
class Foo(torch.nn.Sequential):
|
||||
def __init__(self, layers):
|
||||
|
@ -804,6 +804,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
if "_NONE" in inst.opname:
|
||||
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
|
||||
is_op.argval = is_op.arg
|
||||
is_op.positions = inst.positions
|
||||
if sys.version_info < (3, 12):
|
||||
jump_op = create_instruction(
|
||||
"POP_JUMP_FORWARD_IF_TRUE"
|
||||
@ -813,6 +814,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
)
|
||||
else:
|
||||
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
|
||||
jump_op.positions = inst.positions
|
||||
# update inst.exn_tab_entry.end if necessary
|
||||
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
|
||||
inst.exn_tab_entry.end = jump_op
|
||||
@ -838,6 +840,7 @@ def remove_binary_store_slice(instructions: List[Instruction]) -> None:
|
||||
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
|
||||
inst.exn_tab_entry.end = subscr_inst
|
||||
subscr_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
|
||||
subscr_inst.positions = inst.positions
|
||||
# modify inst in-place to preserve jump target
|
||||
inst.opcode = dis.opmap["BUILD_SLICE"]
|
||||
inst.opname = "BUILD_SLICE"
|
||||
@ -1176,10 +1179,10 @@ def cleaned_instructions(code, safe=False) -> List[Instruction]:
|
||||
explicit_super(code, instructions)
|
||||
if sys.version_info >= (3, 11):
|
||||
remove_jump_if_none(instructions)
|
||||
if sys.version_info >= (3, 12):
|
||||
remove_binary_store_slice(instructions)
|
||||
update_offsets(instructions)
|
||||
devirtualize_jumps(instructions)
|
||||
if sys.version_info >= (3, 12):
|
||||
remove_binary_store_slice(instructions)
|
||||
return instructions
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user