mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] refactor resume_execution.py to use bytecode templates (#136483)
Use bytecode from template instead of hardcoding bytecode in resume_execution.py. Gets rid of a lot of Python-version dependent bytecode generation. Also makes resume_execution.py easier to support in future Python version updates. Pull Request resolved: https://github.com/pytorch/pytorch/pull/136483 Approved by: https://github.com/jansel, https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
36f0e61166
commit
ae80bce496
@ -933,6 +933,32 @@ def strip_extended_args(instructions: List[Instruction]) -> None:
|
||||
instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
|
||||
|
||||
|
||||
# Overwrites old_inst with a sequence of new instructions.
|
||||
# This is necessary in order to preserve jump targets to the old
|
||||
# instruction, exception table entries, and positions.
|
||||
# Returns the modified sequence of instructions (including the modified
|
||||
# old instruction!) that can be manipulated elsewhere.
|
||||
def overwrite_instruction(old_inst, new_insts):
|
||||
# update old_inst.exnt_tab_entry.end if necessary
|
||||
if (
|
||||
old_inst.exn_tab_entry
|
||||
and old_inst.exn_tab_entry.end is old_inst
|
||||
and len(new_insts) > 1
|
||||
):
|
||||
old_inst.exn_tab_entry.end = new_insts[-1]
|
||||
# preserve exception table entries and positions
|
||||
for inst in new_insts[1:]:
|
||||
inst.exn_tab_entry = copy.copy(old_inst.exn_tab_entry)
|
||||
inst.positions = old_inst.positions
|
||||
# modify old_inst in-place to preserve jump target
|
||||
old_inst.opcode = new_insts[0].opcode
|
||||
old_inst.opname = new_insts[0].opname
|
||||
old_inst.arg = new_insts[0].arg
|
||||
old_inst.argval = new_insts[0].argval
|
||||
old_inst.target = new_insts[0].target
|
||||
return [old_inst] + new_insts[1:]
|
||||
|
||||
|
||||
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
|
||||
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
|
||||
assert sys.version_info < (3, 11)
|
||||
@ -947,11 +973,11 @@ def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction
|
||||
def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst in instructions:
|
||||
new_insts.append(inst)
|
||||
if "_NONE" in inst.opname:
|
||||
is_op = create_instruction("IS_OP", arg=int("NOT" in inst.opname))
|
||||
# need both argval and arg set correctly now (not later)
|
||||
is_op.argval = is_op.arg
|
||||
is_op.positions = inst.positions
|
||||
|
||||
if sys.version_info < (3, 12):
|
||||
jump_op = create_instruction(
|
||||
"POP_JUMP_FORWARD_IF_TRUE"
|
||||
@ -961,19 +987,15 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
)
|
||||
else:
|
||||
jump_op = create_instruction("POP_JUMP_IF_TRUE", target=inst.target)
|
||||
jump_op.positions = inst.positions
|
||||
# update inst.exn_tab_entry.end if necessary
|
||||
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
|
||||
inst.exn_tab_entry.end = jump_op
|
||||
# preserve exception table entries
|
||||
is_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
|
||||
jump_op.exn_tab_entry = copy.copy(inst.exn_tab_entry)
|
||||
# modify inst in-place to preserve jump target
|
||||
inst.opcode = dis.opmap["LOAD_CONST"]
|
||||
inst.opname = "LOAD_CONST"
|
||||
inst.arg = None
|
||||
inst.argval = None
|
||||
new_insts.extend([is_op, jump_op])
|
||||
|
||||
replace_insts = [
|
||||
create_instruction("LOAD_CONST", argval=None),
|
||||
is_op,
|
||||
jump_op,
|
||||
]
|
||||
new_insts.extend(overwrite_instruction(inst, replace_insts))
|
||||
else:
|
||||
new_insts.append(inst)
|
||||
instructions[:] = new_insts
|
||||
|
||||
|
||||
@ -1007,24 +1029,17 @@ FUSED_INSTS = {
|
||||
def remove_fused_load_store(instructions: List[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst in instructions:
|
||||
new_insts.append(inst)
|
||||
if inst.opname in FUSED_INSTS:
|
||||
inst0, inst1 = FUSED_INSTS[inst.opname]
|
||||
argval0, argval1 = inst.argval
|
||||
|
||||
# modify inst in-place to preserve jump target
|
||||
inst.opcode = dis.opmap[inst0]
|
||||
inst.opname = inst0
|
||||
inst.argval = argval0
|
||||
|
||||
new_inst = create_instruction(inst1, argval=argval1)
|
||||
# update inst.exn_tab_entry.end if necessary
|
||||
if inst.exn_tab_entry and inst.exn_tab_entry.end is inst:
|
||||
inst.exn_tab_entry.end = new_inst
|
||||
# preserve exception table entries
|
||||
new_inst.exn_tab_entry = copy.copy(inst.exn_tab_entry)
|
||||
|
||||
new_insts.append(new_inst)
|
||||
replace_insts = [
|
||||
create_instruction(inst0, argval=argval0),
|
||||
create_instruction(inst1, argval=argval1),
|
||||
]
|
||||
new_insts.append(overwrite_instruction(inst, replace_insts))
|
||||
else:
|
||||
new_insts.append(inst)
|
||||
instructions[:] = new_insts
|
||||
|
||||
|
||||
@ -1435,7 +1450,9 @@ def bytecode_from_template(fn, varname_map=None, noreturn=True, noprefix=True):
|
||||
For example, local variables in `fn` can be replaced with
|
||||
new names that are generated by `OutputGraph.new_var`.
|
||||
noreturn: remove all RETURN_* bytecodes and replace them with a jump
|
||||
to the end of the bytecode.
|
||||
to the end of the bytecode. NOTE: any items pushed to the stack
|
||||
for return WILL remain on the stack! Append a POP_TOP if you don't want
|
||||
that item to be present.
|
||||
noprefix: remove prefix bytecodes (all bytecode before the first RESUME, inclusive).
|
||||
"""
|
||||
insts = cleaned_instructions(fn.__code__)
|
||||
|
Reference in New Issue
Block a user