mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
PEP585 update - torch/_dynamo (#145105)
See #145101 for details. Pull Request resolved: https://github.com/pytorch/pytorch/pull/145105 Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by
PyTorch MergeBot
parent
c95efc37ba
commit
a79100ab11
@ -6,7 +6,8 @@ import functools
|
||||
import itertools
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, Sequence, Union
|
||||
from collections.abc import Iterator, Sequence
|
||||
from typing import Any, Callable, cast, Optional, Union
|
||||
|
||||
from ..utils._backport_slots import dataclass_slots
|
||||
from .bytecode_analysis import (
|
||||
@ -198,7 +199,7 @@ def create_dup_top() -> Instruction:
|
||||
return create_instruction("DUP_TOP")
|
||||
|
||||
|
||||
def create_rot_n(n) -> List[Instruction]:
|
||||
def create_rot_n(n) -> list[Instruction]:
|
||||
"""
|
||||
Returns a "simple" sequence of instructions that rotates TOS to the n-th
|
||||
position in the stack. For Python < 3.11, returns a single ROT_*
|
||||
@ -228,8 +229,8 @@ def create_rot_n(n) -> List[Instruction]:
|
||||
|
||||
|
||||
def add_push_null(
|
||||
inst_or_insts: Union[Instruction, List[Instruction]],
|
||||
) -> List[Instruction]:
|
||||
inst_or_insts: Union[Instruction, list[Instruction]],
|
||||
) -> list[Instruction]:
|
||||
"""
|
||||
Appends or prepends a PUSH_NULL instruction to `inst_or_insts`,
|
||||
depending on Python version. Used when you know that
|
||||
@ -286,8 +287,8 @@ def add_push_null(
|
||||
|
||||
|
||||
def add_push_null_call_function_ex(
|
||||
inst_or_insts: Union[Instruction, List[Instruction]],
|
||||
) -> List[Instruction]:
|
||||
inst_or_insts: Union[Instruction, list[Instruction]],
|
||||
) -> list[Instruction]:
|
||||
"""Like add_push_null, but the low bit of LOAD_ATTR/LOAD_SUPER_ATTR
|
||||
is not set, due to an expected CALL_FUNCTION_EX instruction.
|
||||
"""
|
||||
@ -314,7 +315,7 @@ def add_push_null_call_function_ex(
|
||||
return insts
|
||||
|
||||
|
||||
def create_call_function(nargs, push_null) -> List[Instruction]:
|
||||
def create_call_function(nargs, push_null) -> list[Instruction]:
|
||||
"""
|
||||
Creates a sequence of instructions that makes a function call.
|
||||
|
||||
@ -369,7 +370,7 @@ def create_call_function(nargs, push_null) -> List[Instruction]:
|
||||
return [create_instruction("CALL_FUNCTION", arg=nargs)]
|
||||
|
||||
|
||||
def create_call_method(nargs) -> List[Instruction]:
|
||||
def create_call_method(nargs) -> list[Instruction]:
|
||||
if sys.version_info >= (3, 12):
|
||||
return [create_instruction("CALL", arg=nargs)]
|
||||
if sys.version_info >= (3, 11):
|
||||
@ -392,7 +393,7 @@ def create_setup_with(target) -> Instruction:
|
||||
return create_instruction(opname, target=target)
|
||||
|
||||
|
||||
def create_swap(n) -> List[Instruction]:
|
||||
def create_swap(n) -> list[Instruction]:
|
||||
if sys.version_info >= (3, 11):
|
||||
return [create_instruction("SWAP", arg=n)]
|
||||
# in Python < 3.11, SWAP is a macro that expands to multiple instructions
|
||||
@ -436,14 +437,14 @@ def create_swap(n) -> List[Instruction]:
|
||||
|
||||
def lnotab_writer(
|
||||
lineno: int, byteno: int = 0
|
||||
) -> tuple[List[int], Callable[[int, int], None]]:
|
||||
) -> tuple[list[int], Callable[[int, int], None]]:
|
||||
"""
|
||||
Used to create typing.CodeType.co_lnotab
|
||||
See https://github.com/python/cpython/blob/main/Objects/lnotab_notes.txt
|
||||
This is the internal format of the line number table if Python < 3.10
|
||||
"""
|
||||
assert sys.version_info < (3, 10)
|
||||
lnotab: List[int] = []
|
||||
lnotab: list[int] = []
|
||||
|
||||
def update(lineno_new, byteno_new):
|
||||
nonlocal byteno, lineno
|
||||
@ -465,7 +466,7 @@ def linetable_310_writer(first_lineno):
|
||||
This is the internal format of the line number table for Python 3.10
|
||||
"""
|
||||
assert sys.version_info >= (3, 10) and sys.version_info < (3, 11)
|
||||
linetable: List[int] = []
|
||||
linetable: list[int] = []
|
||||
lineno = first_lineno
|
||||
lineno_delta = 0
|
||||
byteno = 0
|
||||
@ -493,7 +494,7 @@ def linetable_310_writer(first_lineno):
|
||||
return linetable, update, end
|
||||
|
||||
|
||||
def encode_varint(n: int) -> List[int]:
|
||||
def encode_varint(n: int) -> list[int]:
|
||||
"""
|
||||
6-bit chunk encoding of an unsigned integer
|
||||
See https://github.com/python/cpython/blob/3.11/Objects/locations.md
|
||||
@ -577,7 +578,7 @@ class ExceptionTableEntry:
|
||||
lasti: bool
|
||||
|
||||
|
||||
def encode_exception_table_varint(n: int) -> List[int]:
|
||||
def encode_exception_table_varint(n: int) -> list[int]:
|
||||
"""
|
||||
Similar to `encode_varint`, but the 6-bit chunks are ordered in reverse.
|
||||
"""
|
||||
@ -606,7 +607,7 @@ def decode_exception_table_varint(bytes_iter: Iterator[int]) -> int:
|
||||
return val
|
||||
|
||||
|
||||
def check_exception_table(tab: List[ExceptionTableEntry]) -> None:
|
||||
def check_exception_table(tab: list[ExceptionTableEntry]) -> None:
|
||||
"""
|
||||
Verifies that a list of ExceptionTableEntries will make a well-formed
|
||||
jump table: entries are non-empty, sorted, and do not overlap.
|
||||
@ -619,7 +620,7 @@ def check_exception_table(tab: List[ExceptionTableEntry]) -> None:
|
||||
)
|
||||
|
||||
|
||||
def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
|
||||
def parse_exception_table(exntab: bytes) -> list[ExceptionTableEntry]:
|
||||
"""
|
||||
Parse the exception table according to
|
||||
https://github.com/python/cpython/blob/3.11/Objects/exception_handling_notes.txt
|
||||
@ -641,7 +642,7 @@ def parse_exception_table(exntab: bytes) -> List[ExceptionTableEntry]:
|
||||
return tab
|
||||
|
||||
|
||||
def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
|
||||
def assemble_exception_table(tab: list[ExceptionTableEntry]) -> bytes:
|
||||
"""
|
||||
Inverse of parse_exception_table - encodes list of exception
|
||||
table entries into bytes.
|
||||
@ -659,9 +660,9 @@ def assemble_exception_table(tab: List[ExceptionTableEntry]) -> bytes:
|
||||
return bytes(b)
|
||||
|
||||
|
||||
def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
|
||||
def assemble(instructions: list[Instruction], firstlineno: int) -> tuple[bytes, bytes]:
|
||||
"""Do the opposite of dis.get_instructions()"""
|
||||
code: List[int] = []
|
||||
code: list[int] = []
|
||||
if sys.version_info >= (3, 11):
|
||||
lnotab, update_lineno = linetable_311_writer(firstlineno)
|
||||
num_ext = 0
|
||||
@ -701,7 +702,7 @@ def assemble(instructions: List[Instruction], firstlineno: int) -> tuple[bytes,
|
||||
return bytes(code), bytes(lnotab)
|
||||
|
||||
|
||||
def _get_instruction_by_offset(offset_to_inst: Dict[int, Instruction], offset: int):
|
||||
def _get_instruction_by_offset(offset_to_inst: dict[int, Instruction], offset: int):
|
||||
"""
|
||||
Get the instruction located at a given offset, accounting for EXTENDED_ARGs
|
||||
"""
|
||||
@ -736,7 +737,7 @@ def flip_jump_direction(instruction: Instruction) -> None:
|
||||
assert instruction.opcode in _REL_JUMPS
|
||||
|
||||
|
||||
def _get_instruction_front(instructions: List[Instruction], idx: int):
|
||||
def _get_instruction_front(instructions: list[Instruction], idx: int):
|
||||
"""
|
||||
i.e. get the first EXTENDED_ARG instruction (if any) when targeting
|
||||
instructions[idx] with a jump.
|
||||
@ -798,7 +799,7 @@ def devirtualize_jumps(instructions):
|
||||
inst.argrepr = f"to {target.offset}"
|
||||
|
||||
|
||||
def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruction]):
|
||||
def virtualize_exception_table(exn_tab_bytes: bytes, instructions: list[Instruction]):
|
||||
"""Replace exception table entries with pointers to make editing easier"""
|
||||
exn_tab = parse_exception_table(exn_tab_bytes)
|
||||
offset_to_inst = {cast(int, inst.offset): inst for inst in instructions}
|
||||
@ -840,10 +841,10 @@ def virtualize_exception_table(exn_tab_bytes: bytes, instructions: List[Instruct
|
||||
|
||||
|
||||
def compute_exception_table(
|
||||
instructions: List[Instruction],
|
||||
) -> List[ExceptionTableEntry]:
|
||||
instructions: list[Instruction],
|
||||
) -> list[ExceptionTableEntry]:
|
||||
"""Compute exception table in list format from instructions with exn_tab_entries"""
|
||||
exn_dict: Dict[tuple[int, int], tuple[int, int, bool]] = {}
|
||||
exn_dict: dict[tuple[int, int], tuple[int, int, bool]] = {}
|
||||
indexof = get_indexof(instructions)
|
||||
|
||||
for inst in instructions:
|
||||
@ -877,8 +878,8 @@ def compute_exception_table(
|
||||
# smallest byte that the next exception table entry can start at
|
||||
nexti = 0
|
||||
# stack of current nested keys
|
||||
key_stack: List[tuple[int, int]] = []
|
||||
exn_tab: List[ExceptionTableEntry] = []
|
||||
key_stack: list[tuple[int, int]] = []
|
||||
exn_tab: list[ExceptionTableEntry] = []
|
||||
|
||||
def pop():
|
||||
"""
|
||||
@ -914,7 +915,7 @@ def compute_exception_table(
|
||||
|
||||
|
||||
def check_inst_exn_tab_entries_nested(
|
||||
tab: List[InstructionExnTabEntry], indexof
|
||||
tab: list[InstructionExnTabEntry], indexof
|
||||
) -> None:
|
||||
"""
|
||||
Checks `tab` is a properly sorted list of nested InstructionExnTabEntry's,
|
||||
@ -922,7 +923,7 @@ def check_inst_exn_tab_entries_nested(
|
||||
"Properly sorted" means entries are sorted by increasing starts, then
|
||||
decreasing ends.
|
||||
"""
|
||||
entry_stack: List[tuple[int, int]] = []
|
||||
entry_stack: list[tuple[int, int]] = []
|
||||
for entry in tab:
|
||||
key = (indexof[entry.start], indexof[entry.end])
|
||||
while entry_stack and entry_stack[-1][1] < key[0]:
|
||||
@ -932,13 +933,13 @@ def check_inst_exn_tab_entries_nested(
|
||||
entry_stack.append(key)
|
||||
|
||||
|
||||
def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
|
||||
def propagate_inst_exn_table_entries(instructions: list[Instruction]) -> None:
|
||||
"""
|
||||
Copies exception table entries to all instructions in an entry's range.
|
||||
Supports nested exception table entries.
|
||||
"""
|
||||
indexof = get_indexof(instructions)
|
||||
entries: Dict[tuple[int, int], InstructionExnTabEntry] = {}
|
||||
entries: dict[tuple[int, int], InstructionExnTabEntry] = {}
|
||||
for inst in instructions:
|
||||
if inst.exn_tab_entry:
|
||||
key = (
|
||||
@ -959,7 +960,7 @@ def propagate_inst_exn_table_entries(instructions: List[Instruction]) -> None:
|
||||
instructions[i].exn_tab_entry = copy.copy(entry)
|
||||
|
||||
|
||||
def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
|
||||
def check_inst_exn_tab_entries_valid(instructions: list[Instruction]):
|
||||
"""
|
||||
Checks that exn_tab_entries of instructions are valid.
|
||||
An entry's start, end, and target must be in instructions.
|
||||
@ -983,7 +984,7 @@ def check_inst_exn_tab_entries_valid(instructions: List[Instruction]):
|
||||
assert indexof[entry.start] <= i <= indexof[entry.end]
|
||||
|
||||
|
||||
def strip_extended_args(instructions: List[Instruction]) -> None:
|
||||
def strip_extended_args(instructions: list[Instruction]) -> None:
|
||||
instructions[:] = [i for i in instructions if i.opcode != dis.EXTENDED_ARG]
|
||||
|
||||
|
||||
@ -1013,7 +1014,7 @@ def overwrite_instruction(old_inst, new_insts):
|
||||
return [old_inst] + new_insts[1:]
|
||||
|
||||
|
||||
def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction]:
|
||||
def remove_load_call_method(instructions: list[Instruction]) -> list[Instruction]:
|
||||
"""LOAD_METHOD puts a NULL on the stack which causes issues, so remove it"""
|
||||
assert sys.version_info < (3, 11)
|
||||
rewrites = {"LOAD_METHOD": "LOAD_ATTR", "CALL_METHOD": "CALL_FUNCTION"}
|
||||
@ -1024,7 +1025,7 @@ def remove_load_call_method(instructions: List[Instruction]) -> List[Instruction
|
||||
return instructions
|
||||
|
||||
|
||||
def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
def remove_jump_if_none(instructions: list[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst in instructions:
|
||||
if "_NONE" in inst.opname:
|
||||
@ -1055,7 +1056,7 @@ def remove_jump_if_none(instructions: List[Instruction]) -> None:
|
||||
instructions[:] = new_insts
|
||||
|
||||
|
||||
def remove_binary_store_slice(instructions: List[Instruction]) -> None:
|
||||
def remove_binary_store_slice(instructions: list[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst in instructions:
|
||||
new_insts.append(inst)
|
||||
@ -1082,7 +1083,7 @@ FUSED_INSTS = {
|
||||
}
|
||||
|
||||
|
||||
def remove_fused_load_store(instructions: List[Instruction]) -> None:
|
||||
def remove_fused_load_store(instructions: list[Instruction]) -> None:
|
||||
new_insts = []
|
||||
for inst in instructions:
|
||||
if inst.opname in FUSED_INSTS:
|
||||
@ -1099,7 +1100,7 @@ def remove_fused_load_store(instructions: List[Instruction]) -> None:
|
||||
instructions[:] = new_insts
|
||||
|
||||
|
||||
def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> None:
|
||||
def explicit_super(code: types.CodeType, instructions: list[Instruction]) -> None:
|
||||
"""convert super() with no args into explicit arg form"""
|
||||
cell_and_free = (code.co_cellvars or ()) + (code.co_freevars or ())
|
||||
if not len(code.co_varnames):
|
||||
@ -1137,9 +1138,9 @@ def explicit_super(code: types.CodeType, instructions: List[Instruction]) -> Non
|
||||
instructions[:] = output
|
||||
|
||||
|
||||
def fix_extended_args(instructions: List[Instruction]) -> int:
|
||||
def fix_extended_args(instructions: list[Instruction]) -> int:
|
||||
"""Fill in correct argvals for EXTENDED_ARG ops"""
|
||||
output: List[Instruction] = []
|
||||
output: list[Instruction] = []
|
||||
|
||||
def maybe_pop_n(n):
|
||||
for _ in range(n):
|
||||
@ -1229,7 +1230,7 @@ def get_const_index(code_options, val) -> int:
|
||||
return len(code_options["co_consts"]) - 1
|
||||
|
||||
|
||||
def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None):
|
||||
def fix_vars(instructions: list[Instruction], code_options, varname_from_oparg=None):
|
||||
# compute instruction arg from argval if arg is not provided
|
||||
names = {name: idx for idx, name in enumerate(code_options["co_names"])}
|
||||
|
||||
@ -1353,7 +1354,7 @@ def clear_instruction_args(instructions):
|
||||
inst.arg = None
|
||||
|
||||
|
||||
def get_code_keys() -> List[str]:
|
||||
def get_code_keys() -> list[str]:
|
||||
# Python 3.11 changes to code keys are not fully documented.
|
||||
# See https://github.com/python/cpython/blob/3.11/Objects/clinic/codeobject.c.h#L24
|
||||
# for new format.
|
||||
@ -1405,8 +1406,8 @@ def transform_code_object(code, transformations, safe=False) -> types.CodeType:
|
||||
|
||||
|
||||
def clean_and_assemble_instructions(
|
||||
instructions: List[Instruction], keys: List[str], code_options: Dict[str, Any]
|
||||
) -> tuple[List[Instruction], types.CodeType]:
|
||||
instructions: list[Instruction], keys: list[str], code_options: dict[str, Any]
|
||||
) -> tuple[list[Instruction], types.CodeType]:
|
||||
# also implicitly checks for no duplicate instructions
|
||||
check_inst_exn_tab_entries_valid(instructions)
|
||||
|
||||
@ -1453,7 +1454,7 @@ def populate_kw_names_argval(instructions, consts):
|
||||
|
||||
# If safe=True, we do not make any bytecode modifications.
|
||||
# Mainly used for debugging bytecode_transformation (see debug_checks)
|
||||
def cleaned_instructions(code, safe=False) -> List[Instruction]:
|
||||
def cleaned_instructions(code, safe=False) -> list[Instruction]:
|
||||
instructions = _cached_cleaned_instructions(code, safe)
|
||||
# We have a lot of code that implicitly mutates the instruction array. We
|
||||
# could do better here by making the copies explicit when necessary.
|
||||
|
Reference in New Issue
Block a user