PEP585 update - torch/_inductor (#145198)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
Aaron Orenstein authored on 2025-01-20 12:27:30 -08:00, committed by PyTorch MergeBot
parent 2f9d378f7b
commit bac62341eb
34 changed files with 494 additions and 545 deletions
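The change itself is mechanical: annotations that used the deprecated typing aliases (List, Dict, Tuple, Type, DefaultDict, Sequence imported from typing, and so on) are rewritten to the built-in generics standardized by PEP 585, which have been subscriptable since Python 3.9. A minimal before/after sketch of the pattern, with hypothetical function and parameter names rather than ones taken from the diff:

# Before: typing aliases, deprecated since Python 3.9 (PEP 585).
from typing import Dict, List, Optional

def pick_block_sizes(configs: List[Dict[str, int]], key: Optional[str] = None) -> List[int]:
    return [cfg.get(key or "BLOCK_M", 0) for cfg in configs]

# After: built-in generics; typing is only needed for Optional/Union/etc.
from typing import Optional

def pick_block_sizes(configs: list[dict[str, int]], key: Optional[str] = None) -> list[int]:
    return [cfg.get(key or "BLOCK_M", 0) for cfg in configs]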

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Union
from typing import Optional, Union
import sympy
@ -113,8 +113,8 @@ def register_jagged_ops():
@register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
def _jagged_to_padded_dense_forward(
jagged_values: TensorBox,
jagged_offsets: List[TensorBox],
max_lengths: List[int], # list of ints/SymInts
jagged_offsets: list[TensorBox],
max_lengths: list[int], # list of ints/SymInts
padding_value: float = 0.0,
) -> TensorBox:
device = jagged_values.get_device_or_error()
@ -184,7 +184,7 @@ def register_jagged_ops():
def _dense_to_jagged_forward_impl(
fallback_op, # pyre-ignore[2]
dense: TensorBox,
jagged_offsets: List[TensorBox],
jagged_offsets: list[TensorBox],
jagged_len: Optional[int] = None,
) -> TensorBox:
device = dense.get_device_or_error()
@ -257,7 +257,7 @@ def register_jagged_ops():
@register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
def _dense_to_jagged_forward(
dense: TensorBox,
jagged_offsets: List[TensorBox],
jagged_offsets: list[TensorBox],
jagged_len: Optional[int] = None,
) -> TensorBox:
return _dense_to_jagged_forward_impl(

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import logging
from typing import cast, Optional, Sequence, TYPE_CHECKING, TypedDict
from typing import cast, Optional, TYPE_CHECKING, TypedDict
import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@ -33,6 +33,8 @@ from .mm_common import build_rocm_gemm_configs, filtered_configs
if TYPE_CHECKING:
from collections.abc import Sequence
from ..ir import TensorBox
log = logging.getLogger(__name__)
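This file (and several later ones) shows a second pattern: Sequence is only needed for type annotations, so its import moves from typing into a collections.abc import guarded by TYPE_CHECKING. With `from __future__ import annotations` in effect, annotations are stored as strings and never evaluated at runtime, so the name does not need to exist outside the type checker. A small sketch of the idiom, using made-up names:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers; no runtime import cost.
    from collections.abc import Sequence

def flatten_sizes(shapes: Sequence[tuple[int, ...]]) -> list[int]:
    # The annotation is a lazily evaluated string, so Sequence need not
    # be importable when this function actually runs.
    return [dim for shape in shapes for dim in shape]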

View File

@ -3,9 +3,10 @@
import logging
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Optional, Union
import sympy
@ -90,7 +91,7 @@ def create_placeholder(
return TensorBox.create(input_buffer)
def maybe_realize(args: List[Optional[IRNode]]):
def maybe_realize(args: list[Optional[IRNode]]):
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
return tree_map(
lambda x: (
@ -109,7 +110,7 @@ def get_float32_precision():
return "'tf32'"
def zeros_and_scatter_lowering(shape: List[int], indices, values):
def zeros_and_scatter_lowering(shape: list[int], indices, values):
# Always accumulate into fp32 then cast
grad = _full(0, values.get_device(), torch.float32, shape)
assert isinstance(grad, TensorBox)
@ -153,10 +154,10 @@ def zeros_and_scatter_lowering(shape: List[int], indices, values):
return buffer
SubgraphResults = Union[List[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
def build_subgraph_buffer(args: List[TensorBox], subgraph: Subgraph) -> SubgraphResults:
def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults:
"""This function's goal is to take in the required args and produce the subgraph buffer
The subgraph buffer is a ComputedBuffer that will be inlined into the triton template
@ -870,7 +871,7 @@ def lower_cpu(
"torch.compile on current platform is not supported for CPU."
)
fake_buffers: List[Buffer] = [] # noqa: F821
fake_buffers: list[Buffer] = [] # noqa: F821
placeholder_inps = [
create_placeholder(name, dtype, query.get_device())
for name, dtype in [
@ -968,7 +969,7 @@ def lower_cpu(
[B, Hq, seq_len_q, v_head_dim],
stride=[sympy.sympify(s) for s in out_strides],
)
_choices: List[Any] = []
_choices: list[Any] = []
input_nodes = [query, key, value, kv_num_blocks, kv_indices]
if not full_kv_num_blocks:
no_full_kv_block = True
@ -1214,8 +1215,8 @@ def flex_attention(
"V_HEAD_DIM", V.graph.sizevars.evaluate_static_shape(v_head_dim)
)
choices: List[Any] = []
configs: List[tuple[int, int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int, int]] = []
configs.append(_get_default_config_fwd(query))
if config.max_autotune:
configs += [
@ -2071,9 +2072,9 @@ class JointOutputResult:
"""Results from processing joint outputs."""
grad_input: ComputedBuffer
captured_grads_compute: List[ComputedBuffer]
captured_grads: List[Optional[TensorBox]]
mutated_grads: List[TensorBox]
captured_grads_compute: list[ComputedBuffer]
captured_grads: list[Optional[TensorBox]]
mutated_grads: list[TensorBox]
def process_joint_outputs(
@ -2088,7 +2089,7 @@ def process_joint_outputs(
Returns:
JointOutputResult containing processed buffers and gradients
"""
assert isinstance(all_joint_outputs, List)
assert isinstance(all_joint_outputs, list)
assert (
all_joint_outputs[0] is not None
), "joint_subgraph_buffer is None this is a bug!"
@ -2307,8 +2308,8 @@ def flex_attention_backward(*args, **kwargs):
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)
choices: List[Any] = []
configs: List[tuple[int, int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int, int]] = []
configs.append(_get_default_config_bwd(query))
if config.max_autotune:
num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any, List
from typing import Any
import sympy
@ -415,8 +415,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)
choices: List[Any] = []
configs: List[tuple[int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int]] = []
configs.append(_get_decoding_default_config(key))
# Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
if config.max_autotune:

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional
import torch
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@ -933,7 +933,7 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
def mul_epilogue(v1, v2):
return V.ops.mul(v1, v2)
choices: List[Dict[Any, Any]] = []
choices: list[dict[Any, Any]] = []
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):

View File

@ -2,7 +2,8 @@
import functools
import itertools
import logging
from typing import Any, cast, Dict, Sequence
from collections.abc import Sequence
from typing import Any, cast
import sympy
@ -438,7 +439,7 @@ def mm_grid(m, n, meta):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)
def persistent_mm_grid(M: int, N: int, meta: Dict[str, Any]):
def persistent_mm_grid(M: int, N: int, meta: dict[str, Any]):
"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),

View File

@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, List, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional
import sympy
@ -428,7 +429,7 @@ def scaled_mm_options_device_tma( # type: ignore[no-untyped-def]
scale_b: StorageBox,
use_fast_accum: bool,
b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
@ -464,7 +465,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def]
scale_b: StorageBox,
use_fast_accum: bool,
b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
@ -533,7 +534,7 @@ def tuned_scaled_mm(
input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import logging
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING
from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
@ -75,7 +75,7 @@ uint4x2_mixed_mm_template = TritonTemplate(
def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(

View File

@ -6,7 +6,7 @@ import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, TypeVar
import sympy
@ -21,6 +21,10 @@ from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V
if TYPE_CHECKING:
from collections.abc import Sequence
T = TypeVar("T")
@ -83,14 +87,14 @@ class LoopBody:
indexing simplifications and makes it easier to analyze loop bodies.
"""
indexing_exprs: Dict[str, sympy.Expr]
indexing_exprs_name: Dict[sympy.Expr, str]
submodules: Dict[str, Any]
subblocks: Dict[str, LoopBodyBlock]
indirect_vars: List[sympy.Symbol]
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
indexing_exprs: dict[str, sympy.Expr]
indexing_exprs_name: dict[sympy.Expr, str]
submodules: dict[str, Any]
subblocks: dict[str, LoopBodyBlock]
indirect_vars: list[sympy.Symbol]
indirect_var_ranges: dict[sympy.Symbol, sympy.Expr]
root_block: LoopBodyBlock
memory_usage: Dict[MemoryUsageType, List[MemoryEntry]]
memory_usage: dict[MemoryUsageType, list[MemoryEntry]]
op_counts: collections.Counter[str]
def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
@ -120,7 +124,7 @@ class LoopBody:
self.submodules = {"get_index": self.get_index}
self.subblocks = {}
self.indirect_vars = []
self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {}
self.memory_usage = {t: [] for t in MemoryUsageType}
self.op_counts = collections.Counter()
self.root_block = LoopBodyBlock(self, fn, args) # traces
@ -433,7 +437,7 @@ class LoopBodyBlock:
operations will manifest as an extra LoopBodyBlock.
"""
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
self.body = body
def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):

View File

@ -8,8 +8,8 @@ import operator
import os
import warnings
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch
@ -90,9 +90,9 @@ FALLBACK_ALLOW_LIST = OrderedSet(
)
log = logging.getLogger(__name__)
lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
_maybe_layout_constraints: Dict[
_maybe_layout_constraints: dict[
torch._ops.OpOverload, Optional[Callable[..., Any]]
] = {}
fallbacks = OrderedSet[torch._ops.OpOverload]()
@ -106,7 +106,7 @@ foreach_ops = OrderedSet[torch._ops.OpOverload](
# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
# so why is it in foreach_ops?
inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
quantized_decomposed = torch.ops.quantized_decomposed
@ -313,12 +313,12 @@ def in_namespace(op, namespace):
def transform_args(
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
broadcast: bool,
type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
convert_input_to_bool: bool,
) -> tuple[List[Any], Dict[str, Any]]:
) -> tuple[list[Any], dict[str, Any]]:
args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
# check that there's something to transform
@ -428,8 +428,8 @@ def _register_lowering(
@functools.wraps(decomp_fn)
def wrapped(*args, **kwargs):
args: List[Any] = list(args)
kwargs: Dict[str, Any] = dict(kwargs)
args: list[Any] = list(args)
kwargs: dict[str, Any] = dict(kwargs)
unpacked = False
# TODO maybe we need to use pytrees here
if len(args) == 1 and isinstance(args[0], (list, tuple)):
@ -654,7 +654,7 @@ def make_pointwise(
def make_foreach_pointwise(pw_fn, allow_alpha=False):
def inner(*inputs: List[List[TensorBox]], alpha=1):
def inner(*inputs: list[list[TensorBox]], alpha=1):
realize_outputs = (
len(V.graph.current_node.users) == 0
or V.graph.current_node.target in inplace_foreach_ops
@ -682,7 +682,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
outputs = [None] * len(a_list_input)
for (device, use_foreach), group in groups.items():
operation_list: List[str] = []
operation_list: list[str] = []
for (
output_ind,
args,
@ -749,7 +749,7 @@ def _foreach_map(subgraph, *args, **kwargs):
outputs = [None] * len(sub_outputs)
for (device, use_foreach), group in groups.items():
operation_list: List[str] = []
operation_list: list[str] = []
for (
output_ind,
output,
@ -949,7 +949,7 @@ def where(cond, a, b):
def broadcast_tensors(*inputs):
if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
return broadcast_tensors(*inputs[0])
target: List[sympy.Expr] = functools.reduce(
target: list[sympy.Expr] = functools.reduce(
broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
)
outputs = []
@ -1231,7 +1231,7 @@ def as_strided_copy(x, size, stride, storage_offset=None):
def pointwise_cat(inputs, dim=0):
# (inclusive, exclusive)
inputs_ranges: List[tuple[sympy.Expr, sympy.Expr]] = []
inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
prev_end = 0
for inp in inputs:
inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type]
@ -2173,7 +2173,7 @@ def inductor_lookup_seed(seeds, index):
@register_lowering(inductor_prims.random, type_promotion_kind=None)
def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0):
def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
assert not config.fallback_random
assert mode in ("rand", "randn")
size = [*size]
@ -2202,7 +2202,7 @@ def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int
@register_lowering(inductor_prims.randint, type_promotion_kind=None)
def inductor_randint(
low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0
low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
):
assert not config.fallback_random
size = [*size]
@ -2916,7 +2916,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
else:
dtype = dtype or torch.get_default_dtype()
ranges: List[sympy.Expr] = []
ranges: list[sympy.Expr] = []
if isinstance(data, sympy.Basic):
@ -4041,7 +4041,7 @@ def constant_pad_nd(x, padding, fill_value=0):
n = len(sizes) - len(bounds)
# if padding is a complicated expression, hoist it
bounds_precomp: List[tuple[sympy.Symbol, Any]] = []
bounds_precomp: list[tuple[sympy.Symbol, Any]] = []
for l, h in bounds:
bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]

View File

@ -4,7 +4,7 @@ import collections
import dataclasses
import heapq
import logging
from typing import Callable, Dict, List, TYPE_CHECKING, TypedDict, Union
from typing import Callable, TYPE_CHECKING, TypedDict, Union
from torch._utils_internal import signpost_event
from torch.utils._ordered_set import OrderedSet
@ -61,9 +61,9 @@ class FreeableInputBuffer:
def get_freeable_input_buf(
nodes: List[BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
graph_inputs: OrderedSet[str],
) -> Dict[str, FreeableInputBuffer]:
) -> dict[str, FreeableInputBuffer]:
"""
Create and keep track of all input buffers that can be freed during the program
@ -87,10 +87,10 @@ def get_freeable_input_buf(
# get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: Dict[
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_size: Dict[str, int] = dict()
dep_name_to_size: dict[str, int] = dict()
for node in nodes:
for dep in node.read_writes.reads:
if dep.name in graph_inputs and not dep.name.startswith(
@ -100,7 +100,7 @@ def get_freeable_input_buf(
dep_name_to_size[dep.name] = _dep_size_hint(dep)
# create FreeableInputBuffer objects and add them to the returned dictionary
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict()
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict()
for dep_name, succ_nodes in dep_name_to_succ_nodes.items():
name_to_freeable_input_buf[dep_name] = FreeableInputBuffer(
dep_name,
@ -112,8 +112,8 @@ def get_freeable_input_buf(
def compute_size_for_scheduler_buffer(
name_to_buf: Dict[str, SchedulerBuffer]
) -> Dict[str, tuple[int, int]]:
name_to_buf: dict[str, SchedulerBuffer]
) -> dict[str, tuple[int, int]]:
"""
Compute the size of each scheduler buffer, including (1) memory allocated when
it is created and (2) memory deallocated when it is freed.
@ -134,7 +134,7 @@ def compute_size_for_scheduler_buffer(
from .ir import MultiOutput
from .scheduler import OutputNode
sched_buf_to_size: Dict[str, tuple[int, int]] = dict()
sched_buf_to_size: dict[str, tuple[int, int]] = dict()
def _compute_and_update_buf_size(
sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
@ -175,8 +175,8 @@ def compute_size_for_scheduler_buffer(
def assign_memory_planning_info_for_scheduler_buffers(
nodes: List[BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
nodes: list[BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
) -> None:
"""
For each SchedulerBuffer, assign its size info and successor nodes.
@ -187,7 +187,7 @@ def assign_memory_planning_info_for_scheduler_buffers(
# get buffer's successor nodes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: Dict[
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
for node in nodes:
@ -205,10 +205,10 @@ def assign_memory_planning_info_for_scheduler_buffers(
def assign_memory_planning_info_for_scheduler_nodes(
nodes: List[BaseSchedulerNode],
name_to_fused_node: Dict[str, BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
nodes: list[BaseSchedulerNode],
name_to_fused_node: dict[str, BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
) -> None:
"""
Assign to each scheduler node its predecessor and successor nodes.
@ -243,10 +243,10 @@ def assign_memory_planning_info_for_scheduler_nodes(
def estimate_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[int, List[int]]:
) -> tuple[int, list[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
@ -267,12 +267,12 @@ def estimate_peak_memory(
# get the execution step of each node, this will be used to determine
# the end_step of buffers
node_to_step: Dict[BaseSchedulerNode, int] = dict()
node_to_step: dict[BaseSchedulerNode, int] = dict()
for step, node in enumerate(nodes):
node_to_step[node] = step
# get buffers' size and liveliness information
buf_info_list: List[BufferInfo] = []
buf_info_list: list[BufferInfo] = []
# 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = (
@ -340,11 +340,11 @@ def estimate_peak_memory(
def topological_sort_lpmf(
nodes: List[BaseSchedulerNode],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
name_to_buf: Dict[str, SchedulerBuffer],
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
name_to_buf: dict[str, SchedulerBuffer],
graph_outputs: OrderedSet[str],
) -> List[BaseSchedulerNode]:
) -> list[BaseSchedulerNode]:
"""
A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First".
@ -372,8 +372,8 @@ def topological_sort_lpmf(
class BufferInfo(TypedDict):
outdegree: int
node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
@ -422,7 +422,7 @@ def topological_sort_lpmf(
node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free
# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
schedule: list[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule:
@ -464,7 +464,7 @@ def topological_sort_lpmf(
return schedule
def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
A BFS topological sort that selects nodes whose dependencies are executed the
earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
@ -478,11 +478,11 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
indegree: int
order: int
node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
@dataclasses.dataclass
class NodeWithPriority:
priority: List[int]
priority: list[int]
node: BaseSchedulerNode
def __lt__(self, other: NodeWithPriority) -> bool:
@ -490,7 +490,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
return self.node.mpi_node.index < other.node.mpi_node.index
return self.priority < other.priority
def _node_priority(node: BaseSchedulerNode) -> List[int]:
def _node_priority(node: BaseSchedulerNode) -> list[int]:
# priority is the order in which predecessor nodes are executed
assert node_info[node]["indegree"] == 0
exec_orders = sorted(
@ -502,7 +502,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
nodes_to_schedule: List[NodeWithPriority] = []
nodes_to_schedule: list[NodeWithPriority] = []
for node in nodes:
node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1}
if node_info[node]["indegree"] == 0:
@ -511,7 +511,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
)
# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
schedule: list[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule
@ -536,7 +536,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
return schedule
def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
in scheduler.py. The difference is the order nodes are visited in the outer loop.
@ -546,9 +546,9 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
the nodes in ascending order of this priority.
"""
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []
size_with_reads: Dict[BaseSchedulerNode, int] = dict()
name_to_node: dict[str, BaseSchedulerNode] = dict()
result: list[BaseSchedulerNode] = []
size_with_reads: dict[BaseSchedulerNode, int] = dict()
def visit(n: BaseSchedulerNode) -> None:
if n not in seen:
@ -579,17 +579,17 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
def reorder_for_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_fused_node: Dict[str, BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
name_to_fused_node: dict[str, BaseSchedulerNode],
graph_inputs: OrderedSet[str],
graph_outputs: OrderedSet[str],
methods: List[Callable[..., List[BaseSchedulerNode]]] = [ # noqa: B006
methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006
topological_sort_lpmf,
topological_sort_bfs,
topological_sort_dfs,
],
) -> List[BaseSchedulerNode]:
) -> list[BaseSchedulerNode]:
"""
Try a few heuristics based topological sort algorithms, and pick the one whose
resulting topological order has the lowest peak memory estimation.
@ -599,13 +599,13 @@ def reorder_for_peak_memory(
@dataclasses.dataclass
class PeakMemoryResult:
order: List[BaseSchedulerNode]
order: list[BaseSchedulerNode]
peak_memory: int
method: str
# preparation -- as nodes are scheduled one at a time, these help
# keep track of when a buffer can be freed, and when a node can be scheduled
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf(
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
nodes, graph_inputs
)
assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
@ -614,7 +614,7 @@ def reorder_for_peak_memory(
)
# keep track of the peak memory estimates of different methods
peak_memory_diff_methods: List[PeakMemoryResult] = []
peak_memory_diff_methods: list[PeakMemoryResult] = []
# the default
estimated_peak_memory, _ = estimate_peak_memory(

View File

@ -8,7 +8,7 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING
from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
@ -23,13 +23,13 @@ if TYPE_CHECKING:
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
nodes_num_elem: list[
tuple[
BaseSchedulerNode,
int,
]
] = []
node_runtimes: List[tuple[BaseSchedulerNode, float]] = []
node_runtimes: list[tuple[BaseSchedulerNode, float]] = []
# counters for tracking fusions
ir_nodes_pre_fusion = 0
@ -45,7 +45,7 @@ class CppOuterLoopFusedCount:
# The length counts the number of outer loop fusions.
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = []
num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0
@ -122,13 +122,13 @@ class CachedMetricsHelper:
globals()[metric] += getattr(delta, metric)
REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {}
@dataclass
class MetricTable:
table_name: str
column_names: List[str]
column_names: list[str]
num_rows_added: int = 0

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional
from typing import Any, Optional
import sympy
@ -31,13 +31,13 @@ def _prepare_convolution_fusion_create(
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding: List[int],
stride: List[int],
dilation: List[int],
padding: list[int],
stride: list[int],
dilation: list[int],
groups: int,
transposed: bool = False,
output_padding: Optional[List[int]] = None,
quantize_args: Optional[List["TensorBox"]] = None,
output_padding: Optional[list[int]] = None,
quantize_args: Optional[list["TensorBox"]] = None,
other: Optional["TensorBox"] = None,
):
"""
@ -204,7 +204,7 @@ def _prepare_linear_fusion_create(
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
quantize_args: Optional[List["TensorBox"]] = None,
quantize_args: Optional[list["TensorBox"]] = None,
other: Optional["TensorBox"] = None,
binary_sum: bool = False,
):
@ -252,7 +252,7 @@ def _prepare_linear_fusion_create(
output_size,
output_stride,
)
constant_args: List[Any] = []
constant_args: list[Any] = []
if bias is not None:
inputs.append(bias)
@ -298,12 +298,12 @@ class ConvolutionUnary(ExternKernelAlloc):
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
attr,
scalars: Optional[List[Any]],
scalars: Optional[list[Any]],
algorithm,
):
(
@ -357,14 +357,14 @@ class ConvolutionBinary(ExternKernelAlloc):
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_scalars: Optional[list[Any]],
unary_algorithm: Optional[str],
):
(
@ -431,14 +431,14 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_scalars: Optional[list[Any]],
unary_algorithm: Optional[str],
):
(
@ -496,13 +496,13 @@ class ConvolutionTransposeUnary(ExternKernelAlloc):
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
output_padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
output_padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups_: int,
attr,
scalars: Optional[List[Any]],
scalars: Optional[list[Any]],
algorithm,
):
transposed = True
@ -580,9 +580,9 @@ class QConvPointWisePT2E(ExternKernelAlloc):
w_scale: "TensorBox",
w_zero_point: "TensorBox",
bias: "TensorBox",
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
output_scale: float,
output_zero_point: int,
@ -692,9 +692,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
w_zero_point,
qaccum: "TensorBox",
bias: "TensorBox",
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
output_scale: "TensorBox",
output_zero_point: "TensorBox",
@ -1139,7 +1139,7 @@ class MkldnnRnnLayer(ExternKernelAlloc):
hx: "TensorBox",
cx: "TensorBox",
reverse: bool,
batch_sizes: List[int],
batch_sizes: list[int],
mode: int,
hidden_size: int,
num_layers: int,

View File

@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import functools
from typing import List, Optional
from typing import Optional
import torch
import torch.utils._pytree as pytree
@ -31,8 +31,8 @@ from .virtualized import ops, V
def grouped_gemm_lowering(
x: TensorBox,
w: List[TensorBox],
b: List[TensorBox],
w: list[TensorBox],
b: list[TensorBox],
attr=None,
scalars=None,
algorithm=None,
@ -47,7 +47,7 @@ def grouped_gemm_lowering(
assert use_max_autotune()
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
kwargs = dict(
@ -245,7 +245,7 @@ def register_onednn_fusion_ops():
x = view(x, [-1, x_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
@ -308,7 +308,7 @@ def register_onednn_fusion_ops():
y = view(y, [-1, y_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w, y = mm_args(
@ -397,7 +397,7 @@ def register_onednn_fusion_ops():
hx: TensorBox,
cx: TensorBox,
reverse: bool,
batch_sizes: List[int],
batch_sizes: list[int],
mode: int,
hidden_size: int,
num_layers: int,
@ -611,7 +611,7 @@ def register_onednn_fusion_ops():
bias_dtype = None if bias is None else bias.get_dtype()
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
*_, layout, x, packed_weight = mm_args(
x, packed_weight, layout=layout, out_dtype=output_dtype
@ -888,7 +888,7 @@ def register_onednn_fusion_ops():
), "dtype of accum for qlinear post op sum should be the same as output"
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if (
use_max_autotune() and binary_attr == "add"
): # <TODO> Support inplace sum fusion
@ -1131,7 +1131,7 @@ def register_onednn_fusion_ops():
*,
layout=None,
):
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(orig_w, [1, 0])
*_, layout, x, transposed_w = mm_args(

View File

@ -6,7 +6,7 @@ import contextlib
import dataclasses
import sys
import threading
from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing_extensions import override, Self
from unittest.mock import patch
@ -56,7 +56,7 @@ class Stats:
class _GlobalItemStats(Stats):
cache: Dict[str, object]
cache: dict[str, object]
def __init__(self) -> None:
super().__init__()
@ -266,7 +266,7 @@ class PatchCaches(contextlib.AbstractContextManager):
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:

View File

@ -1,17 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
)
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from typing_extensions import Protocol
from unittest.mock import patch
@ -961,7 +950,7 @@ def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
class OpCountResult(NamedTuple):
num_ops: int
used_ops: OrderedSet[str]
read_buffers: List[str]
read_buffers: list[str]
nontrivial_read_count: int
@ -974,7 +963,7 @@ class OpCounterCSE:
self.op_count = 0
self.var_names = {}
self._used_ops = OrderedSet[str]()
self._read_names: List[str] = []
self._read_names: list[str] = []
self._nontrivial_read_count = 0
def __getattr__(self, name):
@ -1076,7 +1065,7 @@ class SimpleCSEHandler(WrapperHandler[T]):
def __init__(self, inner: OpsHandler[T]):
super().__init__(inner)
self.cse_cache: Dict[str, Union[T, tuple[T, ...]]] = {}
self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {}
self.mock = MockHandler()
def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:

View File

@ -1,5 +1,5 @@
import math
from typing import Any, Dict, List
from typing import Any
import sympy
@ -40,10 +40,10 @@ def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool:
def try_to_reduce_precision(
node: Any,
bounds: Dict[Any, Any],
indirect_vars: List[Any],
indices: Dict[Any, sympy.Expr],
replacement_vals: Dict[Any, ValueRanges[sympy.Expr]],
bounds: dict[Any, Any],
indirect_vars: list[Any],
indices: dict[Any, sympy.Expr],
replacement_vals: dict[Any, ValueRanges[sympy.Expr]],
) -> None:
# if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
# then it's precision is set for that chain of uses, and we don't need to consider those

View File

@ -27,17 +27,7 @@ import logging
import os
import re
from pathlib import Path
from typing import (
Any,
Callable,
Counter,
Dict,
List,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
import torch
@ -62,6 +52,9 @@ from .runtime.autotune_cache import AutotuneCacheBundler
if TYPE_CHECKING:
from collections import Counter
from collections.abc import Sequence
from torch._inductor import metrics
from torch._inductor.graph import GraphLowering
@ -108,13 +101,13 @@ def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t: torch.Tensor) -> List[int]:
def get_expanded_dims(t: torch.Tensor) -> list[int]:
if not isinstance(t, torch.Tensor):
return None
return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]
def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor:
for expanded_dim in expanded_dims:
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
return t
@ -146,7 +139,7 @@ def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
constants: Dict[str, torch.Tensor],
constants: dict[str, torch.Tensor],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then
@ -213,7 +206,7 @@ def cudagraph_post_compile(
# should already exist from forward
assert manager is not None
def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)
@ -270,7 +263,7 @@ class CompiledFxGraphConstants:
the value of constants directly off of the original saved object.
"""
def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
assert g.constants is not None
return g.constants
@ -287,7 +280,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
def __init__(self, gm: torch.fx.GraphModule) -> None:
self.gm = gm
def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
if g.allocated_constant_name is not None:
return {
name: getattr(self.gm, name)
@ -308,7 +301,7 @@ class CompiledFxGraph(OutputCode):
current_callable: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[tuple[int, str]]]
cache_linemap: Optional[list[tuple[int, str]]]
device_types: OrderedSet[str]
device_idxs: OrderedSet[int]
mutated_inputs: OrderedSet[str]
@ -320,10 +313,10 @@ class CompiledFxGraph(OutputCode):
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]]
allocated_constant_name: Optional[dict[str, str]]
constants: Optional[dict[str, torch.Tensor]]
torchbind_constants: dict[str, torch._C.ScriptObject]
output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]
@ -340,14 +333,14 @@ class CompiledFxGraph(OutputCode):
boxed_forward_device_index: Optional[BoxedDeviceIndex]
_boxed_call: Optional[bool] = None
_triton_bundle: Optional[List[TritonKernelArtifacts]] = None
_triton_bundle: Optional[list[TritonKernelArtifacts]] = None
def __init__(
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[tuple[_StrideExprStr, ...]]],
output_strides: list[Optional[tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
counter_deltas: Counter[str],
@ -583,7 +576,7 @@ class CompiledAOTI(OutputCode):
Class holding an AOTInductor compiled so.
"""
filename: Union[str, List[str]]
filename: Union[str, list[str]]
def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")

View File

@ -7,7 +7,7 @@ import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union
import torch
import torch._inductor
@ -85,7 +85,7 @@ class PT2ArchiveReader:
assert self.archive_file is not None
self.archive_file.extractall(path)
def get_file_names(self) -> List[str]:
def get_file_names(self) -> list[str]:
assert self.archive_file is not None
return self.archive_file.namelist()
@ -98,7 +98,7 @@ def _run_command_and_check(cmd: str) -> None:
raise exc.CppCompileError(cmd, e.output) from e
def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
def get_aoti_file_with_suffix(suffix: str) -> str:
for file in aoti_files:
if file.endswith(suffix):
@ -159,7 +159,7 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
def package_aoti(
archive_file: Union[str, io.BytesIO],
aoti_files: Union[List[str], Dict[str, List[str]]],
aoti_files: Union[list[str], dict[str, list[str]]],
) -> Union[str, io.BytesIO]:
"""
Saves the AOTInductor generated files to the PT2Archive format.
@ -244,12 +244,12 @@ class AOTICompiledModel:
flat_outputs = self.loader.boxed_run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)
def get_metadata(self) -> Dict[str, str]:
def get_metadata(self) -> dict[str, str]:
return self.loader.get_metadata() # type: ignore[attr-defined]
def load_constants(
self,
constants_map: Dict[str, torch.Tensor],
constants_map: dict[str, torch.Tensor],
*,
check_full_update: bool,
) -> None:
@ -265,7 +265,7 @@ class AOTICompiledModel:
"""
self.loader.load_constants(constants_map, False, check_full_update) # type: ignore[attr-defined]
def get_constant_fqns(self) -> List[str]:
def get_constant_fqns(self) -> list[str]:
return self.loader.get_constant_fqns() # type: ignore[attr-defined]

View File

@ -50,24 +50,9 @@ import textwrap
import typing
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Generator,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
)
from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union
from typing_extensions import Self, TypeIs
import torch
@ -139,7 +124,7 @@ MULTIPLE = Multiple()
def _transfer_meta(
new_meta: Dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
new_meta: dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
) -> None:
from torch.fx.traceback import NodeSource, NodeSourceAction
@ -177,10 +162,10 @@ class Match:
"""
pattern: PatternExpr
args: List[Any]
kwargs: Dict[str, Any]
nodes: List[torch.fx.Node]
targets: Dict[_TargetExpr, torch.fx.node.Target]
args: list[Any]
kwargs: dict[str, Any]
nodes: list[torch.fx.Node]
targets: dict[_TargetExpr, torch.fx.node.Target]
ctx: MatchContext
replacement_graph: Optional[torch.fx.GraphModule]
@ -189,7 +174,7 @@ class Match:
ctx: MatchContext,
pattern: PatternExpr,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.pattern = pattern
@ -231,7 +216,7 @@ class Match:
if not n._erased and not n.users:
graph.erase_node(n)
def output_nodes(self) -> List[Optional[torch.fx.Node]]:
def output_nodes(self) -> list[Optional[torch.fx.Node]]:
return [
(self.ctx.pattern_to_node[p] if p is not None else None)
for p in self.ctx.outputs
@ -338,15 +323,15 @@ class MatchContext:
Internal state needed while running PatternExpr._match().
"""
outputs: List[Optional[PatternExpr]]
pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
outputs: list[Optional[PatternExpr]]
pattern_to_node: dict[PatternExpr, Optional[torch.fx.Node]]
graph: torch.fx.Graph
exclusive_node_set: List[NodeOrConstant]
exclusive_node_set: list[NodeOrConstant]
def __init__(
self,
outputs: List[Optional[PatternExpr]],
pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
outputs: list[Optional[PatternExpr]],
pattern_to_node: Optional[dict[PatternExpr, torch.fx.Node]] = None,
*,
graph: torch.fx.Graph,
) -> None:
@ -367,7 +352,7 @@ class MatchContext:
self.pattern_to_node[pattern] = node if m else None
return m
def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
def filter_multi_user_patterns(self) -> dict[PatternExpr, torch.fx.Node]:
return {
pattern: node
for pattern, node in self.pattern_to_node.items()
@ -487,7 +472,7 @@ class _TargetExpr(PatternExpr):
Base class for filtering match by node.target
"""
fns: List[FnsType]
fns: list[FnsType]
fns_set: OrderedSet[FnsType]
def __init__(
@ -806,7 +791,7 @@ class ListOf(PatternExpr):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.pattern})"
def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
def _match(self, node: list[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
if not isinstance(node, (list, tuple)) or len(node) == 0:
return FailedMatch("non_list")
m = Match(ctx, self)
@ -840,7 +825,7 @@ class ListOf(PatternExpr):
class MultiOutputPattern(PatternExpr):
outputs: List[Optional[PatternExpr]]
outputs: list[Optional[PatternExpr]]
def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
super().__init__()
@ -959,8 +944,8 @@ class PatternPrettyPrinter:
def __init__(self) -> None:
self.namespace = torch.fx.graph._Namespace()
self.memoized_objs_names: Dict[PatternExpr, str] = {}
self.memoized_objs_pp: Dict[PatternExpr, str] = {}
self.memoized_objs_names: dict[PatternExpr, str] = {}
self.memoized_objs_pp: dict[PatternExpr, str] = {}
@staticmethod
@functools.lru_cache(None)
@ -1006,7 +991,7 @@ class PatternPrettyPrinter:
class _PassDictsType(Protocol):
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
...
@ -1069,7 +1054,7 @@ class GraphPatternEntry(PatternEntry):
@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
normalize_args: Callable[..., List[Any]]
normalize_args: Callable[..., list[Any]]
@staticmethod
def replace_with_graph(
@ -1253,7 +1238,7 @@ def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
def check_and_add_duplicate_pattern(
pattern: PatternExpr,
graph: Optional[torch.fx.Graph],
seen_patterns: Dict[str, List[Optional[str]]],
seen_patterns: dict[str, list[Optional[str]]],
skip_duplicates: bool = False,
) -> bool:
"""
@ -1299,7 +1284,7 @@ def register_replacement(
trace_fn: TraceFn,
pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
extra_check: Callable[[Match], bool] = _return_true,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
search_fn_pattern: Union[PatternExpr, None] = None,
skip_duplicates: bool = False,
@ -1339,7 +1324,7 @@ def register_replacement(
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
)
)
sym_args: List[torch.SymInt] = []
sym_args: list[torch.SymInt] = []
with torch._dynamo.utils.detect_fake_mode(args):
for i, grad in enumerate(requires_grad):
if isinstance(args[i], torch.Tensor):
@ -1432,7 +1417,7 @@ def register_replacement(
return True
return False
def normalize_args(**kwargs: Any) -> List[Any]:
def normalize_args(**kwargs: Any) -> list[Any]:
args = [kwargs.pop(name) for name in argnames_static]
for i in range(1, len(kwargs) + 1):
if f"tangents_{i}" not in kwargs:
@ -1449,7 +1434,7 @@ def register_replacement(
# TODO: Revisit the functionalize_rng_ops for lowmem dropout
with functorch_config.patch(functionalize_rng_ops=False):
requires_grad: List[bool] = [
requires_grad: list[bool] = [
isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
]
if search_fn_pattern is None:
@ -1493,7 +1478,7 @@ def _serialize_pattern(
search_fn: SearchFn,
example_inputs: Sequence[Any],
trace_fn: TraceFn,
scalar_workaround: Union[Dict[str, Union[float, int]], None],
scalar_workaround: Union[dict[str, Union[float, int]], None],
) -> PatternExpr:
def get_file_template() -> str:
auto_generated_msg = textwrap.dedent(
@ -1566,7 +1551,7 @@ SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patt
# This is the set of serialized patterns that we've registered. Used by
# test_serialized_patterns_up_to_date() to ensure the patterns are up
# to date.
_known_precompiled_patterns: List[
_known_precompiled_patterns: list[
tuple[
Any,
Iterable[Any],
@ -1585,7 +1570,7 @@ def gen_register_replacement(
trace_fn: TraceFn,
pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
extra_check: Callable[[Match], bool] = _return_true,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
skip_duplicates: bool = False,
) -> None:
@ -1638,7 +1623,7 @@ def gen_pattern_and_search_gm(
search_fn: SearchFn,
example_inputs: Sequence[Any],
trace_fn: TraceFn,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
) -> tuple[PatternExpr, torch.fx.GraphModule]:
argnames = [*inspect.signature(search_fn).parameters.keys()]
@ -1672,7 +1657,7 @@ def gen_pattern(
search_fn: SearchFn,
example_inputs: Sequence[Any],
trace_fn: TraceFn,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
return gen_pattern_and_search_gm(
@ -1803,8 +1788,8 @@ class PatternMatcherPass:
pass_name: Optional[str] = None,
) -> None:
super().__init__()
self.patterns: DefaultDict[
tuple[str, torch.fx.node.Target], List[PatternEntry]
self.patterns: defaultdict[
tuple[str, torch.fx.node.Target], list[PatternEntry]
] = defaultdict(list)
self.pass_name = pass_name
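PEP 585 covers concrete standard-library classes as well, which is why DefaultDict becomes collections.defaultdict used directly as a generic, and Type[...] becomes the built-in type[...]. A brief illustrative sketch (the variable names are made up, not taken from this file):

from collections import defaultdict

# defaultdict and type are themselves subscriptable on Python 3.9+,
# so typing.DefaultDict and typing.Type are no longer needed.
patterns: defaultdict[str, list[str]] = defaultdict(list)
patterns["mm"].append("tuned_mm")

exc_type: type[BaseException] = ValueError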
@ -1812,9 +1797,9 @@ class PatternMatcherPass:
# of the graph used to generate them. Because we ignore certain patterns
# in searching, but not in matching, use the graph to distinguish if two equivalent
# searches are actually different.
self.seen_patterns: Dict[str, List[Optional[str]]] = defaultdict(list)
self.seen_patterns: dict[str, list[Optional[str]]] = defaultdict(list)
def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
return self.patterns[item]
def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
@ -1888,9 +1873,9 @@ def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
def fx_to_pattern(
gm: Union[torch.fx.GraphModule, torch.fx.Graph],
ignore_types: Sequence[Type[Any]] = (),
ignore_types: Sequence[type[Any]] = (),
argnames: Sequence[str] = (),
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
) -> PatternExpr:
"""
@ -1904,7 +1889,7 @@ def fx_to_pattern(
assert len(inv_scalar_workaround) == len(scalar_workaround)
def process_arg(
x: T, ignore_types_override: Optional[Sequence[Type[Any]]] = None
x: T, ignore_types_override: Optional[Sequence[type[Any]]] = None
) -> Union[T, KeywordArg, Ignored]:
current_ignore_types = (
ignore_types_override if ignore_types_override is not None else ignore_types
@ -1950,7 +1935,7 @@ def fx_to_pattern(
def process_arg_fn_impl(
x: T,
ignore_types_override: Optional[Sequence[Type[Any]]] = tuple(
ignore_types_override: Optional[Sequence[type[Any]]] = tuple(
t for t in ignore_types if t is not int
),
) -> Union[T, KeywordArg, Ignored]:
@ -2054,8 +2039,8 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
return gm
def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
args: List[torch.fx.node.Argument] = []
def _args(n: torch.fx.Node) -> list[torch.fx.node.Argument]:
args: list[torch.fx.node.Argument] = []
torch.fx.map_arg((n.args, n.kwargs), args.append)
return args
@ -2152,7 +2137,7 @@ def get_arg_value(
)
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]:
fns = [fn]
if isinstance(fn, torch._ops.OpOverloadPacket):
fns.extend([getattr(fn, overload) for overload in fn.overloads()])

View File

@ -10,7 +10,7 @@ import os
import sys
import typing
from abc import abstractmethod
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
from typing import Any, Callable, Generic, Optional, TypeVar, Union
from typing_extensions import override, TypeAlias
from torch._dynamo.utils import dynamo_timed
@ -34,7 +34,7 @@ if config.is_fbcode():
Sample: TypeAlias = Sample_
else:
Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef]
Sample: TypeAlias = type[object] # type: ignore[misc,no-redef]
_T = TypeVar("_T")
@ -106,7 +106,7 @@ class RemoteCacheSerde(Generic[_T, _U]):
JsonDataTy = Optional[
Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]]
]
@ -371,7 +371,7 @@ class _CacheStat:
class _CacheStats:
_stats: Dict[str, _CacheStat]
_stats: dict[str, _CacheStat]
def __init__(self) -> None:
self._stats = collections.defaultdict(_CacheStat)
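Recursive aliases keep working after the rewrite: a quoted forward reference inside dict[...] or list[...] behaves the same as it did with typing.Dict and typing.List. A hedged usage sketch built around the JsonDataTy alias shown in this file (the depth helper is hypothetical):

from typing import Optional, Union

JsonDataTy = Optional[
    Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]]
]

def depth(value: JsonDataTy) -> int:
    # Count nesting levels of a JSON-like value.
    if isinstance(value, dict):
        return 1 + max((depth(v) for v in value.values()), default=0)
    if isinstance(value, list):
        return 1 + max((depth(v) for v in value), default=0)
    return 0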

View File

@ -6,7 +6,7 @@ import logging
import os
import os.path
import re
from typing import Dict, List, Optional, TYPE_CHECKING
from typing import Optional, TYPE_CHECKING
from typing_extensions import override
import torch
@ -29,7 +29,7 @@ if TYPE_CHECKING:
log = logging.getLogger(__name__)
_InductorMetaTy = Dict[str, object]
_InductorMetaTy = dict[str, object]
def inductor_meta_from_config() -> _InductorMetaTy:
@ -88,7 +88,7 @@ class AutotuneCache:
return hashlib.sha256(key.encode("utf-8")).hexdigest()
# Read the best config options from the most local cache and return it.
def _read(self) -> Optional[Dict[str, JsonDataTy]]:
def _read(self) -> Optional[dict[str, JsonDataTy]]:
if local_cache := self.local_cache:
cache, key = local_cache
if best_config := cache.get(key):
@ -106,7 +106,7 @@ class AutotuneCache:
# Read the best config options from the most local cache and figure out
# which `configs` represents that option.
def read_best(
self, inductor_meta: _InductorMetaTy, configs: List[Config]
self, inductor_meta: _InductorMetaTy, configs: list[Config]
) -> Optional[Config]:
if best := self._read():
return _load_cached_autotuning(
@ -196,7 +196,7 @@ class _AutotuneCacheBundlerImpl:
_cache: RemoteCache[JsonDataTy]
# All known entries from LocalAutotuneCache.put()
_entries: Dict[str, JsonDataTy]
_entries: dict[str, JsonDataTy]
def end_compile(self) -> None:
# TODO: Do we need to compute time_taken_ms and encode that somehow?
@ -407,9 +407,9 @@ def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool:
def _load_cached_autotuning(
best_config: Dict[str, JsonDataTy],
best_config: dict[str, JsonDataTy],
configs_hash: str,
configs: List[Config],
configs: list[Config],
inductor_meta: _InductorMetaTy,
) -> Optional[Config]:
if best_config is None:

View File

@ -3,7 +3,7 @@ import time
from functools import cached_property, wraps
from itertools import chain
from statistics import median
from typing import Any, Callable, Dict, List, Tuple
from typing import Any, Callable
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
import torch
@ -49,8 +49,8 @@ class Benchmarker:
def benchmark(
self: Self,
fn: Callable[..., Any],
fn_args: Tuple[Any, ...],
fn_kwargs: Dict[str, Any],
fn_args: tuple[Any, ...],
fn_kwargs: dict[str, Any],
**kwargs: Any,
) -> float:
"""Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
@ -114,7 +114,7 @@ class Benchmarker:
- The median runtime of `_callable`, in milliseconds.
"""
def run_for(ms: int) -> List[float]:
def run_for(ms: int) -> list[float]:
timings = []
run_start_t = time.perf_counter()
while True:
@ -183,7 +183,7 @@ class InductorBenchmarker(TritonBenchmarker):
def get_event_pairs(
self: Self, iters: int
) -> List[Tuple[torch.cuda.Event, torch.cuda.Event]]:
) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]:
"""Get `iters` pairs of CUDA events."""
return [
(
@ -194,7 +194,7 @@ class InductorBenchmarker(TritonBenchmarker):
]
def get_event_pairs_min_timing(
self: Self, event_pairs: List[Tuple[torch.cuda.Event, torch.cuda.Event]]
self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]]
) -> float:
"""Get the minimum timing, in milliseconds, for a group of CUDA event pairs."""
return min(
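
For context on the surrounding code: `run_for` implements a simple time-budgeted loop (call the workload repeatedly until roughly `ms` milliseconds of wall time have passed, record each call's duration, and let the caller take the median), while the event-pair helpers do the analogous thing with CUDA events for GPU timing. A rough CPU-only analogue of that loop, offered as a hypothetical sketch rather than the benchmarker's actual implementation:

import time
from statistics import median
from typing import Any, Callable

def benchmark_cpu_sketch(fn: Callable[[], Any], budget_ms: float = 100.0) -> float:
    # Run `fn` repeatedly for ~budget_ms of wall time and return the
    # median per-call runtime in milliseconds.
    timings: list[float] = []
    start = time.perf_counter()
    while True:
        t0 = time.perf_counter()
        fn()
        t1 = time.perf_counter()
        timings.append((t1 - t0) * 1000.0)
        if (t1 - start) * 1000.0 >= budget_ms:
            return median(timings)

# Example: time a small pure-Python workload.
print(benchmark_cpu_sketch(lambda: sum(range(100_000)), budget_ms=20.0))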

View File

@ -6,7 +6,7 @@ import sys
import warnings
from pathlib import Path
from types import ModuleType
from typing import Callable, Dict, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
@ -67,7 +67,7 @@ def _set_triton_ptxas_path() -> None:
def _worker_compile_triton(
load_kernel: Callable[[], CachingAutotuner], extra_env: Dict[str, str]
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
) -> None:
_set_triton_ptxas_path()
os.environ.update(extra_env)

View File

@ -5,7 +5,7 @@ import collections
import functools
import typing
from enum import auto, Enum
from typing import Dict, List, Optional, Union
from typing import Optional, Union
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
@ -167,8 +167,8 @@ class DeviceProperties(typing.NamedTuple):
class HalideInputSpec(typing.NamedTuple):
ctype: str
name: str
shape: Optional[List[str]] = None
stride: Optional[List[str]] = None
shape: Optional[list[str]] = None
stride: Optional[list[str]] = None
offset: Optional[str] = None
alias_of: Optional[str] = None
@ -192,13 +192,13 @@ class HalideInputSpec(typing.NamedTuple):
class HalideMeta(typing.NamedTuple):
argtypes: List[HalideInputSpec]
argtypes: list[HalideInputSpec]
target: str
scheduler: Optional[str] = None
scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
scheduler_flags: Optional[dict[str, Union[int, str]]] = None
cuda_device: Optional[int] = None
def args(self) -> List[str]:
def args(self) -> list[str]:
"""Command line args to pass to halide generator"""
args = [f"target={self.target}"]
if self.scheduler:
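
`HalideInputSpec` and `HalideMeta` show that the rewrite is safe inside `NamedTuple` bodies too: field annotations such as `Optional[list[str]]` are evaluated when the class is created, which works on 3.9+ with the builtin generics. A reduced sketch of the same shape; the class, fields, and flag spellings below are invented for illustration, not Halide's real interface:

import typing
from typing import Optional, Union

class MetaSketch(typing.NamedTuple):
    target: str
    scheduler: Optional[str] = None
    scheduler_flags: Optional[dict[str, Union[int, str]]] = None

    def args(self) -> list[str]:
        # Build a flag list of the kind a generator command line expects.
        out = [f"target={self.target}"]
        if self.scheduler:
            out.append(f"scheduler={self.scheduler}")
        for key, value in (self.scheduler_flags or {}).items():
            out.append(f"{key}={value}")
        return out

assert MetaSketch("host", "greedy", {"parallelism": 8}).args() == [
    "target=host",
    "scheduler=greedy",
    "parallelism=8",
]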

View File

@ -2,7 +2,7 @@ from __future__ import annotations
import functools
import operator
from typing import Any, Hashable, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import torch
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
@ -13,6 +13,8 @@ from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
if TYPE_CHECKING:
from collections.abc import Hashable
from .triton_compat import Config
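
This file shows the second half of the PR's recipe: abstract container types move from `typing` to `collections.abc`, and when they are needed only for annotations the import goes under `TYPE_CHECKING` so it costs nothing at runtime. A minimal sketch of the idiom with a made-up helper function:

from __future__ import annotations  # annotations stay strings, so the guarded
                                    # import below never has to exist at runtime

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Hashable, Sequence

def dedupe(keys: Sequence[Hashable]) -> list[Hashable]:
    # Return `keys` with duplicates removed, preserving first-seen order.
    seen: set[Hashable] = set()
    out: list[Hashable] = []
    for key in keys:
        if key not in seen:
            seen.add(key)
            out.append(key)
    return out

assert dedupe(["buf0", "buf1", "buf0"]) == ["buf0", "buf1"]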

View File

@ -16,17 +16,7 @@ import sys
import threading
import time
from collections import namedtuple
from typing import (
Any,
Callable,
Container,
Dict,
Hashable,
List,
Optional,
Tuple,
TYPE_CHECKING,
)
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
from torch.utils._ordered_set import OrderedSet
@ -76,13 +66,15 @@ from .triton_compat import (
if TYPE_CHECKING:
from collections.abc import Container, Hashable
LauncherType = Any
log = logging.getLogger(__name__)
def get_total_reduction_numel(numels: Dict[str, int]) -> int:
def get_total_reduction_numel(numels: dict[str, int]) -> int:
return conditional_product(
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
)
@ -93,7 +85,7 @@ def autotune_hints_to_configs(
size_hints,
block_size: int,
device_props: DeviceProperties,
) -> List[Config]:
) -> list[Config]:
"""
AutotuneHints can be attached to the metadata of triton kernels for providing
suggestions about what to try for autotuning. One reason to do this is if there are
@ -104,7 +96,7 @@ def autotune_hints_to_configs(
configs to try.
"""
xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
configs: List[Config] = []
configs: list[Config] = []
for hint in hints:
if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
if len(size_hints) == 1:
@ -180,14 +172,14 @@ class CachingAutotuner(KernelInterface):
triton_meta, # passed directly to triton
configs,
save_cache_hook,
mutated_arg_names: List[str], # see [Note: clone mutated buffers]
mutated_arg_names: list[str], # see [Note: clone mutated buffers]
optimize_mem,
heuristic_type,
size_hints=None,
inductor_meta=None, # metadata not relevant to triton
custom_kernel=False, # whether the kernel is inductor-generated or custom
filename: Optional[str] = None,
reset_to_zero_arg_names: Optional[List[str]] = None,
reset_to_zero_arg_names: Optional[list[str]] = None,
):
super().__init__()
@ -223,8 +215,8 @@ class CachingAutotuner(KernelInterface):
for c in self.configs:
log.debug(c)
self.compile_results: List[TritonCompileResult] = []
self.launchers: List[LauncherType] = []
self.compile_results: list[TritonCompileResult] = []
self.launchers: list[LauncherType] = []
self.lock = threading.Lock()
if os.getenv("TRITON_CACHE_DIR") is None:
os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
@ -430,7 +422,7 @@ class CachingAutotuner(KernelInterface):
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
self.launchers = []
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
assert (
not self.launchers
), "pickle should not be called with after make_launchers()"
@ -439,7 +431,7 @@ class CachingAutotuner(KernelInterface):
"lock": None,
}
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
self.__dict__.update(state)
self.lock = threading.Lock()
@ -636,7 +628,7 @@ class CachingAutotuner(KernelInterface):
def maybe_clone_args(
self, exclude: Container[str], *args, **kwargs
) -> tuple[List[Any], Dict[str, Any]]:
) -> tuple[list[Any], dict[str, Any]]:
"""
Prepare new args and kwargs by cloning any in-place buffers
(that are not in the provided exclusion list), to avoid autotune
@ -659,7 +651,7 @@ class CachingAutotuner(KernelInterface):
return cloned_args, cloned_kwargs
def clone_args(self, *args, **kwargs) -> tuple[List[Any], Dict[str, Any]]:
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
def benchmark_all_configs(self, *args, **kwargs):
@ -888,15 +880,15 @@ class TritonCompileResult:
@staticmethod
@functools.lru_cache(32)
def _kernel_metadata_cls(fields: Tuple[str, ...]) -> Any:
def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
return namedtuple("KernelMetadata", sorted(fields))
def __init__(
self,
kernel: CompiledKernel,
config: Config,
compile_meta: Dict[str, Any],
inductor_meta: Dict[str, Any],
compile_meta: dict[str, Any],
inductor_meta: dict[str, Any],
) -> None:
super().__init__()
self.kernel = kernel
@ -904,7 +896,7 @@ class TritonCompileResult:
self.compile_meta = compile_meta
self.inductor_meta = inductor_meta
def __getstate__(self) -> Dict[str, Any]:
def __getstate__(self) -> dict[str, Any]:
kernel = self.kernel
# replace the fields that don't pickle nicely
kernel_state = {
@ -916,7 +908,7 @@ class TritonCompileResult:
}
return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item]
def __setstate__(self, state: Dict[str, Any]) -> None:
def __setstate__(self, state: dict[str, Any]) -> None:
# src = ASTSource.__new__(ASTSource)
# src.__setstate__(state["kernel"]["src"])
# TODO(jansel): need to fixup src.fn which is now None
@ -1101,7 +1093,7 @@ def _find_names(obj):
return obj_names
collected_calls: List[Any] = []
collected_calls: list[Any] = []
def start_graph():
@ -1220,7 +1212,7 @@ class DebugAutotuner(CachingAutotuner):
collected_calls.append(self.cached)
def hash_configs(configs: List[Config]):
def hash_configs(configs: list[Config]):
"""
Hash used to check for changes in configurations
"""
@ -1233,8 +1225,8 @@ def hash_configs(configs: List[Config]):
def cached_autotune(
size_hints: Optional[List[int]],
configs: List[Config],
size_hints: Optional[list[int]],
configs: list[Config],
triton_meta,
heuristic_type,
filename=None,
@ -1276,7 +1268,7 @@ def cached_autotune(
if "restore_value" in triton_meta:
mutated_arg_names += triton_meta.pop("restore_value")
reset_to_zero_arg_names: List[str] = []
reset_to_zero_arg_names: list[str] = []
if "reset_to_zero" in triton_meta:
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
@ -1331,7 +1323,7 @@ def cached_autotune(
return decorator
def unique_configs(configs: List[Config]):
def unique_configs(configs: list[Config]):
"""Remove duplicate configurations"""
seen: OrderedSet[Hashable] = OrderedSet()
pruned_configs = []
@ -1362,7 +1354,7 @@ def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
)
def check_max_block(cfg: Dict[str, int]):
def check_max_block(cfg: dict[str, int]):
"""
Check that block sizes are within the maximum allowed.
"""
@ -1505,7 +1497,7 @@ def triton_config(
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def _get_nd_reduction_numels(r: int, size_hints: Dict[str, int]) -> Dict[str, int]:
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
"""
Converts a linear reduction numel to ND, in row major order.
This order is often desirable as it presents opportunities to coalesce memory
@ -1596,7 +1588,7 @@ def triton_config_reduction(
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
def _get_config(numels: Dict[str, int]) -> Dict[str, int]:
def _get_config(numels: dict[str, int]) -> dict[str, int]:
"""
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
"""
@ -1729,8 +1721,8 @@ def pointwise(
def _reduction_configs(
*, size_hints: Dict[str, int], inductor_meta: Dict[str, Any]
) -> List[Config]:
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
) -> list[Config]:
reduction_hint = inductor_meta.get("reduction_hint", None)
# Convert reductions to 1D, to simplify heuristics.
@ -1981,7 +1973,7 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
)
def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
"""Extract triton.Config options that should become kwargs"""
popped = {}
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):

View File

@ -14,20 +14,8 @@ import pprint
import textwrap
import traceback
import typing
from collections import defaultdict
from typing import (
Any,
Callable,
Counter,
DefaultDict,
Dict,
Generic,
List,
Optional,
Sequence,
TypeVar,
Union,
)
from collections import Counter, defaultdict
from typing import Any, Callable, Generic, Optional, TypeVar, Union
import sympy
@ -68,6 +56,10 @@ from .utils import (
from .virtualized import V
if typing.TYPE_CHECKING:
from collections.abc import Sequence
log = logging.getLogger(__name__)
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
@ -78,7 +70,7 @@ class SchedulerBuffer:
scheduler: Scheduler
node: ir.Buffer
defining_op: BaseSchedulerNode
users: List[NodeUser] = dataclasses.field(default_factory=list)
users: list[NodeUser] = dataclasses.field(default_factory=list)
mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
default_factory=MemoryPlanningInfoForBuffer
)
@ -154,9 +146,9 @@ class SchedulerBuffer:
return False
return True
def set_users(self, users: List[NodeUser]) -> None:
def set_users(self, users: list[NodeUser]) -> None:
# deduplicate
result: Dict[int, NodeUser] = {}
result: dict[int, NodeUser] = {}
for use in users:
if id(use.node) in result:
result[id(use.node)] = use.merge(result[id(use.node)])
@ -194,7 +186,7 @@ class BaseSchedulerNode:
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[
[BaseSchedulerNode], List[str]
[BaseSchedulerNode], list[str]
] = lambda *args, **kwargs: []
def _init_from_node(self, node: ir.Operation) -> None:
@ -204,7 +196,7 @@ class BaseSchedulerNode:
str
]() # buffers that won't be used after this kernel
self.written = False
self.outputs: List[SchedulerBuffer] = [
self.outputs: list[SchedulerBuffer] = [
SchedulerBuffer(
scheduler=self.scheduler,
node=output,
@ -212,7 +204,7 @@ class BaseSchedulerNode:
)
for output in node.get_outputs()
]
self.outputs_by_name: Dict[str, SchedulerBuffer] = {
self.outputs_by_name: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for buf in self.outputs
}
@ -247,7 +239,7 @@ class BaseSchedulerNode:
def debug_str_extra(self) -> str:
return ""
def _debug_str_for_device(self) -> List[str]:
def _debug_str_for_device(self) -> list[str]:
return self.debug_device_str(self)
def debug_str_short(self) -> str:
@ -278,7 +270,7 @@ class BaseSchedulerNode:
) -> None:
return
def update_mutated_names(self, renames: Dict[str, str]) -> None:
def update_mutated_names(self, renames: dict[str, str]) -> None:
self.set_read_writes(self.read_writes.rename(renames))
def add_fake_dep(self, dep: Dep) -> None:
@ -295,7 +287,7 @@ class BaseSchedulerNode:
self.prune_deps()
def set_last_usage(
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
) -> None:
used_buffers = self.used_or_aliased_buffer_names()
used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers)
@ -352,7 +344,7 @@ class BaseSchedulerNode:
self.set_read_writes(self.read_writes.remove_reads(to_remove))
def prune_redundant_deps(
self, name_to_fused_node: Dict[str, BaseSchedulerNode]
self, name_to_fused_node: dict[str, BaseSchedulerNode]
) -> None:
_prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
@ -582,7 +574,7 @@ class BaseSchedulerNode:
def get_read_write_buffer_accesses(
self, include_reads: bool, include_writes: bool
) -> Dict[str, int]:
) -> dict[str, int]:
"""
Counting the number of bytes accessed for a kernel is
surprisingly tricky. In particular, there is a differentiation
@ -658,7 +650,7 @@ class BaseSchedulerNode:
writes = writes - removed_buffers
reads = reads - removed_buffers
buf_byte_accesses: Dict[str, int] = {}
buf_byte_accesses: dict[str, int] = {}
for buf_name in reads | writes:
buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
@ -811,8 +803,8 @@ class BaseSchedulerNode:
@staticmethod
def get_prologue_template_epilogue(
nodes: List[BaseSchedulerNode],
) -> tuple[List[BaseSchedulerNode], BaseSchedulerNode, List[BaseSchedulerNode]]:
nodes: list[BaseSchedulerNode],
) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]:
"""
For the list of nodes, get the prologue, template, and epilogue
"""
@ -874,8 +866,8 @@ class OutputNode:
def _prune_redundant_deps(
node: BaseSchedulerNode,
name_to_fused_node: Dict[str, BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_fused_node: dict[str, BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
) -> None:
"""
Prunes weakdeps intended for mutation ordering
@ -961,7 +953,7 @@ class SchedulerNode(BaseSchedulerNode):
def _compute_attrs(
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
) -> None:
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
@ -993,7 +985,7 @@ class SchedulerNode(BaseSchedulerNode):
def recompute_size_and_body(
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
) -> None:
self._compute_attrs(
@ -1120,7 +1112,7 @@ class SchedulerNode(BaseSchedulerNode):
def ranges_from_index_vars(
self, index_vars: Sequence[Sequence[sympy.Expr]]
) -> Dict[sympy.Expr, sympy.Expr]:
) -> dict[sympy.Expr, sympy.Expr]:
sizes = self._sizes
assert sum(map(len, sizes)) == sum(map(len, index_vars))
var_ranges = dict(
@ -1220,7 +1212,7 @@ def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None:
def init_group_node(
group_snode: BaseSchedulerNode,
scheduler: Scheduler,
snodes: List[BaseSchedulerNode],
snodes: list[BaseSchedulerNode],
) -> None:
assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode))
group_snode.snodes = snodes
@ -1246,7 +1238,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
its unmet dependencies as the union of its constituent nodes.
"""
snodes: List[BaseSchedulerNode]
snodes: list[BaseSchedulerNode]
@classmethod
def fuse(
@ -1319,10 +1311,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
refresh_group_node_dependencies(self)
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
super().__init__(scheduler)
init_group_node(self, scheduler, snodes)
self.users: List[NodeUser] = []
self.users: list[NodeUser] = []
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
@cache_on_self
@ -1336,8 +1328,8 @@ class FusedSchedulerNode(BaseSchedulerNode):
def get_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
def get_outputs(self) -> List[SchedulerBuffer]:
result: List[SchedulerBuffer] = []
def get_outputs(self) -> list[SchedulerBuffer]:
result: list[SchedulerBuffer] = []
for node in self.snodes:
result.extend(node.get_outputs())
return result
@ -1358,7 +1350,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
return f"{self}, snodes: {snodes_str}"
def set_last_usage(
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
) -> None:
# Set self.last_usage using the global information
# This will be used for inter-kernel optimisations
@ -1414,7 +1406,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
# None of these need to be implemented, as a FusedSchedulerNode is just an
# abstraction for scheduling purposes
def update_mutated_names(self, renames: Dict[str, str]) -> None:
def update_mutated_names(self, renames: dict[str, str]) -> None:
raise NotImplementedError
def add_fake_dep(self, name: Dep) -> None:
@ -1546,7 +1538,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
enable_autotune = consumer.enable_autotune
prev_node_1 = None
prev_node_2 = None
fused_nodes: List[BaseSchedulerNode]
fused_nodes: list[BaseSchedulerNode]
if producer.is_foreach() and consumer.is_foreach():
producer = typing.cast(ForeachKernelSchedulerNode, producer)
consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
@ -1599,7 +1591,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
def __init__(
self,
scheduler: Scheduler,
snodes: List[BaseSchedulerNode],
snodes: list[BaseSchedulerNode],
use_custom_partition_algo: bool,
prev_node_1: Optional[BaseSchedulerNode] = None,
prev_node_2: Optional[BaseSchedulerNode] = None,
@ -1621,7 +1613,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
self.scheduler = scheduler
self.snodes = snodes
self.node = None
self.users: List[NodeUser] = []
self.users: list[NodeUser] = []
self.set_read_writes(
dependencies.ReadWrites.merge_list(
@ -1666,8 +1658,8 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@classmethod
def combinable_nodes(
cls, nodes: List[BaseSchedulerNode]
) -> List[BaseSchedulerNode]:
cls, nodes: list[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
if extern:
log.debug(
@ -1700,7 +1692,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@staticmethod
def _default_group_nodes_for_combo_kernels(
scheduler: Scheduler,
) -> List[List[BaseSchedulerNode]]:
) -> list[list[BaseSchedulerNode]]:
"""
Returns a list of lists of nodes that are to be grouped together.
"""
@ -1718,12 +1710,12 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
return grouped_nodes
group_algorithm_for_combo_kernels: Callable[
[Scheduler], List[List[BaseSchedulerNode]]
[Scheduler], list[list[BaseSchedulerNode]]
] = _default_group_nodes_for_combo_kernels
@staticmethod
def set_group_algorithm_for_combo_kernels(
custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]]
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
) -> None:
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
custom_group_algorithm
@ -1732,7 +1724,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@staticmethod
def group_nodes_for_combo_kernels(
scheduler: Scheduler,
) -> List[List[BaseSchedulerNode]]:
) -> list[list[BaseSchedulerNode]]:
return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler)
def mark_run(self) -> None:
@ -1744,7 +1736,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
def is_foreach(self) -> bool:
return True
def get_subkernel_nodes(self) -> List[BaseSchedulerNode]:
def get_subkernel_nodes(self) -> list[BaseSchedulerNode]:
"""Returns a list of nodes which comprise the combo kernel.
These nodes may be vertically fused."""
return list(self.snodes)
@ -1758,7 +1750,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
return self.snodes[0].get_first_name()
def prune_redundant_deps(
self, name_to_fused_node: Dict[str, BaseSchedulerNode]
self, name_to_fused_node: dict[str, BaseSchedulerNode]
) -> None:
_prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
@ -1776,10 +1768,10 @@ class GroupedSchedulerNode(BaseSchedulerNode):
At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node.
"""
snodes: List[BaseSchedulerNode]
snodes: list[BaseSchedulerNode]
@classmethod
def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode:
def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode:
scheduler = snodes[0].scheduler
assert all(node.scheduler is scheduler for node in snodes)
grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type]
@ -1788,11 +1780,11 @@ class GroupedSchedulerNode(BaseSchedulerNode):
scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode
return grouped_snode
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
super().__init__(scheduler)
init_group_node(self, scheduler, snodes)
def unpack(self) -> List[BaseSchedulerNode]:
def unpack(self) -> list[BaseSchedulerNode]:
"""
Do fusion among nodes within this GroupedSchedulerNode,
and then unpack this GroupedSchedulerNode into regular nodes.
@ -1817,8 +1809,8 @@ class GroupedSchedulerNode(BaseSchedulerNode):
def get_buffer_names(self) -> OrderedSet[str]:
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
def get_outputs(self) -> List[SchedulerBuffer]:
result: List[SchedulerBuffer] = []
def get_outputs(self) -> list[SchedulerBuffer]:
result: list[SchedulerBuffer] = []
for node in self.snodes:
result.extend(node.get_outputs())
return result
@ -1833,10 +1825,10 @@ class GroupedSchedulerNode(BaseSchedulerNode):
def pick_loop_order(
stride_lengths: List[List[int]],
stride_lengths: list[list[int]],
sizes: Sequence[sympy.Expr],
priority_idx: tuple[int, ...] = (),
) -> List[int]:
) -> list[int]:
"""
A heuristic to decide loop iteration orders. This has not been well
tuned and may be something we should autotune.
@ -1914,17 +1906,17 @@ _post_grad_graph_counter = itertools.count()
class Scheduler:
__dep_size_hint_cache: Dict[Dep, int]
__dep_size_hint_cache: dict[Dep, int]
def __init__(self, nodes: List[ir.Operation]) -> None:
def __init__(self, nodes: list[ir.Operation]) -> None:
with dynamo_timed("Scheduler.__init__"):
self._init(nodes)
def _init(self, nodes: List[ir.Operation]) -> None:
def _init(self, nodes: list[ir.Operation]) -> None:
super().__init__()
self.__dep_size_hint_cache = {}
V.graph.scheduler = self
self.backends: Dict[torch.device, BaseScheduling] = {}
self.backends: dict[torch.device, BaseScheduling] = {}
self.post_grad_graph_id = next(_post_grad_graph_counter)
self.completed_operations = OrderedSet[str]()
@ -1943,23 +1935,23 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: Dict[
self.name_to_donated_buffer: dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_node: Dict[str, BaseSchedulerNode] = {
self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
self.name_to_buf: Dict[str, SchedulerBuffer] = {
self.name_to_buf: dict[str, SchedulerBuffer] = {
buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
}
self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy()
self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy()
# mutation_real_name: Maps back to the original name for codegen
# Example:
# If you mutate buf0 inside of buf1's kernel, then:
# mutation_real_name = {"buf0" : "buf1"}
# all subsequent uses of buf0 become buf1's usage in dependency graph
self.mutation_real_name: Dict[str, str] = {}
self.mutation_real_name: dict[str, str] = {}
# We handle mutation by renaming modified versions of the same
# buffer in the dependency graph to prevent cycles.
@ -1969,7 +1961,7 @@ class Scheduler:
# If you mutate buf0 inside of buf1's kernel, then:
# mutation_renames = {"buf1" : "buf0"}
# in codegen we only use buf0, never buf1
self.mutation_renames: Dict[str, str] = {}
self.mutation_renames: dict[str, str] = {}
# Must run first to correctly set dependencies, before all other passes that rely on
# reading from .read_writes.reads or .unmet_dependencies
@ -2024,7 +2016,7 @@ class Scheduler:
# fx graph node to the position it appears in the graph
# for debug attribution
self.origin_to_index: Dict[torch.fx.Node, int] = {}
self.origin_to_index: dict[torch.fx.Node, int] = {}
get_metric_table("graph_stats").add_row(
lambda: {
@ -2034,7 +2026,7 @@ class Scheduler:
}
)
def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]:
name_to_donated_buf = {}
for name in V.graph.graph_inputs_original:
if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
@ -2135,7 +2127,7 @@ class Scheduler:
def __init__(
self,
items: Optional[List[T]] = None,
items: Optional[list[T]] = None,
membership: Optional[OrderedSet[T]] = None,
) -> None:
self.items = items or []
@ -2154,7 +2146,7 @@ class Scheduler:
]
return DedupList(new_items, new_membership)
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict(
DedupList
)
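
The import hunk at the top of this file and the `name_to_users` annotation above illustrate the companion change for `typing.Counter` and `typing.DefaultDict`: since Python 3.9 the `collections` classes are generic themselves, so a single name serves as both the runtime constructor and the annotation. A small sketch with invented buffer names:

from collections import Counter, defaultdict

def tally_users(edges: list[tuple[str, str]]) -> defaultdict[str, list[str]]:
    # Map each buffer name to the list of nodes that read it.
    buf_to_users: defaultdict[str, list[str]] = defaultdict(list)
    for buf, user in edges:
        buf_to_users[buf].append(user)
    return buf_to_users

def op_histogram(ops: list[str]) -> Counter[str]:
    return Counter(ops)

users = tally_users([("buf0", "op1"), ("buf0", "op2"), ("buf1", "op2")])
assert users["buf0"] == ["op1", "op2"]
assert op_histogram(["mm", "add", "mm"])["mm"] == 2
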
@ -2196,7 +2188,7 @@ class Scheduler:
NodeUser(user_node, can_inplace, is_weak)
)
unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {}
unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {}
# NB: None means that the dependency is on an input. Don't actually
# generate a dependency because if we do, Inductor will start trying
@ -2367,14 +2359,14 @@ class Scheduler:
node.prune_weak_deps()
def topological_sort_schedule(
self, nodes: List[BaseSchedulerNode]
) -> List[BaseSchedulerNode]:
self, nodes: list[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
"""
Ensure nodes is in topologically sorted order
"""
seen = OrderedSet[BaseSchedulerNode]()
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []
name_to_node: dict[str, BaseSchedulerNode] = dict()
result: list[BaseSchedulerNode] = []
def visit(n: BaseSchedulerNode) -> None:
if n not in seen:
@ -2393,7 +2385,7 @@ class Scheduler:
visit(node)
return result
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
unmet_deps = OrderedSet[str]()
if isinstance(
snode,
@ -2415,13 +2407,13 @@ class Scheduler:
OrderedSet(self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops)
)
def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]:
def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]:
"""
Sort nodes by their topological order, return a list of node lists.
"""
order = []
nodes = dict.fromkeys(self.nodes, 0)
children: Dict[Any, Any] = {}
children: dict[Any, Any] = {}
for node in self.nodes:
deps = self._get_unmet_dep_nodes(node)
nodes[node] = len(deps)
@ -2446,7 +2438,7 @@ class Scheduler:
Populate each node.ancestors
"""
# note self.nodes is topologically sorted
name_to_ancestors: Dict[str, OrderedSet[str]] = {}
name_to_ancestors: dict[str, OrderedSet[str]] = {}
for node in self.nodes:
ancestors = OrderedSet[str]()
for dep in node.unmet_dependencies:
@ -2487,7 +2479,7 @@ class Scheduler:
# FusedSchedulerNode having different merged loops.
# Skip CPU backend for now.
def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
Combine eligible nodes into FusedSchedulerNodes.
"""
@ -2518,7 +2510,7 @@ class Scheduler:
"""
Unpack GroupedSchedulerNode into regular nodes.
"""
new_nodes: List[BaseSchedulerNode] = []
new_nodes: list[BaseSchedulerNode] = []
for node in self.nodes:
new_nodes.extend(
node.unpack() if isinstance(node, GroupedSchedulerNode) else [node]
@ -2802,8 +2794,8 @@ class Scheduler:
return ms_fused < ms1 + ms2
def fuse_nodes_once(
self, nodes: List[BaseSchedulerNode]
) -> List[BaseSchedulerNode]:
self, nodes: list[BaseSchedulerNode]
) -> list[BaseSchedulerNode]:
"""
Combine eligible nodes into FusedSchedulerNodes.
@ -2890,20 +2882,20 @@ class Scheduler:
)
self.prune_redundant_deps(self.nodes)
def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None:
def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None:
for node in nodes:
node.prune_redundant_deps(self.name_to_fused_node)
def get_possible_fusions(
self, nodes: List[BaseSchedulerNode]
) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
self, nodes: list[BaseSchedulerNode]
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
"""
Helper to find all legal fusion opportunities, sorted by self.score_fusion()
"""
possible_fusions = []
seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]()
def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None:
for node1_index, node1 in enumerate(nodes):
for node2 in nodes[node1_index + 1 :]:
key = (node1, node2)
@ -3015,7 +3007,7 @@ class Scheduler:
def _find_single_user_inputs(
node: BaseSchedulerNode,
) -> List[ir.Buffer]:
) -> list[ir.Buffer]:
output = []
for rd in node.read_writes.reads:
buf = self.name_to_buf.get(rd.name)
@ -3403,7 +3395,7 @@ class Scheduler:
"""
node1_buf_names = node1.get_buffer_names()
why = WhyNoFuse(node1, node2)
remaining_deps_by_name: Dict[str, List[Dep]] = defaultdict(list)
remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list)
for dep in node2.unmet_dependencies:
name = self.mutation_renames.get(dep.name, dep.name)
@ -3562,14 +3554,14 @@ class Scheduler:
return sum(self.dep_size_hint(dep) for dep in common_memory_deps)
def get_possible_fusions_with_highest_priority(
self, possible_fusions: List[tuple[BaseSchedulerNode, BaseSchedulerNode]]
) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
# Group the possible fusions based on their priority from the backend.
# Only return the group of possible fusions with highest priority.
if len(possible_fusions) == 0:
return possible_fusions
possible_fusions_group_by_priority: Dict[
int, List[tuple[BaseSchedulerNode, BaseSchedulerNode]]
possible_fusions_group_by_priority: dict[
int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
] = {}
for node1, node2 in possible_fusions:
@ -3828,7 +3820,7 @@ class Scheduler:
backend = self.get_backend(device)
return backend.benchmark_combo_kernel(node_list)
def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool:
def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
"""
If config.benchmark_fusion is False, always return True.
Otherwise, return True if fusion can bring a speedup.

View File

@ -16,17 +16,8 @@ import textwrap
import time
from concurrent.futures import as_completed, ThreadPoolExecutor
from io import StringIO
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import Self
from unittest.mock import patch
import sympy
@ -86,7 +77,7 @@ from .virtualized import V
log = logging.getLogger(__name__)
# correctness checks struggle with fp16/tf32
VERIFY: Dict[str, Any] = {}
VERIFY: dict[str, Any] = {}
PRINT_AUTOTUNE = True
DEBUG = False
@ -104,14 +95,11 @@ class KernelNamespace:
extern_kernels = KernelNamespace()
_T = TypeVar("_T", bound="AutotuneArgs")
@dataclasses.dataclass
class BenchmarkTensors:
"""Represents a set of inputs and outputs for autotuning with a template"""
input_tensors: List[torch.Tensor]
input_tensors: list[torch.Tensor]
output_tensor: Optional[torch.Tensor]
def unpack(self):
@ -139,13 +127,13 @@ class AutotuneArgs:
@classmethod
def from_choice_args(
cls: Type[_T],
example_inputs: List[torch.Tensor],
example_inputs_extern: List[torch.Tensor],
cls,
example_inputs: list[torch.Tensor],
example_inputs_extern: list[torch.Tensor],
out: torch.Tensor,
out_extern: torch.Tensor,
expected: Optional[torch.Tensor] = None,
) -> _T:
) -> Self:
"""Factory method to create AutotuneInputs from separate inputs/outputs"""
return cls(
triton=BenchmarkTensors(example_inputs, out),
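
The `from_choice_args` hunk also retires the `_T = TypeVar("_T", bound="AutotuneArgs")` pattern: annotating the return type as `typing_extensions.Self` lets the checker infer the concrete subclass without an explicit `cls: type[_T]` parameter. A stripped-down sketch of that change on an invented class:

import dataclasses
from typing_extensions import Self

@dataclasses.dataclass
class ArgsSketch:
    inputs: list[int]

    @classmethod
    def from_parts(cls, *parts: int) -> Self:
        # `Self` resolves to whichever subclass the method is called on,
        # so no TypeVar bound to the base class is needed.
        return cls(inputs=list(parts))

class ExtendedArgsSketch(ArgsSketch):
    pass

built = ExtendedArgsSketch.from_parts(1, 2, 3)
assert isinstance(built, ExtendedArgsSketch) and built.inputs == [1, 2, 3]
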
@ -207,7 +195,7 @@ class SubgraphInfo:
ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined]
# only copied over if not None
range_trees: Optional[List["IterationRangesRoot"]] = None
range_trees: Optional[list["IterationRangesRoot"]] = None
numels = None # type: ignore[var-annotated]
def __post_init__(self):
@ -226,7 +214,7 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
self,
kernel,
subgraph_number: int,
fixed_inputs: Dict[str, Any],
fixed_inputs: dict[str, Any],
mask: Optional[str],
):
super().__init__(V.ops)
@ -290,7 +278,7 @@ class TritonTemplateKernel(TritonKernel):
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
subgraphs: Optional[List[ir.ComputedBuffer]] = None,
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
) -> None:
numel = sympy_product(output_node.get_size())
@ -317,9 +305,9 @@ class TritonTemplateKernel(TritonKernel):
self.suffix_args = suffix_args
self.epilogue_fn = epilogue_fn
self.render_hooks = {} # type: ignore[var-annotated]
self.triton_meta: Optional[Dict[str, object]] = None
self.triton_meta: Optional[dict[str, object]] = None
# For Templated Attention this can be a list of ir.Subgraph
self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
self.subgraphs: Optional[list[ir.ComputedBuffer]] = subgraphs
# Some templates use extra global memory as a workspace
self.workspace_arg = workspace_arg
@ -330,7 +318,7 @@ class TritonTemplateKernel(TritonKernel):
# used for triton kernel codegen.
# They are swapped onto the TritonTemplateKernel object by
# `set_subgraph_body`
self.subgraph_bodies: Dict[str, SubgraphInfo] = {}
self.subgraph_bodies: dict[str, SubgraphInfo] = {}
# input buffers which we are allowed to prologue fuse into
self.prologue_supported_inputs: OrderedSet[str] = OrderedSet()
@ -420,7 +408,7 @@ class TritonTemplateKernel(TritonKernel):
return "@triton.jit"
argdefs, _, signature, _ = self.args.python_argdefs()
triton_meta: Dict[str, Any] = {
triton_meta: dict[str, Any] = {
"signature": signature_to_meta(
signature, size_dtype=self.index_dtype, argdefs=argdefs
),
@ -622,7 +610,7 @@ class TritonTemplateKernel(TritonKernel):
)
with V.set_ops_handler(modification_handler):
assert isinstance(
subgraph, (ir.ComputedBuffer, List)
subgraph, (ir.ComputedBuffer, list)
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
# Handle scatter stores
if isinstance(subgraph, list):
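
The `isinstance` hunk above is one of the few spots where the alias appeared in a runtime check rather than an annotation; the builtin `list` is the class the check actually means, and subscripted `typing` generics are rejected by `isinstance` outright. A tiny standalone illustration:

from typing import List

value = [1, 2, 3]

# The builtin class is what isinstance should test against.
assert isinstance(value, list)

# Subscripted typing generics refuse runtime class checks.
try:
    isinstance(value, List[int])
except TypeError:
    pass  # "Subscripted generics cannot be used with class and instance checks"
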
@ -651,7 +639,7 @@ class TritonTemplateKernel(TritonKernel):
self,
input_name: str,
output_name: str,
indices: Union[List[Any], tuple[Any]],
indices: Union[list[Any], tuple[Any]],
mask: Optional[str] = None,
other: Optional[Union[float, int]] = 0.0,
indent_width: int = 4,
@ -826,7 +814,7 @@ class TritonTemplateKernel(TritonKernel):
def store_output(
self,
indices: Union[List[Any], tuple[Any]],
indices: Union[list[Any], tuple[Any]],
val: str,
mask: Optional[str] = None,
indent_width: int = 4,
@ -1032,7 +1020,7 @@ def _jinja2_env():
class TritonTemplate(KernelTemplate):
index_counter = itertools.count()
all_templates: Dict[str, "TritonTemplate"] = {}
all_templates: dict[str, "TritonTemplate"] = {}
def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
super().__init__(name)
@ -1180,7 +1168,7 @@ class TritonTemplate(KernelTemplate):
),
kwargs,
)
bmreq_cls: Type[TritonBenchmarkRequest]
bmreq_cls: type[TritonBenchmarkRequest]
if layout.device.type == "cpu":
bmreq_cls = TritonCPUBenchmarkRequest
else:
@ -1292,7 +1280,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
description,
bmreq,
log_info: Optional[
Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
] = None,
mutated_inputs=None,
workspace_arg: Optional[WorkspaceArg] = None,
@ -1303,7 +1291,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.bmreq: TritonBenchmarkRequest = bmreq
if log_info is None:
log_info = {}
self.log_info: Dict[str, Any] = log_info
self.log_info: dict[str, Any] = log_info
self.log_info.update(
{
"backend": "Triton",
@ -1351,7 +1339,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
)
)
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return self.log_info
@ -1447,7 +1435,7 @@ class ExternKernelCaller(ChoiceCaller):
return ir.TensorBox.create(inner)
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {
"backend": "extern",
@ -1589,7 +1577,7 @@ def create_inputs_key(input_nodes) -> str:
def create_precompile_key(
name: str, inputs_key: str, choices: List[ChoiceCaller]
name: str, inputs_key: str, choices: list[ChoiceCaller]
) -> str:
return ":".join(
[
@ -1609,18 +1597,18 @@ class AlgorithmSelectorCache(PersistentCache):
# no guarantee that the first lowering for a given key will also be the
# first to benchmark it. share a single precompilation function for all lowerings
# of a particular key
self.precompile_cache: Dict[str, Callable[[], None]] = {}
self.precompile_cache: dict[str, Callable[[], None]] = {}
# list of callbacks that are called after benchmarking
self.feedback_saver_fns: List[
self.feedback_saver_fns: list[
Callable[
[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
]
] = []
def __call__(
self,
name,
choices: List[ChoiceCaller],
choices: list[ChoiceCaller],
input_nodes,
layout,
# optional dict mapping arg indices to the functions
@ -1628,7 +1616,7 @@ class AlgorithmSelectorCache(PersistentCache):
# corresponding ir.Buffer. if passed for a given
# arg, the function will be called instead of
# generating a random torch.Tensor for benchmarking.
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template=False,
):
@ -1761,9 +1749,9 @@ class AlgorithmSelectorCache(PersistentCache):
executor = ThreadPoolExecutor(max_workers=num_workers)
async_compile = torch._inductor.async_compile.AsyncCompile()
futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
start_times: Dict[concurrent.futures.Future[Any], float] = {}
elapsed_times: Dict[concurrent.futures.Future[Any], float] = {}
futures: dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
start_times: dict[concurrent.futures.Future[Any], float] = {}
elapsed_times: dict[concurrent.futures.Future[Any], float] = {}
for c in choices:
if hasattr(c, "precompile"):
@ -1925,7 +1913,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {}
def get_inputs(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
) -> AutotuneArgs:
# de-duplicate args
unique_example_inputs = {
@ -1996,8 +1984,8 @@ class AlgorithmSelectorCache(PersistentCache):
return result
def benchmark_in_current_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
inputs = get_inputs(choices)
timings = {}
for choice in choices:
@ -2045,7 +2033,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings
def benchmark_in_sub_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
):
from . import autotune_process
@ -2069,8 +2057,8 @@ class AlgorithmSelectorCache(PersistentCache):
@staticmethod
def log_results(
name: str,
input_nodes: List[ir.IRNode],
timings: Dict[ChoiceCaller, float],
input_nodes: list[ir.IRNode],
timings: dict[ChoiceCaller, float],
elapse: float,
precompile_elapse: float,
):
@ -2227,7 +2215,7 @@ class AlgorithmSelectorCache(PersistentCache):
def add_feedback_saver(
self,
fn: Callable[
[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
],
):
self.feedback_saver_fns.append(fn)
@ -2250,7 +2238,7 @@ def autotune_select_algorithm(*args, **kwargs):
def add_feedback_saver(
fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:

View File

@ -2,7 +2,8 @@
import functools
import itertools
import logging
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
from collections.abc import Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union
import sympy
from sympy import Expr
@ -60,7 +61,7 @@ class SizeVarAllocator:
shape_env = ShapeEnv()
self.shape_env = shape_env
self.var_to_val = self.shape_env.var_to_val
self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
# Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
# The basic idea is if we have some complicated sympy expression
# f(s0), we may choose to precompute it on the host and then replace
@ -71,8 +72,8 @@ class SizeVarAllocator:
# which potentially could have already had a precomputed replacement
# on it, we are obligated to invert the precomputed replacements
# (inv_precomputed_replacements).
self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
self.precomputed_replacements: dict[Expr, sympy.Symbol] = {}
self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {}
self.stride_vars = self.make_stride_vars_cache()
self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
self._simplify_loops = self.make_simplify_loops_cache()
@ -84,7 +85,7 @@ class SizeVarAllocator:
"""
self._simplify_with_ranges() can be expensive, so cache its results
"""
cache: Dict[tuple[Any, ...], Expr] = {}
cache: dict[tuple[Any, ...], Expr] = {}
replacement_count = len(self.replacements)
def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
@ -106,7 +107,7 @@ class SizeVarAllocator:
"""
self._simplify_with_ranges() can be expensive, so cache its results
"""
cache: Dict[tuple[Any, ...], Any] = {}
cache: dict[tuple[Any, ...], Any] = {}
replacement_count = len(self.replacements)
def simplify_loops(index_vars, sizes, index_formulas):
@ -221,7 +222,7 @@ class SizeVarAllocator:
return expr
def _simplify_loops_impl(
self, index_vars: List[sympy.Symbol], sizes, index_formulas
self, index_vars: list[sympy.Symbol], sizes, index_formulas
):
"""
Try to remove as many axes from loop iterations as possible, by:
@ -337,7 +338,7 @@ class SizeVarAllocator:
return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
# See Note - [On Statically Known]
def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
"""
@ -501,7 +502,7 @@ class SizeVarAllocator:
self.guard_equals(left, sympy.Integer(right))
return int(right)
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]:
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
return [self.evaluate_static_shape(x) for x in left]
def remove_precomputed_replacements(self, expr: Expr) -> Expr:
@ -582,7 +583,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Optional[Sequence[sympy.Symbol]] = None,
) -> List[Expr]:
) -> list[Expr]:
if not support_vars:
support_vars = vars
return cache(index, tuple(vars), tuple(support_vars))
@ -594,7 +595,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Sequence[sympy.Symbol],
) -> List[Expr]:
) -> list[Expr]:
"""Convert an indexing expression back into strides
NOTE: This is only valid if the index is a standard strided offset
@ -647,7 +648,7 @@ class SizeVarAllocator:
}
return expr.subs(size_dict)
def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr:
"""Extract offset part of an indexing expression"""
index = self.simplify(index)
return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0})
@ -657,7 +658,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Optional[Sequence[sympy.Symbol]] = None,
) -> List[int]:
) -> list[int]:
for v in index.free_symbols:
if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined]
index = sympy_subs(index, {v: 0}) # type: ignore[dict-item]
@ -669,7 +670,7 @@ class SizeVarAllocator:
result.append(0)
return result
def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]:
strides = tuple(map(abs, self.stride_hints(index, vars)))
order = list(range(len(strides)))
order.sort(key=lambda x: (strides[x] == 0, strides[x]))

View File

@ -2,9 +2,10 @@
import functools
import operator
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -20,7 +21,7 @@ T = TypeVar("T")
_P = ParamSpec("_P")
OpOverload = torch._ops.OpOverload
LoweringDict = Dict[Union[OpOverload, str], Callable[..., Any]]
LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]]
TargetType = Union[Callable[..., Any], str]
@ -30,13 +31,13 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
lowering object. Errors if buffers are created unexpectedly
"""
graph_outputs: Optional[List[ir.IRNode]]
graph_outputs: Optional[list[ir.IRNode]]
root_graph: torch._inductor.graph.GraphLowering
_current_op: Optional[TargetType]
# For backwards of buffer_grads with scatters we allow mutations
allowed_mutations: Optional[OrderedSet[OpOverload]]
additional_lowerings: Optional[LoweringDict]
buffers: List[ir.Buffer]
buffers: list[ir.Buffer]
mutated_buffers: OrderedSet[str]
def __init__(
@ -102,7 +103,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
self,
target: TargetType,
args: Any,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> Any:
from .lowering import lowerings
@ -123,7 +124,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
return lowerings[target](*args, **kwargs)
def output(self, target: str, args: tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override]
def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override]
assert len(args) == 1
self.graph_outputs = args[0]
@ -155,7 +156,7 @@ class TracingOpsHandler(WrapperHandler[T]):
def lower_pointwise_subgraph(
subgraph: ir.Subgraph, inputs: List[InputDescriptor]
subgraph: ir.Subgraph, inputs: list[InputDescriptor]
) -> Callable[_P, Any]:
# Lower subgraph to ir.Pointwise nodes
def fake_inner_fn(

View File

@ -3,7 +3,7 @@ import logging
import os
import uuid
from pathlib import Path
from typing import List, Optional
from typing import Optional
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._utils_internal import justknobs_check
@ -48,7 +48,7 @@ class TritonKernelArtifacts:
kernel_hash: str
device: int
artifacts: List[TritonKernelArtifact]
artifacts: list[TritonKernelArtifact]
@dataclasses.dataclass(frozen=True)
@ -57,7 +57,7 @@ class TritonBundlerMetadata:
Metadata used for instrumentation
"""
cached_kernel_names: List[str]
cached_kernel_names: list[str]
class TritonBundler:
@ -76,7 +76,7 @@ class TritonBundler:
- TritonBundler.read_and_emit is called when a cache entry is read
"""
_entries: Optional[List[TritonBundleEntry]] = None
_entries: Optional[list[TritonBundleEntry]] = None
# __grp__kernel_name.json contains metadata with source code paths
# we use this as a sentinel value for search and replace
@ -134,7 +134,7 @@ class TritonBundler:
@classmethod
def collect(
cls,
) -> tuple[List[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
@ -150,10 +150,10 @@ class TritonBundler:
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
entries = cls._entries
if entries is not None:
result: List[TritonKernelArtifacts] = []
kernel_names: List[str] = []
result: list[TritonKernelArtifacts] = []
kernel_names: list[str] = []
for entry in entries:
artifacts: List[TritonKernelArtifact] = []
artifacts: list[TritonKernelArtifact] = []
path = os.path.join(entry.directory, entry.kernel_hash)
if not os.path.exists(path):
continue
@ -203,7 +203,7 @@ class TritonBundler:
@staticmethod
def read_and_emit(
bundle: List[TritonKernelArtifacts],
bundle: list[TritonKernelArtifacts],
) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
@ -223,7 +223,7 @@ class TritonBundler:
with dynamo_timed(
key="TritonBundler.read_and_emit", log_pt2_compile_event=True
):
kernel_names: List[str] = []
kernel_names: list[str] = []
for artifacts in bundle:
basedir = triton_cache_dir(artifacts.device)

View File

@ -26,18 +26,13 @@ from io import StringIO
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
ValuesView,
)
from typing_extensions import Concatenate, dataclass_transform, ParamSpec, TypeGuard
from unittest import mock
@ -49,6 +44,8 @@ from torch._inductor.runtime.hints import DeviceProperties
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence, ValuesView
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from .codegen.common import WorkspaceArg
@ -94,7 +91,7 @@ _IS_WINDOWS = sys.platform == "win32"
log = logging.getLogger(__name__)
_T = TypeVar("_T")
VarRanges = Dict[sympy.Expr, sympy.Expr]
VarRanges = dict[sympy.Expr, sympy.Expr]
InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"}
@ -308,7 +305,7 @@ def _type_of(key):
def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
) -> list[sympy.Expr]:
"""
Gets the shape and stride of a tensor. For non-symbolic tensors, this is
trivial. But for symbolic tensors, we need to map from SymIntNode into
@ -319,7 +316,7 @@ def convert_shape_to_inductor(
def convert_shape_to_symint(
lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
) -> list[Union[int, torch.SymInt]]:
"""
Takes a list of shapes from Inductor and converts them into symints (or just
ints if all shapes are static).
@ -433,7 +430,7 @@ def precompute_method(obj: Any, method: str):
setattr(obj, method, lambda: result)
def precompute_methods(obj: Any, methods: List[str]):
def precompute_methods(obj: Any, methods: list[str]):
"""Replace methods with new methods that returns a precomputed constants."""
for method in methods:
precompute_method(obj, method)
@ -451,7 +448,7 @@ def pad_listlike(x, size):
# Used to ensure that iterating over a set is deterministic
def tuple_sorted(x: tuple[_T, ...]) -> List[_T]:
def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
if len(x) == 0:
return []
@ -716,7 +713,7 @@ def sympy_index_symbol(name: str) -> sympy.Symbol:
return sympy.Symbol(name, integer=True, nonnegative=True)
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
"""
When the passed replacement symbol v is a string, it is converted to a symbol with name v that
has the same integer and nonnegative properties as the replaced expression.
@ -804,7 +801,7 @@ def output_node(gm: torch.fx.GraphModule):
return last_node
_registered_caches: List[Any] = []
_registered_caches: list[Any] = []
def clear_on_fresh_inductor_cache(obj: Any):
@ -871,7 +868,7 @@ def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
clear_inductor_caches()
def argsort(seq) -> List[int]:
def argsort(seq) -> list[int]:
# preserve original order for equal strides
getter = seq.__getitem__
a_r = range(len(seq))
@ -880,7 +877,7 @@ def argsort(seq) -> List[int]:
def argsort_sym(
shape_env, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
) -> List[int]:
) -> list[int]:
def cmp(a, b):
a_idx, a_val = a
b_idx, b_val = b
@ -1180,7 +1177,7 @@ def use_max_autotune() -> bool:
)
def _use_template_for_gpu(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
def _use_template_for_gpu(layout, allowed_layout_dtypes: list[torch.dtype]) -> bool:
return (
is_gpu(layout.device.type)
and layout.dtype in allowed_layout_dtypes
@ -1462,10 +1459,10 @@ class DebugDirManager:
torch._dynamo.config.debug_dir_root = self.prev_debug_name
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
from .graph import GraphLowering
source_codes: List[str] = []
source_codes: list[str] = []
def save_output_code(code: str):
source_codes.append(code)
@ -1476,7 +1473,7 @@ def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
return result, source_codes
def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, List[str]]:
def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, list[str]]:
result, source_codes = run_and_get_code(fn, *args, **kwargs)
kernels = []
for code in source_codes:
@ -1497,7 +1494,7 @@ def get_code(fn, *args, **kwargs):
"""Get the inductor-generated code, but skip any actual compilation or running."""
from .graph import GraphLowering
source_codes: List[str] = []
source_codes: list[str] = []
def save_output_code(code: str):
source_codes.append(code)
@ -2217,13 +2214,13 @@ def shape_env_from_inputs(inputs: Sequence[InputType]):
def align_inputs_from_check_idxs(
model: Callable[[List[InputType]], Any],
model: Callable[[list[InputType]], Any],
inputs_to_check: Sequence[int],
) -> Callable[[List[InputType]], Any]:
) -> Callable[[list[InputType]], Any]:
if len(inputs_to_check) == 0:
return model
def run(new_inputs: List[InputType]):
def run(new_inputs: list[InputType]):
copy_misaligned_inputs(new_inputs, inputs_to_check)
return model(new_inputs)
@ -2243,7 +2240,7 @@ def clone_preserve_strides(x: torch.Tensor):
def copy_misaligned_inputs(
new_inputs: List[InputType], check_inputs_idxs: Sequence[int]
new_inputs: list[InputType], check_inputs_idxs: Sequence[int]
) -> None:
for i in check_inputs_idxs:
_inp = new_inputs[i]
@ -2408,7 +2405,7 @@ class OpDtypeRule:
override_return_dtype: Optional[torch.dtype]
op_dtype_propagation_rules: Dict[str, OpDtypeRule] = {}
op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
def register_op_dtype_propagation_rules(
@ -2445,7 +2442,7 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
return wrap(cls)
def get_donated_idxs() -> Optional[List[int]]:
def get_donated_idxs() -> Optional[list[int]]:
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and tracing_context.fw_metadata:
return tracing_context.fw_metadata.bw_donated_idxs

View File

@ -59,7 +59,7 @@ from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
from torch.utils._ordered_set import OrderedSet
@ -108,7 +108,7 @@ class Virtualized(Generic[T]):
store other things, like booleans.
"""
def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]):
self._key: str = f"__torchinductor_{vname}"
self._default = default
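
`Virtualized.__init__` is the runtime-value counterpart of the same rewrite: `type[NullHandler]` annotates a parameter that receives a class object rather than an instance, and a class is callable just like the zero-argument factory alternative. A minimal sketch of that shape with made-up names:

from typing import Callable, Union

class NullHandlerSketch:
    pass

def resolve_default(default: Union[Callable[[], object], type[NullHandlerSketch]]) -> object:
    # Either branch of the Union is callable with no arguments.
    return default()

assert isinstance(resolve_default(NullHandlerSketch), NullHandlerSketch)
assert resolve_default(lambda: 42) == 42
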
@ -156,7 +156,7 @@ class NullKernelHandler(NullHandler):
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
"kernel", NullKernelHandler