diff --git a/torch/_inductor/jagged_lowerings.py b/torch/_inductor/jagged_lowerings.py index 7036105b9ab2..9b393b36b42e 100644 --- a/torch/_inductor/jagged_lowerings.py +++ b/torch/_inductor/jagged_lowerings.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import List, Optional, Union +from typing import Optional, Union import sympy @@ -113,8 +113,8 @@ def register_jagged_ops(): @register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default) def _jagged_to_padded_dense_forward( jagged_values: TensorBox, - jagged_offsets: List[TensorBox], - max_lengths: List[int], # list of ints/SymInts + jagged_offsets: list[TensorBox], + max_lengths: list[int], # list of ints/SymInts padding_value: float = 0.0, ) -> TensorBox: device = jagged_values.get_device_or_error() @@ -184,7 +184,7 @@ def register_jagged_ops(): def _dense_to_jagged_forward_impl( fallback_op, # pyre-ignore[2] dense: TensorBox, - jagged_offsets: List[TensorBox], + jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, ) -> TensorBox: device = dense.get_device_or_error() @@ -257,7 +257,7 @@ def register_jagged_ops(): @register_lowering(torch.ops.aten._padded_dense_to_jagged_forward) def _dense_to_jagged_forward( dense: TensorBox, - jagged_offsets: List[TensorBox], + jagged_offsets: list[TensorBox], jagged_len: Optional[int] = None, ) -> TensorBox: return _dense_to_jagged_forward_impl( diff --git a/torch/_inductor/kernel/conv.py b/torch/_inductor/kernel/conv.py index dc2d6037e5a1..0b8e88d41b18 100644 --- a/torch/_inductor/kernel/conv.py +++ b/torch/_inductor/kernel/conv.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging -from typing import cast, Optional, Sequence, TYPE_CHECKING, TypedDict +from typing import cast, Optional, TYPE_CHECKING, TypedDict import torch from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate @@ -33,6 +33,8 @@ from .mm_common import build_rocm_gemm_configs, filtered_configs if TYPE_CHECKING: + from collections.abc import Sequence + from ..ir import TensorBox log = logging.getLogger(__name__) diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index b6145005da21..3a472d0c547d 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -3,9 +3,10 @@ import logging import math +from collections.abc import Sequence from dataclasses import dataclass from enum import auto, Enum -from typing import Any, List, Optional, Sequence, Union +from typing import Any, Optional, Union import sympy @@ -90,7 +91,7 @@ def create_placeholder( return TensorBox.create(input_buffer) -def maybe_realize(args: List[Optional[IRNode]]): +def maybe_realize(args: list[Optional[IRNode]]): """Accepts a list of optional IRNodes and returns a list of realized IRNodes""" return tree_map( lambda x: ( @@ -109,7 +110,7 @@ def get_float32_precision(): return "'tf32'" -def zeros_and_scatter_lowering(shape: List[int], indices, values): +def zeros_and_scatter_lowering(shape: list[int], indices, values): # Always accumulate into fp32 then cast grad = _full(0, values.get_device(), torch.float32, shape) assert isinstance(grad, TensorBox) @@ -153,10 +154,10 @@ def zeros_and_scatter_lowering(shape: List[int], indices, values): return buffer -SubgraphResults = Union[List[Optional[ComputedBuffer]], Optional[ComputedBuffer]] +SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] -def build_subgraph_buffer(args: List[TensorBox], subgraph: Subgraph) -> SubgraphResults: 
+def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults: """This function's goal is to take in the required args and produce the subgraph buffer The subgraph buffer is a ComputedBuffer that will be inlined into the triton template @@ -870,7 +871,7 @@ def lower_cpu( "torch.compile on current platform is not supported for CPU." ) - fake_buffers: List[Buffer] = [] # noqa: F821 + fake_buffers: list[Buffer] = [] # noqa: F821 placeholder_inps = [ create_placeholder(name, dtype, query.get_device()) for name, dtype in [ @@ -968,7 +969,7 @@ def lower_cpu( [B, Hq, seq_len_q, v_head_dim], stride=[sympy.sympify(s) for s in out_strides], ) - _choices: List[Any] = [] + _choices: list[Any] = [] input_nodes = [query, key, value, kv_num_blocks, kv_indices] if not full_kv_num_blocks: no_full_kv_block = True @@ -1214,8 +1215,8 @@ def flex_attention( "V_HEAD_DIM", V.graph.sizevars.evaluate_static_shape(v_head_dim) ) - choices: List[Any] = [] - configs: List[tuple[int, int, int, int]] = [] + choices: list[Any] = [] + configs: list[tuple[int, int, int, int]] = [] configs.append(_get_default_config_fwd(query)) if config.max_autotune: configs += [ @@ -2071,9 +2072,9 @@ class JointOutputResult: """Results from processing joint outputs.""" grad_input: ComputedBuffer - captured_grads_compute: List[ComputedBuffer] - captured_grads: List[Optional[TensorBox]] - mutated_grads: List[TensorBox] + captured_grads_compute: list[ComputedBuffer] + captured_grads: list[Optional[TensorBox]] + mutated_grads: list[TensorBox] def process_joint_outputs( @@ -2088,7 +2089,7 @@ def process_joint_outputs( Returns: JointOutputResult containing processed buffers and gradients """ - assert isinstance(all_joint_outputs, List) + assert isinstance(all_joint_outputs, list) assert ( all_joint_outputs[0] is not None ), "joint_subgraph_buffer is None this is a bug!" @@ -2307,8 +2308,8 @@ def flex_attention_backward(*args, **kwargs): SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE) - choices: List[Any] = [] - configs: List[tuple[int, int, int, int]] = [] + choices: list[Any] = [] + configs: list[tuple[int, int, int, int]] = [] configs.append(_get_default_config_bwd(query)) if config.max_autotune: num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1] diff --git a/torch/_inductor/kernel/flex_decoding.py b/torch/_inductor/kernel/flex_decoding.py index a0b6351df6ef..933a080683fc 100644 --- a/torch/_inductor/kernel/flex_decoding.py +++ b/torch/_inductor/kernel/flex_decoding.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs """ Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" -from typing import Any, List +from typing import Any import sympy @@ -415,8 +415,8 @@ def create_flex_decoding_kernel(*args, **kwargs): score_mod_other_buffers = maybe_realize(score_mod_other_buffers) mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers) - choices: List[Any] = [] - configs: List[tuple[int, int, int]] = [] + choices: list[Any] = [] + configs: list[tuple[int, int, int]] = [] configs.append(_get_decoding_default_config(key)) # Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops. 
if config.max_autotune: diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 65424406d749..a0b2f1c42a98 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -1,7 +1,7 @@ # mypy: allow-untyped-defs import functools import logging -from typing import Any, Dict, List, Optional +from typing import Any, Optional import torch from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm @@ -933,7 +933,7 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None): def mul_epilogue(v1, v2): return V.ops.mul(v1, v2) - choices: List[Dict[Any, Any]] = [] + choices: list[dict[Any, Any]] = [] for config in int8_mm_configs( m, n, k, **mm_config_kwargs(ir.get_device_type(mat1)) ): diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index f7472a6ccbd7..8df7458cf40e 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -2,7 +2,8 @@ import functools import itertools import logging -from typing import Any, cast, Dict, Sequence +from collections.abc import Sequence +from typing import Any, cast import sympy @@ -438,7 +439,7 @@ def mm_grid(m, n, meta): return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1) -def persistent_mm_grid(M: int, N: int, meta: Dict[str, Any]): +def persistent_mm_grid(M: int, N: int, meta: dict[str, Any]): """Defines the grid for persistent kernels.""" return ( min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])), diff --git a/torch/_inductor/kernel/mm_scaled.py b/torch/_inductor/kernel/mm_scaled.py index 3bdf91852e01..b9e9cc757594 100644 --- a/torch/_inductor/kernel/mm_scaled.py +++ b/torch/_inductor/kernel/mm_scaled.py @@ -1,5 +1,6 @@ import logging -from typing import Any, Dict, List, Optional, Sequence +from collections.abc import Sequence +from typing import Any, Optional import sympy @@ -428,7 +429,7 @@ def scaled_mm_options_device_tma( # type: ignore[no-untyped-def] scale_b: StorageBox, use_fast_accum: bool, b_prologue_cast_type: Optional[str] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: even_k_symbolic = ( sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] ) @@ -464,7 +465,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def] scale_b: StorageBox, use_fast_accum: bool, b_prologue_cast_type: Optional[str] = None, -) -> Dict[str, Any]: +) -> dict[str, Any]: even_k_symbolic = ( sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"] ) @@ -533,7 +534,7 @@ def tuned_scaled_mm( input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum ) - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if use_aten_gemm_kernels(): choices.append(aten_choice) diff --git a/torch/_inductor/kernel/unpack_mixed_mm.py b/torch/_inductor/kernel/unpack_mixed_mm.py index 674da97c1655..8d3fe67785d9 100644 --- a/torch/_inductor/kernel/unpack_mixed_mm.py +++ b/torch/_inductor/kernel/unpack_mixed_mm.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import logging -from typing import List, TYPE_CHECKING +from typing import TYPE_CHECKING from ..select_algorithm import autotune_select_algorithm, TritonTemplate from .mm_common import mm_args, mm_configs, mm_grid, mm_options @@ -75,7 +75,7 @@ uint4x2_mixed_mm_template = TritonTemplate( def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype): m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True) - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] 
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "") for config in mm_configs(m, n, k): uint4x2_mixed_mm_template.maybe_append_choice( diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 33f5ff1464e7..21f63a11b674 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -6,7 +6,7 @@ import functools import itertools import re from enum import auto, Enum -from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, TypeVar import sympy @@ -21,6 +21,10 @@ from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs from .virtualized import ops, V +if TYPE_CHECKING: + from collections.abc import Sequence + + T = TypeVar("T") @@ -83,14 +87,14 @@ class LoopBody: indexing simplifications and makes it easier to analyze loop bodies. """ - indexing_exprs: Dict[str, sympy.Expr] - indexing_exprs_name: Dict[sympy.Expr, str] - submodules: Dict[str, Any] - subblocks: Dict[str, LoopBodyBlock] - indirect_vars: List[sympy.Symbol] - indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] + indexing_exprs: dict[str, sympy.Expr] + indexing_exprs_name: dict[sympy.Expr, str] + submodules: dict[str, Any] + subblocks: dict[str, LoopBodyBlock] + indirect_vars: list[sympy.Symbol] + indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] root_block: LoopBodyBlock - memory_usage: Dict[MemoryUsageType, List[MemoryEntry]] + memory_usage: dict[MemoryUsageType, list[MemoryEntry]] op_counts: collections.Counter[str] def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars): @@ -120,7 +124,7 @@ class LoopBody: self.submodules = {"get_index": self.get_index} self.subblocks = {} self.indirect_vars = [] - self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {} + self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {} self.memory_usage = {t: [] for t in MemoryUsageType} self.op_counts = collections.Counter() self.root_block = LoopBodyBlock(self, fn, args) # traces @@ -433,7 +437,7 @@ class LoopBodyBlock: operations will manifest as an extra LoopBodyBlock. 
""" - def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]): + def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]): self.body = body def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs): diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index dd89cecb0fc5..3695d454a243 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -8,8 +8,8 @@ import operator import os import warnings from collections import defaultdict -from collections.abc import Iterable -from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union +from collections.abc import Iterable, Sequence +from typing import Any, Callable, Optional, TypeVar, Union from typing_extensions import ParamSpec from unittest.mock import patch @@ -90,9 +90,9 @@ FALLBACK_ALLOW_LIST = OrderedSet( ) log = logging.getLogger(__name__) -lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} +lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {} # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints -_maybe_layout_constraints: Dict[ +_maybe_layout_constraints: dict[ torch._ops.OpOverload, Optional[Callable[..., Any]] ] = {} fallbacks = OrderedSet[torch._ops.OpOverload]() @@ -106,7 +106,7 @@ foreach_ops = OrderedSet[torch._ops.OpOverload]( # TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload # so why is it in foreach_ops? inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]() -inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} +inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {} quantized_decomposed = torch.ops.quantized_decomposed @@ -313,12 +313,12 @@ def in_namespace(op, namespace): def transform_args( - args: List[Any], - kwargs: Dict[str, Any], + args: list[Any], + kwargs: dict[str, Any], broadcast: bool, type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND], convert_input_to_bool: bool, -) -> tuple[List[Any], Dict[str, Any]]: +) -> tuple[list[Any], dict[str, Any]]: args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)] kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)] # check that there's something to transform @@ -428,8 +428,8 @@ def _register_lowering( @functools.wraps(decomp_fn) def wrapped(*args, **kwargs): - args: List[Any] = list(args) - kwargs: Dict[str, Any] = dict(kwargs) + args: list[Any] = list(args) + kwargs: dict[str, Any] = dict(kwargs) unpacked = False # TODO maybe we need to use pytrees here if len(args) == 1 and isinstance(args[0], (list, tuple)): @@ -654,7 +654,7 @@ def make_pointwise( def make_foreach_pointwise(pw_fn, allow_alpha=False): - def inner(*inputs: List[List[TensorBox]], alpha=1): + def inner(*inputs: list[list[TensorBox]], alpha=1): realize_outputs = ( len(V.graph.current_node.users) == 0 or V.graph.current_node.target in inplace_foreach_ops @@ -682,7 +682,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False): outputs = [None] * len(a_list_input) for (device, use_foreach), group in groups.items(): - operation_list: List[str] = [] + operation_list: list[str] = [] for ( output_ind, args, @@ -749,7 +749,7 @@ def _foreach_map(subgraph, *args, **kwargs): outputs = [None] * len(sub_outputs) for (device, use_foreach), group in groups.items(): - operation_list: List[str] = [] + operation_list: list[str] = [] for ( output_ind, output, @@ -949,7 +949,7 @@ def where(cond, a, b): def 
broadcast_tensors(*inputs): if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)): return broadcast_tensors(*inputs[0]) - target: List[sympy.Expr] = functools.reduce( + target: list[sympy.Expr] = functools.reduce( broadcast_symbolic_shapes, [x.get_size() for x in inputs], [] ) outputs = [] @@ -1231,7 +1231,7 @@ def as_strided_copy(x, size, stride, storage_offset=None): def pointwise_cat(inputs, dim=0): # (inclusive, exclusive) - inputs_ranges: List[tuple[sympy.Expr, sympy.Expr]] = [] + inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = [] prev_end = 0 for inp in inputs: inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type] @@ -2173,7 +2173,7 @@ def inductor_lookup_seed(seeds, index): @register_lowering(inductor_prims.random, type_promotion_kind=None) -def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0): +def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0): assert not config.fallback_random assert mode in ("rand", "randn") size = [*size] @@ -2202,7 +2202,7 @@ def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int @register_lowering(inductor_prims.randint, type_promotion_kind=None) def inductor_randint( - low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0 + low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0 ): assert not config.fallback_random size = [*size] @@ -2916,7 +2916,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False): else: dtype = dtype or torch.get_default_dtype() - ranges: List[sympy.Expr] = [] + ranges: list[sympy.Expr] = [] if isinstance(data, sympy.Basic): @@ -4041,7 +4041,7 @@ def constant_pad_nd(x, padding, fill_value=0): n = len(sizes) - len(bounds) # if padding is a complicated expression, hoist it - bounds_precomp: List[tuple[sympy.Symbol, Any]] = [] + bounds_precomp: list[tuple[sympy.Symbol, Any]] = [] for l, h in bounds: bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type] diff --git a/torch/_inductor/memory.py b/torch/_inductor/memory.py index f936a1abdd07..bd3011dcbeb5 100644 --- a/torch/_inductor/memory.py +++ b/torch/_inductor/memory.py @@ -4,7 +4,7 @@ import collections import dataclasses import heapq import logging -from typing import Callable, Dict, List, TYPE_CHECKING, TypedDict, Union +from typing import Callable, TYPE_CHECKING, TypedDict, Union from torch._utils_internal import signpost_event from torch.utils._ordered_set import OrderedSet @@ -61,9 +61,9 @@ class FreeableInputBuffer: def get_freeable_input_buf( - nodes: List[BaseSchedulerNode], + nodes: list[BaseSchedulerNode], graph_inputs: OrderedSet[str], -) -> Dict[str, FreeableInputBuffer]: +) -> dict[str, FreeableInputBuffer]: """ Create and keep track of all input buffers that can be freed during the program @@ -87,10 +87,10 @@ def get_freeable_input_buf( # get freeable input buffers' successor nodes and their sizes # note that different deps can have the same name, so we use name as keys - dep_name_to_succ_nodes: Dict[ + dep_name_to_succ_nodes: dict[ str, OrderedSet[BaseSchedulerNode] ] = collections.defaultdict(OrderedSet) - dep_name_to_size: Dict[str, int] = dict() + dep_name_to_size: dict[str, int] = dict() for node in nodes: for dep in node.read_writes.reads: if dep.name in graph_inputs and not dep.name.startswith( @@ -100,7 +100,7 @@ def get_freeable_input_buf( dep_name_to_size[dep.name] = _dep_size_hint(dep) # create FreeableInputBuffer objects 
and add them to the returned dictionary - name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict() + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict() for dep_name, succ_nodes in dep_name_to_succ_nodes.items(): name_to_freeable_input_buf[dep_name] = FreeableInputBuffer( dep_name, @@ -112,8 +112,8 @@ def get_freeable_input_buf( def compute_size_for_scheduler_buffer( - name_to_buf: Dict[str, SchedulerBuffer] -) -> Dict[str, tuple[int, int]]: + name_to_buf: dict[str, SchedulerBuffer] +) -> dict[str, tuple[int, int]]: """ Compute the size of each scheduler buffer, including (1) memory allocated when it is created and (2) memory deallocated when it is freed. @@ -134,7 +134,7 @@ def compute_size_for_scheduler_buffer( from .ir import MultiOutput from .scheduler import OutputNode - sched_buf_to_size: Dict[str, tuple[int, int]] = dict() + sched_buf_to_size: dict[str, tuple[int, int]] = dict() def _compute_and_update_buf_size( sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False @@ -175,8 +175,8 @@ def compute_size_for_scheduler_buffer( def assign_memory_planning_info_for_scheduler_buffers( - nodes: List[BaseSchedulerNode], - name_to_buf: Dict[str, SchedulerBuffer], + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], ) -> None: """ For each SchedulerBuffer, assign its size info and successor nodes. @@ -187,7 +187,7 @@ def assign_memory_planning_info_for_scheduler_buffers( # get buffer's successor nodes # note that different deps can have the same name, so we use name as keys - dep_name_to_succ_nodes: Dict[ + dep_name_to_succ_nodes: dict[ str, OrderedSet[BaseSchedulerNode] ] = collections.defaultdict(OrderedSet) for node in nodes: @@ -205,10 +205,10 @@ def assign_memory_planning_info_for_scheduler_buffers( def assign_memory_planning_info_for_scheduler_nodes( - nodes: List[BaseSchedulerNode], - name_to_fused_node: Dict[str, BaseSchedulerNode], - name_to_buf: Dict[str, SchedulerBuffer], - name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], + nodes: list[BaseSchedulerNode], + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], ) -> None: """ Assign to each scheduler node its predecessor and successor nodes. @@ -243,10 +243,10 @@ def assign_memory_planning_info_for_scheduler_nodes( def estimate_peak_memory( - nodes: List[BaseSchedulerNode], - name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], graph_outputs: OrderedSet[str], -) -> tuple[int, List[int]]: +) -> tuple[int, list[int]]: """ Given a list of nodes in their execution order, estimate the peak memory, by keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers. @@ -267,12 +267,12 @@ def estimate_peak_memory( # get the execution step of each node, this will be used to determine # the end_step of buffers - node_to_step: Dict[BaseSchedulerNode, int] = dict() + node_to_step: dict[BaseSchedulerNode, int] = dict() for step, node in enumerate(nodes): node_to_step[node] = step # get buffers' size and liveliness information - buf_info_list: List[BufferInfo] = [] + buf_info_list: list[BufferInfo] = [] # 1. 
for freeable input buffers for buf_name, input_buf in name_to_freeable_input_buf.items(): end_step = ( @@ -340,11 +340,11 @@ def estimate_peak_memory( def topological_sort_lpmf( - nodes: List[BaseSchedulerNode], - name_to_freeable_input_buf: Dict[str, FreeableInputBuffer], - name_to_buf: Dict[str, SchedulerBuffer], + nodes: list[BaseSchedulerNode], + name_to_freeable_input_buf: dict[str, FreeableInputBuffer], + name_to_buf: dict[str, SchedulerBuffer], graph_outputs: OrderedSet[str], -) -> List[BaseSchedulerNode]: +) -> list[BaseSchedulerNode]: """ A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First". @@ -372,8 +372,8 @@ def topological_sort_lpmf( class BufferInfo(TypedDict): outdegree: int - node_info: Dict[BaseSchedulerNode, NodeInfo] = dict() - buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict() + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() + buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict() # compute nodes' number of unmet dependencies (for schedulability) # initialize the list of nodes ready to be scheduled @@ -422,7 +422,7 @@ def topological_sort_lpmf( node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free # schedule nodes one at a time - schedule: List[BaseSchedulerNode] = [] + schedule: list[BaseSchedulerNode] = [] num_iters: int = 0 while num_iters < len(nodes) and nodes_to_schedule: # select a node to schedule: @@ -464,7 +464,7 @@ def topological_sort_lpmf( return schedule -def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: +def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: """ A BFS topological sort that selects nodes whose dependencies are executed the earliest. This follows a FIFO idea. 
Specifically, at every iteration, for each node @@ -478,11 +478,11 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo indegree: int order: int - node_info: Dict[BaseSchedulerNode, NodeInfo] = dict() + node_info: dict[BaseSchedulerNode, NodeInfo] = dict() @dataclasses.dataclass class NodeWithPriority: - priority: List[int] + priority: list[int] node: BaseSchedulerNode def __lt__(self, other: NodeWithPriority) -> bool: @@ -490,7 +490,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo return self.node.mpi_node.index < other.node.mpi_node.index return self.priority < other.priority - def _node_priority(node: BaseSchedulerNode) -> List[int]: + def _node_priority(node: BaseSchedulerNode) -> list[int]: # priority is the order in which predecessor nodes are executed assert node_info[node]["indegree"] == 0 exec_orders = sorted( @@ -502,7 +502,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo # compute nodes' number of unmet dependencies (for schedulability) # initialize the list of nodes ready to be scheduled - nodes_to_schedule: List[NodeWithPriority] = [] + nodes_to_schedule: list[NodeWithPriority] = [] for node in nodes: node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1} if node_info[node]["indegree"] == 0: @@ -511,7 +511,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo ) # schedule nodes one at a time - schedule: List[BaseSchedulerNode] = [] + schedule: list[BaseSchedulerNode] = [] num_iters: int = 0 while num_iters < len(nodes) and nodes_to_schedule: # select a node to schedule @@ -536,7 +536,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo return schedule -def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: +def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: """ This is a DFS topological sort. The setup is similar to `topological_sort_schedule` in scheduler.py. The difference is the order nodes are visited in the outer loop. @@ -546,9 +546,9 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo the nodes in ascending order of this priority. 
""" seen: OrderedSet[BaseSchedulerNode] = OrderedSet() - name_to_node: Dict[str, BaseSchedulerNode] = dict() - result: List[BaseSchedulerNode] = [] - size_with_reads: Dict[BaseSchedulerNode, int] = dict() + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] + size_with_reads: dict[BaseSchedulerNode, int] = dict() def visit(n: BaseSchedulerNode) -> None: if n not in seen: @@ -579,17 +579,17 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo def reorder_for_peak_memory( - nodes: List[BaseSchedulerNode], - name_to_buf: Dict[str, SchedulerBuffer], - name_to_fused_node: Dict[str, BaseSchedulerNode], + nodes: list[BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], graph_inputs: OrderedSet[str], graph_outputs: OrderedSet[str], - methods: List[Callable[..., List[BaseSchedulerNode]]] = [ # noqa: B006 + methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006 topological_sort_lpmf, topological_sort_bfs, topological_sort_dfs, ], -) -> List[BaseSchedulerNode]: +) -> list[BaseSchedulerNode]: """ Try a few heuristics based topological sort algorithms, and pick the one whose resulting topological order has the lowest peak memory estimation. @@ -599,13 +599,13 @@ def reorder_for_peak_memory( @dataclasses.dataclass class PeakMemoryResult: - order: List[BaseSchedulerNode] + order: list[BaseSchedulerNode] peak_memory: int method: str # preparation -- as nodes are scheduled one at a time, these help # keep track of when a buffer can be freed, and when a node can be scheduled - name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf( + name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf( nodes, graph_inputs ) assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf) @@ -614,7 +614,7 @@ def reorder_for_peak_memory( ) # keep track of the peak memory estimates of different methods - peak_memory_diff_methods: List[PeakMemoryResult] = [] + peak_memory_diff_methods: list[PeakMemoryResult] = [] # the default estimated_peak_memory, _ = estimate_peak_memory( diff --git a/torch/_inductor/metrics.py b/torch/_inductor/metrics.py index 85dd418a84ad..88f53d277e85 100644 --- a/torch/_inductor/metrics.py +++ b/torch/_inductor/metrics.py @@ -8,7 +8,7 @@ import os import re from dataclasses import dataclass from functools import lru_cache -from typing import Dict, List, TYPE_CHECKING +from typing import TYPE_CHECKING from torch._inductor import config from torch._inductor.utils import get_benchmark_name @@ -23,13 +23,13 @@ if TYPE_CHECKING: generated_kernel_count = 0 generated_cpp_vec_kernel_count = 0 num_bytes_accessed = 0 -nodes_num_elem: List[ +nodes_num_elem: list[ tuple[ BaseSchedulerNode, int, ] ] = [] -node_runtimes: List[tuple[BaseSchedulerNode, float]] = [] +node_runtimes: list[tuple[BaseSchedulerNode, float]] = [] # counters for tracking fusions ir_nodes_pre_fusion = 0 @@ -45,7 +45,7 @@ class CppOuterLoopFusedCount: # The length counts the number of outer loop fusions. 
-cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = [] +cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = [] num_comprehensive_padding = 0 num_matches_for_scatter_upon_const_tensor = 0 @@ -122,13 +122,13 @@ class CachedMetricsHelper: globals()[metric] += getattr(delta, metric) -REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {} +REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {} @dataclass class MetricTable: table_name: str - column_names: List[str] + column_names: list[str] num_rows_added: int = 0 diff --git a/torch/_inductor/mkldnn_ir.py b/torch/_inductor/mkldnn_ir.py index 237a25805b29..c62b0085b5de 100644 --- a/torch/_inductor/mkldnn_ir.py +++ b/torch/_inductor/mkldnn_ir.py @@ -1,5 +1,5 @@ # mypy: allow-untyped-defs -from typing import Any, List, Optional +from typing import Any, Optional import sympy @@ -31,13 +31,13 @@ def _prepare_convolution_fusion_create( x: "TensorBox", weight: "TensorBox", bias: "TensorBox", - padding: List[int], - stride: List[int], - dilation: List[int], + padding: list[int], + stride: list[int], + dilation: list[int], groups: int, transposed: bool = False, - output_padding: Optional[List[int]] = None, - quantize_args: Optional[List["TensorBox"]] = None, + output_padding: Optional[list[int]] = None, + quantize_args: Optional[list["TensorBox"]] = None, other: Optional["TensorBox"] = None, ): """ @@ -204,7 +204,7 @@ def _prepare_linear_fusion_create( x: "TensorBox", weight: "TensorBox", bias: "TensorBox", - quantize_args: Optional[List["TensorBox"]] = None, + quantize_args: Optional[list["TensorBox"]] = None, other: Optional["TensorBox"] = None, binary_sum: bool = False, ): @@ -252,7 +252,7 @@ def _prepare_linear_fusion_create( output_size, output_stride, ) - constant_args: List[Any] = [] + constant_args: list[Any] = [] if bias is not None: inputs.append(bias) @@ -298,12 +298,12 @@ class ConvolutionUnary(ExternKernelAlloc): x: "TensorBox", weight: "TensorBox", bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], + padding_: list[int], + stride_: list[int], + dilation_: list[int], groups: int, attr, - scalars: Optional[List[Any]], + scalars: Optional[list[Any]], algorithm, ): ( @@ -357,14 +357,14 @@ class ConvolutionBinary(ExternKernelAlloc): other: "TensorBox", weight: "TensorBox", bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], + padding_: list[int], + stride_: list[int], + dilation_: list[int], groups: int, binary_attr: str, binary_alpha: Optional[float], unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], + unary_scalars: Optional[list[Any]], unary_algorithm: Optional[str], ): ( @@ -431,14 +431,14 @@ class ConvolutionBinaryInplace(ExternKernelAlloc): other: "TensorBox", weight: "TensorBox", bias: "TensorBox", - padding_: List[int], - stride_: List[int], - dilation_: List[int], + padding_: list[int], + stride_: list[int], + dilation_: list[int], groups: int, binary_attr: str, binary_alpha: Optional[float], unary_attr: Optional[str], - unary_scalars: Optional[List[Any]], + unary_scalars: Optional[list[Any]], unary_algorithm: Optional[str], ): ( @@ -496,13 +496,13 @@ class ConvolutionTransposeUnary(ExternKernelAlloc): x: "TensorBox", weight: "TensorBox", bias: "TensorBox", - padding_: List[int], - output_padding_: List[int], - stride_: List[int], - dilation_: List[int], + padding_: list[int], + output_padding_: list[int], + stride_: list[int], + dilation_: list[int], groups_: int, attr, - scalars: Optional[List[Any]], + scalars: 
Optional[list[Any]], algorithm, ): transposed = True @@ -580,9 +580,9 @@ class QConvPointWisePT2E(ExternKernelAlloc): w_scale: "TensorBox", w_zero_point: "TensorBox", bias: "TensorBox", - stride: List[int], - padding: List[int], - dilation: List[int], + stride: list[int], + padding: list[int], + dilation: list[int], groups: int, output_scale: float, output_zero_point: int, @@ -692,9 +692,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc): w_zero_point, qaccum: "TensorBox", bias: "TensorBox", - stride: List[int], - padding: List[int], - dilation: List[int], + stride: list[int], + padding: list[int], + dilation: list[int], groups: int, output_scale: "TensorBox", output_zero_point: "TensorBox", @@ -1139,7 +1139,7 @@ class MkldnnRnnLayer(ExternKernelAlloc): hx: "TensorBox", cx: "TensorBox", reverse: bool, - batch_sizes: List[int], + batch_sizes: list[int], mode: int, hidden_size: int, num_layers: int, diff --git a/torch/_inductor/mkldnn_lowerings.py b/torch/_inductor/mkldnn_lowerings.py index de06a43bc52a..f644b787f8da 100644 --- a/torch/_inductor/mkldnn_lowerings.py +++ b/torch/_inductor/mkldnn_lowerings.py @@ -1,6 +1,6 @@ # mypy: allow-untyped-defs import functools -from typing import List, Optional +from typing import Optional import torch import torch.utils._pytree as pytree @@ -31,8 +31,8 @@ from .virtualized import ops, V def grouped_gemm_lowering( x: TensorBox, - w: List[TensorBox], - b: List[TensorBox], + w: list[TensorBox], + b: list[TensorBox], attr=None, scalars=None, algorithm=None, @@ -47,7 +47,7 @@ def grouped_gemm_lowering( assert use_max_autotune() b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b] - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout) kwargs = dict( @@ -245,7 +245,7 @@ def register_onednn_fusion_ops(): x = view(x, [-1, x_size[-1]]) if b is not None: b = ir.ExternKernel.realize_input(b) - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if use_max_autotune(): transposed_w = permute(w, [1, 0]) *_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout) @@ -308,7 +308,7 @@ def register_onednn_fusion_ops(): y = view(y, [-1, y_size[-1]]) if b is not None: b = ir.ExternKernel.realize_input(b) - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if use_max_autotune(): transposed_w = permute(w, [1, 0]) *_, layout, x, transposed_w, y = mm_args( @@ -397,7 +397,7 @@ def register_onednn_fusion_ops(): hx: TensorBox, cx: TensorBox, reverse: bool, - batch_sizes: List[int], + batch_sizes: list[int], mode: int, hidden_size: int, num_layers: int, @@ -611,7 +611,7 @@ def register_onednn_fusion_ops(): bias_dtype = None if bias is None else bias.get_dtype() - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if use_max_autotune(): *_, layout, x, packed_weight = mm_args( x, packed_weight, layout=layout, out_dtype=output_dtype @@ -888,7 +888,7 @@ def register_onednn_fusion_ops(): ), "dtype of accum for qlinear post op sum should be the same as output" x2_dtype = x2.get_dtype() bias_dtype = bias.get_dtype() if bias is not None else None - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if ( use_max_autotune() and binary_attr == "add" ): # Support inplace sum fusion @@ -1131,7 +1131,7 @@ def register_onednn_fusion_ops(): *, layout=None, ): - choices: List[ChoiceCaller] = [] + choices: list[ChoiceCaller] = [] if use_max_autotune(): transposed_w = permute(orig_w, [1, 0]) *_, layout, x, 
transposed_w = mm_args( diff --git a/torch/_inductor/mock_cache.py b/torch/_inductor/mock_cache.py index b333e347e756..a610ce219ea5 100644 --- a/torch/_inductor/mock_cache.py +++ b/torch/_inductor/mock_cache.py @@ -6,7 +6,7 @@ import contextlib import dataclasses import sys import threading -from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING +from typing import Any, Callable, Optional, TYPE_CHECKING from typing_extensions import override, Self from unittest.mock import patch @@ -56,7 +56,7 @@ class Stats: class _GlobalItemStats(Stats): - cache: Dict[str, object] + cache: dict[str, object] def __init__(self) -> None: super().__init__() @@ -266,7 +266,7 @@ class PatchCaches(contextlib.AbstractContextManager): def __exit__( self, - exc_type: Optional[Type[BaseException]], + exc_type: Optional[type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], ) -> None: diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 8466a796553e..5bf7de418c33 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -1,17 +1,6 @@ # mypy: allow-untyped-defs import itertools -from typing import ( - Any, - Callable, - Dict, - Generic, - List, - Literal, - NamedTuple, - Optional, - TypeVar, - Union, -) +from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from typing_extensions import Protocol from unittest.mock import patch @@ -961,7 +950,7 @@ def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: class OpCountResult(NamedTuple): num_ops: int used_ops: OrderedSet[str] - read_buffers: List[str] + read_buffers: list[str] nontrivial_read_count: int @@ -974,7 +963,7 @@ class OpCounterCSE: self.op_count = 0 self.var_names = {} self._used_ops = OrderedSet[str]() - self._read_names: List[str] = [] + self._read_names: list[str] = [] self._nontrivial_read_count = 0 def __getattr__(self, name): @@ -1076,7 +1065,7 @@ class SimpleCSEHandler(WrapperHandler[T]): def __init__(self, inner: OpsHandler[T]): super().__init__(inner) - self.cse_cache: Dict[str, Union[T, tuple[T, ...]]] = {} + self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {} self.mock = MockHandler() def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: diff --git a/torch/_inductor/optimize_indexing.py b/torch/_inductor/optimize_indexing.py index cd7ac7207dd4..67c2a74e886a 100644 --- a/torch/_inductor/optimize_indexing.py +++ b/torch/_inductor/optimize_indexing.py @@ -1,5 +1,5 @@ import math -from typing import Any, Dict, List +from typing import Any import sympy @@ -40,10 +40,10 @@ def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool: def try_to_reduce_precision( node: Any, - bounds: Dict[Any, Any], - indirect_vars: List[Any], - indices: Dict[Any, sympy.Expr], - replacement_vals: Dict[Any, ValueRanges[sympy.Expr]], + bounds: dict[Any, Any], + indirect_vars: list[Any], + indices: dict[Any, sympy.Expr], + replacement_vals: dict[Any, ValueRanges[sympy.Expr]], ) -> None: # if a downstream use of a node explicitly converts to int32, or float16/float32/float64, # then it's precision is set for that chain of uses, and we don't need to consider those diff --git a/torch/_inductor/output_code.py b/torch/_inductor/output_code.py index 06641b7f9840..844be9d8e544 100644 --- a/torch/_inductor/output_code.py +++ b/torch/_inductor/output_code.py @@ -27,17 +27,7 @@ import logging import os import re from pathlib import Path -from typing import ( - Any, - Callable, - Counter, - Dict, - List, - 
Optional, - Sequence, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Optional, TYPE_CHECKING, Union from typing_extensions import TypeAlias import torch @@ -62,6 +52,9 @@ from .runtime.autotune_cache import AutotuneCacheBundler if TYPE_CHECKING: + from collections import Counter + from collections.abc import Sequence + from torch._inductor import metrics from torch._inductor.graph import GraphLowering @@ -108,13 +101,13 @@ def has_frozen_params(gm: torch.fx.GraphModule) -> bool: # for expanded dimensions (a dimension which used to have size 1 -> ?) # we can select one element from that dimension and write to it # to achieve writing to all values of that dimension of the input tensor -def get_expanded_dims(t: torch.Tensor) -> List[int]: +def get_expanded_dims(t: torch.Tensor) -> list[int]: if not isinstance(t, torch.Tensor): return None return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1] -def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor: +def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor: for expanded_dim in expanded_dims: t = torch.ops.aten.slice(t, expanded_dim, 0, 1) return t @@ -146,7 +139,7 @@ def cudagraph_post_compile( example_inputs: Sequence[InputType], compiled_graph: CompiledFxGraph, cudagraphs: BoxedBool, - constants: Dict[str, torch.Tensor], + constants: dict[str, torch.Tensor], ) -> None: """ Checks for any reasons not to run cudagraphs and then @@ -213,7 +206,7 @@ def cudagraph_post_compile( # should already exist from forward assert manager is not None - def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]: + def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]: manager.set_to_running_backward() # type: ignore[union-attr] return compiled_graph_callable(new_inputs) @@ -270,7 +263,7 @@ class CompiledFxGraphConstants: the value of constants directly off of the original saved object. """ - def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]: + def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]: assert g.constants is not None return g.constants @@ -287,7 +280,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants): def __init__(self, gm: torch.fx.GraphModule) -> None: self.gm = gm - def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]: + def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]: if g.allocated_constant_name is not None: return { name: getattr(self.gm, name) @@ -308,7 +301,7 @@ class CompiledFxGraph(OutputCode): current_callable: Optional[Callable[..., Any]] cache_key: str source_code: str = dataclasses.field(repr=False) # Do not display source_code - cache_linemap: Optional[List[tuple[int, str]]] + cache_linemap: Optional[list[tuple[int, str]]] device_types: OrderedSet[str] device_idxs: OrderedSet[int] mutated_inputs: OrderedSet[str] @@ -320,10 +313,10 @@ class CompiledFxGraph(OutputCode): # original name of the attribute in the GraphModule. When we create the module from # the cache entry, we then look up the constants from the current GraphModule. This # scheme allows us to support caching with freezing. 
- allocated_constant_name: Optional[Dict[str, str]] - constants: Optional[Dict[str, torch.Tensor]] - torchbind_constants: Dict[str, torch._C.ScriptObject] - output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]] + allocated_constant_name: Optional[dict[str, str]] + constants: Optional[dict[str, torch.Tensor]] + torchbind_constants: dict[str, torch._C.ScriptObject] + output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]] disabled_cudagraphs_reason: Optional[str] metrics_deltas: metrics.CachedMetricsDeltas counter_deltas: Counter[str] @@ -340,14 +333,14 @@ class CompiledFxGraph(OutputCode): boxed_forward_device_index: Optional[BoxedDeviceIndex] _boxed_call: Optional[bool] = None - _triton_bundle: Optional[List[TritonKernelArtifacts]] = None + _triton_bundle: Optional[list[TritonKernelArtifacts]] = None def __init__( self, current_callable: Optional[Callable[..., Any]], graph: GraphLowering, gm: torch.fx.GraphModule, - output_strides: List[Optional[tuple[_StrideExprStr, ...]]], + output_strides: list[Optional[tuple[_StrideExprStr, ...]]], disabled_cudagraphs_reason: Optional[str], metrics_deltas: metrics.CachedMetricsDeltas, counter_deltas: Counter[str], @@ -583,7 +576,7 @@ class CompiledAOTI(OutputCode): Class holding an AOTInductor compiled so. """ - filename: Union[str, List[str]] + filename: Union[str, list[str]] def __call__(self, inputs: Sequence[Any]) -> Any: raise NotImplementedError("NYI") diff --git a/torch/_inductor/package/package.py b/torch/_inductor/package/package.py index c09042674d0e..041909464416 100644 --- a/torch/_inductor/package/package.py +++ b/torch/_inductor/package/package.py @@ -7,7 +7,7 @@ import subprocess import tempfile import zipfile from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Optional, Union import torch import torch._inductor @@ -85,7 +85,7 @@ class PT2ArchiveReader: assert self.archive_file is not None self.archive_file.extractall(path) - def get_file_names(self) -> List[str]: + def get_file_names(self) -> list[str]: assert self.archive_file is not None return self.archive_file.namelist() @@ -98,7 +98,7 @@ def _run_command_and_check(cmd: str) -> None: raise exc.CppCompileError(cmd, e.output) from e -def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str: +def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str: def get_aoti_file_with_suffix(suffix: str) -> str: for file in aoti_files: if file.endswith(suffix): @@ -159,7 +159,7 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str: def package_aoti( archive_file: Union[str, io.BytesIO], - aoti_files: Union[List[str], Dict[str, List[str]]], + aoti_files: Union[list[str], dict[str, list[str]]], ) -> Union[str, io.BytesIO]: """ Saves the AOTInductor generated files to the PT2Archive format. 
@@ -244,12 +244,12 @@ class AOTICompiledModel: flat_outputs = self.loader.boxed_run(flat_inputs) # type: ignore[attr-defined] return pytree.tree_unflatten(flat_outputs, out_spec) - def get_metadata(self) -> Dict[str, str]: + def get_metadata(self) -> dict[str, str]: return self.loader.get_metadata() # type: ignore[attr-defined] def load_constants( self, - constants_map: Dict[str, torch.Tensor], + constants_map: dict[str, torch.Tensor], *, check_full_update: bool, ) -> None: @@ -265,7 +265,7 @@ class AOTICompiledModel: """ self.loader.load_constants(constants_map, False, check_full_update) # type: ignore[attr-defined] - def get_constant_fqns(self) -> List[str]: + def get_constant_fqns(self) -> list[str]: return self.loader.get_constant_fqns() # type: ignore[attr-defined] diff --git a/torch/_inductor/pattern_matcher.py b/torch/_inductor/pattern_matcher.py index 926a33f613f9..474d98620f0c 100644 --- a/torch/_inductor/pattern_matcher.py +++ b/torch/_inductor/pattern_matcher.py @@ -50,24 +50,9 @@ import textwrap import typing from abc import ABC, abstractmethod from collections import defaultdict +from collections.abc import Generator, Iterable, Mapping, Sequence from pathlib import Path -from typing import ( - Any, - Callable, - DefaultDict, - Dict, - Generator, - Iterable, - List, - Mapping, - NoReturn, - Optional, - Protocol, - Sequence, - Type, - TypeVar, - Union, -) +from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union from typing_extensions import Self, TypeIs import torch @@ -139,7 +124,7 @@ MULTIPLE = Multiple() def _transfer_meta( - new_meta: Dict[str, Any], old_node: torch.fx.Node, pass_name: str = "" + new_meta: dict[str, Any], old_node: torch.fx.Node, pass_name: str = "" ) -> None: from torch.fx.traceback import NodeSource, NodeSourceAction @@ -177,10 +162,10 @@ class Match: """ pattern: PatternExpr - args: List[Any] - kwargs: Dict[str, Any] - nodes: List[torch.fx.Node] - targets: Dict[_TargetExpr, torch.fx.node.Target] + args: list[Any] + kwargs: dict[str, Any] + nodes: list[torch.fx.Node] + targets: dict[_TargetExpr, torch.fx.node.Target] ctx: MatchContext replacement_graph: Optional[torch.fx.GraphModule] @@ -189,7 +174,7 @@ class Match: ctx: MatchContext, pattern: PatternExpr, args: Optional[Sequence[Any]] = None, - kwargs: Optional[Dict[str, Any]] = None, + kwargs: Optional[dict[str, Any]] = None, ) -> None: super().__init__() self.pattern = pattern @@ -231,7 +216,7 @@ class Match: if not n._erased and not n.users: graph.erase_node(n) - def output_nodes(self) -> List[Optional[torch.fx.Node]]: + def output_nodes(self) -> list[Optional[torch.fx.Node]]: return [ (self.ctx.pattern_to_node[p] if p is not None else None) for p in self.ctx.outputs @@ -338,15 +323,15 @@ class MatchContext: Internal state needed while running PatternExpr._match(). 
""" - outputs: List[Optional[PatternExpr]] - pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]] + outputs: list[Optional[PatternExpr]] + pattern_to_node: dict[PatternExpr, Optional[torch.fx.Node]] graph: torch.fx.Graph - exclusive_node_set: List[NodeOrConstant] + exclusive_node_set: list[NodeOrConstant] def __init__( self, - outputs: List[Optional[PatternExpr]], - pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None, + outputs: list[Optional[PatternExpr]], + pattern_to_node: Optional[dict[PatternExpr, torch.fx.Node]] = None, *, graph: torch.fx.Graph, ) -> None: @@ -367,7 +352,7 @@ class MatchContext: self.pattern_to_node[pattern] = node if m else None return m - def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]: + def filter_multi_user_patterns(self) -> dict[PatternExpr, torch.fx.Node]: return { pattern: node for pattern, node in self.pattern_to_node.items() @@ -487,7 +472,7 @@ class _TargetExpr(PatternExpr): Base class for filtering match by node.target """ - fns: List[FnsType] + fns: list[FnsType] fns_set: OrderedSet[FnsType] def __init__( @@ -806,7 +791,7 @@ class ListOf(PatternExpr): def __repr__(self) -> str: return f"{self.__class__.__name__}({self.pattern})" - def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] + def _match(self, node: list[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override] if not isinstance(node, (list, tuple)) or len(node) == 0: return FailedMatch("non_list") m = Match(ctx, self) @@ -840,7 +825,7 @@ class ListOf(PatternExpr): class MultiOutputPattern(PatternExpr): - outputs: List[Optional[PatternExpr]] + outputs: list[Optional[PatternExpr]] def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None: super().__init__() @@ -959,8 +944,8 @@ class PatternPrettyPrinter: def __init__(self) -> None: self.namespace = torch.fx.graph._Namespace() - self.memoized_objs_names: Dict[PatternExpr, str] = {} - self.memoized_objs_pp: Dict[PatternExpr, str] = {} + self.memoized_objs_names: dict[PatternExpr, str] = {} + self.memoized_objs_pp: dict[PatternExpr, str] = {} @staticmethod @functools.lru_cache(None) @@ -1006,7 +991,7 @@ class PatternPrettyPrinter: class _PassDictsType(Protocol): - def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: ... 
@@ -1069,7 +1054,7 @@ class GraphPatternEntry(PatternEntry): @dataclasses.dataclass class ReplacementPatternEntry(PatternEntry): - normalize_args: Callable[..., List[Any]] + normalize_args: Callable[..., list[Any]] @staticmethod def replace_with_graph( @@ -1253,7 +1238,7 @@ def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None: def check_and_add_duplicate_pattern( pattern: PatternExpr, graph: Optional[torch.fx.Graph], - seen_patterns: Dict[str, List[Optional[str]]], + seen_patterns: dict[str, list[Optional[str]]], skip_duplicates: bool = False, ) -> bool: """ @@ -1299,7 +1284,7 @@ def register_replacement( trace_fn: TraceFn, pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], extra_check: Callable[[Match], bool] = _return_true, - scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), search_fn_pattern: Union[PatternExpr, None] = None, skip_duplicates: bool = False, @@ -1339,7 +1324,7 @@ def register_replacement( [match.kwargs[name] for name in argnames], lambda n: n.meta["val"] ) ) - sym_args: List[torch.SymInt] = [] + sym_args: list[torch.SymInt] = [] with torch._dynamo.utils.detect_fake_mode(args): for i, grad in enumerate(requires_grad): if isinstance(args[i], torch.Tensor): @@ -1432,7 +1417,7 @@ def register_replacement( return True return False - def normalize_args(**kwargs: Any) -> List[Any]: + def normalize_args(**kwargs: Any) -> list[Any]: args = [kwargs.pop(name) for name in argnames_static] for i in range(1, len(kwargs) + 1): if f"tangents_{i}" not in kwargs: @@ -1449,7 +1434,7 @@ def register_replacement( # TODO: Revisit the functionalize_rng_ops for lowmem dropout with functorch_config.patch(functionalize_rng_ops=False): - requires_grad: List[bool] = [ + requires_grad: list[bool] = [ isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs ] if search_fn_pattern is None: @@ -1493,7 +1478,7 @@ def _serialize_pattern( search_fn: SearchFn, example_inputs: Sequence[Any], trace_fn: TraceFn, - scalar_workaround: Union[Dict[str, Union[float, int]], None], + scalar_workaround: Union[dict[str, Union[float, int]], None], ) -> PatternExpr: def get_file_template() -> str: auto_generated_msg = textwrap.dedent( @@ -1566,7 +1551,7 @@ SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patt # This is the set of serialized patterns that we've registered. Used by # test_serialized_patterns_up_to_date() to ensure the patterns are up # to date. 
-_known_precompiled_patterns: List[ +_known_precompiled_patterns: list[ tuple[ Any, Iterable[Any], @@ -1585,7 +1570,7 @@ def gen_register_replacement( trace_fn: TraceFn, pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]], extra_check: Callable[[Match], bool] = _return_true, - scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), skip_duplicates: bool = False, ) -> None: @@ -1638,7 +1623,7 @@ def gen_pattern_and_search_gm( search_fn: SearchFn, example_inputs: Sequence[Any], trace_fn: TraceFn, - scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), ) -> tuple[PatternExpr, torch.fx.GraphModule]: argnames = [*inspect.signature(search_fn).parameters.keys()] @@ -1672,7 +1657,7 @@ def gen_pattern( search_fn: SearchFn, example_inputs: Sequence[Any], trace_fn: TraceFn, - scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: return gen_pattern_and_search_gm( @@ -1803,8 +1788,8 @@ class PatternMatcherPass: pass_name: Optional[str] = None, ) -> None: super().__init__() - self.patterns: DefaultDict[ - tuple[str, torch.fx.node.Target], List[PatternEntry] + self.patterns: defaultdict[ + tuple[str, torch.fx.node.Target], list[PatternEntry] ] = defaultdict(list) self.pass_name = pass_name @@ -1812,9 +1797,9 @@ class PatternMatcherPass: # of the graph used to generate them. Because we ignore certain patterns # in searching, but not in matching, use the graph to distinguish if two equivalent # searches are actually different. 
- self.seen_patterns: Dict[str, List[Optional[str]]] = defaultdict(list) + self.seen_patterns: dict[str, list[Optional[str]]] = defaultdict(list) - def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]: + def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: return self.patterns[item] def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int: @@ -1888,9 +1873,9 @@ def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn: def fx_to_pattern( gm: Union[torch.fx.GraphModule, torch.fx.Graph], - ignore_types: Sequence[Type[Any]] = (), + ignore_types: Sequence[type[Any]] = (), argnames: Sequence[str] = (), - scalar_workaround: Union[Dict[str, Union[float, int]], None] = None, + scalar_workaround: Union[dict[str, Union[float, int]], None] = None, exclusive_arg_names: Sequence[str] = (), ) -> PatternExpr: """ @@ -1904,7 +1889,7 @@ def fx_to_pattern( assert len(inv_scalar_workaround) == len(scalar_workaround) def process_arg( - x: T, ignore_types_override: Optional[Sequence[Type[Any]]] = None + x: T, ignore_types_override: Optional[Sequence[type[Any]]] = None ) -> Union[T, KeywordArg, Ignored]: current_ignore_types = ( ignore_types_override if ignore_types_override is not None else ignore_types @@ -1950,7 +1935,7 @@ def fx_to_pattern( def process_arg_fn_impl( x: T, - ignore_types_override: Optional[Sequence[Type[Any]]] = tuple( + ignore_types_override: Optional[Sequence[type[Any]]] = tuple( t for t in ignore_types if t is not int ), ) -> Union[T, KeywordArg, Ignored]: @@ -2054,8 +2039,8 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph return gm -def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]: - args: List[torch.fx.node.Argument] = [] +def _args(n: torch.fx.Node) -> list[torch.fx.node.Argument]: + args: list[torch.fx.node.Argument] = [] torch.fx.map_arg((n.args, n.kwargs), args.append) return args @@ -2152,7 +2137,7 @@ def get_arg_value( ) -def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]: +def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]: fns = [fn] if isinstance(fn, torch._ops.OpOverloadPacket): fns.extend([getattr(fn, overload) for overload in fn.overloads()]) diff --git a/torch/_inductor/remote_cache.py b/torch/_inductor/remote_cache.py index fe95a2e90dbe..85100c114136 100644 --- a/torch/_inductor/remote_cache.py +++ b/torch/_inductor/remote_cache.py @@ -10,7 +10,7 @@ import os import sys import typing from abc import abstractmethod -from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Generic, Optional, TypeVar, Union from typing_extensions import override, TypeAlias from torch._dynamo.utils import dynamo_timed @@ -34,7 +34,7 @@ if config.is_fbcode(): Sample: TypeAlias = Sample_ else: - Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef] + Sample: TypeAlias = type[object] # type: ignore[misc,no-redef] _T = TypeVar("_T") @@ -106,7 +106,7 @@ class RemoteCacheSerde(Generic[_T, _U]): JsonDataTy = Optional[ - Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]] + Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]] ] @@ -371,7 +371,7 @@ class _CacheStat: class _CacheStats: - _stats: Dict[str, _CacheStat] + _stats: dict[str, _CacheStat] def __init__(self) -> None: self._stats = collections.defaultdict(_CacheStat) diff --git a/torch/_inductor/runtime/autotune_cache.py 
b/torch/_inductor/runtime/autotune_cache.py index 18b95d362808..dabe299fa2cf 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -6,7 +6,7 @@ import logging import os import os.path import re -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Optional, TYPE_CHECKING from typing_extensions import override import torch @@ -29,7 +29,7 @@ if TYPE_CHECKING: log = logging.getLogger(__name__) -_InductorMetaTy = Dict[str, object] +_InductorMetaTy = dict[str, object] def inductor_meta_from_config() -> _InductorMetaTy: @@ -88,7 +88,7 @@ class AutotuneCache: return hashlib.sha256(key.encode("utf-8")).hexdigest() # Read the best config options from the most local cache and return it. - def _read(self) -> Optional[Dict[str, JsonDataTy]]: + def _read(self) -> Optional[dict[str, JsonDataTy]]: if local_cache := self.local_cache: cache, key = local_cache if best_config := cache.get(key): @@ -106,7 +106,7 @@ class AutotuneCache: # Read the best config options from the most local cache and figure out # which `configs` represents that option. def read_best( - self, inductor_meta: _InductorMetaTy, configs: List[Config] + self, inductor_meta: _InductorMetaTy, configs: list[Config] ) -> Optional[Config]: if best := self._read(): return _load_cached_autotuning( @@ -196,7 +196,7 @@ class _AutotuneCacheBundlerImpl: _cache: RemoteCache[JsonDataTy] # All known entries from LocalAutotuneCache.put() - _entries: Dict[str, JsonDataTy] + _entries: dict[str, JsonDataTy] def end_compile(self) -> None: # TODO: Do we need to compute time_taken_ms and encode that somehow? @@ -407,9 +407,9 @@ def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool: def _load_cached_autotuning( - best_config: Dict[str, JsonDataTy], + best_config: dict[str, JsonDataTy], configs_hash: str, - configs: List[Config], + configs: list[Config], inductor_meta: _InductorMetaTy, ) -> Optional[Config]: if best_config is None: diff --git a/torch/_inductor/runtime/benchmarking.py b/torch/_inductor/runtime/benchmarking.py index 40060c96916a..fb50a385e54a 100644 --- a/torch/_inductor/runtime/benchmarking.py +++ b/torch/_inductor/runtime/benchmarking.py @@ -3,7 +3,7 @@ import time from functools import cached_property, wraps from itertools import chain from statistics import median -from typing import Any, Callable, Dict, List, Tuple +from typing import Any, Callable from typing_extensions import Concatenate, ParamSpec, Self, TypeVar import torch @@ -49,8 +49,8 @@ class Benchmarker: def benchmark( self: Self, fn: Callable[..., Any], - fn_args: Tuple[Any, ...], - fn_kwargs: Dict[str, Any], + fn_args: tuple[Any, ...], + fn_kwargs: dict[str, Any], **kwargs: Any, ) -> float: """Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the @@ -114,7 +114,7 @@ class Benchmarker: - The median runtime of `_callable`, in milliseconds. 
""" - def run_for(ms: int) -> List[float]: + def run_for(ms: int) -> list[float]: timings = [] run_start_t = time.perf_counter() while True: @@ -183,7 +183,7 @@ class InductorBenchmarker(TritonBenchmarker): def get_event_pairs( self: Self, iters: int - ) -> List[Tuple[torch.cuda.Event, torch.cuda.Event]]: + ) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]: """Get `iters` pairs of CUDA events.""" return [ ( @@ -194,7 +194,7 @@ class InductorBenchmarker(TritonBenchmarker): ] def get_event_pairs_min_timing( - self: Self, event_pairs: List[Tuple[torch.cuda.Event, torch.cuda.Event]] + self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]] ) -> float: """Get the minimum timing, in milliseconds, for a group of CUDA event pairs.""" return min( diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 3c947f5c3c0e..8e1f8659adf1 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -6,7 +6,7 @@ import sys import warnings from pathlib import Path from types import ModuleType -from typing import Callable, Dict, TYPE_CHECKING +from typing import Callable, TYPE_CHECKING if TYPE_CHECKING: @@ -67,7 +67,7 @@ def _set_triton_ptxas_path() -> None: def _worker_compile_triton( - load_kernel: Callable[[], CachingAutotuner], extra_env: Dict[str, str] + load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str] ) -> None: _set_triton_ptxas_path() os.environ.update(extra_env) diff --git a/torch/_inductor/runtime/hints.py b/torch/_inductor/runtime/hints.py index 14389c912d1a..227bbd110e24 100644 --- a/torch/_inductor/runtime/hints.py +++ b/torch/_inductor/runtime/hints.py @@ -5,7 +5,7 @@ import collections import functools import typing from enum import auto, Enum -from typing import Dict, List, Optional, Union +from typing import Optional, Union # The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values @@ -167,8 +167,8 @@ class DeviceProperties(typing.NamedTuple): class HalideInputSpec(typing.NamedTuple): ctype: str name: str - shape: Optional[List[str]] = None - stride: Optional[List[str]] = None + shape: Optional[list[str]] = None + stride: Optional[list[str]] = None offset: Optional[str] = None alias_of: Optional[str] = None @@ -192,13 +192,13 @@ class HalideInputSpec(typing.NamedTuple): class HalideMeta(typing.NamedTuple): - argtypes: List[HalideInputSpec] + argtypes: list[HalideInputSpec] target: str scheduler: Optional[str] = None - scheduler_flags: Optional[Dict[str, Union[int, str]]] = None + scheduler_flags: Optional[dict[str, Union[int, str]]] = None cuda_device: Optional[int] = None - def args(self) -> List[str]: + def args(self) -> list[str]: """Command line args to pass to halide generator""" args = [f"target={self.target}"] if self.scheduler: diff --git a/torch/_inductor/runtime/runtime_utils.py b/torch/_inductor/runtime/runtime_utils.py index 17d4a8517d12..05bb3ddf0b97 100644 --- a/torch/_inductor/runtime/runtime_utils.py +++ b/torch/_inductor/runtime/runtime_utils.py @@ -2,7 +2,7 @@ from __future__ import annotations import functools import operator -from typing import Any, Hashable, TYPE_CHECKING +from typing import Any, TYPE_CHECKING import torch from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 @@ -13,6 +13,8 @@ from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401 if TYPE_CHECKING: + from collections.abc import Hashable + from .triton_compat import Config diff --git 
a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 9492b8cfabbe..9afb3453f277 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -16,17 +16,7 @@ import sys import threading import time from collections import namedtuple -from typing import ( - Any, - Callable, - Container, - Dict, - Hashable, - List, - Optional, - Tuple, - TYPE_CHECKING, -) +from typing import Any, Callable, Optional, TYPE_CHECKING import torch from torch.utils._ordered_set import OrderedSet @@ -76,13 +66,15 @@ from .triton_compat import ( if TYPE_CHECKING: + from collections.abc import Container, Hashable + LauncherType = Any log = logging.getLogger(__name__) -def get_total_reduction_numel(numels: Dict[str, int]) -> int: +def get_total_reduction_numel(numels: dict[str, int]) -> int: return conditional_product( *[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)] ) @@ -93,7 +85,7 @@ def autotune_hints_to_configs( size_hints, block_size: int, device_props: DeviceProperties, -) -> List[Config]: +) -> list[Config]: """ AutotuneHints can be attached to the metadata of triton kernels for providing suggestions about what to try for autotuning. One reason to do this is if there are @@ -104,7 +96,7 @@ def autotune_hints_to_configs( configs to try. """ xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...] - configs: List[Config] = [] + configs: list[Config] = [] for hint in hints: if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD: if len(size_hints) == 1: @@ -180,14 +172,14 @@ class CachingAutotuner(KernelInterface): triton_meta, # passed directly to triton configs, save_cache_hook, - mutated_arg_names: List[str], # see [Note: clone mutated buffers] + mutated_arg_names: list[str], # see [Note: clone mutated buffers] optimize_mem, heuristic_type, size_hints=None, inductor_meta=None, # metadata not relevant to triton custom_kernel=False, # whether the kernel is inductor-generated or custom filename: Optional[str] = None, - reset_to_zero_arg_names: Optional[List[str]] = None, + reset_to_zero_arg_names: Optional[list[str]] = None, ): super().__init__() @@ -223,8 +215,8 @@ class CachingAutotuner(KernelInterface): for c in self.configs: log.debug(c) - self.compile_results: List[TritonCompileResult] = [] - self.launchers: List[LauncherType] = [] + self.compile_results: list[TritonCompileResult] = [] + self.launchers: list[LauncherType] = [] self.lock = threading.Lock() if os.getenv("TRITON_CACHE_DIR") is None: os.environ["TRITON_CACHE_DIR"] = triton_cache_dir( @@ -430,7 +422,7 @@ class CachingAutotuner(KernelInterface): self.fn.repr = _ConstRepr(self.fn.repr(self.fn)) self.launchers = [] - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: assert ( not self.launchers ), "pickle should not be called with after make_launchers()" @@ -439,7 +431,7 @@ class CachingAutotuner(KernelInterface): "lock": None, } - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: self.__dict__.update(state) self.lock = threading.Lock() @@ -636,7 +628,7 @@ class CachingAutotuner(KernelInterface): def maybe_clone_args( self, exclude: Container[str], *args, **kwargs - ) -> tuple[List[Any], Dict[str, Any]]: + ) -> tuple[list[Any], dict[str, Any]]: """ Prepare new args and kwargs by cloning any in-place buffers (that are not in the provided exclusion list), to avoid autotune @@ -659,7 +651,7 @@ class CachingAutotuner(KernelInterface): 
return cloned_args, cloned_kwargs - def clone_args(self, *args, **kwargs) -> tuple[List[Any], Dict[str, Any]]: + def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]: return self.maybe_clone_args(OrderedSet(), *args, **kwargs) def benchmark_all_configs(self, *args, **kwargs): @@ -888,15 +880,15 @@ class TritonCompileResult: @staticmethod @functools.lru_cache(32) - def _kernel_metadata_cls(fields: Tuple[str, ...]) -> Any: + def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any: return namedtuple("KernelMetadata", sorted(fields)) def __init__( self, kernel: CompiledKernel, config: Config, - compile_meta: Dict[str, Any], - inductor_meta: Dict[str, Any], + compile_meta: dict[str, Any], + inductor_meta: dict[str, Any], ) -> None: super().__init__() self.kernel = kernel @@ -904,7 +896,7 @@ class TritonCompileResult: self.compile_meta = compile_meta self.inductor_meta = inductor_meta - def __getstate__(self) -> Dict[str, Any]: + def __getstate__(self) -> dict[str, Any]: kernel = self.kernel # replace the fields that don't pickle nicely kernel_state = { @@ -916,7 +908,7 @@ class TritonCompileResult: } return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item] - def __setstate__(self, state: Dict[str, Any]) -> None: + def __setstate__(self, state: dict[str, Any]) -> None: # src = ASTSource.__new__(ASTSource) # src.__setstate__(state["kernel"]["src"]) # TODO(jansel): need to fixup src.fn which is now None @@ -1101,7 +1093,7 @@ def _find_names(obj): return obj_names -collected_calls: List[Any] = [] +collected_calls: list[Any] = [] def start_graph(): @@ -1220,7 +1212,7 @@ class DebugAutotuner(CachingAutotuner): collected_calls.append(self.cached) -def hash_configs(configs: List[Config]): +def hash_configs(configs: list[Config]): """ Hash used to check for changes in configurations """ @@ -1233,8 +1225,8 @@ def hash_configs(configs: List[Config]): def cached_autotune( - size_hints: Optional[List[int]], - configs: List[Config], + size_hints: Optional[list[int]], + configs: list[Config], triton_meta, heuristic_type, filename=None, @@ -1276,7 +1268,7 @@ def cached_autotune( if "restore_value" in triton_meta: mutated_arg_names += triton_meta.pop("restore_value") - reset_to_zero_arg_names: List[str] = [] + reset_to_zero_arg_names: list[str] = [] if "reset_to_zero" in triton_meta: reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero")) @@ -1331,7 +1323,7 @@ def cached_autotune( return decorator -def unique_configs(configs: List[Config]): +def unique_configs(configs: list[Config]): """Remove duplicate configurations""" seen: OrderedSet[Hashable] = OrderedSet() pruned_configs = [] @@ -1362,7 +1354,7 @@ def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None): ) -def check_max_block(cfg: Dict[str, int]): +def check_max_block(cfg: dict[str, int]): """ Check that block sizes are within the maximum allowed. """ @@ -1505,7 +1497,7 @@ def triton_config( return Config(cfg, num_warps=num_warps, num_stages=num_stages) -def _get_nd_reduction_numels(r: int, size_hints: Dict[str, int]) -> Dict[str, int]: +def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]: """ Converts a linear reduction numel to ND, in row major order. 
This order is often desirable as it presents opportunities to coalesce memory @@ -1596,7 +1588,7 @@ def triton_config_reduction( return Config(cfg, num_warps=num_warps, num_stages=num_stages) -def _get_config(numels: Dict[str, int]) -> Dict[str, int]: +def _get_config(numels: dict[str, int]) -> dict[str, int]: """ Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc. """ @@ -1729,8 +1721,8 @@ def pointwise( def _reduction_configs( - *, size_hints: Dict[str, int], inductor_meta: Dict[str, Any] -) -> List[Config]: + *, size_hints: dict[str, int], inductor_meta: dict[str, Any] +) -> list[Config]: reduction_hint = inductor_meta.get("reduction_hint", None) # Convert reductions to 1D, to simplify heuristics. @@ -1981,7 +1973,7 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No ) -def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]: +def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]: """Extract triton.Config options that should become kwargs""" popped = {} for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"): diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index 26baa5c2f8cf..f84e18822bb2 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -14,20 +14,8 @@ import pprint import textwrap import traceback import typing -from collections import defaultdict -from typing import ( - Any, - Callable, - Counter, - DefaultDict, - Dict, - Generic, - List, - Optional, - Sequence, - TypeVar, - Union, -) +from collections import Counter, defaultdict +from typing import Any, Callable, Generic, Optional, TypeVar, Union import sympy @@ -68,6 +56,10 @@ from .utils import ( from .virtualized import V +if typing.TYPE_CHECKING: + from collections.abc import Sequence + + log = logging.getLogger(__name__) fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering") @@ -78,7 +70,7 @@ class SchedulerBuffer: scheduler: Scheduler node: ir.Buffer defining_op: BaseSchedulerNode - users: List[NodeUser] = dataclasses.field(default_factory=list) + users: list[NodeUser] = dataclasses.field(default_factory=list) mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field( default_factory=MemoryPlanningInfoForBuffer ) @@ -154,9 +146,9 @@ class SchedulerBuffer: return False return True - def set_users(self, users: List[NodeUser]) -> None: + def set_users(self, users: list[NodeUser]) -> None: # deduplicate - result: Dict[int, NodeUser] = {} + result: dict[int, NodeUser] = {} for use in users: if id(use.node) in result: result[id(use.node)] = use.merge(result[id(use.node)]) @@ -194,7 +186,7 @@ class BaseSchedulerNode: def __init__(self, scheduler: Scheduler) -> None: self.scheduler: Scheduler = scheduler self.debug_device_str: Callable[ - [BaseSchedulerNode], List[str] + [BaseSchedulerNode], list[str] ] = lambda *args, **kwargs: [] def _init_from_node(self, node: ir.Operation) -> None: @@ -204,7 +196,7 @@ class BaseSchedulerNode: str ]() # buffers that won't be used after this kernel self.written = False - self.outputs: List[SchedulerBuffer] = [ + self.outputs: list[SchedulerBuffer] = [ SchedulerBuffer( scheduler=self.scheduler, node=output, @@ -212,7 +204,7 @@ class BaseSchedulerNode: ) for output in node.get_outputs() ] - self.outputs_by_name: Dict[str, SchedulerBuffer] = { + self.outputs_by_name: dict[str, SchedulerBuffer] = { buf.get_name(): buf for buf in self.outputs } @@ -247,7 +239,7 @@ class 
BaseSchedulerNode: def debug_str_extra(self) -> str: return "" - def _debug_str_for_device(self) -> List[str]: + def _debug_str_for_device(self) -> list[str]: return self.debug_device_str(self) def debug_str_short(self) -> str: @@ -278,7 +270,7 @@ class BaseSchedulerNode: ) -> None: return - def update_mutated_names(self, renames: Dict[str, str]) -> None: + def update_mutated_names(self, renames: dict[str, str]) -> None: self.set_read_writes(self.read_writes.rename(renames)) def add_fake_dep(self, dep: Dep) -> None: @@ -295,7 +287,7 @@ class BaseSchedulerNode: self.prune_deps() def set_last_usage( - self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] ) -> None: used_buffers = self.used_or_aliased_buffer_names() used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers) @@ -352,7 +344,7 @@ class BaseSchedulerNode: self.set_read_writes(self.read_writes.remove_reads(to_remove)) def prune_redundant_deps( - self, name_to_fused_node: Dict[str, BaseSchedulerNode] + self, name_to_fused_node: dict[str, BaseSchedulerNode] ) -> None: _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) @@ -582,7 +574,7 @@ class BaseSchedulerNode: def get_read_write_buffer_accesses( self, include_reads: bool, include_writes: bool - ) -> Dict[str, int]: + ) -> dict[str, int]: """ Counting the number of bytes accessed for a kernel is surprisingly tricky. In particular, there is a differentiation @@ -658,7 +650,7 @@ class BaseSchedulerNode: writes = writes - removed_buffers reads = reads - removed_buffers - buf_byte_accesses: Dict[str, int] = {} + buf_byte_accesses: dict[str, int] = {} for buf_name in reads | writes: buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name]) @@ -811,8 +803,8 @@ class BaseSchedulerNode: @staticmethod def get_prologue_template_epilogue( - nodes: List[BaseSchedulerNode], - ) -> tuple[List[BaseSchedulerNode], BaseSchedulerNode, List[BaseSchedulerNode]]: + nodes: list[BaseSchedulerNode], + ) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]: """ For the list of nodes, get the prologue, template, and epilogue """ @@ -874,8 +866,8 @@ class OutputNode: def _prune_redundant_deps( node: BaseSchedulerNode, - name_to_fused_node: Dict[str, BaseSchedulerNode], - name_to_buf: Dict[str, SchedulerBuffer], + name_to_fused_node: dict[str, BaseSchedulerNode], + name_to_buf: dict[str, SchedulerBuffer], ) -> None: """ Prunes weakdeps intended for mutation ordering @@ -961,7 +953,7 @@ class SchedulerNode(BaseSchedulerNode): def _compute_attrs( self, - extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)) @@ -993,7 +985,7 @@ class SchedulerNode(BaseSchedulerNode): def recompute_size_and_body( self, - extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None, + extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None, ) -> None: self._compute_attrs( @@ -1120,7 +1112,7 @@ class SchedulerNode(BaseSchedulerNode): def ranges_from_index_vars( self, index_vars: Sequence[Sequence[sympy.Expr]] - ) -> Dict[sympy.Expr, sympy.Expr]: + ) -> dict[sympy.Expr, sympy.Expr]: sizes = self._sizes 
assert sum(map(len, sizes)) == sum(map(len, index_vars)) var_ranges = dict( @@ -1220,7 +1212,7 @@ def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None: def init_group_node( group_snode: BaseSchedulerNode, scheduler: Scheduler, - snodes: List[BaseSchedulerNode], + snodes: list[BaseSchedulerNode], ) -> None: assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode)) group_snode.snodes = snodes @@ -1246,7 +1238,7 @@ class FusedSchedulerNode(BaseSchedulerNode): its unmet dependencies as the union of its constituent nodes. """ - snodes: List[BaseSchedulerNode] + snodes: list[BaseSchedulerNode] @classmethod def fuse( @@ -1319,10 +1311,10 @@ class FusedSchedulerNode(BaseSchedulerNode): refresh_group_node_dependencies(self) - def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: super().__init__(scheduler) init_group_node(self, scheduler, snodes) - self.users: List[NodeUser] = [] + self.users: list[NodeUser] = [] self.group = max(snodes, key=lambda x: int(x.is_reduction())).group @cache_on_self @@ -1336,8 +1328,8 @@ class FusedSchedulerNode(BaseSchedulerNode): def get_buffer_names(self) -> OrderedSet[str]: return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) - def get_outputs(self) -> List[SchedulerBuffer]: - result: List[SchedulerBuffer] = [] + def get_outputs(self) -> list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] for node in self.snodes: result.extend(node.get_outputs()) return result @@ -1358,7 +1350,7 @@ class FusedSchedulerNode(BaseSchedulerNode): return f"{self}, snodes: {snodes_str}" def set_last_usage( - self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str] + self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str] ) -> None: # Set self.last_usage using the global information # This will be used for inter-kernel optimisations @@ -1414,7 +1406,7 @@ class FusedSchedulerNode(BaseSchedulerNode): # None of these need to be implemented, as a FusedSchedulerNode is just an # abstraction for scheduling purposes - def update_mutated_names(self, renames: Dict[str, str]) -> None: + def update_mutated_names(self, renames: dict[str, str]) -> None: raise NotImplementedError def add_fake_dep(self, name: Dep) -> None: @@ -1546,7 +1538,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): enable_autotune = consumer.enable_autotune prev_node_1 = None prev_node_2 = None - fused_nodes: List[BaseSchedulerNode] + fused_nodes: list[BaseSchedulerNode] if producer.is_foreach() and consumer.is_foreach(): producer = typing.cast(ForeachKernelSchedulerNode, producer) consumer = typing.cast(ForeachKernelSchedulerNode, consumer) @@ -1599,7 +1591,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): def __init__( self, scheduler: Scheduler, - snodes: List[BaseSchedulerNode], + snodes: list[BaseSchedulerNode], use_custom_partition_algo: bool, prev_node_1: Optional[BaseSchedulerNode] = None, prev_node_2: Optional[BaseSchedulerNode] = None, @@ -1621,7 +1613,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): self.scheduler = scheduler self.snodes = snodes self.node = None - self.users: List[NodeUser] = [] + self.users: list[NodeUser] = [] self.set_read_writes( dependencies.ReadWrites.merge_list( @@ -1666,8 +1658,8 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): @classmethod def combinable_nodes( - cls, nodes: List[BaseSchedulerNode] - ) -> List[BaseSchedulerNode]: + cls, 
nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)] if extern: log.debug( @@ -1700,7 +1692,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): @staticmethod def _default_group_nodes_for_combo_kernels( scheduler: Scheduler, - ) -> List[List[BaseSchedulerNode]]: + ) -> list[list[BaseSchedulerNode]]: """ Returns a list of lists of nodes that are to be grouped together. """ @@ -1718,12 +1710,12 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): return grouped_nodes group_algorithm_for_combo_kernels: Callable[ - [Scheduler], List[List[BaseSchedulerNode]] + [Scheduler], list[list[BaseSchedulerNode]] ] = _default_group_nodes_for_combo_kernels @staticmethod def set_group_algorithm_for_combo_kernels( - custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]] + custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]] ) -> None: ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( custom_group_algorithm @@ -1732,7 +1724,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): @staticmethod def group_nodes_for_combo_kernels( scheduler: Scheduler, - ) -> List[List[BaseSchedulerNode]]: + ) -> list[list[BaseSchedulerNode]]: return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler) def mark_run(self) -> None: @@ -1744,7 +1736,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): def is_foreach(self) -> bool: return True - def get_subkernel_nodes(self) -> List[BaseSchedulerNode]: + def get_subkernel_nodes(self) -> list[BaseSchedulerNode]: """Returns a list of nodes which comprise the combo kernel. These nodes may be vertically fused.""" return list(self.snodes) @@ -1758,7 +1750,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode): return self.snodes[0].get_first_name() def prune_redundant_deps( - self, name_to_fused_node: Dict[str, BaseSchedulerNode] + self, name_to_fused_node: dict[str, BaseSchedulerNode] ) -> None: _prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf) @@ -1776,10 +1768,10 @@ class GroupedSchedulerNode(BaseSchedulerNode): At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node. """ - snodes: List[BaseSchedulerNode] + snodes: list[BaseSchedulerNode] @classmethod - def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode: + def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode: scheduler = snodes[0].scheduler assert all(node.scheduler is scheduler for node in snodes) grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type] @@ -1788,11 +1780,11 @@ class GroupedSchedulerNode(BaseSchedulerNode): scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode return grouped_snode - def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None: + def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None: super().__init__(scheduler) init_group_node(self, scheduler, snodes) - def unpack(self) -> List[BaseSchedulerNode]: + def unpack(self) -> list[BaseSchedulerNode]: """ Do fusion among nodes within this GroupedSchedulerNode, and then unpack this GroupedSchedulerNode into regular nodes. 
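
For reference on the recurring pattern in these scheduler.py hunks: `Sequence` is now imported from `collections.abc` only under `if typing.TYPE_CHECKING:`, yet signatures such as `pick_loop_order(..., sizes: Sequence[sympy.Expr], ...)` keep using it. That is safe only because the module defers annotation evaluation (scheduler.py appears to rely on `from __future__ import annotations`). A minimal, self-contained sketch of that pattern — the function and names below are invented for illustration and are not part of the patch:

    from __future__ import annotations

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Imported only for the type checker; never executed at runtime.
        from collections.abc import Sequence


    def first_as_list(items: Sequence[int]) -> list[int]:
        # With postponed evaluation, the annotations above stay as strings,
        # so Sequence does not have to exist when this module is imported.
        return [items[0]] if items else []


    if __name__ == "__main__":
        print(first_as_list((3, 1, 2)))  # -> [3]
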
@@ -1817,8 +1809,8 @@ class GroupedSchedulerNode(BaseSchedulerNode): def get_buffer_names(self) -> OrderedSet[str]: return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes]) - def get_outputs(self) -> List[SchedulerBuffer]: - result: List[SchedulerBuffer] = [] + def get_outputs(self) -> list[SchedulerBuffer]: + result: list[SchedulerBuffer] = [] for node in self.snodes: result.extend(node.get_outputs()) return result @@ -1833,10 +1825,10 @@ class GroupedSchedulerNode(BaseSchedulerNode): def pick_loop_order( - stride_lengths: List[List[int]], + stride_lengths: list[list[int]], sizes: Sequence[sympy.Expr], priority_idx: tuple[int, ...] = (), -) -> List[int]: +) -> list[int]: """ A heuristic to decide loop iteration orders. This has not been well tuned and may be something we should autotune. @@ -1914,17 +1906,17 @@ _post_grad_graph_counter = itertools.count() class Scheduler: - __dep_size_hint_cache: Dict[Dep, int] + __dep_size_hint_cache: dict[Dep, int] - def __init__(self, nodes: List[ir.Operation]) -> None: + def __init__(self, nodes: list[ir.Operation]) -> None: with dynamo_timed("Scheduler.__init__"): self._init(nodes) - def _init(self, nodes: List[ir.Operation]) -> None: + def _init(self, nodes: list[ir.Operation]) -> None: super().__init__() self.__dep_size_hint_cache = {} V.graph.scheduler = self - self.backends: Dict[torch.device, BaseScheduling] = {} + self.backends: dict[torch.device, BaseScheduling] = {} self.post_grad_graph_id = next(_post_grad_graph_counter) self.completed_operations = OrderedSet[str]() @@ -1943,23 +1935,23 @@ class Scheduler: for node in self.nodes: node.prune_deps() - self.name_to_donated_buffer: Dict[ + self.name_to_donated_buffer: dict[ str, SchedulerDonatedBuffer ] = self.get_donated_buffers() - self.name_to_node: Dict[str, BaseSchedulerNode] = { + self.name_to_node: dict[str, BaseSchedulerNode] = { n.get_name(): n for n in self.nodes } - self.name_to_buf: Dict[str, SchedulerBuffer] = { + self.name_to_buf: dict[str, SchedulerBuffer] = { buf.get_name(): buf for node in self.nodes for buf in node.get_outputs() } - self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy() + self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy() # mutation_real_name: Maps back to the original name for codegen # Example: # If you mutate buf0 inside of buf1's kernel, then: # mutation_real_name = {"buf0" : "buf1"} # all subsequent uses of buf0 become buf1's usage in dependency graph - self.mutation_real_name: Dict[str, str] = {} + self.mutation_real_name: dict[str, str] = {} # We handle mutation by renaming modified versions of the same # buffer in the dependency graph to prevent cycles. 
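
The same file also swaps `typing.Counter` / `typing.DefaultDict` for the concrete `collections.Counter` / `collections.defaultdict` in annotations (e.g. `name_to_users: defaultdict[str, DedupList[NodeUser]]` further below). That relies on the concrete classes being subscriptable, which is available from Python 3.9 onward. A small illustrative sketch under that assumption — the function and variable names are invented, not taken from the patch:

    from collections import Counter, defaultdict


    def tally(words: list[str]) -> Counter[str]:
        # Since Python 3.9 the concrete collections classes are subscriptable,
        # so the typing.Counter / typing.DefaultDict aliases can be dropped.
        return Counter(words)


    def group_by_initial(words: list[str]) -> defaultdict[str, list[str]]:
        groups: defaultdict[str, list[str]] = defaultdict(list)
        for word in words:
            groups[word[:1]].append(word)
        return groups


    if __name__ == "__main__":
        print(tally(["a", "b", "a"]))                         # Counter({'a': 2, 'b': 1})
        print(dict(group_by_initial(["ant", "bee", "axe"])))  # {'a': ['ant', 'axe'], 'b': ['bee']}
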
@@ -1969,7 +1961,7 @@ class Scheduler: # If you mutate buf0 inside of buf1's kernel, then: # mutation_renames = {"buf1" : "buf0"} # in codegen we only use buf0, never buf1 - self.mutation_renames: Dict[str, str] = {} + self.mutation_renames: dict[str, str] = {} # Must run first to correctly set dependencies, before all other passes that rely on # reading from .read_writes.reads or .unmet_dependencies @@ -2024,7 +2016,7 @@ class Scheduler: # fx graph node to the position it appears in the graph # for debug attribution - self.origin_to_index: Dict[torch.fx.Node, int] = {} + self.origin_to_index: dict[torch.fx.Node, int] = {} get_metric_table("graph_stats").add_row( lambda: { @@ -2034,7 +2026,7 @@ class Scheduler: } ) - def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]: + def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]: name_to_donated_buf = {} for name in V.graph.graph_inputs_original: if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer): @@ -2135,7 +2127,7 @@ class Scheduler: def __init__( self, - items: Optional[List[T]] = None, + items: Optional[list[T]] = None, membership: Optional[OrderedSet[T]] = None, ) -> None: self.items = items or [] @@ -2154,7 +2146,7 @@ class Scheduler: ] return DedupList(new_items, new_membership) - name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict( + name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict( DedupList ) @@ -2196,7 +2188,7 @@ class Scheduler: NodeUser(user_node, can_inplace, is_weak) ) - unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {} + unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {} # NB: None means that the dependency is on an input. Don't actually # generate a dependency because if we do, Inductor will start trying @@ -2367,14 +2359,14 @@ class Scheduler: node.prune_weak_deps() def topological_sort_schedule( - self, nodes: List[BaseSchedulerNode] - ) -> List[BaseSchedulerNode]: + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: """ Ensure nodes is in topologically sorted order """ seen = OrderedSet[BaseSchedulerNode]() - name_to_node: Dict[str, BaseSchedulerNode] = dict() - result: List[BaseSchedulerNode] = [] + name_to_node: dict[str, BaseSchedulerNode] = dict() + result: list[BaseSchedulerNode] = [] def visit(n: BaseSchedulerNode) -> None: if n not in seen: @@ -2393,7 +2385,7 @@ class Scheduler: visit(node) return result - def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]: + def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]: unmet_deps = OrderedSet[str]() if isinstance( snode, @@ -2415,13 +2407,13 @@ class Scheduler: OrderedSet(self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops) ) - def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]: + def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]: """ Sort nodes by their topological order, return a list of node lists. 
""" order = [] nodes = dict.fromkeys(self.nodes, 0) - children: Dict[Any, Any] = {} + children: dict[Any, Any] = {} for node in self.nodes: deps = self._get_unmet_dep_nodes(node) nodes[node] = len(deps) @@ -2446,7 +2438,7 @@ class Scheduler: Populate each node.ancestors """ # note self.nodes is topologically sorted - name_to_ancestors: Dict[str, OrderedSet[str]] = {} + name_to_ancestors: dict[str, OrderedSet[str]] = {} for node in self.nodes: ancestors = OrderedSet[str]() for dep in node.unmet_dependencies: @@ -2487,7 +2479,7 @@ class Scheduler: # FusedSchedulerNode having different merged loops. # Skip CPU backend for now. - def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: + def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]: """ Combine eligible nodes into FusedSchedulerNodes. """ @@ -2518,7 +2510,7 @@ class Scheduler: """ Unpack GroupedSchedulerNode into regular nodes. """ - new_nodes: List[BaseSchedulerNode] = [] + new_nodes: list[BaseSchedulerNode] = [] for node in self.nodes: new_nodes.extend( node.unpack() if isinstance(node, GroupedSchedulerNode) else [node] @@ -2802,8 +2794,8 @@ class Scheduler: return ms_fused < ms1 + ms2 def fuse_nodes_once( - self, nodes: List[BaseSchedulerNode] - ) -> List[BaseSchedulerNode]: + self, nodes: list[BaseSchedulerNode] + ) -> list[BaseSchedulerNode]: """ Combine eligible nodes into FusedSchedulerNodes. @@ -2890,20 +2882,20 @@ class Scheduler: ) self.prune_redundant_deps(self.nodes) - def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None: + def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None: for node in nodes: node.prune_redundant_deps(self.name_to_fused_node) def get_possible_fusions( - self, nodes: List[BaseSchedulerNode] - ) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + self, nodes: list[BaseSchedulerNode] + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: """ Helper to find all legal fusion opportunities, sorted by self.score_fusion() """ possible_fusions = [] seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]() - def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None: + def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None: for node1_index, node1 in enumerate(nodes): for node2 in nodes[node1_index + 1 :]: key = (node1, node2) @@ -3015,7 +3007,7 @@ class Scheduler: def _find_single_user_inputs( node: BaseSchedulerNode, - ) -> List[ir.Buffer]: + ) -> list[ir.Buffer]: output = [] for rd in node.read_writes.reads: buf = self.name_to_buf.get(rd.name) @@ -3403,7 +3395,7 @@ class Scheduler: """ node1_buf_names = node1.get_buffer_names() why = WhyNoFuse(node1, node2) - remaining_deps_by_name: Dict[str, List[Dep]] = defaultdict(list) + remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list) for dep in node2.unmet_dependencies: name = self.mutation_renames.get(dep.name, dep.name) @@ -3562,14 +3554,14 @@ class Scheduler: return sum(self.dep_size_hint(dep) for dep in common_memory_deps) def get_possible_fusions_with_highest_priority( - self, possible_fusions: List[tuple[BaseSchedulerNode, BaseSchedulerNode]] - ) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]: + self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]] + ) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]: # Group the possible fusions based on their priority from the backend. # Only return the group of possible fusions with highest priority. 
if len(possible_fusions) == 0: return possible_fusions - possible_fusions_group_by_priority: Dict[ - int, List[tuple[BaseSchedulerNode, BaseSchedulerNode]] + possible_fusions_group_by_priority: dict[ + int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]] ] = {} for node1, node2 in possible_fusions: @@ -3828,7 +3820,7 @@ class Scheduler: backend = self.get_backend(device) return backend.benchmark_combo_kernel(node_list) - def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool: + def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool: """ If config.benchmark_fusion is False, always return True. Otherwise, return True if fusion can brings speedup. diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 46d4e2e11736..57b1baa8f658 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -16,17 +16,8 @@ import textwrap import time from concurrent.futures import as_completed, ThreadPoolExecutor from io import StringIO -from typing import ( - Any, - Callable, - Dict, - List, - Optional, - Type, - TYPE_CHECKING, - TypeVar, - Union, -) +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import Self from unittest.mock import patch import sympy @@ -86,7 +77,7 @@ from .virtualized import V log = logging.getLogger(__name__) # correctness checks struggle with fp16/tf32 -VERIFY: Dict[str, Any] = {} +VERIFY: dict[str, Any] = {} PRINT_AUTOTUNE = True DEBUG = False @@ -104,14 +95,11 @@ class KernelNamespace: extern_kernels = KernelNamespace() -_T = TypeVar("_T", bound="AutotuneArgs") - - @dataclasses.dataclass class BenchmarkTensors: """Represents a set of inputs and outputs for autotuning with a template""" - input_tensors: List[torch.Tensor] + input_tensors: list[torch.Tensor] output_tensor: Optional[torch.Tensor] def unpack(self): @@ -139,13 +127,13 @@ class AutotuneArgs: @classmethod def from_choice_args( - cls: Type[_T], - example_inputs: List[torch.Tensor], - example_inputs_extern: List[torch.Tensor], + cls, + example_inputs: list[torch.Tensor], + example_inputs_extern: list[torch.Tensor], out: torch.Tensor, out_extern: torch.Tensor, expected: Optional[torch.Tensor] = None, - ) -> _T: + ) -> Self: """Factory method to create AutotuneInputs from separate inputs/outputs""" return cls( triton=BenchmarkTensors(example_inputs, out), @@ -207,7 +195,7 @@ class SubgraphInfo: ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined] # only copied over if not None - range_trees: Optional[List["IterationRangesRoot"]] = None + range_trees: Optional[list["IterationRangesRoot"]] = None numels = None # type: ignore[var-annotated] def __post_init__(self): @@ -226,7 +214,7 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined] self, kernel, subgraph_number: int, - fixed_inputs: Dict[str, Any], + fixed_inputs: dict[str, Any], mask: Optional[str], ): super().__init__(V.ops) @@ -290,7 +278,7 @@ class TritonTemplateKernel(TritonKernel): prefix_args=0, suffix_args=0, epilogue_fn=identity, - subgraphs: Optional[List[ir.ComputedBuffer]] = None, + subgraphs: Optional[list[ir.ComputedBuffer]] = None, workspace_arg: Optional[WorkspaceArg] = None, ) -> None: numel = sympy_product(output_node.get_size()) @@ -317,9 +305,9 @@ class TritonTemplateKernel(TritonKernel): self.suffix_args = suffix_args self.epilogue_fn = epilogue_fn self.render_hooks = {} # type: ignore[var-annotated] - self.triton_meta: Optional[Dict[str, object]] = None + 
self.triton_meta: Optional[dict[str, object]] = None # For Templated Attention this can be a list of ir.Subgraph - self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs + self.subgraphs: Optional[list[ir.ComputedBuffer]] = subgraphs # Some templates use extra global memory as a workspace self.workspace_arg = workspace_arg @@ -330,7 +318,7 @@ class TritonTemplateKernel(TritonKernel): # used for triton kernel codegen. # They are swapped onto the TritonTemplateKernel object by # `set_subgraph_body` - self.subgraph_bodies: Dict[str, SubgraphInfo] = {} + self.subgraph_bodies: dict[str, SubgraphInfo] = {} # input buffers which we are allowed to prologue fuse into self.prologue_supported_inputs: OrderedSet[str] = OrderedSet() @@ -420,7 +408,7 @@ class TritonTemplateKernel(TritonKernel): return "@triton.jit" argdefs, _, signature, _ = self.args.python_argdefs() - triton_meta: Dict[str, Any] = { + triton_meta: dict[str, Any] = { "signature": signature_to_meta( signature, size_dtype=self.index_dtype, argdefs=argdefs ), @@ -622,7 +610,7 @@ class TritonTemplateKernel(TritonKernel): ) with V.set_ops_handler(modification_handler): assert isinstance( - subgraph, (ir.ComputedBuffer, List) + subgraph, (ir.ComputedBuffer, list) ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" # Handle scatter stores if isinstance(subgraph, list): @@ -651,7 +639,7 @@ class TritonTemplateKernel(TritonKernel): self, input_name: str, output_name: str, - indices: Union[List[Any], tuple[Any]], + indices: Union[list[Any], tuple[Any]], mask: Optional[str] = None, other: Optional[Union[float, int]] = 0.0, indent_width: int = 4, @@ -826,7 +814,7 @@ class TritonTemplateKernel(TritonKernel): def store_output( self, - indices: Union[List[Any], tuple[Any]], + indices: Union[list[Any], tuple[Any]], val: str, mask: Optional[str] = None, indent_width: int = 4, @@ -1032,7 +1020,7 @@ def _jinja2_env(): class TritonTemplate(KernelTemplate): index_counter = itertools.count() - all_templates: Dict[str, "TritonTemplate"] = {} + all_templates: dict[str, "TritonTemplate"] = {} def __init__(self, name: str, grid: Any, source: str, debug=False) -> None: super().__init__(name) @@ -1180,7 +1168,7 @@ class TritonTemplate(KernelTemplate): ), kwargs, ) - bmreq_cls: Type[TritonBenchmarkRequest] + bmreq_cls: type[TritonBenchmarkRequest] if layout.device.type == "cpu": bmreq_cls = TritonCPUBenchmarkRequest else: @@ -1292,7 +1280,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase): description, bmreq, log_info: Optional[ - Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]] + dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]] ] = None, mutated_inputs=None, workspace_arg: Optional[WorkspaceArg] = None, @@ -1303,7 +1291,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase): self.bmreq: TritonBenchmarkRequest = bmreq if log_info is None: log_info = {} - self.log_info: Dict[str, Any] = log_info + self.log_info: dict[str, Any] = log_info self.log_info.update( { "backend": "Triton", @@ -1351,7 +1339,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase): ) ) - def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: """Information returned here is logged to the autotune log file when that is enabled.""" return self.log_info @@ -1447,7 +1435,7 @@ class ExternKernelCaller(ChoiceCaller): return ir.TensorBox.create(inner) - def info_dict(self) -> Dict[str, 
Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: + def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]: """Information returned here is logged to the autotune log file when that is enabled.""" return { "backend": "extern", @@ -1589,7 +1577,7 @@ def create_inputs_key(input_nodes) -> str: def create_precompile_key( - name: str, inputs_key: str, choices: List[ChoiceCaller] + name: str, inputs_key: str, choices: list[ChoiceCaller] ) -> str: return ":".join( [ @@ -1609,18 +1597,18 @@ class AlgorithmSelectorCache(PersistentCache): # no guarantee that the first lowering for a given key will also be the # first to benchmark it. share a single precompilation function for all lowerings # of a particular key - self.precompile_cache: Dict[str, Callable[[], None]] = {} + self.precompile_cache: dict[str, Callable[[], None]] = {} # list of callbacks that are called after benchmarking - self.feedback_saver_fns: List[ + self.feedback_saver_fns: list[ Callable[ - [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + [dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None ] ] = [] def __call__( self, name, - choices: List[ChoiceCaller], + choices: list[ChoiceCaller], input_nodes, layout, # optional dict mapping arg indices to the functions @@ -1628,7 +1616,7 @@ class AlgorithmSelectorCache(PersistentCache): # corresponding ir.Buffer. if passed for a given # arg, the function will be called instead of # generating a random torch.Tensor for benchmarking. - input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, + input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None, precompilation_timeout_seconds: int = 60 * 60, return_multi_template=False, ): @@ -1761,9 +1749,9 @@ class AlgorithmSelectorCache(PersistentCache): executor = ThreadPoolExecutor(max_workers=num_workers) async_compile = torch._inductor.async_compile.AsyncCompile() - futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {} - start_times: Dict[concurrent.futures.Future[Any], float] = {} - elapsed_times: Dict[concurrent.futures.Future[Any], float] = {} + futures: dict[concurrent.futures.Future[Any], ChoiceCaller] = {} + start_times: dict[concurrent.futures.Future[Any], float] = {} + elapsed_times: dict[concurrent.futures.Future[Any], float] = {} for c in choices: if hasattr(c, "precompile"): @@ -1925,7 +1913,7 @@ class AlgorithmSelectorCache(PersistentCache): input_gen_fns = {} def get_inputs( - choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]] + choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] ) -> AutotuneArgs: # de-duplicate args unique_example_inputs = { @@ -1996,8 +1984,8 @@ class AlgorithmSelectorCache(PersistentCache): return result def benchmark_in_current_process( - choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]], - ) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]: + choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]], + ) -> dict[Union[ExternKernelCaller, TritonTemplateCaller], float]: inputs = get_inputs(choices) timings = {} for choice in choices: @@ -2045,7 +2033,7 @@ class AlgorithmSelectorCache(PersistentCache): return timings def benchmark_in_sub_process( - choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]] + choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] ): from . 
import autotune_process @@ -2069,8 +2057,8 @@ class AlgorithmSelectorCache(PersistentCache): @staticmethod def log_results( name: str, - input_nodes: List[ir.IRNode], - timings: Dict[ChoiceCaller, float], + input_nodes: list[ir.IRNode], + timings: dict[ChoiceCaller, float], elapse: float, precompile_elapse: float, ): @@ -2227,7 +2215,7 @@ class AlgorithmSelectorCache(PersistentCache): def add_feedback_saver( self, fn: Callable[ - [Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None + [dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None ], ): self.feedback_saver_fns.append(fn) @@ -2250,7 +2238,7 @@ def autotune_select_algorithm(*args, **kwargs): def add_feedback_saver( - fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None] + fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None] ): global _ALGORITHM_SELECTOR_CACHE if _ALGORITHM_SELECTOR_CACHE is None: diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index 4b14e20038f6..cdb944559394 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -2,7 +2,8 @@ import functools import itertools import logging -from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union +from collections.abc import Iterable, Sequence +from typing import Any, Callable, cast, Optional, Union import sympy from sympy import Expr @@ -60,7 +61,7 @@ class SizeVarAllocator: shape_env = ShapeEnv() self.shape_env = shape_env self.var_to_val = self.shape_env.var_to_val - self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements + self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements # Maps of dynamic sizes that have to be precomputed on the host to the kernel args. # The basic idea is if we have some complicated sympy expression # f(s0), we may choose to precompute it on the host and then replace @@ -71,8 +72,8 @@ class SizeVarAllocator: # which potentially could have already had a precomputed replacement # on it, we are obligated to invert the precomputed replacements # (inv_precomputed_replacements). 
- self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {} - self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {} + self.precomputed_replacements: dict[Expr, sympy.Symbol] = {} + self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {} self.stride_vars = self.make_stride_vars_cache() self.simplify_with_ranges = self.make_simplify_with_ranges_cache() self._simplify_loops = self.make_simplify_loops_cache() @@ -84,7 +85,7 @@ class SizeVarAllocator: """ self._simplify_with_ranges() can be expensive, cache its results """ - cache: Dict[tuple[Any, ...], Expr] = {} + cache: dict[tuple[Any, ...], Expr] = {} replacement_count = len(self.replacements) def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr: @@ -106,7 +107,7 @@ class SizeVarAllocator: """ self._simplify_with_ranges() can be expensive, cache its results """ - cache: Dict[tuple[Any, ...], Any] = {} + cache: dict[tuple[Any, ...], Any] = {} replacement_count = len(self.replacements) def simplify_loops(index_vars, sizes, index_formulas): @@ -221,7 +222,7 @@ class SizeVarAllocator: return expr def _simplify_loops_impl( - self, index_vars: List[sympy.Symbol], sizes, index_formulas + self, index_vars: list[sympy.Symbol], sizes, index_formulas ): """ Try to remove as many axis from loop iterations as possible, by: @@ -337,7 +338,7 @@ class SizeVarAllocator: return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type] # See Note - [On Statically Known] - def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool: + def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool: """ Returns a bool indicating if it is sound to optimize as if left and right lists are equal. """ @@ -501,7 +502,7 @@ class SizeVarAllocator: self.guard_equals(left, sympy.Integer(right)) return int(right) - def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]: + def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]: return [self.evaluate_static_shape(x) for x in left] def remove_precomputed_replacements(self, expr: Expr) -> Expr: @@ -582,7 +583,7 @@ class SizeVarAllocator: index: Expr, vars: Sequence[sympy.Symbol], support_vars: Optional[Sequence[sympy.Symbol]] = None, - ) -> List[Expr]: + ) -> list[Expr]: if not support_vars: support_vars = vars return cache(index, tuple(vars), tuple(support_vars)) @@ -594,7 +595,7 @@ class SizeVarAllocator: index: Expr, vars: Sequence[sympy.Symbol], support_vars: Sequence[sympy.Symbol], - ) -> List[Expr]: + ) -> list[Expr]: """Convert an indexing expression back into strides NOTE: This is only valid if the index is a standard strided offset @@ -647,7 +648,7 @@ class SizeVarAllocator: } return expr.subs(size_dict) - def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr: + def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr: """Extract offset part of an indexing expression""" index = self.simplify(index) return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0}) @@ -657,7 +658,7 @@ class SizeVarAllocator: index: Expr, vars: Sequence[sympy.Symbol], support_vars: Optional[Sequence[sympy.Symbol]] = None, - ) -> List[int]: + ) -> list[int]: for v in index.free_symbols: if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined] index = sympy_subs(index, {v: 0}) # type: ignore[dict-item] @@ -669,7 +670,7 @@ class SizeVarAllocator: result.append(0) return result - def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]: + 
def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]: strides = tuple(map(abs, self.stride_hints(index, vars))) order = list(range(len(strides))) order.sort(key=lambda x: (strides[x] == 0, strides[x])) diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index 9a8684e1564a..ce35959c5322 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -2,9 +2,10 @@ import functools import operator +from collections.abc import Generator from contextlib import contextmanager from dataclasses import dataclass -from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar, Union +from typing import Any, Callable, Optional, TypeVar, Union from typing_extensions import ParamSpec import torch @@ -20,7 +21,7 @@ T = TypeVar("T") _P = ParamSpec("_P") OpOverload = torch._ops.OpOverload -LoweringDict = Dict[Union[OpOverload, str], Callable[..., Any]] +LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]] TargetType = Union[Callable[..., Any], str] @@ -30,13 +31,13 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): lowering object. Errors if buffers are created unexpectedly """ - graph_outputs: Optional[List[ir.IRNode]] + graph_outputs: Optional[list[ir.IRNode]] root_graph: torch._inductor.graph.GraphLowering _current_op: Optional[TargetType] # For backwards of buffer_grads with scatters we allow mutations allowed_mutations: Optional[OrderedSet[OpOverload]] additional_lowerings: Optional[LoweringDict] - buffers: List[ir.Buffer] + buffers: list[ir.Buffer] mutated_buffers: OrderedSet[str] def __init__( @@ -102,7 +103,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): self, target: TargetType, args: Any, - kwargs: Dict[str, Any], + kwargs: dict[str, Any], ) -> Any: from .lowering import lowerings @@ -123,7 +124,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter): return lowerings[target](*args, **kwargs) - def output(self, target: str, args: tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override] + def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override] assert len(args) == 1 self.graph_outputs = args[0] @@ -155,7 +156,7 @@ class TracingOpsHandler(WrapperHandler[T]): def lower_pointwise_subgraph( - subgraph: ir.Subgraph, inputs: List[InputDescriptor] + subgraph: ir.Subgraph, inputs: list[InputDescriptor] ) -> Callable[_P, Any]: # Lower subgraph to ir.Pointwise nodes def fake_inner_fn( diff --git a/torch/_inductor/triton_bundler.py b/torch/_inductor/triton_bundler.py index ed5d26e68061..b181eb2a0edf 100644 --- a/torch/_inductor/triton_bundler.py +++ b/torch/_inductor/triton_bundler.py @@ -3,7 +3,7 @@ import logging import os import uuid from pathlib import Path -from typing import List, Optional +from typing import Optional from torch._dynamo.utils import counters, dynamo_timed, set_feature_use from torch._utils_internal import justknobs_check @@ -48,7 +48,7 @@ class TritonKernelArtifacts: kernel_hash: str device: int - artifacts: List[TritonKernelArtifact] + artifacts: list[TritonKernelArtifact] @dataclasses.dataclass(frozen=True) @@ -57,7 +57,7 @@ class TritonBundlerMetadata: Metadata used for instrumentation """ - cached_kernel_names: List[str] + cached_kernel_names: list[str] class TritonBundler: @@ -76,7 +76,7 @@ class TritonBundler: - TritonBundler.read_and_emit is called when a cache entry is read """ - _entries: Optional[List[TritonBundleEntry]] = None + _entries: Optional[list[TritonBundleEntry]] = None 
# __grp__kernel_name.json contains metadata with source code paths # we use this as sentinal value for search and replace @@ -134,7 +134,7 @@ class TritonBundler: @classmethod def collect( cls, - ) -> tuple[List[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]: + ) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]: """ This is the main function called when a cache write happens. This function converts all the previously remembered kernels into bundled format so that @@ -150,10 +150,10 @@ class TritonBundler: with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True): entries = cls._entries if entries is not None: - result: List[TritonKernelArtifacts] = [] - kernel_names: List[str] = [] + result: list[TritonKernelArtifacts] = [] + kernel_names: list[str] = [] for entry in entries: - artifacts: List[TritonKernelArtifact] = [] + artifacts: list[TritonKernelArtifact] = [] path = os.path.join(entry.directory, entry.kernel_hash) if not os.path.exists(path): continue @@ -203,7 +203,7 @@ class TritonBundler: @staticmethod def read_and_emit( - bundle: List[TritonKernelArtifacts], + bundle: list[TritonKernelArtifacts], ) -> Optional[TritonBundlerMetadata]: """ This is the main function called when a cache read happens. This function @@ -223,7 +223,7 @@ class TritonBundler: with dynamo_timed( key="TritonBundler.read_and_emit", log_pt2_compile_event=True ): - kernel_names: List[str] = [] + kernel_names: list[str] = [] for artifacts in bundle: basedir = triton_cache_dir(artifacts.device) diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 45d7bce84e75..394e0d7fa1e8 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -26,18 +26,13 @@ from io import StringIO from typing import ( Any, Callable, - Dict, Generic, - Iterable, - List, NamedTuple, Optional, Protocol, - Sequence, TYPE_CHECKING, TypeVar, Union, - ValuesView, ) from typing_extensions import Concatenate, dataclass_transform, ParamSpec, TypeGuard from unittest import mock @@ -49,6 +44,8 @@ from torch._inductor.runtime.hints import DeviceProperties if TYPE_CHECKING: + from collections.abc import Iterable, Sequence, ValuesView + from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND from .codegen.common import WorkspaceArg @@ -94,7 +91,7 @@ _IS_WINDOWS = sys.platform == "win32" log = logging.getLogger(__name__) _T = TypeVar("_T") -VarRanges = Dict[sympy.Expr, sympy.Expr] +VarRanges = dict[sympy.Expr, sympy.Expr] InputType = Optional[Union[torch.Tensor, int, torch.SymInt]] GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"} @@ -308,7 +305,7 @@ def _type_of(key): def convert_shape_to_inductor( lst: Iterable[Union[int, torch.SymInt]] -) -> List[sympy.Expr]: +) -> list[sympy.Expr]: """ Gets the shape and stride of a tensor. For non-symbolic tensors, this is trivial. But for symbolic tensors, we need to map from SymIntNode into @@ -319,7 +316,7 @@ def convert_shape_to_inductor( def convert_shape_to_symint( lst: Iterable[Union[int, sympy.Expr]] -) -> List[Union[int, torch.SymInt]]: +) -> list[Union[int, torch.SymInt]]: """ Takes a list of shapes from Inductor and converts them into symints (or just ints if all shapes are static). 
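
One distinction worth noting around the utils.py hunks: module-level aliases such as `VarRanges = dict[sympy.Expr, sympy.Expr]` are ordinary runtime assignments, so they can only use names that are importable at runtime (builtin generics qualify on Python 3.9+), whereas names used purely inside annotations (`Iterable`, `Sequence`, `ValuesView`) can move behind `TYPE_CHECKING` once annotation evaluation is deferred. A short standalone sketch of that distinction — every name below is invented for illustration, not taken from the patch:

    from __future__ import annotations

    from typing import Optional, TYPE_CHECKING, Union

    if TYPE_CHECKING:
        # Annotation-only names can hide behind TYPE_CHECKING...
        from collections.abc import Iterable

    # ...but an alias is evaluated at import time, so it must be built only
    # from names that exist at runtime; builtin generics like dict[...] do.
    CellValue = Optional[Union[int, float, str]]
    Row = dict[str, CellValue]


    def first_cell(rows: Iterable[Row]) -> CellValue:
        # Signature annotations are postponed, so Iterable need not exist at runtime.
        for row in rows:
            for value in row.values():
                return value
        return None


    if __name__ == "__main__":
        print(first_cell([{"a": 1}]))  # -> 1
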
@@ -433,7 +430,7 @@ def precompute_method(obj: Any, method: str):
     setattr(obj, method, lambda: result)
 
 
-def precompute_methods(obj: Any, methods: List[str]):
+def precompute_methods(obj: Any, methods: list[str]):
     """Replace methods with new methods that returns a precomputed constants."""
     for method in methods:
         precompute_method(obj, method)
@@ -451,7 +448,7 @@ def pad_listlike(x, size):
 
 
 # Used to ensure that iterating over a set is deterministic
-def tuple_sorted(x: tuple[_T, ...]) -> List[_T]:
+def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
     if len(x) == 0:
         return []
 
@@ -716,7 +713,7 @@ def sympy_index_symbol(name: str) -> sympy.Symbol:
     return sympy.Symbol(name, integer=True, nonnegative=True)
 
 
-def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
+def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
     """
     When the passed replacement symbol v is a string, it is converted to a symbol with name v that have the same
     replaced expression integer and nonnegative properties.
@@ -804,7 +801,7 @@ def output_node(gm: torch.fx.GraphModule):
     return last_node
 
 
-_registered_caches: List[Any] = []
+_registered_caches: list[Any] = []
 
 
 def clear_on_fresh_inductor_cache(obj: Any):
@@ -871,7 +868,7 @@ def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
         clear_inductor_caches()
 
 
-def argsort(seq) -> List[int]:
+def argsort(seq) -> list[int]:
     # preserve original order for equal strides
     getter = seq.__getitem__
     a_r = range(len(seq))
@@ -880,7 +877,7 @@ def argsort(seq) -> List[int]:
 
 def argsort_sym(
     shape_env, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
-) -> List[int]:
+) -> list[int]:
     def cmp(a, b):
         a_idx, a_val = a
         b_idx, b_val = b
@@ -1180,7 +1177,7 @@ def use_max_autotune() -> bool:
     )
 
 
-def _use_template_for_gpu(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
+def _use_template_for_gpu(layout, allowed_layout_dtypes: list[torch.dtype]) -> bool:
     return (
         is_gpu(layout.device.type)
         and layout.dtype in allowed_layout_dtypes
@@ -1462,10 +1459,10 @@ class DebugDirManager:
         torch._dynamo.config.debug_dir_root = self.prev_debug_name
 
 
-def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
+def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
     from .graph import GraphLowering
 
-    source_codes: List[str] = []
+    source_codes: list[str] = []
 
     def save_output_code(code: str):
         source_codes.append(code)
@@ -1476,7 +1473,7 @@ def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
     return result, source_codes
 
 
-def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, List[str]]:
+def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, list[str]]:
     result, source_codes = run_and_get_code(fn, *args, **kwargs)
     kernels = []
     for code in source_codes:
@@ -1497,7 +1494,7 @@ def get_code(fn, *args, **kwargs):
     """Get the inductor-generated code, but skip any actual compilation or running."""
     from .graph import GraphLowering
 
-    source_codes: List[str] = []
+    source_codes: list[str] = []
 
     def save_output_code(code: str):
         source_codes.append(code)
@@ -2217,13 +2214,13 @@ def shape_env_from_inputs(inputs: Sequence[InputType]):
 
 
 def align_inputs_from_check_idxs(
-    model: Callable[[List[InputType]], Any],
+    model: Callable[[list[InputType]], Any],
     inputs_to_check: Sequence[int],
-) -> Callable[[List[InputType]], Any]:
+) -> Callable[[list[InputType]], Any]:
     if len(inputs_to_check) == 0:
         return model
 
-    def run(new_inputs: List[InputType]):
+    def run(new_inputs: list[InputType]):
         copy_misaligned_inputs(new_inputs, inputs_to_check)
         return model(new_inputs)
 
@@ -2243,7 +2240,7 @@ def clone_preserve_strides(x: torch.Tensor):
 
 
 def copy_misaligned_inputs(
-    new_inputs: List[InputType], check_inputs_idxs: Sequence[int]
+    new_inputs: list[InputType], check_inputs_idxs: Sequence[int]
 ) -> None:
     for i in check_inputs_idxs:
         _inp = new_inputs[i]
@@ -2408,7 +2405,7 @@ class OpDtypeRule:
     override_return_dtype: Optional[torch.dtype]
 
 
-op_dtype_propagation_rules: Dict[str, OpDtypeRule] = {}
+op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
 
 
 def register_op_dtype_propagation_rules(
@@ -2445,7 +2442,7 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
     return wrap(cls)
 
 
-def get_donated_idxs() -> Optional[List[int]]:
+def get_donated_idxs() -> Optional[list[int]]:
     tracing_context = torch._guards.TracingContext.try_get()
     if tracing_context is not None and tracing_context.fw_metadata:
         return tracing_context.fw_metadata.bw_donated_idxs
diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py
index fa5f6e7ba1eb..4de418461662 100644
--- a/torch/_inductor/virtualized.py
+++ b/torch/_inductor/virtualized.py
@@ -59,7 +59,7 @@ from __future__ import annotations
 
 from contextlib import AbstractContextManager, contextmanager
 from threading import local
-from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
 
 from torch.utils._ordered_set import OrderedSet
 
@@ -108,7 +108,7 @@ class Virtualized(Generic[T]):
     store other things, like booleans.
     """
 
-    def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
+    def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]):
        self._key: str = f"__torchinductor_{vname}"
        self._default = default
 
@@ -156,7 +156,7 @@ class NullKernelHandler(NullHandler):
 
 _ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
 _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
-_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
+_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
 _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
 _kernel: Virtualized[NullKernelHandler] = Virtualized(
     "kernel", NullKernelHandler
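
The hunks above repeat a single pattern: PEP 585 builtin generics (list[...], dict[...], tuple[...], type[...]) replace the typing.List/Dict/Tuple/Type aliases, and abstract container types such as Sequence, Iterable, Generator, and ValuesView move to collections.abc, usually behind an `if TYPE_CHECKING:` guard so they are never imported at runtime. The sketch below is illustrative only and is not part of the patch; the module and function names are invented for the example, and it assumes Python 3.9+ (or `from __future__ import annotations` on older interpreters).

from __future__ import annotations  # annotations stay lazy strings, so guarded names are safe

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Needed only by the type checker; skipped entirely at import time.
    from collections.abc import Iterable, Sequence


def as_int_lengths(lengths: Iterable[Union[int, float]]) -> list[int]:
    # Builtin generic list[int] replaces typing.List[int].
    return [int(length) for length in lengths]


def first_or_none(items: Sequence[int]) -> Optional[int]:
    return items[0] if items else None

Keeping the collections.abc imports under TYPE_CHECKING mirrors the utils.py hunk: the names appear only in annotations, so deferring them adds no import-time cost, while runtime code keeps using the builtin container types directly.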