PEP585 update - torch/_dynamo (#145105)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145105
Approved by: https://github.com/bobrenjc93
This commit is contained in:
Aaron Orenstein
2025-01-18 08:56:06 -08:00
committed by PyTorch MergeBot
parent c95efc37ba
commit a79100ab11
69 changed files with 847 additions and 864 deletions

View File

@ -6,7 +6,8 @@ import functools
import itertools
import sys
import types
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Sequence, Union
from collections.abc import Iterator, Sequence
from typing import Any, Callable, cast, Optional, Union
from ..utils._backport_slots import dataclass_slots
from .bytecode_analysis import (
@ -198,7 +199,7 @@ def create_dup_top() -> Instruction:
return create_instruction("DUP_TOP")
def create_rot_n(n) -> List[Instruction]:
def create_rot_n(n) -> list[Instruction]:
"""
Returns a "simple" sequence of instructions that rotates TOS to the n-th
position in the stack. For Python < 3.11, returns a single ROT_*
@ -228,8 +229,8 @@ def create_rot_n(n) -> List[Instruction]:
def add_push_null(
inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
inst_or_insts: Union[Instruction, list[Instruction]],
) -> list[Instruction]:
"""
Appends or prepends a PUSH_NULL instruction to `inst_or_insts`,
depending on Python version. Used when you know that
@ -286,8 +287,8 @@ def add_push_null(
def add_push_null_call_function_ex(
inst_or_insts: Union[Instruction, List[Instruction]],
) -> List[Instruction]:
inst_or_insts: Union[Instruction, list[Instruction]],
) -> list[Instruction]:
"""Like add_push_null, but the low bit of LOAD_ATTR/LOAD_SUPER_ATTR
is not set, due to an expected CALL_FUNCTION_EX instruction.
"""
@ -314,7 +315,7 @@ def add_push_null_call_function_ex(
return insts
def create_call_function(nargs, push_null) -> List[Instruction]:
def create_call_function(nargs, push_null) -> list[Instruction]:
"""
Creates a sequence of instructions that makes a function call.
@ -369,7 +370,7 @@ def create_call_function(nargs, push_null) -> List[Instruction]:
return [create_instruction("CALL_FUNCTION", arg=nargs)]
def create_call_method(nargs) -> List[Instruction]:
def create_call_method(nargs) -> list[Instruction]:
if sys.version_info >= (3, 12):
return [create_instruction("CALL", arg=nargs)]
if sys.version_info >= (3, 11):
@ -392,7 +393,7 @@ def create_setup_with(target) -> Instruction:
return create_instruction(opname, target=target)
def create_swap(n) -> List[Instruction]:
def create_swap(n) -> list[Instruction]:
if sys.version_info >= (3, 11):
return [create_instruction("SWAP", arg=n)]
# in Python < 3.11, SWAP is a macro that expands to multiple instructions
@ -436,14 +437,14 @@ def create_swap(n) -> List[Instruction]:
def lnotab_writer(
lineno: int, byteno: int = 0
) -> tuple[List[int], Callable[[int, int], None]]:
) -> tuple[list[int], Callable[[int, int], None]]:
"""
Used to create typing.CodeType.co_lnotab
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
This is the internal format of the line number table if Python < 3.10
"""
assert sys.version_info < (3, 10)
lnotab: List[int] = []
lnotab: list[int] = []
def update(lineno_new, byteno_new):
nonlocal byteno, lineno
@ -465,7 +466,7 @@ def linetable_310_writer(first_lineno):
This is the internal format of the line number table for Python 3.10
"""
assert sys.version_info >= (3, 10) and sys.version_info < (3, 11)
linetable: List[int] = []
linetable: list[int] = []
lineno = first_lineno
lineno_delta = 0
byteno = 0
@ -493,7 +494,7 @@ def linetable_310_writer(first_lineno):
return linetable, update, end
def encode_varint(n: int) -> List[int]:
def encode_varint(n: int) -> list[int]:
"""
6-bit chunk encoding of an unsigned integer
See https://github.com/python/cpython/blob/3.11/Objects/locations.md
@ -577,7 +578,7 @@ class ExceptionTableEntry:
lasti: bool
def encode_exception_table_varint(n: int) -> List[int]:
def encode_exception_table_varint(n: int) -> list[int]:
"""
Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse.
"""
@ -606,7 +607,7 @@ def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int:
return val
def check_exception_table(tab: List[ExceptionTableEntry]) -> None:
def check_exception_table(tab: list[ExceptionTableEntry]) -> None:
"""
Verifies that a list of ExceptionTableEntries will make a well-formed
jump table: entries are non-empty, sorted, and do not overlap.
@ -619,7 +620,7 @@ def check_exception_table(tab: List[ExceptionTableEntry]) -> None:
)
def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
def parse_exception_table(exntab: bytes) -> list[ExceptionTableEntry]:
"""
Parse the exception table according to
https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
@ -641,7 +642,7 @@ def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
return tab
def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
def assemble_exception_table(tab: list[ExceptionTableEntry]) -> bytes:
"""
Inverse of parse_exception_table - encodes list of exception
table entries into bytes.
@ -659,9 +660,9 @@ def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
return bytes(b)
def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
"""Do the opposite of dis.get_instructions()"""
code: List[int] = []
code: list[int] = []
if sys.version_info >= (3, 11):
lnotab, update_lineno = linetable_311_writer(firstlineno)
num_ext = 0
@ -701,7 +702,7 @@ def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes,
return bytes(code), bytes(lnotab)
def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int):
def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int):
"""
Get the instruction located at a given offset, accounting for EXTENDED_ARGs
"""
@ -736,7 +737,7 @@ def flip_jump_direction(instruction: Instruction) -> None:
assert instruction.opcode in _REL_JUMPS
def _get_instruction_front(instructions: List[Instruction], idx: int):
def _get_instruction_front(instructions: list[Instruction], idx: int):
"""
i.e. get the first EXTENDED_ARG instruction (if any) when targeting
instructions[idx] with a jump.
@ -798,7 +799,7 @@ def devirtualize_jumps(instructions):
inst.argrepr = f"to {target.offset}"
def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]):
def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]):
"""Replace exception table entries with pointers to make editing easier"""
exn_tab = parse_exception_table(exn_tab_bytes)
offset_to_inst = {cast(int, inst.offset): inst for inst in instructions}
@ -840,10 +841,10 @@ def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruct
def compute_exception_table(
instructions: List[Instruction],
) -> List[ExceptionTableEntry]:
instructions: list[Instruction],
) -> list[ExceptionTableEntry]:
"""Compute exception table in list format from instructions with exn_tab_entries"""
exn_dict: Dict[tuple[int, int], tuple[int, int, bool]] = {}
exn_dict: dict[tuple[int, int], tuple[int, int, bool]] = {}
indexof = get_indexof(instructions)
for inst in instructions:
@ -877,8 +878,8 @@ def compute_exception_table(
# smallest byte that the next exception table entry can start at
nexti = 0
# stack of current nested keys
key_stack: List[tuple[int, int]] = []
exn_tab: List[ExceptionTableEntry] = []
key_stack: list[tuple[int, int]] = []
exn_tab: list[ExceptionTableEntry] = []
def pop():
"""
@ -914,7 +915,7 @@ def compute_exception_table(
def check_inst_exn_tab_entries_nested(
tab: List[InstructionExnTabEntry], indexof
tab: list[InstructionExnTabEntry], indexof
) -> None:
"""
Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's,
@ -922,7 +923,7 @@ def check_inst_exn_tab_entries_nested(
"Properly sorted" means entries are sorted by increasing starts, then
decreasing ends.
"""
entry_stack: List[tuple[int, int]] = []
entry_stack: list[tuple[int, int]] = []
for entry in tab:
key = (indexof[entry.start], indexof[entry.end])
while entry_stack and entry_stack[-1][1] < key[0]:
@ -932,13 +933,13 @@ def check_inst_exn_tab_entries_nested(
entry_stack.append(key)
def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None:
"""
Copies exception table entries to all instructions in an entry's range.
Supports nested exception table entries.
"""
indexof = get_indexof(instructions)
entries: Dict[tuple[int, int], InstructionExnTabEntry] = {}
entries: dict[tuple[int, int], InstructionExnTabEntry] = {}
for inst in instructions:
if inst.exn_tab_entry:
key = (
@ -959,7 +960,7 @@ def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
instructions[i].exn_tab_entry = copy.copy(entry)
def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
def check_inst_exn_tab_entries_valid(instructions: list[Instruction]):
"""
Checks that exn_tab_entries of instructions are valid.
An entry's start, end, and target must be in instructions.
@ -983,7 +984,7 @@ def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
assert indexof[entry.start] <= i <= indexof[entry.end]
def strip_extended_args(instructions: List[Instruction]) -> None:
def strip_extended_args(instructions: list[Instruction]) -> None:
instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
@ -1013,7 +1014,7 @@ def overwrite_instruction(old_inst, new_insts):
return [old_inst] + new_insts[1:]
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
def remove_load_call_method(instructions: list[Instruction]) -> list[Instruction]:
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
assert sys.version_info < (3, 11)
rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
@ -1024,7 +1025,7 @@ def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction
return instructions
def remove_jump_if_none(instructions: List[Instruction]) -> None:
def remove_jump_if_none(instructions: list[Instruction]) -> None:
new_insts = []
for inst in instructions:
if "_NONE" in inst.opname:
@ -1055,7 +1056,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
instructions[:] = new_insts
def remove_binary_store_slice(instructions: List[Instruction]) -> None:
def remove_binary_store_slice(instructions: list[Instruction]) -> None:
new_insts = []
for inst in instructions:
new_insts.append(inst)
@ -1082,7 +1083,7 @@ FUSED_INSTS = {
}
def remove_fused_load_store(instructions: List[Instruction]) -> None:
def remove_fused_load_store(instructions: list[Instruction]) -> None:
new_insts = []
for inst in instructions:
if inst.opname in FUSED_INSTS:
@ -1099,7 +1100,7 @@ def remove_fused_load_store(instructions: List[Instruction]) -> None:
instructions[:] = new_insts
def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None:
def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None:
"""convert super() with no args into explicit arg form"""
cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ())
if not len(code.co_varnames):
@ -1137,9 +1138,9 @@ def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> Non
instructions[:] = output
def fix_extended_args(instructions: List[Instruction]) -> int:
def fix_extended_args(instructions: list[Instruction]) -> int:
"""Fill in correct argvals for EXTENDED_ARG ops"""
output: List[Instruction] = []
output: list[Instruction] = []
def maybe_pop_n(n):
for _ in range(n):
@ -1229,7 +1230,7 @@ def get_const_index(code_options, val) -> int:
return len(code_options["co_consts"]) - 1
def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None):
def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None):
# compute instruction arg from argval if arg is not provided
names = {name: idx for idx, name in enumerate(code_options["co_names"])}
@ -1353,7 +1354,7 @@ def clear_instruction_args(instructions):
inst.arg = None
def get_code_keys() -> List[str]:
def get_code_keys() -> list[str]:
# Python 3.11 changes to code keys are not fully documented.
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
# for new format.
@ -1405,8 +1406,8 @@ def transform_code_object(code, transformations, safe=False) -> types.CodeType:
def clean_and_assemble_instructions(
instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
) -> tuple[List[Instruction], types.CodeType]:
instructions: list[Instruction], keys: list[str], code_options: dict[str, Any]
) -> tuple[list[Instruction], types.CodeType]:
# also implicitly checks for no duplicate instructions
check_inst_exn_tab_entries_valid(instructions)
@ -1453,7 +1454,7 @@ def populate_kw_names_argval(instructions, consts):
# If safe=True, we do not make any bytecode modifications.
# Mainly used for debugging bytecode_transformation (see debug_checks)
def cleaned_instructions(code, safe=False) -> List[Instruction]:
def cleaned_instructions(code, safe=False) -> list[Instruction]:
instructions = _cached_cleaned_instructions(code, safe)
# We have a lot of code that implicitly mutates the instruction array. We
# could do better here by making the copies explicit when necessary.