[dynamo] refactor resume_execution.py to use bytecode templates (#136483)

Use bytecode from template instead of hardcoding bytecode in resume_execution.py. Gets rid of a lot of Python-version dependent bytecode generation. Also makes resume_execution.py easier to support in future Python version updates.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/136483
Approved by: https://github.com/jansel, https://github.com/anijain2305
This commit is contained in:
William Wen
2024-09-24 17:08:37 +00:00
committed by PyTorch MergeBot
parent 36f0e61166
commit ae80bce496
3 changed files with 149 additions and 325 deletions

View File

@ -933,6 +933,32 @@ def strip_extended_args(instructions: List[Instruction]) -> None:
instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
# Overwrites old_inst with a sequence of new instructions.
# This is necessary in order to preserve jump targets to the old
# instruction, exception table entries, and positions.
# Returns the modified sequence of instructions (including the modified
# old instruction!) that can be manipulated elsewhere.
def overwrite_instruction(old_inst, new_insts):
# update old_inst.exnt_tab_entry.end if necessary
if (
old_inst.exn_tab_entry
and old_inst.exn_tab_entry.end is old_inst
and len(new_insts) > 1
):
old_inst.exn_tab_entry.end = new_insts[-1]
# preserve exception table entries and positions
for inst in new_insts[1:]:
inst.exn_tab_entry = copy.copy(old_inst.exn_tab_entry)
inst.positions = old_inst.positions
# modify old_inst in-place to preserve jump target
old_inst.opcode = new_insts[0].opcode
old_inst.opname = new_insts[0].opname
old_inst.arg = new_insts[0].arg
old_inst.argval = new_insts[0].argval
old_inst.target = new_insts[0].target
return [old_inst] + new_insts[1:]
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
assert sys.version_info < (3, 11)
@ -947,11 +973,11 @@ def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction
def remove_jump_if_none(instructions: List[Instruction]) -> None:
new_insts = []
for inst in instructions:
new_insts.append(inst)
if "_NONE" in inst.opname:
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
# need both argval and arg set correctly now (not later)
is_op.argval = is_op.arg
is_op.positions = inst.positions
if sys.version_info < (3, 12):
jump_op = create_instruction(
"POP_JUMP_FORWARD_IF_TRUE"
@ -961,19 +987,15 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
)
else:
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
jump_op.positions = inst.positions
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = jump_op
# preserve exception table entries
is_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
jump_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap["LOAD_CONST"]
inst.opname = "LOAD_CONST"
inst.arg = None
inst.argval = None
new_insts.extend([is_op, jump_op])
replace_insts = [
create_instruction("LOAD_CONST", argval=None),
is_op,
jump_op,
]
new_insts.extend(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)
instructions[:] = new_insts
@ -1007,24 +1029,17 @@ FUSED_INSTS = {
def remove_fused_load_store(instructions: List[Instruction]) -> None:
new_insts = []
for inst in instructions:
new_insts.append(inst)
if inst.opname in FUSED_INSTS:
inst0, inst1 = FUSED_INSTS[inst.opname]
argval0, argval1 = inst.argval
# modify inst in-place to preserve jump target
inst.opcode = dis.opmap[inst0]
inst.opname = inst0
inst.argval = argval0
new_inst = create_instruction(inst1, argval=argval1)
# update inst.exn_tab_entry.end if necessary
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
inst.exn_tab_entry.end = new_inst
# preserve exception table entries
new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
new_insts.append(new_inst)
replace_insts = [
create_instruction(inst0, argval=argval0),
create_instruction(inst1, argval=argval1),
]
new_insts.append(overwrite_instruction(inst, replace_insts))
else:
new_insts.append(inst)
instructions[:] = new_insts
@ -1435,7 +1450,9 @@ def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
For example, local variables in `fn` can be replaced with
new names that are generated by `OutputGraph.new_var`.
noreturn: remove all RETURN_* bytecodes and replace them with a jump
to the end of the bytecode.
to the end of the bytecode. NOTE: any items pushed to the stack
for return WILL remain on the stack! Append a POP_TOP if you don't want
that item to be present.
noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).
"""
insts = cleaned_instructions(fn.__code__)