From b71e813bceca1e3862b0094e4d58b51a621630c6 Mon Sep 17 00:00:00 2001 From: William Wen Date: Fri, 1 Nov 2024 13:52:46 -0700 Subject: [PATCH] [dynamo, 3.13] fix bytecode nop tests (#139323) Pull Request resolved: https://github.com/pytorch/pytorch/pull/139323 Approved by: https://github.com/jansel --- torch/_dynamo/bytecode_transformation.py | 30 +++++++++++++++++------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py index 8202d32dcd1b..73054dfb740b 100644 --- a/torch/_dynamo/bytecode_transformation.py +++ b/torch/_dynamo/bytecode_transformation.py @@ -980,9 +980,11 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None: if sys.version_info < (3, 12): jump_op = create_instruction( - "POP_JUMP_FORWARD_IF_TRUE" - if "FORWARD" in inst.opname - else "POP_JUMP_BACKWARD_IF_TRUE", + ( + "POP_JUMP_FORWARD_IF_TRUE" + if "FORWARD" in inst.opname + else "POP_JUMP_BACKWARD_IF_TRUE" + ), target=inst.target, ) else: @@ -1244,6 +1246,15 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N + (cast(int, instructions[i].arg) % 2) + 2 ) + elif instructions[i].opname in FUSED_INSTS: + assert sys.version_info >= (3, 13) + assert isinstance(instructions[i].argval, tuple) + assert len(instructions[i].argval) == 2 + arg_tuple = tuple( + varnames[name] if name in varnames else freenames[name] + for name in instructions[i].argval + ) + instructions[i].arg = (arg_tuple[0] << 4) + (arg_tuple[1] & 15) elif instructions[i].opcode in HAS_LOCAL: if should_compute_arg(): if ( @@ -1385,6 +1396,8 @@ def populate_kw_names_argval(instructions, consts): inst.argval = consts[inst.arg] +# If safe=True, we do not make any bytecode modifications. +# Mainly used for debugging bytecode_transformation (see debug_checks) def cleaned_instructions(code, safe=False) -> List[Instruction]: instructions = list(map(convert_instruction, dis.get_instructions(code))) check_offsets(instructions) @@ -1398,12 +1411,13 @@ def cleaned_instructions(code, safe=False) -> List[Instruction]: remove_load_call_method(instructions) if sys.version_info < (3, 12): explicit_super(code, instructions) + if sys.version_info >= (3, 11): + remove_jump_if_none(instructions) + if sys.version_info >= (3, 12): + remove_binary_store_slice(instructions) + if sys.version_info >= (3, 13): + remove_fused_load_store(instructions) if sys.version_info >= (3, 11): - remove_jump_if_none(instructions) - if sys.version_info >= (3, 12): - remove_binary_store_slice(instructions) - if sys.version_info >= (3, 13): - remove_fused_load_store(instructions) update_offsets(instructions) devirtualize_jumps(instructions) return instructions