Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 12:54:11 +08:00
PEP585 update - torch/_inductor (#145198)
See #145101 for details.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145198
Approved by: https://github.com/bobrenjc93
This commit is contained in:
committed by PyTorch MergeBot
parent 2f9d378f7b
commit bac62341eb
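The change itself is mechanical: PEP 585 (Python 3.9+) makes the builtin containers subscriptable, so the deprecated typing aliases (List, Dict, Type, and friends) are replaced with list, dict, and type, and the typing imports shrink accordingly, while Optional and Union are left as-is. A minimal before/after sketch of the style follows; it is illustrative only, not code from this PR, and the function and names are hypothetical.

# Illustrative sketch only -- not code from this PR; the names are hypothetical.
# Before (pre-PEP 585):
#   from typing import Dict, List, Optional
#   def bucket(sizes: List[int], limit: Optional[int] = None) -> Dict[int, List[int]]: ...
# After (PEP 585, Python 3.9+): builtin generics, smaller typing import.
from typing import Optional


def bucket(sizes: list[int], limit: Optional[int] = None) -> dict[int, list[int]]:
    # list and dict are subscriptable at runtime on Python 3.9+, so no typing aliases are needed.
    out: dict[int, list[int]] = {}
    for s in sizes if limit is None else sizes[:limit]:
        out.setdefault(s, []).append(s)
    return out

Note that in the diff below only the generic container aliases are swapped; Optional and Union continue to be imported from typing.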
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import List, Optional, Union
from typing import Optional, Union

import sympy

@@ -113,8 +113,8 @@ def register_jagged_ops():
@register_lowering(torch.ops.aten._jagged_to_padded_dense_forward.default)
def _jagged_to_padded_dense_forward(
jagged_values: TensorBox,
jagged_offsets: List[TensorBox],
max_lengths: List[int], # list of ints/SymInts
jagged_offsets: list[TensorBox],
max_lengths: list[int], # list of ints/SymInts
padding_value: float = 0.0,
) -> TensorBox:
device = jagged_values.get_device_or_error()
@@ -184,7 +184,7 @@ def register_jagged_ops():
def _dense_to_jagged_forward_impl(
fallback_op, # pyre-ignore[2]
dense: TensorBox,
jagged_offsets: List[TensorBox],
jagged_offsets: list[TensorBox],
jagged_len: Optional[int] = None,
) -> TensorBox:
device = dense.get_device_or_error()
@@ -257,7 +257,7 @@ def register_jagged_ops():
@register_lowering(torch.ops.aten._padded_dense_to_jagged_forward)
def _dense_to_jagged_forward(
dense: TensorBox,
jagged_offsets: List[TensorBox],
jagged_offsets: list[TensorBox],
jagged_len: Optional[int] = None,
) -> TensorBox:
return _dense_to_jagged_forward_impl(

@@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import cast, Optional, Sequence, TYPE_CHECKING, TypedDict
from typing import cast, Optional, TYPE_CHECKING, TypedDict

import torch
from torch._inductor.codegen.rocm.ck_conv_template import CKGroupedConvFwdTemplate
@@ -33,6 +33,8 @@ from .mm_common import build_rocm_gemm_configs, filtered_configs


if TYPE_CHECKING:
from collections.abc import Sequence

from ..ir import TensorBox

log = logging.getLogger(__name__)

@@ -3,9 +3,10 @@

import logging
import math
from collections.abc import Sequence
from dataclasses import dataclass
from enum import auto, Enum
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Optional, Union

import sympy

@@ -90,7 +91,7 @@ def create_placeholder(
return TensorBox.create(input_buffer)


def maybe_realize(args: List[Optional[IRNode]]):
def maybe_realize(args: list[Optional[IRNode]]):
"""Accepts a list of optional IRNodes and returns a list of realized IRNodes"""
return tree_map(
lambda x: (
@@ -109,7 +110,7 @@ def get_float32_precision():
return "'tf32'"


def zeros_and_scatter_lowering(shape: List[int], indices, values):
def zeros_and_scatter_lowering(shape: list[int], indices, values):
# Always accumulate into fp32 then cast
grad = _full(0, values.get_device(), torch.float32, shape)
assert isinstance(grad, TensorBox)
@@ -153,10 +154,10 @@ def zeros_and_scatter_lowering(shape: List[int], indices, values):
return buffer


SubgraphResults = Union[List[Optional[ComputedBuffer]], Optional[ComputedBuffer]]
SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]]


def build_subgraph_buffer(args: List[TensorBox], subgraph: Subgraph) -> SubgraphResults:
def build_subgraph_buffer(args: list[TensorBox], subgraph: Subgraph) -> SubgraphResults:
"""This function's goal is to take in the required args and produce the subgraph buffer
The subgraph buffer is a ComputedBuffer that will be inlined into the triton template

@@ -870,7 +871,7 @@ def lower_cpu(
"torch.compile on current platform is not supported for CPU."
)

fake_buffers: List[Buffer] = [] # noqa: F821
fake_buffers: list[Buffer] = [] # noqa: F821
placeholder_inps = [
create_placeholder(name, dtype, query.get_device())
for name, dtype in [
@@ -968,7 +969,7 @@ def lower_cpu(
[B, Hq, seq_len_q, v_head_dim],
stride=[sympy.sympify(s) for s in out_strides],
)
_choices: List[Any] = []
_choices: list[Any] = []
input_nodes = [query, key, value, kv_num_blocks, kv_indices]
if not full_kv_num_blocks:
no_full_kv_block = True
@@ -1214,8 +1215,8 @@ def flex_attention(
"V_HEAD_DIM", V.graph.sizevars.evaluate_static_shape(v_head_dim)
)

choices: List[Any] = []
configs: List[tuple[int, int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int, int]] = []
configs.append(_get_default_config_fwd(query))
if config.max_autotune:
configs += [
@@ -2071,9 +2072,9 @@ class JointOutputResult:
"""Results from processing joint outputs."""

grad_input: ComputedBuffer
captured_grads_compute: List[ComputedBuffer]
captured_grads: List[Optional[TensorBox]]
mutated_grads: List[TensorBox]
captured_grads_compute: list[ComputedBuffer]
captured_grads: list[Optional[TensorBox]]
mutated_grads: list[TensorBox]


def process_joint_outputs(
@@ -2088,7 +2089,7 @@ def process_joint_outputs(
Returns:
JointOutputResult containing processed buffers and gradients
"""
assert isinstance(all_joint_outputs, List)
assert isinstance(all_joint_outputs, list)
assert (
all_joint_outputs[0] is not None
), "joint_subgraph_buffer is None this is a bug!"
@@ -2307,8 +2308,8 @@ def flex_attention_backward(*args, **kwargs):
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
SPARSE_KV_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_KV_BLOCK_SIZE)

choices: List[Any] = []
configs: List[tuple[int, int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int, int]] = []
configs.append(_get_default_config_bwd(query))
if config.max_autotune:
num_stages_list = [1, 3, 4, 5] if torch.version.hip is None else [1]

@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any, List
from typing import Any

import sympy

@@ -415,8 +415,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
score_mod_other_buffers = maybe_realize(score_mod_other_buffers)
mask_mod_other_buffers = maybe_realize(mask_mod_other_buffers)

choices: List[Any] = []
configs: List[tuple[int, int, int]] = []
choices: list[Any] = []
configs: list[tuple[int, int, int]] = []
configs.append(_get_decoding_default_config(key))
# Note: max_autotune is not supported yet. Causes error in lowering the dynamic shape in reduction ops.
if config.max_autotune:

@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import functools
import logging
from typing import Any, Dict, List, Optional
from typing import Any, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic import AutoHeuristicSelectAlgorithm
@@ -933,7 +933,7 @@ def tuned_fused_int_mm_mul(mat1, mat2, mat3, out_dtype, *, layout=None):
def mul_epilogue(v1, v2):
return V.ops.mul(v1, v2)

choices: List[Dict[Any, Any]] = []
choices: list[dict[Any, Any]] = []
for config in int8_mm_configs(
m, n, k, **mm_config_kwargs(ir.get_device_type(mat1))
):

@@ -2,7 +2,8 @@
import functools
import itertools
import logging
from typing import Any, cast, Dict, Sequence
from collections.abc import Sequence
from typing import Any, cast

import sympy

@@ -438,7 +439,7 @@ def mm_grid(m, n, meta):
return (cdiv(m, meta["BLOCK_M"]) * cdiv(n, meta["BLOCK_N"]), 1, 1)


def persistent_mm_grid(M: int, N: int, meta: Dict[str, Any]):
def persistent_mm_grid(M: int, N: int, meta: dict[str, Any]):
"""Defines the grid for persistent kernels."""
return (
min(meta["NUM_SMS"], cdiv(M, meta["BLOCK_M"]) * cdiv(N, meta["BLOCK_N"])),

@@ -1,5 +1,6 @@
import logging
from typing import Any, Dict, List, Optional, Sequence
from collections.abc import Sequence
from typing import Any, Optional

import sympy

@@ -428,7 +429,7 @@ def scaled_mm_options_device_tma( # type: ignore[no-untyped-def]
scale_b: StorageBox,
use_fast_accum: bool,
b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
@@ -464,7 +465,7 @@ def scaled_mm_options( # type: ignore[no-untyped-def]
scale_b: StorageBox,
use_fast_accum: bool,
b_prologue_cast_type: Optional[str] = None,
) -> Dict[str, Any]:
) -> dict[str, Any]:
even_k_symbolic = (
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
@@ -533,7 +534,7 @@ def tuned_scaled_mm(
input_nodes, layout, out_dtype=out_dtype, use_fast_accum=use_fast_accum
)

choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_aten_gemm_kernels():
choices.append(aten_choice)

@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import logging
from typing import List, TYPE_CHECKING
from typing import TYPE_CHECKING

from ..select_algorithm import autotune_select_algorithm, TritonTemplate
from .mm_common import mm_args, mm_configs, mm_grid, mm_options
@@ -75,7 +75,7 @@ uint4x2_mixed_mm_template = TritonTemplate(

def tuned_uint4x2_mixed_mm(mat1, mat2, mat2_mm_shape, mat2_dtype):
m, n, k, layout, mat1, mat2 = mm_args(mat1, mat2, layout=None, use_4x2_dim=True)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
b_prologue_cast_type = f"tl.{mat2_dtype}".replace("torch.", "")
for config in mm_configs(m, n, k):
uint4x2_mixed_mm_template.maybe_append_choice(

@@ -6,7 +6,7 @@ import functools
import itertools
import re
from enum import auto, Enum
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, TypeVar
from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, TypeVar

import sympy

@@ -21,6 +21,10 @@ from .utils import cache_on_self, sympy_index_symbol_with_prefix, sympy_subs
from .virtualized import ops, V


if TYPE_CHECKING:
from collections.abc import Sequence


T = TypeVar("T")


@@ -83,14 +87,14 @@ class LoopBody:
indexing simplifications and makes it easier to analyze loop bodies.
"""

indexing_exprs: Dict[str, sympy.Expr]
indexing_exprs_name: Dict[sympy.Expr, str]
submodules: Dict[str, Any]
subblocks: Dict[str, LoopBodyBlock]
indirect_vars: List[sympy.Symbol]
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr]
indexing_exprs: dict[str, sympy.Expr]
indexing_exprs_name: dict[sympy.Expr, str]
submodules: dict[str, Any]
subblocks: dict[str, LoopBodyBlock]
indirect_vars: list[sympy.Symbol]
indirect_var_ranges: dict[sympy.Symbol, sympy.Expr]
root_block: LoopBodyBlock
memory_usage: Dict[MemoryUsageType, List[MemoryEntry]]
memory_usage: dict[MemoryUsageType, list[MemoryEntry]]
op_counts: collections.Counter[str]

def __init__(self, fn, args, var_ranges, iter_vars, reduce_vars):
@@ -120,7 +124,7 @@ class LoopBody:
self.submodules = {"get_index": self.get_index}
self.subblocks = {}
self.indirect_vars = []
self.indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr] = {}
self.indirect_var_ranges: dict[sympy.Symbol, sympy.Expr] = {}
self.memory_usage = {t: [] for t in MemoryUsageType}
self.op_counts = collections.Counter()
self.root_block = LoopBodyBlock(self, fn, args) # traces
@@ -433,7 +437,7 @@ class LoopBodyBlock:
operations will manifest as an extra LoopBodyBlock.
"""

def __init__(self, body: LoopBody, fn: Callable[..., Any], args: List[Any]):
def __init__(self, body: LoopBody, fn: Callable[..., Any], args: list[Any]):
self.body = body

def add_index(expr: sympy.Expr, mtype: MemoryUsageType, **kwargs):

@@ -8,8 +8,8 @@ import operator
import os
import warnings
from collections import defaultdict
from collections.abc import Iterable
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from collections.abc import Iterable, Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
from unittest.mock import patch

@@ -90,9 +90,9 @@ FALLBACK_ALLOW_LIST = OrderedSet(
)

log = logging.getLogger(__name__)
lowerings: Dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
# Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
_maybe_layout_constraints: Dict[
_maybe_layout_constraints: dict[
torch._ops.OpOverload, Optional[Callable[..., Any]]
] = {}
fallbacks = OrderedSet[torch._ops.OpOverload]()
@@ -106,7 +106,7 @@ foreach_ops = OrderedSet[torch._ops.OpOverload](
# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
# so why is it in foreach_ops?
inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
inplaceable_foreach_ops: Dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
quantized_decomposed = torch.ops.quantized_decomposed


@@ -313,12 +313,12 @@ def in_namespace(op, namespace):


def transform_args(
args: List[Any],
kwargs: Dict[str, Any],
args: list[Any],
kwargs: dict[str, Any],
broadcast: bool,
type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
convert_input_to_bool: bool,
) -> tuple[List[Any], Dict[str, Any]]:
) -> tuple[list[Any], dict[str, Any]]:
args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
# check that there's something to transform
@@ -428,8 +428,8 @@ def _register_lowering(

@functools.wraps(decomp_fn)
def wrapped(*args, **kwargs):
args: List[Any] = list(args)
kwargs: Dict[str, Any] = dict(kwargs)
args: list[Any] = list(args)
kwargs: dict[str, Any] = dict(kwargs)
unpacked = False
# TODO maybe we need to use pytrees here
if len(args) == 1 and isinstance(args[0], (list, tuple)):
@@ -654,7 +654,7 @@ def make_pointwise(


def make_foreach_pointwise(pw_fn, allow_alpha=False):
def inner(*inputs: List[List[TensorBox]], alpha=1):
def inner(*inputs: list[list[TensorBox]], alpha=1):
realize_outputs = (
len(V.graph.current_node.users) == 0
or V.graph.current_node.target in inplace_foreach_ops
@@ -682,7 +682,7 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):

outputs = [None] * len(a_list_input)
for (device, use_foreach), group in groups.items():
operation_list: List[str] = []
operation_list: list[str] = []
for (
output_ind,
args,
@@ -749,7 +749,7 @@ def _foreach_map(subgraph, *args, **kwargs):

outputs = [None] * len(sub_outputs)
for (device, use_foreach), group in groups.items():
operation_list: List[str] = []
operation_list: list[str] = []
for (
output_ind,
output,
@@ -949,7 +949,7 @@ def where(cond, a, b):
def broadcast_tensors(*inputs):
if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
return broadcast_tensors(*inputs[0])
target: List[sympy.Expr] = functools.reduce(
target: list[sympy.Expr] = functools.reduce(
broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
)
outputs = []
@@ -1231,7 +1231,7 @@ def as_strided_copy(x, size, stride, storage_offset=None):

def pointwise_cat(inputs, dim=0):
# (inclusive, exclusive)
inputs_ranges: List[tuple[sympy.Expr, sympy.Expr]] = []
inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
prev_end = 0
for inp in inputs:
inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim])) # type: ignore[arg-type]
@@ -2173,7 +2173,7 @@ def inductor_lookup_seed(seeds, index):


@register_lowering(inductor_prims.random, type_promotion_kind=None)
def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int = 0):
def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
assert not config.fallback_random
assert mode in ("rand", "randn")
size = [*size]
@@ -2202,7 +2202,7 @@ def inductor_random(size: List[int], seed: TensorBox, mode: str, *, offset: int

@register_lowering(inductor_prims.randint, type_promotion_kind=None)
def inductor_randint(
low: int, high: int, size: List[int], seed: TensorBox, *, offset: int = 0
low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
):
assert not config.fallback_random
size = [*size]
@@ -2916,7 +2916,7 @@ def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
else:
dtype = dtype or torch.get_default_dtype()

ranges: List[sympy.Expr] = []
ranges: list[sympy.Expr] = []

if isinstance(data, sympy.Basic):

@@ -4041,7 +4041,7 @@ def constant_pad_nd(x, padding, fill_value=0):
n = len(sizes) - len(bounds)

# if padding is a complicated expression, hoist it
bounds_precomp: List[tuple[sympy.Symbol, Any]] = []
bounds_precomp: list[tuple[sympy.Symbol, Any]] = []
for l, h in bounds:
bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h)) # type: ignore[arg-type]

@@ -4,7 +4,7 @@ import collections
import dataclasses
import heapq
import logging
from typing import Callable, Dict, List, TYPE_CHECKING, TypedDict, Union
from typing import Callable, TYPE_CHECKING, TypedDict, Union

from torch._utils_internal import signpost_event
from torch.utils._ordered_set import OrderedSet
@@ -61,9 +61,9 @@ class FreeableInputBuffer:


def get_freeable_input_buf(
nodes: List[BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
graph_inputs: OrderedSet[str],
) -> Dict[str, FreeableInputBuffer]:
) -> dict[str, FreeableInputBuffer]:
"""
Create and keep track of all input buffers that can be freed during the program

@@ -87,10 +87,10 @@ def get_freeable_input_buf(

# get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: Dict[
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_size: Dict[str, int] = dict()
dep_name_to_size: dict[str, int] = dict()
for node in nodes:
for dep in node.read_writes.reads:
if dep.name in graph_inputs and not dep.name.startswith(
@@ -100,7 +100,7 @@ def get_freeable_input_buf(
dep_name_to_size[dep.name] = _dep_size_hint(dep)

# create FreeableInputBuffer objects and add them to the returned dictionary
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = dict()
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = dict()
for dep_name, succ_nodes in dep_name_to_succ_nodes.items():
name_to_freeable_input_buf[dep_name] = FreeableInputBuffer(
dep_name,
@@ -112,8 +112,8 @@ def get_freeable_input_buf(


def compute_size_for_scheduler_buffer(
name_to_buf: Dict[str, SchedulerBuffer]
) -> Dict[str, tuple[int, int]]:
name_to_buf: dict[str, SchedulerBuffer]
) -> dict[str, tuple[int, int]]:
"""
Compute the size of each scheduler buffer, including (1) memory allocated when
it is created and (2) memory deallocated when it is freed.
@@ -134,7 +134,7 @@ def compute_size_for_scheduler_buffer(
from .ir import MultiOutput
from .scheduler import OutputNode

sched_buf_to_size: Dict[str, tuple[int, int]] = dict()
sched_buf_to_size: dict[str, tuple[int, int]] = dict()

def _compute_and_update_buf_size(
sched_buf: SchedulerBuffer, user_of_MultiOutputLayout: bool = False
@@ -175,8 +175,8 @@ def compute_size_for_scheduler_buffer(


def assign_memory_planning_info_for_scheduler_buffers(
nodes: List[BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
nodes: list[BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
) -> None:
"""
For each SchedulerBuffer, assign its size info and successor nodes.
@@ -187,7 +187,7 @@ def assign_memory_planning_info_for_scheduler_buffers(

# get buffer's successor nodes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: Dict[
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
for node in nodes:
@@ -205,10 +205,10 @@ def assign_memory_planning_info_for_scheduler_buffers(


def assign_memory_planning_info_for_scheduler_nodes(
nodes: List[BaseSchedulerNode],
name_to_fused_node: Dict[str, BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
nodes: list[BaseSchedulerNode],
name_to_fused_node: dict[str, BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
) -> None:
"""
Assign to each scheduler node its predecessor and successor nodes.
@@ -243,10 +243,10 @@ def assign_memory_planning_info_for_scheduler_nodes(


def estimate_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
graph_outputs: OrderedSet[str],
) -> tuple[int, List[int]]:
) -> tuple[int, list[int]]:
"""
Given a list of nodes in their execution order, estimate the peak memory, by
keeping track of the liveliness of SchedulerBuffers and FreeableInputBuffers.
@@ -267,12 +267,12 @@ def estimate_peak_memory(

# get the execution step of each node, this will be used to determine
# the end_step of buffers
node_to_step: Dict[BaseSchedulerNode, int] = dict()
node_to_step: dict[BaseSchedulerNode, int] = dict()
for step, node in enumerate(nodes):
node_to_step[node] = step

# get buffers' size and liveliness information
buf_info_list: List[BufferInfo] = []
buf_info_list: list[BufferInfo] = []
# 1. for freeable input buffers
for buf_name, input_buf in name_to_freeable_input_buf.items():
end_step = (
@@ -340,11 +340,11 @@ def estimate_peak_memory(


def topological_sort_lpmf(
nodes: List[BaseSchedulerNode],
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer],
name_to_buf: Dict[str, SchedulerBuffer],
nodes: list[BaseSchedulerNode],
name_to_freeable_input_buf: dict[str, FreeableInputBuffer],
name_to_buf: dict[str, SchedulerBuffer],
graph_outputs: OrderedSet[str],
) -> List[BaseSchedulerNode]:
) -> list[BaseSchedulerNode]:
"""
A bfs-based greedy topological order. LPMF stands for "Least Peak Memory First".

@@ -372,8 +372,8 @@ def topological_sort_lpmf(
class BufferInfo(TypedDict):
outdegree: int

node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
buf_info: Dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()
buf_info: dict[Union[SchedulerBuffer, FreeableInputBuffer], BufferInfo] = dict()

# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
@@ -422,7 +422,7 @@ def topological_sort_lpmf(
node_info[node]["memory_to_free"] += buf.mpi_buffer.size_free

# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
schedule: list[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule:
@@ -464,7 +464,7 @@ def topological_sort_lpmf(
return schedule


def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def topological_sort_bfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
A BFS topological sort that selects nodes whose dependencies are executed the
earliest. This follows a FIFO idea. Specifically, at every iteration, for each node
@@ -478,11 +478,11 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
indegree: int
order: int

node_info: Dict[BaseSchedulerNode, NodeInfo] = dict()
node_info: dict[BaseSchedulerNode, NodeInfo] = dict()

@dataclasses.dataclass
class NodeWithPriority:
priority: List[int]
priority: list[int]
node: BaseSchedulerNode

def __lt__(self, other: NodeWithPriority) -> bool:
@@ -490,7 +490,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
return self.node.mpi_node.index < other.node.mpi_node.index
return self.priority < other.priority

def _node_priority(node: BaseSchedulerNode) -> List[int]:
def _node_priority(node: BaseSchedulerNode) -> list[int]:
# priority is the order in which predecessor nodes are executed
assert node_info[node]["indegree"] == 0
exec_orders = sorted(
@@ -502,7 +502,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo

# compute nodes' number of unmet dependencies (for schedulability)
# initialize the list of nodes ready to be scheduled
nodes_to_schedule: List[NodeWithPriority] = []
nodes_to_schedule: list[NodeWithPriority] = []
for node in nodes:
node_info[node] = {"indegree": len(node.mpi_node.pred_nodes), "order": -1}
if node_info[node]["indegree"] == 0:
@@ -511,7 +511,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
)

# schedule nodes one at a time
schedule: List[BaseSchedulerNode] = []
schedule: list[BaseSchedulerNode] = []
num_iters: int = 0
while num_iters < len(nodes) and nodes_to_schedule:
# select a node to schedule
@@ -536,7 +536,7 @@ def topological_sort_bfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
return schedule


def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def topological_sort_dfs(nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
This is a DFS topological sort. The setup is similar to `topological_sort_schedule`
in scheduler.py. The difference is the order nodes are visited in the outer loop.
@@ -546,9 +546,9 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo
the nodes in ascending order of this priority.
"""
seen: OrderedSet[BaseSchedulerNode] = OrderedSet()
name_to_node: Dict[str, BaseSchedulerNode] = dict()
result: List[BaseSchedulerNode] = []
size_with_reads: Dict[BaseSchedulerNode, int] = dict()
name_to_node: dict[str, BaseSchedulerNode] = dict()
result: list[BaseSchedulerNode] = []
size_with_reads: dict[BaseSchedulerNode, int] = dict()

def visit(n: BaseSchedulerNode) -> None:
if n not in seen:
@@ -579,17 +579,17 @@ def topological_sort_dfs(nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNo


def reorder_for_peak_memory(
nodes: List[BaseSchedulerNode],
name_to_buf: Dict[str, SchedulerBuffer],
name_to_fused_node: Dict[str, BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
name_to_buf: dict[str, SchedulerBuffer],
name_to_fused_node: dict[str, BaseSchedulerNode],
graph_inputs: OrderedSet[str],
graph_outputs: OrderedSet[str],
methods: List[Callable[..., List[BaseSchedulerNode]]] = [ # noqa: B006
methods: list[Callable[..., list[BaseSchedulerNode]]] = [ # noqa: B006
topological_sort_lpmf,
topological_sort_bfs,
topological_sort_dfs,
],
) -> List[BaseSchedulerNode]:
) -> list[BaseSchedulerNode]:
"""
Try a few heuristics based topological sort algorithms, and pick the one whose
resulting topological order has the lowest peak memory estimation.
@@ -599,13 +599,13 @@ def reorder_for_peak_memory(

@dataclasses.dataclass
class PeakMemoryResult:
order: List[BaseSchedulerNode]
order: list[BaseSchedulerNode]
peak_memory: int
method: str

# preparation -- as nodes are scheduled one at a time, these help
# keep track of when a buffer can be freed, and when a node can be scheduled
name_to_freeable_input_buf: Dict[str, FreeableInputBuffer] = get_freeable_input_buf(
name_to_freeable_input_buf: dict[str, FreeableInputBuffer] = get_freeable_input_buf(
nodes, graph_inputs
)
assign_memory_planning_info_for_scheduler_buffers(nodes, name_to_buf)
@@ -614,7 +614,7 @@ def reorder_for_peak_memory(
)

# keep track of the peak memory estimates of different methods
peak_memory_diff_methods: List[PeakMemoryResult] = []
peak_memory_diff_methods: list[PeakMemoryResult] = []

# the default
estimated_peak_memory, _ = estimate_peak_memory(

@@ -8,7 +8,7 @@ import os
import re
from dataclasses import dataclass
from functools import lru_cache
from typing import Dict, List, TYPE_CHECKING
from typing import TYPE_CHECKING

from torch._inductor import config
from torch._inductor.utils import get_benchmark_name
@@ -23,13 +23,13 @@ if TYPE_CHECKING:
generated_kernel_count = 0
generated_cpp_vec_kernel_count = 0
num_bytes_accessed = 0
nodes_num_elem: List[
nodes_num_elem: list[
tuple[
BaseSchedulerNode,
int,
]
] = []
node_runtimes: List[tuple[BaseSchedulerNode, float]] = []
node_runtimes: list[tuple[BaseSchedulerNode, float]] = []

# counters for tracking fusions
ir_nodes_pre_fusion = 0
@@ -45,7 +45,7 @@ class CppOuterLoopFusedCount:


# The length counts the number of outer loop fusions.
cpp_outer_loop_fused_inner_counts: List[CppOuterLoopFusedCount] = []
cpp_outer_loop_fused_inner_counts: list[CppOuterLoopFusedCount] = []

num_comprehensive_padding = 0
num_matches_for_scatter_upon_const_tensor = 0
@@ -122,13 +122,13 @@ class CachedMetricsHelper:
globals()[metric] += getattr(delta, metric)


REGISTERED_METRIC_TABLES: Dict[str, MetricTable] = {}
REGISTERED_METRIC_TABLES: dict[str, MetricTable] = {}


@dataclass
class MetricTable:
table_name: str
column_names: List[str]
column_names: list[str]

num_rows_added: int = 0

@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
from typing import Any, List, Optional
from typing import Any, Optional

import sympy

@@ -31,13 +31,13 @@ def _prepare_convolution_fusion_create(
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding: List[int],
stride: List[int],
dilation: List[int],
padding: list[int],
stride: list[int],
dilation: list[int],
groups: int,
transposed: bool = False,
output_padding: Optional[List[int]] = None,
quantize_args: Optional[List["TensorBox"]] = None,
output_padding: Optional[list[int]] = None,
quantize_args: Optional[list["TensorBox"]] = None,
other: Optional["TensorBox"] = None,
):
"""
@@ -204,7 +204,7 @@ def _prepare_linear_fusion_create(
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
quantize_args: Optional[List["TensorBox"]] = None,
quantize_args: Optional[list["TensorBox"]] = None,
other: Optional["TensorBox"] = None,
binary_sum: bool = False,
):
@@ -252,7 +252,7 @@ def _prepare_linear_fusion_create(
output_size,
output_stride,
)
constant_args: List[Any] = []
constant_args: list[Any] = []

if bias is not None:
inputs.append(bias)
@@ -298,12 +298,12 @@ class ConvolutionUnary(ExternKernelAlloc):
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
attr,
scalars: Optional[List[Any]],
scalars: Optional[list[Any]],
algorithm,
):
(
@@ -357,14 +357,14 @@ class ConvolutionBinary(ExternKernelAlloc):
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_scalars: Optional[list[Any]],
unary_algorithm: Optional[str],
):
(
@@ -431,14 +431,14 @@ class ConvolutionBinaryInplace(ExternKernelAlloc):
other: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups: int,
binary_attr: str,
binary_alpha: Optional[float],
unary_attr: Optional[str],
unary_scalars: Optional[List[Any]],
unary_scalars: Optional[list[Any]],
unary_algorithm: Optional[str],
):
(
@@ -496,13 +496,13 @@ class ConvolutionTransposeUnary(ExternKernelAlloc):
x: "TensorBox",
weight: "TensorBox",
bias: "TensorBox",
padding_: List[int],
output_padding_: List[int],
stride_: List[int],
dilation_: List[int],
padding_: list[int],
output_padding_: list[int],
stride_: list[int],
dilation_: list[int],
groups_: int,
attr,
scalars: Optional[List[Any]],
scalars: Optional[list[Any]],
algorithm,
):
transposed = True
@@ -580,9 +580,9 @@ class QConvPointWisePT2E(ExternKernelAlloc):
w_scale: "TensorBox",
w_zero_point: "TensorBox",
bias: "TensorBox",
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
output_scale: float,
output_zero_point: int,
@@ -692,9 +692,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
w_zero_point,
qaccum: "TensorBox",
bias: "TensorBox",
stride: List[int],
padding: List[int],
dilation: List[int],
stride: list[int],
padding: list[int],
dilation: list[int],
groups: int,
output_scale: "TensorBox",
output_zero_point: "TensorBox",
@@ -1139,7 +1139,7 @@ class MkldnnRnnLayer(ExternKernelAlloc):
hx: "TensorBox",
cx: "TensorBox",
reverse: bool,
batch_sizes: List[int],
batch_sizes: list[int],
mode: int,
hidden_size: int,
num_layers: int,

@@ -1,6 +1,6 @@
# mypy: allow-untyped-defs
import functools
from typing import List, Optional
from typing import Optional

import torch
import torch.utils._pytree as pytree
@@ -31,8 +31,8 @@ from .virtualized import ops, V

def grouped_gemm_lowering(
x: TensorBox,
w: List[TensorBox],
b: List[TensorBox],
w: list[TensorBox],
b: list[TensorBox],
attr=None,
scalars=None,
algorithm=None,
@@ -47,7 +47,7 @@ def grouped_gemm_lowering(
assert use_max_autotune()
b = [bias if bias is None else ir.ExternKernel.realize_input(bias) for bias in b]

choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
*_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

kwargs = dict(
@@ -245,7 +245,7 @@ def register_onednn_fusion_ops():
x = view(x, [-1, x_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w = mm_args(x, transposed_w, layout=layout)
@@ -308,7 +308,7 @@ def register_onednn_fusion_ops():
y = view(y, [-1, y_size[-1]])
if b is not None:
b = ir.ExternKernel.realize_input(b)
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(w, [1, 0])
*_, layout, x, transposed_w, y = mm_args(
@@ -397,7 +397,7 @@ def register_onednn_fusion_ops():
hx: TensorBox,
cx: TensorBox,
reverse: bool,
batch_sizes: List[int],
batch_sizes: list[int],
mode: int,
hidden_size: int,
num_layers: int,
@@ -611,7 +611,7 @@ def register_onednn_fusion_ops():

bias_dtype = None if bias is None else bias.get_dtype()

choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
*_, layout, x, packed_weight = mm_args(
x, packed_weight, layout=layout, out_dtype=output_dtype
@@ -888,7 +888,7 @@ def register_onednn_fusion_ops():
), "dtype of accum for qlinear post op sum should be the same as output"
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if (
use_max_autotune() and binary_attr == "add"
): # <TODO> Support inplace sum fusion
@@ -1131,7 +1131,7 @@ def register_onednn_fusion_ops():
*,
layout=None,
):
choices: List[ChoiceCaller] = []
choices: list[ChoiceCaller] = []
if use_max_autotune():
transposed_w = permute(orig_w, [1, 0])
*_, layout, x, transposed_w = mm_args(

@@ -6,7 +6,7 @@ import contextlib
import dataclasses
import sys
import threading
from typing import Any, Callable, Dict, Optional, Type, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing_extensions import override, Self
from unittest.mock import patch

@@ -56,7 +56,7 @@ class Stats:


class _GlobalItemStats(Stats):
cache: Dict[str, object]
cache: dict[str, object]

def __init__(self) -> None:
super().__init__()
@@ -266,7 +266,7 @@ class PatchCaches(contextlib.AbstractContextManager):

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:

@@ -1,17 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
NamedTuple,
Optional,
TypeVar,
Union,
)
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from typing_extensions import Protocol
from unittest.mock import patch

@@ -961,7 +950,7 @@ def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
class OpCountResult(NamedTuple):
num_ops: int
used_ops: OrderedSet[str]
read_buffers: List[str]
read_buffers: list[str]
nontrivial_read_count: int


@@ -974,7 +963,7 @@ class OpCounterCSE:
self.op_count = 0
self.var_names = {}
self._used_ops = OrderedSet[str]()
self._read_names: List[str] = []
self._read_names: list[str] = []
self._nontrivial_read_count = 0

def __getattr__(self, name):
@@ -1076,7 +1065,7 @@ class SimpleCSEHandler(WrapperHandler[T]):

def __init__(self, inner: OpsHandler[T]):
super().__init__(inner)
self.cse_cache: Dict[str, Union[T, tuple[T, ...]]] = {}
self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {}
self.mock = MockHandler()

def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:

@@ -1,5 +1,5 @@
import math
from typing import Any, Dict, List
from typing import Any

import sympy

@@ -40,10 +40,10 @@ def range_expressable_in_32_bits(range: ValueRanges[sympy.Expr]) -> bool:

def try_to_reduce_precision(
node: Any,
bounds: Dict[Any, Any],
indirect_vars: List[Any],
indices: Dict[Any, sympy.Expr],
replacement_vals: Dict[Any, ValueRanges[sympy.Expr]],
bounds: dict[Any, Any],
indirect_vars: list[Any],
indices: dict[Any, sympy.Expr],
replacement_vals: dict[Any, ValueRanges[sympy.Expr]],
) -> None:
# if a downstream use of a node explicitly converts to int32, or float16/float32/float64,
# then it's precision is set for that chain of uses, and we don't need to consider those

@@ -27,17 +27,7 @@ import logging
import os
import re
from pathlib import Path
from typing import (
Any,
Callable,
Counter,
Dict,
List,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias

import torch
@@ -62,6 +52,9 @@ from .runtime.autotune_cache import AutotuneCacheBundler


if TYPE_CHECKING:
from collections import Counter
from collections.abc import Sequence

from torch._inductor import metrics
from torch._inductor.graph import GraphLowering

@@ -108,13 +101,13 @@ def has_frozen_params(gm: torch.fx.GraphModule) -> bool:
# for expanded dimensions (a dimension which used to have size 1 -> ?)
# we can select one element from that dimension and write to it
# to achieve writing to all values of that dimension of the input tensor
def get_expanded_dims(t: torch.Tensor) -> List[int]:
def get_expanded_dims(t: torch.Tensor) -> list[int]:
if not isinstance(t, torch.Tensor):
return None
return [i for i in range(t.ndim) if t.stride(i) == 0 and t.size(i) != 1]


def index_expanded_dims(t: torch.Tensor, expanded_dims: List[int]) -> torch.Tensor:
def index_expanded_dims(t: torch.Tensor, expanded_dims: list[int]) -> torch.Tensor:
for expanded_dim in expanded_dims:
t = torch.ops.aten.slice(t, expanded_dim, 0, 1)
return t
@@ -146,7 +139,7 @@ def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
constants: Dict[str, torch.Tensor],
constants: dict[str, torch.Tensor],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then
@@ -213,7 +206,7 @@ def cudagraph_post_compile(
# should already exist from forward
assert manager is not None

def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
def compiled_artifact(new_inputs: list[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)

@@ -270,7 +263,7 @@ class CompiledFxGraphConstants:
the value of constants directly off of the original saved object.
"""

def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
assert g.constants is not None
return g.constants

@@ -287,7 +280,7 @@ class CompiledFxGraphConstantsWithGm(CompiledFxGraphConstants):
def __init__(self, gm: torch.fx.GraphModule) -> None:
self.gm = gm

def unwrap(self, g: CompiledFxGraph) -> Dict[str, torch.Tensor]:
def unwrap(self, g: CompiledFxGraph) -> dict[str, torch.Tensor]:
if g.allocated_constant_name is not None:
return {
name: getattr(self.gm, name)
@@ -308,7 +301,7 @@ class CompiledFxGraph(OutputCode):
current_callable: Optional[Callable[..., Any]]
cache_key: str
source_code: str = dataclasses.field(repr=False) # Do not display source_code
cache_linemap: Optional[List[tuple[int, str]]]
cache_linemap: Optional[list[tuple[int, str]]]
device_types: OrderedSet[str]
device_idxs: OrderedSet[int]
mutated_inputs: OrderedSet[str]
@@ -320,10 +313,10 @@ class CompiledFxGraph(OutputCode):
# original name of the attribute in the GraphModule. When we create the module from
# the cache entry, we then look up the constants from the current GraphModule. This
# scheme allows us to support caching with freezing.
allocated_constant_name: Optional[Dict[str, str]]
constants: Optional[Dict[str, torch.Tensor]]
torchbind_constants: Dict[str, torch._C.ScriptObject]
output_strides: Optional[List[Optional[tuple[_StrideExprStr, ...]]]]
allocated_constant_name: Optional[dict[str, str]]
constants: Optional[dict[str, torch.Tensor]]
torchbind_constants: dict[str, torch._C.ScriptObject]
output_strides: Optional[list[Optional[tuple[_StrideExprStr, ...]]]]
disabled_cudagraphs_reason: Optional[str]
metrics_deltas: metrics.CachedMetricsDeltas
counter_deltas: Counter[str]
@@ -340,14 +333,14 @@ class CompiledFxGraph(OutputCode):
boxed_forward_device_index: Optional[BoxedDeviceIndex]

_boxed_call: Optional[bool] = None
_triton_bundle: Optional[List[TritonKernelArtifacts]] = None
_triton_bundle: Optional[list[TritonKernelArtifacts]] = None

def __init__(
self,
current_callable: Optional[Callable[..., Any]],
graph: GraphLowering,
gm: torch.fx.GraphModule,
output_strides: List[Optional[tuple[_StrideExprStr, ...]]],
output_strides: list[Optional[tuple[_StrideExprStr, ...]]],
disabled_cudagraphs_reason: Optional[str],
metrics_deltas: metrics.CachedMetricsDeltas,
counter_deltas: Counter[str],
@@ -583,7 +576,7 @@ class CompiledAOTI(OutputCode):
Class holding an AOTInductor compiled so.
"""

filename: Union[str, List[str]]
filename: Union[str, list[str]]

def __call__(self, inputs: Sequence[Any]) -> Any:
raise NotImplementedError("NYI")

@@ -7,7 +7,7 @@ import subprocess
import tempfile
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Union
from typing import Optional, Union

import torch
import torch._inductor
@@ -85,7 +85,7 @@ class PT2ArchiveReader:
assert self.archive_file is not None
self.archive_file.extractall(path)

def get_file_names(self) -> List[str]:
def get_file_names(self) -> list[str]:
assert self.archive_file is not None
return self.archive_file.namelist()

@@ -98,7 +98,7 @@ def _run_command_and_check(cmd: str) -> None:
raise exc.CppCompileError(cmd, e.output) from e


def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:
def compile_so(aoti_dir: str, aoti_files: list[str], so_path: str) -> str:
def get_aoti_file_with_suffix(suffix: str) -> str:
for file in aoti_files:
if file.endswith(suffix):
@@ -159,7 +159,7 @@ def compile_so(aoti_dir: str, aoti_files: List[str], so_path: str) -> str:

def package_aoti(
archive_file: Union[str, io.BytesIO],
aoti_files: Union[List[str], Dict[str, List[str]]],
aoti_files: Union[list[str], dict[str, list[str]]],
) -> Union[str, io.BytesIO]:
"""
Saves the AOTInductor generated files to the PT2Archive format.
@@ -244,12 +244,12 @@ class AOTICompiledModel:
flat_outputs = self.loader.boxed_run(flat_inputs) # type: ignore[attr-defined]
return pytree.tree_unflatten(flat_outputs, out_spec)

def get_metadata(self) -> Dict[str, str]:
def get_metadata(self) -> dict[str, str]:
return self.loader.get_metadata() # type: ignore[attr-defined]

def load_constants(
self,
constants_map: Dict[str, torch.Tensor],
constants_map: dict[str, torch.Tensor],
*,
check_full_update: bool,
) -> None:
@@ -265,7 +265,7 @@ class AOTICompiledModel:
"""
self.loader.load_constants(constants_map, False, check_full_update) # type: ignore[attr-defined]

def get_constant_fqns(self) -> List[str]:
def get_constant_fqns(self) -> list[str]:
return self.loader.get_constant_fqns() # type: ignore[attr-defined]


@@ -50,24 +50,9 @@ import textwrap
import typing
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Generator, Iterable, Mapping, Sequence
from pathlib import Path
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Generator,
Iterable,
List,
Mapping,
NoReturn,
Optional,
Protocol,
Sequence,
Type,
TypeVar,
Union,
)
from typing import Any, Callable, NoReturn, Optional, Protocol, TypeVar, Union
from typing_extensions import Self, TypeIs

import torch
@@ -139,7 +124,7 @@ MULTIPLE = Multiple()


def _transfer_meta(
new_meta: Dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
new_meta: dict[str, Any], old_node: torch.fx.Node, pass_name: str = ""
) -> None:
from torch.fx.traceback import NodeSource, NodeSourceAction

@@ -177,10 +162,10 @@ class Match:
"""

pattern: PatternExpr
args: List[Any]
kwargs: Dict[str, Any]
nodes: List[torch.fx.Node]
targets: Dict[_TargetExpr, torch.fx.node.Target]
args: list[Any]
kwargs: dict[str, Any]
nodes: list[torch.fx.Node]
targets: dict[_TargetExpr, torch.fx.node.Target]
ctx: MatchContext
replacement_graph: Optional[torch.fx.GraphModule]

@@ -189,7 +174,7 @@ class Match:
ctx: MatchContext,
pattern: PatternExpr,
args: Optional[Sequence[Any]] = None,
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
) -> None:
super().__init__()
self.pattern = pattern
@@ -231,7 +216,7 @@ class Match:
if not n._erased and not n.users:
graph.erase_node(n)

def output_nodes(self) -> List[Optional[torch.fx.Node]]:
def output_nodes(self) -> list[Optional[torch.fx.Node]]:
return [
(self.ctx.pattern_to_node[p] if p is not None else None)
for p in self.ctx.outputs
@@ -338,15 +323,15 @@ class MatchContext:
Internal state needed while running PatternExpr._match().
"""

outputs: List[Optional[PatternExpr]]
pattern_to_node: Dict[PatternExpr, Optional[torch.fx.Node]]
outputs: list[Optional[PatternExpr]]
pattern_to_node: dict[PatternExpr, Optional[torch.fx.Node]]
graph: torch.fx.Graph
exclusive_node_set: List[NodeOrConstant]
exclusive_node_set: list[NodeOrConstant]

def __init__(
self,
outputs: List[Optional[PatternExpr]],
pattern_to_node: Optional[Dict[PatternExpr, torch.fx.Node]] = None,
outputs: list[Optional[PatternExpr]],
pattern_to_node: Optional[dict[PatternExpr, torch.fx.Node]] = None,
*,
graph: torch.fx.Graph,
) -> None:
@@ -367,7 +352,7 @@ class MatchContext:
self.pattern_to_node[pattern] = node if m else None
return m

def filter_multi_user_patterns(self) -> Dict[PatternExpr, torch.fx.Node]:
def filter_multi_user_patterns(self) -> dict[PatternExpr, torch.fx.Node]:
return {
pattern: node
for pattern, node in self.pattern_to_node.items()
@@ -487,7 +472,7 @@ class _TargetExpr(PatternExpr):
Base class for filtering match by node.target
"""

fns: List[FnsType]
fns: list[FnsType]
fns_set: OrderedSet[FnsType]

def __init__(
@@ -806,7 +791,7 @@ class ListOf(PatternExpr):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.pattern})"

def _match(self, node: List[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
def _match(self, node: list[torch.fx.Node], ctx: MatchContext) -> MatchResult: # type: ignore[override]
if not isinstance(node, (list, tuple)) or len(node) == 0:
return FailedMatch("non_list")
m = Match(ctx, self)
@@ -840,7 +825,7 @@ class ListOf(PatternExpr):


class MultiOutputPattern(PatternExpr):
outputs: List[Optional[PatternExpr]]
outputs: list[Optional[PatternExpr]]

def __init__(self, outputs: Sequence[Optional[PatternExpr]]) -> None:
super().__init__()
@@ -959,8 +944,8 @@ class PatternPrettyPrinter:

def __init__(self) -> None:
self.namespace = torch.fx.graph._Namespace()
self.memoized_objs_names: Dict[PatternExpr, str] = {}
self.memoized_objs_pp: Dict[PatternExpr, str] = {}
self.memoized_objs_names: dict[PatternExpr, str] = {}
self.memoized_objs_pp: dict[PatternExpr, str] = {}

@staticmethod
@functools.lru_cache(None)
@@ -1006,7 +991,7 @@ class PatternPrettyPrinter:


class _PassDictsType(Protocol):
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
...


@@ -1069,7 +1054,7 @@ class GraphPatternEntry(PatternEntry):

@dataclasses.dataclass
class ReplacementPatternEntry(PatternEntry):
normalize_args: Callable[..., List[Any]]
normalize_args: Callable[..., list[Any]]

@staticmethod
def replace_with_graph(
@@ -1253,7 +1238,7 @@ def log_trace_failure(search_fn: Callable[..., Any], e: RuntimeError) -> None:
def check_and_add_duplicate_pattern(
pattern: PatternExpr,
graph: Optional[torch.fx.Graph],
seen_patterns: Dict[str, List[Optional[str]]],
seen_patterns: dict[str, list[Optional[str]]],
skip_duplicates: bool = False,
) -> bool:
"""
@@ -1299,7 +1284,7 @@ def register_replacement(
trace_fn: TraceFn,
pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
extra_check: Callable[[Match], bool] = _return_true,
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
exclusive_arg_names: Sequence[str] = (),
search_fn_pattern: Union[PatternExpr, None] = None,
skip_duplicates: bool = False,
@@ -1339,7 +1324,7 @@ def register_replacement(
[match.kwargs[name] for name in argnames], lambda n: n.meta["val"]
)
)
sym_args: List[torch.SymInt] = []
sym_args: list[torch.SymInt] = []
with torch._dynamo.utils.detect_fake_mode(args):
for i, grad in enumerate(requires_grad):
if isinstance(args[i], torch.Tensor):
@@ -1432,7 +1417,7 @@ def register_replacement(
return True
return False

def normalize_args(**kwargs: Any) -> List[Any]:
def normalize_args(**kwargs: Any) -> list[Any]:
args = [kwargs.pop(name) for name in argnames_static]
for i in range(1, len(kwargs) + 1):
if f"tangents_{i}" not in kwargs:
@@ -1449,7 +1434,7 @@ def register_replacement(

# TODO: Revisit the functionalize_rng_ops for lowmem dropout
with functorch_config.patch(functionalize_rng_ops=False):
requires_grad: List[bool] = [
requires_grad: list[bool] = [
isinstance(x, torch.Tensor) and x.requires_grad for x in example_inputs
]
if search_fn_pattern is None:
@@ -1493,7 +1478,7 @@ def _serialize_pattern(
search_fn: SearchFn,
example_inputs: Sequence[Any],
trace_fn: TraceFn,
scalar_workaround: Union[Dict[str, Union[float, int]], None],
scalar_workaround: Union[dict[str, Union[float, int]], None],
) -> PatternExpr:
def get_file_template() -> str:
||||
auto_generated_msg = textwrap.dedent(
|
||||
@ -1566,7 +1551,7 @@ SERIALIZED_PATTERN_PATH = Path(__file__).parent / "fx_passes" / "serialized_patt
|
||||
# This is the set of serialized patterns that we've registered. Used by
|
||||
# test_serialized_patterns_up_to_date() to ensure the patterns are up
|
||||
# to date.
|
||||
_known_precompiled_patterns: List[
|
||||
_known_precompiled_patterns: list[
|
||||
tuple[
|
||||
Any,
|
||||
Iterable[Any],
|
||||
@ -1585,7 +1570,7 @@ def gen_register_replacement(
|
||||
trace_fn: TraceFn,
|
||||
pass_dicts: Union[_PassDictsType, Sequence[_PassDictsType]],
|
||||
extra_check: Callable[[Match], bool] = _return_true,
|
||||
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
|
||||
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
|
||||
exclusive_arg_names: Sequence[str] = (),
|
||||
skip_duplicates: bool = False,
|
||||
) -> None:
|
||||
@ -1638,7 +1623,7 @@ def gen_pattern_and_search_gm(
|
||||
search_fn: SearchFn,
|
||||
example_inputs: Sequence[Any],
|
||||
trace_fn: TraceFn,
|
||||
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
|
||||
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
|
||||
exclusive_arg_names: Sequence[str] = (),
|
||||
) -> tuple[PatternExpr, torch.fx.GraphModule]:
|
||||
argnames = [*inspect.signature(search_fn).parameters.keys()]
|
||||
@ -1672,7 +1657,7 @@ def gen_pattern(
|
||||
search_fn: SearchFn,
|
||||
example_inputs: Sequence[Any],
|
||||
trace_fn: TraceFn,
|
||||
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
|
||||
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
|
||||
exclusive_arg_names: Sequence[str] = (),
|
||||
) -> PatternExpr:
|
||||
return gen_pattern_and_search_gm(
|
||||
@ -1803,8 +1788,8 @@ class PatternMatcherPass:
|
||||
pass_name: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.patterns: DefaultDict[
|
||||
tuple[str, torch.fx.node.Target], List[PatternEntry]
|
||||
self.patterns: defaultdict[
|
||||
tuple[str, torch.fx.node.Target], list[PatternEntry]
|
||||
] = defaultdict(list)
|
||||
self.pass_name = pass_name
|
||||
|
||||
@ -1812,9 +1797,9 @@ class PatternMatcherPass:
|
||||
# of the graph used to generate them. Because we ignore certain patterns
|
||||
# in searching, but not in matching, use the graph to distinguish if two equivalent
|
||||
# searches are actually different.
|
||||
self.seen_patterns: Dict[str, List[Optional[str]]] = defaultdict(list)
|
||||
self.seen_patterns: dict[str, list[Optional[str]]] = defaultdict(list)
|
||||
|
||||
def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> List[PatternEntry]:
|
||||
def __getitem__(self, item: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
|
||||
return self.patterns[item]
|
||||
|
||||
def apply(self, gm: Union[torch.fx.GraphModule, torch.fx.Graph]) -> int:
|
||||
@ -1888,9 +1873,9 @@ def _not_implemented(*args: Any, **kwargs: Any) -> NoReturn:
|
||||
|
||||
def fx_to_pattern(
|
||||
gm: Union[torch.fx.GraphModule, torch.fx.Graph],
|
||||
ignore_types: Sequence[Type[Any]] = (),
|
||||
ignore_types: Sequence[type[Any]] = (),
|
||||
argnames: Sequence[str] = (),
|
||||
scalar_workaround: Union[Dict[str, Union[float, int]], None] = None,
|
||||
scalar_workaround: Union[dict[str, Union[float, int]], None] = None,
|
||||
exclusive_arg_names: Sequence[str] = (),
|
||||
) -> PatternExpr:
|
||||
"""
|
||||
@ -1904,7 +1889,7 @@ def fx_to_pattern(
|
||||
assert len(inv_scalar_workaround) == len(scalar_workaround)
|
||||
|
||||
def process_arg(
|
||||
x: T, ignore_types_override: Optional[Sequence[Type[Any]]] = None
|
||||
x: T, ignore_types_override: Optional[Sequence[type[Any]]] = None
|
||||
) -> Union[T, KeywordArg, Ignored]:
|
||||
current_ignore_types = (
|
||||
ignore_types_override if ignore_types_override is not None else ignore_types
|
||||
@ -1950,7 +1935,7 @@ def fx_to_pattern(
|
||||
|
||||
def process_arg_fn_impl(
|
||||
x: T,
|
||||
ignore_types_override: Optional[Sequence[Type[Any]]] = tuple(
|
||||
ignore_types_override: Optional[Sequence[type[Any]]] = tuple(
|
||||
t for t in ignore_types if t is not int
|
||||
),
|
||||
) -> Union[T, KeywordArg, Ignored]:
|
||||
@ -2054,8 +2039,8 @@ def joint_fwd_bwd(fn: Callable[..., Any], args: Sequence[Any]) -> torch.fx.Graph
|
||||
return gm
|
||||
|
||||
|
||||
def _args(n: torch.fx.Node) -> List[torch.fx.node.Argument]:
|
||||
args: List[torch.fx.node.Argument] = []
|
||||
def _args(n: torch.fx.Node) -> list[torch.fx.node.Argument]:
|
||||
args: list[torch.fx.node.Argument] = []
|
||||
torch.fx.map_arg((n.args, n.kwargs), args.append)
|
||||
return args
|
||||
|
||||
@ -2152,7 +2137,7 @@ def get_arg_value(
|
||||
)
|
||||
|
||||
|
||||
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> List[torch.fx.Node]:
|
||||
def filter_nodes(nodes: Iterable[torch.fx.Node], fn: Any) -> list[torch.fx.Node]:
|
||||
fns = [fn]
|
||||
if isinstance(fn, torch._ops.OpOverloadPacket):
|
||||
fns.extend([getattr(fn, overload) for overload in fn.overloads()])
|
||||
|
@ -10,7 +10,7 @@ import os
|
||||
import sys
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
from typing import Any, Callable, Dict, Generic, List, Optional, Type, TypeVar, Union
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union
|
||||
from typing_extensions import override, TypeAlias
|
||||
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
@ -34,7 +34,7 @@ if config.is_fbcode():
|
||||
|
||||
Sample: TypeAlias = Sample_
|
||||
else:
|
||||
Sample: TypeAlias = Type[object] # type: ignore[misc,no-redef]
|
||||
Sample: TypeAlias = type[object] # type: ignore[misc,no-redef]
|
||||
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -106,7 +106,7 @@ class RemoteCacheSerde(Generic[_T, _U]):
|
||||
|
||||
|
||||
JsonDataTy = Optional[
|
||||
Union[int, float, str, bool, Dict[str, "JsonDataTy"], List["JsonDataTy"]]
|
||||
Union[int, float, str, bool, dict[str, "JsonDataTy"], list["JsonDataTy"]]
|
||||
]
|
||||
|
||||
|
||||
@ -371,7 +371,7 @@ class _CacheStat:
|
||||
|
||||
|
||||
class _CacheStats:
|
||||
_stats: Dict[str, _CacheStat]
|
||||
_stats: dict[str, _CacheStat]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._stats = collections.defaultdict(_CacheStat)
|
||||
|
@ -6,7 +6,7 @@ import logging
|
||||
import os
|
||||
import os.path
|
||||
import re
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
@ -29,7 +29,7 @@ if TYPE_CHECKING:
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_InductorMetaTy = Dict[str, object]
|
||||
_InductorMetaTy = dict[str, object]
|
||||
|
||||
|
||||
def inductor_meta_from_config() -> _InductorMetaTy:
|
||||
@ -88,7 +88,7 @@ class AutotuneCache:
|
||||
return hashlib.sha256(key.encode("utf-8")).hexdigest()
|
||||
|
||||
# Read the best config options from the most local cache and return it.
|
||||
def _read(self) -> Optional[Dict[str, JsonDataTy]]:
|
||||
def _read(self) -> Optional[dict[str, JsonDataTy]]:
|
||||
if local_cache := self.local_cache:
|
||||
cache, key = local_cache
|
||||
if best_config := cache.get(key):
|
||||
@ -106,7 +106,7 @@ class AutotuneCache:
|
||||
# Read the best config options from the most local cache and figure out
|
||||
# which `configs` represents that option.
|
||||
def read_best(
|
||||
self, inductor_meta: _InductorMetaTy, configs: List[Config]
|
||||
self, inductor_meta: _InductorMetaTy, configs: list[Config]
|
||||
) -> Optional[Config]:
|
||||
if best := self._read():
|
||||
return _load_cached_autotuning(
|
||||
@ -196,7 +196,7 @@ class _AutotuneCacheBundlerImpl:
|
||||
_cache: RemoteCache[JsonDataTy]
|
||||
|
||||
# All known entries from LocalAutotuneCache.put()
|
||||
_entries: Dict[str, JsonDataTy]
|
||||
_entries: dict[str, JsonDataTy]
|
||||
|
||||
def end_compile(self) -> None:
|
||||
# TODO: Do we need to compute time_taken_ms and encode that somehow?
|
||||
@ -407,9 +407,9 @@ def _should_use_remote_autotune_cache(inductor_meta: _InductorMetaTy) -> bool:
|
||||
|
||||
|
||||
def _load_cached_autotuning(
|
||||
best_config: Dict[str, JsonDataTy],
|
||||
best_config: dict[str, JsonDataTy],
|
||||
configs_hash: str,
|
||||
configs: List[Config],
|
||||
configs: list[Config],
|
||||
inductor_meta: _InductorMetaTy,
|
||||
) -> Optional[Config]:
|
||||
if best_config is None:
|
||||
|
@ -3,7 +3,7 @@ import time
|
||||
from functools import cached_property, wraps
|
||||
from itertools import chain
|
||||
from statistics import median
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
from typing import Any, Callable
|
||||
from typing_extensions import Concatenate, ParamSpec, Self, TypeVar
|
||||
|
||||
import torch
|
||||
@ -49,8 +49,8 @@ class Benchmarker:
|
||||
def benchmark(
|
||||
self: Self,
|
||||
fn: Callable[..., Any],
|
||||
fn_args: Tuple[Any, ...],
|
||||
fn_kwargs: Dict[str, Any],
|
||||
fn_args: tuple[Any, ...],
|
||||
fn_kwargs: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> float:
|
||||
"""Benchmark `fn(*fn_args, *fn_kwargs)` and return the runtime, in milliseconds (the
|
||||
@ -114,7 +114,7 @@ class Benchmarker:
|
||||
- The median runtime of `_callable`, in milliseconds.
|
||||
"""
|
||||
|
||||
def run_for(ms: int) -> List[float]:
|
||||
def run_for(ms: int) -> list[float]:
|
||||
timings = []
|
||||
run_start_t = time.perf_counter()
|
||||
while True:
|
||||
@ -183,7 +183,7 @@ class InductorBenchmarker(TritonBenchmarker):
|
||||
|
||||
def get_event_pairs(
|
||||
self: Self, iters: int
|
||||
) -> List[Tuple[torch.cuda.Event, torch.cuda.Event]]:
|
||||
) -> list[tuple[torch.cuda.Event, torch.cuda.Event]]:
|
||||
"""Get `iters` pairs of CUDA events."""
|
||||
return [
|
||||
(
|
||||
@ -194,7 +194,7 @@ class InductorBenchmarker(TritonBenchmarker):
|
||||
]
|
||||
|
||||
def get_event_pairs_min_timing(
|
||||
self: Self, event_pairs: List[Tuple[torch.cuda.Event, torch.cuda.Event]]
|
||||
self: Self, event_pairs: list[tuple[torch.cuda.Event, torch.cuda.Event]]
|
||||
) -> float:
|
||||
"""Get the minimum timing, in milliseconds, for a group of CUDA event pairs."""
|
||||
return min(
|
||||
|
@ -6,7 +6,7 @@ import sys
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Callable, Dict, TYPE_CHECKING
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -67,7 +67,7 @@ def _set_triton_ptxas_path() -> None:
|
||||
|
||||
|
||||
def _worker_compile_triton(
|
||||
load_kernel: Callable[[], CachingAutotuner], extra_env: Dict[str, str]
|
||||
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
|
||||
) -> None:
|
||||
_set_triton_ptxas_path()
|
||||
os.environ.update(extra_env)
|
||||
|
@ -5,7 +5,7 @@ import collections
|
||||
import functools
|
||||
import typing
|
||||
from enum import auto, Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
|
||||
# The following maximums only apply to runtime autotuning, when using FixedTritonConfig one may see larger values
|
||||
@ -167,8 +167,8 @@ class DeviceProperties(typing.NamedTuple):
|
||||
class HalideInputSpec(typing.NamedTuple):
|
||||
ctype: str
|
||||
name: str
|
||||
shape: Optional[List[str]] = None
|
||||
stride: Optional[List[str]] = None
|
||||
shape: Optional[list[str]] = None
|
||||
stride: Optional[list[str]] = None
|
||||
offset: Optional[str] = None
|
||||
alias_of: Optional[str] = None
|
||||
|
||||
@ -192,13 +192,13 @@ class HalideInputSpec(typing.NamedTuple):
|
||||
|
||||
|
||||
class HalideMeta(typing.NamedTuple):
|
||||
argtypes: List[HalideInputSpec]
|
||||
argtypes: list[HalideInputSpec]
|
||||
target: str
|
||||
scheduler: Optional[str] = None
|
||||
scheduler_flags: Optional[Dict[str, Union[int, str]]] = None
|
||||
scheduler_flags: Optional[dict[str, Union[int, str]]] = None
|
||||
cuda_device: Optional[int] = None
|
||||
|
||||
def args(self) -> List[str]:
|
||||
def args(self) -> list[str]:
|
||||
"""Command line args to pass to halide generator"""
|
||||
args = [f"target={self.target}"]
|
||||
if self.scheduler:
|
||||
|
@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import functools
|
||||
import operator
|
||||
from typing import Any, Hashable, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
|
||||
@ -13,6 +13,8 @@ from torch._inductor.runtime.cache_dir_utils import ( # noqa: F401
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Hashable
|
||||
|
||||
from .triton_compat import Config
|
||||
|
||||
|
||||
|
@ -16,17 +16,7 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Container,
|
||||
Dict,
|
||||
Hashable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
@ -76,13 +66,15 @@ from .triton_compat import (
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Container, Hashable
|
||||
|
||||
LauncherType = Any
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_total_reduction_numel(numels: Dict[str, int]) -> int:
|
||||
def get_total_reduction_numel(numels: dict[str, int]) -> int:
|
||||
return conditional_product(
|
||||
*[numel for prefix, numel in numels.items() if prefix_is_reduction(prefix)]
|
||||
)
|
||||
@ -93,7 +85,7 @@ def autotune_hints_to_configs(
|
||||
size_hints,
|
||||
block_size: int,
|
||||
device_props: DeviceProperties,
|
||||
) -> List[Config]:
|
||||
) -> list[Config]:
|
||||
"""
|
||||
AutotuneHints can be attached to the metadata of triton kernels for providing
|
||||
suggestions about what to try for autotuning. One reason to do this is if there are
|
||||
@ -104,7 +96,7 @@ def autotune_hints_to_configs(
|
||||
configs to try.
|
||||
"""
|
||||
xyz_options: tuple[tuple[int, Optional[int], Optional[int]], ...]
|
||||
configs: List[Config] = []
|
||||
configs: list[Config] = []
|
||||
for hint in hints:
|
||||
if hint == AutotuneHint.ONE_ELEMENT_PER_THREAD:
|
||||
if len(size_hints) == 1:
|
||||
@ -180,14 +172,14 @@ class CachingAutotuner(KernelInterface):
|
||||
triton_meta, # passed directly to triton
|
||||
configs,
|
||||
save_cache_hook,
|
||||
mutated_arg_names: List[str], # see [Note: clone mutated buffers]
|
||||
mutated_arg_names: list[str], # see [Note: clone mutated buffers]
|
||||
optimize_mem,
|
||||
heuristic_type,
|
||||
size_hints=None,
|
||||
inductor_meta=None, # metadata not relevant to triton
|
||||
custom_kernel=False, # whether the kernel is inductor-generated or custom
|
||||
filename: Optional[str] = None,
|
||||
reset_to_zero_arg_names: Optional[List[str]] = None,
|
||||
reset_to_zero_arg_names: Optional[list[str]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@ -223,8 +215,8 @@ class CachingAutotuner(KernelInterface):
|
||||
for c in self.configs:
|
||||
log.debug(c)
|
||||
|
||||
self.compile_results: List[TritonCompileResult] = []
|
||||
self.launchers: List[LauncherType] = []
|
||||
self.compile_results: list[TritonCompileResult] = []
|
||||
self.launchers: list[LauncherType] = []
|
||||
self.lock = threading.Lock()
|
||||
if os.getenv("TRITON_CACHE_DIR") is None:
|
||||
os.environ["TRITON_CACHE_DIR"] = triton_cache_dir(
|
||||
@ -430,7 +422,7 @@ class CachingAutotuner(KernelInterface):
|
||||
self.fn.repr = _ConstRepr(self.fn.repr(self.fn))
|
||||
self.launchers = []
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
assert (
|
||||
not self.launchers
|
||||
), "pickle should not be called with after make_launchers()"
|
||||
@ -439,7 +431,7 @@ class CachingAutotuner(KernelInterface):
|
||||
"lock": None,
|
||||
}
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
self.__dict__.update(state)
|
||||
self.lock = threading.Lock()
|
||||
|
||||
@ -636,7 +628,7 @@ class CachingAutotuner(KernelInterface):
|
||||
|
||||
def maybe_clone_args(
|
||||
self, exclude: Container[str], *args, **kwargs
|
||||
) -> tuple[List[Any], Dict[str, Any]]:
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""
|
||||
Prepare new args and kwargs by cloning any in-place buffers
|
||||
(that are not in the provided exclusion list), to avoid autotune
|
||||
@ -659,7 +651,7 @@ class CachingAutotuner(KernelInterface):
|
||||
|
||||
return cloned_args, cloned_kwargs
|
||||
|
||||
def clone_args(self, *args, **kwargs) -> tuple[List[Any], Dict[str, Any]]:
|
||||
def clone_args(self, *args, **kwargs) -> tuple[list[Any], dict[str, Any]]:
|
||||
return self.maybe_clone_args(OrderedSet(), *args, **kwargs)
|
||||
|
||||
def benchmark_all_configs(self, *args, **kwargs):
|
||||
@ -888,15 +880,15 @@ class TritonCompileResult:
|
||||
|
||||
@staticmethod
|
||||
@functools.lru_cache(32)
|
||||
def _kernel_metadata_cls(fields: Tuple[str, ...]) -> Any:
|
||||
def _kernel_metadata_cls(fields: tuple[str, ...]) -> Any:
|
||||
return namedtuple("KernelMetadata", sorted(fields))
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel: CompiledKernel,
|
||||
config: Config,
|
||||
compile_meta: Dict[str, Any],
|
||||
inductor_meta: Dict[str, Any],
|
||||
compile_meta: dict[str, Any],
|
||||
inductor_meta: dict[str, Any],
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.kernel = kernel
|
||||
@ -904,7 +896,7 @@ class TritonCompileResult:
|
||||
self.compile_meta = compile_meta
|
||||
self.inductor_meta = inductor_meta
|
||||
|
||||
def __getstate__(self) -> Dict[str, Any]:
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
kernel = self.kernel
|
||||
# replace the fields that don't pickle nicely
|
||||
kernel_state = {
|
||||
@ -916,7 +908,7 @@ class TritonCompileResult:
|
||||
}
|
||||
return {**self.__dict__, "kernel": kernel_state} # type: ignore[dict-item]
|
||||
|
||||
def __setstate__(self, state: Dict[str, Any]) -> None:
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
# src = ASTSource.__new__(ASTSource)
|
||||
# src.__setstate__(state["kernel"]["src"])
|
||||
# TODO(jansel): need to fixup src.fn which is now None
|
||||
@ -1101,7 +1093,7 @@ def _find_names(obj):
|
||||
return obj_names
|
||||
|
||||
|
||||
collected_calls: List[Any] = []
|
||||
collected_calls: list[Any] = []
|
||||
|
||||
|
||||
def start_graph():
|
||||
@ -1220,7 +1212,7 @@ class DebugAutotuner(CachingAutotuner):
|
||||
collected_calls.append(self.cached)
|
||||
|
||||
|
||||
def hash_configs(configs: List[Config]):
|
||||
def hash_configs(configs: list[Config]):
|
||||
"""
|
||||
Hash used to check for changes in configurations
|
||||
"""
|
||||
@ -1233,8 +1225,8 @@ def hash_configs(configs: List[Config]):
|
||||
|
||||
|
||||
def cached_autotune(
|
||||
size_hints: Optional[List[int]],
|
||||
configs: List[Config],
|
||||
size_hints: Optional[list[int]],
|
||||
configs: list[Config],
|
||||
triton_meta,
|
||||
heuristic_type,
|
||||
filename=None,
|
||||
@ -1276,7 +1268,7 @@ def cached_autotune(
|
||||
if "restore_value" in triton_meta:
|
||||
mutated_arg_names += triton_meta.pop("restore_value")
|
||||
|
||||
reset_to_zero_arg_names: List[str] = []
|
||||
reset_to_zero_arg_names: list[str] = []
|
||||
if "reset_to_zero" in triton_meta:
|
||||
reset_to_zero_arg_names.extend(triton_meta.pop("reset_to_zero"))
|
||||
|
||||
@ -1331,7 +1323,7 @@ def cached_autotune(
|
||||
return decorator
|
||||
|
||||
|
||||
def unique_configs(configs: List[Config]):
|
||||
def unique_configs(configs: list[Config]):
|
||||
"""Remove duplicate configurations"""
|
||||
seen: OrderedSet[Hashable] = OrderedSet()
|
||||
pruned_configs = []
|
||||
@ -1362,7 +1354,7 @@ def check_config(cfg, *, xnumel=None, ynumel=None, znumel=None):
|
||||
)
|
||||
|
||||
|
||||
def check_max_block(cfg: Dict[str, int]):
|
||||
def check_max_block(cfg: dict[str, int]):
|
||||
"""
|
||||
Check that block sizes are within the maximum allowed.
|
||||
"""
|
||||
@ -1505,7 +1497,7 @@ def triton_config(
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _get_nd_reduction_numels(r: int, size_hints: Dict[str, int]) -> Dict[str, int]:
|
||||
def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
|
||||
"""
|
||||
Converts a linear reduction numel to ND, in row major order.
|
||||
This order is often desirable as it presents opportunities to coalesce memory
|
||||
@ -1596,7 +1588,7 @@ def triton_config_reduction(
|
||||
return Config(cfg, num_warps=num_warps, num_stages=num_stages)
|
||||
|
||||
|
||||
def _get_config(numels: Dict[str, int]) -> Dict[str, int]:
|
||||
def _get_config(numels: dict[str, int]) -> dict[str, int]:
|
||||
"""
|
||||
Convert numels ("x", "r0_", etc.) to block sizes ("XBLOCK", "R0_BLOCK"), etc.
|
||||
"""
|
||||
@ -1729,8 +1721,8 @@ def pointwise(
|
||||
|
||||
|
||||
def _reduction_configs(
|
||||
*, size_hints: Dict[str, int], inductor_meta: Dict[str, Any]
|
||||
) -> List[Config]:
|
||||
*, size_hints: dict[str, int], inductor_meta: dict[str, Any]
|
||||
) -> list[Config]:
|
||||
reduction_hint = inductor_meta.get("reduction_hint", None)
|
||||
|
||||
# Convert reductions to 1D, to simplify heuristics.
|
||||
@ -1981,7 +1973,7 @@ def template(num_stages, num_warps, triton_meta, filename=None, inductor_meta=No
|
||||
)
|
||||
|
||||
|
||||
def _pop_config_kwargs(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def _pop_config_kwargs(config: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract triton.Config options that should become kwargs"""
|
||||
popped = {}
|
||||
for key in ("num_warps", "num_stages", "num_ctas", "maxnreg"):
|
||||
|
@ -14,20 +14,8 @@ import pprint
|
||||
import textwrap
|
||||
import traceback
|
||||
import typing
|
||||
from collections import defaultdict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Counter,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from collections import Counter, defaultdict
|
||||
from typing import Any, Callable, Generic, Optional, TypeVar, Union
|
||||
|
||||
import sympy
|
||||
|
||||
@ -68,6 +56,10 @@ from .utils import (
|
||||
from .virtualized import V
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
|
||||
loop_ordering_log = torch._logging.getArtifactLogger(__name__, "loop_ordering")
|
||||
@ -78,7 +70,7 @@ class SchedulerBuffer:
|
||||
scheduler: Scheduler
|
||||
node: ir.Buffer
|
||||
defining_op: BaseSchedulerNode
|
||||
users: List[NodeUser] = dataclasses.field(default_factory=list)
|
||||
users: list[NodeUser] = dataclasses.field(default_factory=list)
|
||||
mpi_buffer: MemoryPlanningInfoForBuffer = dataclasses.field(
|
||||
default_factory=MemoryPlanningInfoForBuffer
|
||||
)
|
||||
@ -154,9 +146,9 @@ class SchedulerBuffer:
|
||||
return False
|
||||
return True
|
||||
|
||||
def set_users(self, users: List[NodeUser]) -> None:
|
||||
def set_users(self, users: list[NodeUser]) -> None:
|
||||
# deduplicate
|
||||
result: Dict[int, NodeUser] = {}
|
||||
result: dict[int, NodeUser] = {}
|
||||
for use in users:
|
||||
if id(use.node) in result:
|
||||
result[id(use.node)] = use.merge(result[id(use.node)])
|
||||
@ -194,7 +186,7 @@ class BaseSchedulerNode:
|
||||
def __init__(self, scheduler: Scheduler) -> None:
|
||||
self.scheduler: Scheduler = scheduler
|
||||
self.debug_device_str: Callable[
|
||||
[BaseSchedulerNode], List[str]
|
||||
[BaseSchedulerNode], list[str]
|
||||
] = lambda *args, **kwargs: []
|
||||
|
||||
def _init_from_node(self, node: ir.Operation) -> None:
|
||||
@ -204,7 +196,7 @@ class BaseSchedulerNode:
|
||||
str
|
||||
]() # buffers that won't be used after this kernel
|
||||
self.written = False
|
||||
self.outputs: List[SchedulerBuffer] = [
|
||||
self.outputs: list[SchedulerBuffer] = [
|
||||
SchedulerBuffer(
|
||||
scheduler=self.scheduler,
|
||||
node=output,
|
||||
@ -212,7 +204,7 @@ class BaseSchedulerNode:
|
||||
)
|
||||
for output in node.get_outputs()
|
||||
]
|
||||
self.outputs_by_name: Dict[str, SchedulerBuffer] = {
|
||||
self.outputs_by_name: dict[str, SchedulerBuffer] = {
|
||||
buf.get_name(): buf for buf in self.outputs
|
||||
}
|
||||
|
||||
@ -247,7 +239,7 @@ class BaseSchedulerNode:
|
||||
def debug_str_extra(self) -> str:
|
||||
return ""
|
||||
|
||||
def _debug_str_for_device(self) -> List[str]:
|
||||
def _debug_str_for_device(self) -> list[str]:
|
||||
return self.debug_device_str(self)
|
||||
|
||||
def debug_str_short(self) -> str:
|
||||
@ -278,7 +270,7 @@ class BaseSchedulerNode:
|
||||
) -> None:
|
||||
return
|
||||
|
||||
def update_mutated_names(self, renames: Dict[str, str]) -> None:
|
||||
def update_mutated_names(self, renames: dict[str, str]) -> None:
|
||||
self.set_read_writes(self.read_writes.rename(renames))
|
||||
|
||||
def add_fake_dep(self, dep: Dep) -> None:
|
||||
@ -295,7 +287,7 @@ class BaseSchedulerNode:
|
||||
self.prune_deps()
|
||||
|
||||
def set_last_usage(
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
|
||||
) -> None:
|
||||
used_buffers = self.used_or_aliased_buffer_names()
|
||||
used_buffers = OrderedSet(mutation_real_name.get(k, k) for k in used_buffers)
|
||||
@ -352,7 +344,7 @@ class BaseSchedulerNode:
|
||||
self.set_read_writes(self.read_writes.remove_reads(to_remove))
|
||||
|
||||
def prune_redundant_deps(
|
||||
self, name_to_fused_node: Dict[str, BaseSchedulerNode]
|
||||
self, name_to_fused_node: dict[str, BaseSchedulerNode]
|
||||
) -> None:
|
||||
_prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
|
||||
|
||||
@ -582,7 +574,7 @@ class BaseSchedulerNode:
|
||||
|
||||
def get_read_write_buffer_accesses(
|
||||
self, include_reads: bool, include_writes: bool
|
||||
) -> Dict[str, int]:
|
||||
) -> dict[str, int]:
|
||||
"""
|
||||
Counting the number of bytes accessed for a kernel is
|
||||
surprisingly tricky. In particular, there is a differentiation
|
||||
@ -658,7 +650,7 @@ class BaseSchedulerNode:
|
||||
writes = writes - removed_buffers
|
||||
reads = reads - removed_buffers
|
||||
|
||||
buf_byte_accesses: Dict[str, int] = {}
|
||||
buf_byte_accesses: dict[str, int] = {}
|
||||
|
||||
for buf_name in reads | writes:
|
||||
buf_accessed_elems = sum(node_numel for dep in buf_accesses[buf_name])
|
||||
@ -811,8 +803,8 @@ class BaseSchedulerNode:
|
||||
|
||||
@staticmethod
|
||||
def get_prologue_template_epilogue(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
) -> tuple[List[BaseSchedulerNode], BaseSchedulerNode, List[BaseSchedulerNode]]:
|
||||
nodes: list[BaseSchedulerNode],
|
||||
) -> tuple[list[BaseSchedulerNode], BaseSchedulerNode, list[BaseSchedulerNode]]:
|
||||
"""
|
||||
For the list of nodes, get the prologue, template, and epilogue
|
||||
"""
|
||||
@ -874,8 +866,8 @@ class OutputNode:
|
||||
|
||||
def _prune_redundant_deps(
|
||||
node: BaseSchedulerNode,
|
||||
name_to_fused_node: Dict[str, BaseSchedulerNode],
|
||||
name_to_buf: Dict[str, SchedulerBuffer],
|
||||
name_to_fused_node: dict[str, BaseSchedulerNode],
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
) -> None:
|
||||
"""
|
||||
Prunes weakdeps intended for mutation ordering
|
||||
@ -961,7 +953,7 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
|
||||
def _compute_attrs(
|
||||
self,
|
||||
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
|
||||
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
|
||||
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
||||
) -> None:
|
||||
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer))
|
||||
@ -993,7 +985,7 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
|
||||
def recompute_size_and_body(
|
||||
self,
|
||||
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
|
||||
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
|
||||
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
||||
) -> None:
|
||||
self._compute_attrs(
|
||||
@ -1120,7 +1112,7 @@ class SchedulerNode(BaseSchedulerNode):
|
||||
|
||||
def ranges_from_index_vars(
|
||||
self, index_vars: Sequence[Sequence[sympy.Expr]]
|
||||
) -> Dict[sympy.Expr, sympy.Expr]:
|
||||
) -> dict[sympy.Expr, sympy.Expr]:
|
||||
sizes = self._sizes
|
||||
assert sum(map(len, sizes)) == sum(map(len, index_vars))
|
||||
var_ranges = dict(
|
||||
@ -1220,7 +1212,7 @@ def refresh_group_node_dependencies(group_snode: BaseSchedulerNode) -> None:
|
||||
def init_group_node(
|
||||
group_snode: BaseSchedulerNode,
|
||||
scheduler: Scheduler,
|
||||
snodes: List[BaseSchedulerNode],
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> None:
|
||||
assert isinstance(group_snode, (FusedSchedulerNode, GroupedSchedulerNode))
|
||||
group_snode.snodes = snodes
|
||||
@ -1246,7 +1238,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
its unmet dependencies as the union of its constituent nodes.
|
||||
"""
|
||||
|
||||
snodes: List[BaseSchedulerNode]
|
||||
snodes: list[BaseSchedulerNode]
|
||||
|
||||
@classmethod
|
||||
def fuse(
|
||||
@ -1319,10 +1311,10 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
refresh_group_node_dependencies(self)
|
||||
|
||||
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
|
||||
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
|
||||
super().__init__(scheduler)
|
||||
init_group_node(self, scheduler, snodes)
|
||||
self.users: List[NodeUser] = []
|
||||
self.users: list[NodeUser] = []
|
||||
self.group = max(snodes, key=lambda x: int(x.is_reduction())).group
|
||||
|
||||
@cache_on_self
|
||||
@ -1336,8 +1328,8 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
def get_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
|
||||
def get_outputs(self) -> List[SchedulerBuffer]:
|
||||
result: List[SchedulerBuffer] = []
|
||||
def get_outputs(self) -> list[SchedulerBuffer]:
|
||||
result: list[SchedulerBuffer] = []
|
||||
for node in self.snodes:
|
||||
result.extend(node.get_outputs())
|
||||
return result
|
||||
@ -1358,7 +1350,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
return f"{self}, snodes: {snodes_str}"
|
||||
|
||||
def set_last_usage(
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: Dict[str, str]
|
||||
self, future_used_buffers: OrderedSet[str], mutation_real_name: dict[str, str]
|
||||
) -> None:
|
||||
# Set self.last_usage using the global information
|
||||
# This will be used for inter-kernel optimisations
|
||||
@ -1414,7 +1406,7 @@ class FusedSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
# None of these need to be implemented, as a FusedSchedulerNode is just an
|
||||
# abstraction for scheduling purposes
|
||||
def update_mutated_names(self, renames: Dict[str, str]) -> None:
|
||||
def update_mutated_names(self, renames: dict[str, str]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def add_fake_dep(self, name: Dep) -> None:
|
||||
@ -1546,7 +1538,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
enable_autotune = consumer.enable_autotune
|
||||
prev_node_1 = None
|
||||
prev_node_2 = None
|
||||
fused_nodes: List[BaseSchedulerNode]
|
||||
fused_nodes: list[BaseSchedulerNode]
|
||||
if producer.is_foreach() and consumer.is_foreach():
|
||||
producer = typing.cast(ForeachKernelSchedulerNode, producer)
|
||||
consumer = typing.cast(ForeachKernelSchedulerNode, consumer)
|
||||
@ -1599,7 +1591,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
def __init__(
|
||||
self,
|
||||
scheduler: Scheduler,
|
||||
snodes: List[BaseSchedulerNode],
|
||||
snodes: list[BaseSchedulerNode],
|
||||
use_custom_partition_algo: bool,
|
||||
prev_node_1: Optional[BaseSchedulerNode] = None,
|
||||
prev_node_2: Optional[BaseSchedulerNode] = None,
|
||||
@ -1621,7 +1613,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
self.scheduler = scheduler
|
||||
self.snodes = snodes
|
||||
self.node = None
|
||||
self.users: List[NodeUser] = []
|
||||
self.users: list[NodeUser] = []
|
||||
|
||||
self.set_read_writes(
|
||||
dependencies.ReadWrites.merge_list(
|
||||
@ -1666,8 +1658,8 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
|
||||
@classmethod
|
||||
def combinable_nodes(
|
||||
cls, nodes: List[BaseSchedulerNode]
|
||||
) -> List[BaseSchedulerNode]:
|
||||
cls, nodes: list[BaseSchedulerNode]
|
||||
) -> list[BaseSchedulerNode]:
|
||||
extern = [x for x in nodes if isinstance(x, ExternKernelSchedulerNode)]
|
||||
if extern:
|
||||
log.debug(
|
||||
@ -1700,7 +1692,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
@staticmethod
|
||||
def _default_group_nodes_for_combo_kernels(
|
||||
scheduler: Scheduler,
|
||||
) -> List[List[BaseSchedulerNode]]:
|
||||
) -> list[list[BaseSchedulerNode]]:
|
||||
"""
|
||||
Returns a list of lists of nodes that are to be grouped together.
|
||||
"""
|
||||
@ -1718,12 +1710,12 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
return grouped_nodes
|
||||
|
||||
group_algorithm_for_combo_kernels: Callable[
|
||||
[Scheduler], List[List[BaseSchedulerNode]]
|
||||
[Scheduler], list[list[BaseSchedulerNode]]
|
||||
] = _default_group_nodes_for_combo_kernels
|
||||
|
||||
@staticmethod
|
||||
def set_group_algorithm_for_combo_kernels(
|
||||
custom_group_algorithm: Callable[[Scheduler], List[List[BaseSchedulerNode]]]
|
||||
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
|
||||
) -> None:
|
||||
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
|
||||
custom_group_algorithm
|
||||
@ -1732,7 +1724,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
@staticmethod
|
||||
def group_nodes_for_combo_kernels(
|
||||
scheduler: Scheduler,
|
||||
) -> List[List[BaseSchedulerNode]]:
|
||||
) -> list[list[BaseSchedulerNode]]:
|
||||
return ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels(scheduler)
|
||||
|
||||
def mark_run(self) -> None:
|
||||
@ -1744,7 +1736,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
def is_foreach(self) -> bool:
|
||||
return True
|
||||
|
||||
def get_subkernel_nodes(self) -> List[BaseSchedulerNode]:
|
||||
def get_subkernel_nodes(self) -> list[BaseSchedulerNode]:
|
||||
"""Returns a list of nodes which comprise the combo kernel.
|
||||
These nodes may be vertically fused."""
|
||||
return list(self.snodes)
|
||||
@ -1758,7 +1750,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
|
||||
return self.snodes[0].get_first_name()
|
||||
|
||||
def prune_redundant_deps(
|
||||
self, name_to_fused_node: Dict[str, BaseSchedulerNode]
|
||||
self, name_to_fused_node: dict[str, BaseSchedulerNode]
|
||||
) -> None:
|
||||
_prune_redundant_deps(self, name_to_fused_node, self.scheduler.name_to_buf)
|
||||
|
||||
@ -1776,10 +1768,10 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
At codegen time, this scheduler node will be unpacked and codegen is called on each constituent node.
|
||||
"""
|
||||
|
||||
snodes: List[BaseSchedulerNode]
|
||||
snodes: list[BaseSchedulerNode]
|
||||
|
||||
@classmethod
|
||||
def create(cls, snodes: List[BaseSchedulerNode]) -> GroupedSchedulerNode:
|
||||
def create(cls, snodes: list[BaseSchedulerNode]) -> GroupedSchedulerNode:
|
||||
scheduler = snodes[0].scheduler
|
||||
assert all(node.scheduler is scheduler for node in snodes)
|
||||
grouped_snode = cls(scheduler, snodes) # type: ignore[arg-type]
|
||||
@ -1788,11 +1780,11 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
scheduler.name_to_fused_node[grouped_snode.get_name()] = grouped_snode
|
||||
return grouped_snode
|
||||
|
||||
def __init__(self, scheduler: Scheduler, snodes: List[BaseSchedulerNode]) -> None:
|
||||
def __init__(self, scheduler: Scheduler, snodes: list[BaseSchedulerNode]) -> None:
|
||||
super().__init__(scheduler)
|
||||
init_group_node(self, scheduler, snodes)
|
||||
|
||||
def unpack(self) -> List[BaseSchedulerNode]:
|
||||
def unpack(self) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Do fusion among nodes within this GroupedSchedulerNode,
|
||||
and then unpack this GroupedSchedulerNode into regular nodes.
|
||||
@ -1817,8 +1809,8 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
def get_buffer_names(self) -> OrderedSet[str]:
|
||||
return OrderedSet.union(*[x.get_buffer_names() for x in self.snodes])
|
||||
|
||||
def get_outputs(self) -> List[SchedulerBuffer]:
|
||||
result: List[SchedulerBuffer] = []
|
||||
def get_outputs(self) -> list[SchedulerBuffer]:
|
||||
result: list[SchedulerBuffer] = []
|
||||
for node in self.snodes:
|
||||
result.extend(node.get_outputs())
|
||||
return result
|
||||
@ -1833,10 +1825,10 @@ class GroupedSchedulerNode(BaseSchedulerNode):
|
||||
|
||||
|
||||
def pick_loop_order(
|
||||
stride_lengths: List[List[int]],
|
||||
stride_lengths: list[list[int]],
|
||||
sizes: Sequence[sympy.Expr],
|
||||
priority_idx: tuple[int, ...] = (),
|
||||
) -> List[int]:
|
||||
) -> list[int]:
|
||||
"""
|
||||
A heuristic to decide loop iteration orders. This has not been well
|
||||
tuned and may be something we should autotune.
|
||||
@ -1914,17 +1906,17 @@ _post_grad_graph_counter = itertools.count()
|
||||
|
||||
|
||||
class Scheduler:
|
||||
__dep_size_hint_cache: Dict[Dep, int]
|
||||
__dep_size_hint_cache: dict[Dep, int]
|
||||
|
||||
def __init__(self, nodes: List[ir.Operation]) -> None:
|
||||
def __init__(self, nodes: list[ir.Operation]) -> None:
|
||||
with dynamo_timed("Scheduler.__init__"):
|
||||
self._init(nodes)
|
||||
|
||||
def _init(self, nodes: List[ir.Operation]) -> None:
|
||||
def _init(self, nodes: list[ir.Operation]) -> None:
|
||||
super().__init__()
|
||||
self.__dep_size_hint_cache = {}
|
||||
V.graph.scheduler = self
|
||||
self.backends: Dict[torch.device, BaseScheduling] = {}
|
||||
self.backends: dict[torch.device, BaseScheduling] = {}
|
||||
self.post_grad_graph_id = next(_post_grad_graph_counter)
|
||||
|
||||
self.completed_operations = OrderedSet[str]()
|
||||
@ -1943,23 +1935,23 @@ class Scheduler:
|
||||
for node in self.nodes:
|
||||
node.prune_deps()
|
||||
|
||||
self.name_to_donated_buffer: Dict[
|
||||
self.name_to_donated_buffer: dict[
|
||||
str, SchedulerDonatedBuffer
|
||||
] = self.get_donated_buffers()
|
||||
self.name_to_node: Dict[str, BaseSchedulerNode] = {
|
||||
self.name_to_node: dict[str, BaseSchedulerNode] = {
|
||||
n.get_name(): n for n in self.nodes
|
||||
}
|
||||
self.name_to_buf: Dict[str, SchedulerBuffer] = {
|
||||
self.name_to_buf: dict[str, SchedulerBuffer] = {
|
||||
buf.get_name(): buf for node in self.nodes for buf in node.get_outputs()
|
||||
}
|
||||
self.name_to_fused_node: Dict[str, BaseSchedulerNode] = self.name_to_node.copy()
|
||||
self.name_to_fused_node: dict[str, BaseSchedulerNode] = self.name_to_node.copy()
|
||||
|
||||
# mutation_real_name: Maps back to the original name for codegen
|
||||
# Example:
|
||||
# If you mutate buf0 inside of buf1's kernel, then:
|
||||
# mutation_real_name = {"buf0" : "buf1"}
|
||||
# all subsequent uses of buf0 become buf1's usage in dependency graph
|
||||
self.mutation_real_name: Dict[str, str] = {}
|
||||
self.mutation_real_name: dict[str, str] = {}
|
||||
|
||||
# We handle mutation by renaming modified versions of the same
|
||||
# buffer in the dependency graph to prevent cycles.
|
||||
@ -1969,7 +1961,7 @@ class Scheduler:
|
||||
# If you mutate buf0 inside of buf1's kernel, then:
|
||||
# mutation_renames = {"buf1" : "buf0"}
|
||||
# in codegen we only use buf0, never buf1
|
||||
self.mutation_renames: Dict[str, str] = {}
|
||||
self.mutation_renames: dict[str, str] = {}
|
||||
|
||||
# Must run first to correctly set dependencies, before all other passes that rely on
|
||||
# reading from .read_writes.reads or .unmet_dependencies
|
||||
@ -2024,7 +2016,7 @@ class Scheduler:
|
||||
|
||||
# fx graph node to the position it appears in the graph
|
||||
# for debug attribution
|
||||
self.origin_to_index: Dict[torch.fx.Node, int] = {}
|
||||
self.origin_to_index: dict[torch.fx.Node, int] = {}
|
||||
|
||||
get_metric_table("graph_stats").add_row(
|
||||
lambda: {
|
||||
@ -2034,7 +2026,7 @@ class Scheduler:
|
||||
}
|
||||
)
|
||||
|
||||
def get_donated_buffers(self) -> Dict[str, SchedulerDonatedBuffer]:
|
||||
def get_donated_buffers(self) -> dict[str, SchedulerDonatedBuffer]:
|
||||
name_to_donated_buf = {}
|
||||
for name in V.graph.graph_inputs_original:
|
||||
if isinstance(V.graph.graph_inputs_original[name], ir.DonatedBuffer):
|
||||
@ -2135,7 +2127,7 @@ class Scheduler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
items: Optional[List[T]] = None,
|
||||
items: Optional[list[T]] = None,
|
||||
membership: Optional[OrderedSet[T]] = None,
|
||||
) -> None:
|
||||
self.items = items or []
|
||||
@ -2154,7 +2146,7 @@ class Scheduler:
|
||||
]
|
||||
return DedupList(new_items, new_membership)
|
||||
|
||||
name_to_users: DefaultDict[str, DedupList[NodeUser]] = collections.defaultdict(
|
||||
name_to_users: defaultdict[str, DedupList[NodeUser]] = collections.defaultdict(
|
||||
DedupList
|
||||
)
|
||||
|
||||
@ -2196,7 +2188,7 @@ class Scheduler:
|
||||
NodeUser(user_node, can_inplace, is_weak)
|
||||
)
|
||||
|
||||
unbacked_symbol_to_origin_node: Dict[sympy.Symbol, Optional[str]] = {}
|
||||
unbacked_symbol_to_origin_node: dict[sympy.Symbol, Optional[str]] = {}
|
||||
|
||||
# NB: None means that the dependency is on an input. Don't actually
|
||||
# generate a dependency because if we do, Inductor will start trying
|
||||
@ -2367,14 +2359,14 @@ class Scheduler:
|
||||
node.prune_weak_deps()
|
||||
|
||||
def topological_sort_schedule(
|
||||
self, nodes: List[BaseSchedulerNode]
|
||||
) -> List[BaseSchedulerNode]:
|
||||
self, nodes: list[BaseSchedulerNode]
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Ensure nodes is in topologically sorted order
|
||||
"""
|
||||
seen = OrderedSet[BaseSchedulerNode]()
|
||||
name_to_node: Dict[str, BaseSchedulerNode] = dict()
|
||||
result: List[BaseSchedulerNode] = []
|
||||
name_to_node: dict[str, BaseSchedulerNode] = dict()
|
||||
result: list[BaseSchedulerNode] = []
|
||||
|
||||
def visit(n: BaseSchedulerNode) -> None:
|
||||
if n not in seen:
|
||||
@ -2393,7 +2385,7 @@ class Scheduler:
|
||||
visit(node)
|
||||
return result
|
||||
|
||||
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> List[BaseSchedulerNode]:
|
||||
def _get_unmet_dep_nodes(self, snode: BaseSchedulerNode) -> list[BaseSchedulerNode]:
|
||||
unmet_deps = OrderedSet[str]()
|
||||
if isinstance(
|
||||
snode,
|
||||
@ -2415,13 +2407,13 @@ class Scheduler:
|
||||
OrderedSet(self.name_to_fused_node[n.get_name()] for n in unmet_dep_ops)
|
||||
)
|
||||
|
||||
def _topological_sort_nodes(self) -> List[List[BaseSchedulerNode]]:
|
||||
def _topological_sort_nodes(self) -> list[list[BaseSchedulerNode]]:
|
||||
"""
|
||||
Sort nodes by their topological order, return a list of node lists.
|
||||
"""
|
||||
order = []
|
||||
nodes = dict.fromkeys(self.nodes, 0)
|
||||
children: Dict[Any, Any] = {}
|
||||
children: dict[Any, Any] = {}
|
||||
for node in self.nodes:
|
||||
deps = self._get_unmet_dep_nodes(node)
|
||||
nodes[node] = len(deps)
|
||||
@ -2446,7 +2438,7 @@ class Scheduler:
|
||||
Populate each node.ancestors
|
||||
"""
|
||||
# note self.nodes is topologically sorted
|
||||
name_to_ancestors: Dict[str, OrderedSet[str]] = {}
|
||||
name_to_ancestors: dict[str, OrderedSet[str]] = {}
|
||||
for node in self.nodes:
|
||||
ancestors = OrderedSet[str]()
|
||||
for dep in node.unmet_dependencies:
|
||||
@ -2487,7 +2479,7 @@ class Scheduler:
|
||||
# FusedSchedulerNode having different merged loops.
|
||||
# Skip CPU backend for now.
|
||||
|
||||
def fuse_nodes(self, nodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
def fuse_nodes(self, nodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Combine eligible nodes into FusedSchedulerNodes.
|
||||
"""
|
||||
@ -2518,7 +2510,7 @@ class Scheduler:
|
||||
"""
|
||||
Unpack GroupedSchedulerNode into regular nodes.
|
||||
"""
|
||||
new_nodes: List[BaseSchedulerNode] = []
|
||||
new_nodes: list[BaseSchedulerNode] = []
|
||||
for node in self.nodes:
|
||||
new_nodes.extend(
|
||||
node.unpack() if isinstance(node, GroupedSchedulerNode) else [node]
|
||||
@ -2802,8 +2794,8 @@ class Scheduler:
|
||||
return ms_fused < ms1 + ms2
|
||||
|
||||
def fuse_nodes_once(
|
||||
self, nodes: List[BaseSchedulerNode]
|
||||
) -> List[BaseSchedulerNode]:
|
||||
self, nodes: list[BaseSchedulerNode]
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Combine eligible nodes into FusedSchedulerNodes.
|
||||
|
||||
@ -2890,20 +2882,20 @@ class Scheduler:
|
||||
)
|
||||
self.prune_redundant_deps(self.nodes)
|
||||
|
||||
def prune_redundant_deps(self, nodes: List[BaseSchedulerNode]) -> None:
|
||||
def prune_redundant_deps(self, nodes: list[BaseSchedulerNode]) -> None:
|
||||
for node in nodes:
|
||||
node.prune_redundant_deps(self.name_to_fused_node)
|
||||
|
||||
def get_possible_fusions(
|
||||
self, nodes: List[BaseSchedulerNode]
|
||||
) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
|
||||
self, nodes: list[BaseSchedulerNode]
|
||||
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
|
||||
"""
|
||||
Helper to find all legal fusion opportunities, sorted by self.score_fusion()
|
||||
"""
|
||||
possible_fusions = []
|
||||
seen = OrderedSet[tuple[BaseSchedulerNode, BaseSchedulerNode]]()
|
||||
|
||||
def check_all_pairs(nodes: List[BaseSchedulerNode]) -> None:
|
||||
def check_all_pairs(nodes: list[BaseSchedulerNode]) -> None:
|
||||
for node1_index, node1 in enumerate(nodes):
|
||||
for node2 in nodes[node1_index + 1 :]:
|
||||
key = (node1, node2)
|
||||
@ -3015,7 +3007,7 @@ class Scheduler:
|
||||
|
||||
def _find_single_user_inputs(
|
||||
node: BaseSchedulerNode,
|
||||
) -> List[ir.Buffer]:
|
||||
) -> list[ir.Buffer]:
|
||||
output = []
|
||||
for rd in node.read_writes.reads:
|
||||
buf = self.name_to_buf.get(rd.name)
|
||||
@ -3403,7 +3395,7 @@ class Scheduler:
|
||||
"""
|
||||
node1_buf_names = node1.get_buffer_names()
|
||||
why = WhyNoFuse(node1, node2)
|
||||
remaining_deps_by_name: Dict[str, List[Dep]] = defaultdict(list)
|
||||
remaining_deps_by_name: dict[str, list[Dep]] = defaultdict(list)
|
||||
|
||||
for dep in node2.unmet_dependencies:
|
||||
name = self.mutation_renames.get(dep.name, dep.name)
|
||||
@ -3562,14 +3554,14 @@ class Scheduler:
|
||||
return sum(self.dep_size_hint(dep) for dep in common_memory_deps)
|
||||
|
||||
def get_possible_fusions_with_highest_priority(
|
||||
self, possible_fusions: List[tuple[BaseSchedulerNode, BaseSchedulerNode]]
|
||||
) -> List[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
|
||||
self, possible_fusions: list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
|
||||
) -> list[tuple[BaseSchedulerNode, BaseSchedulerNode]]:
|
||||
# Group the possible fusions based on their priority from the backend.
|
||||
# Only return the group of possible fusions with highest priority.
|
||||
if len(possible_fusions) == 0:
|
||||
return possible_fusions
|
||||
possible_fusions_group_by_priority: Dict[
|
||||
int, List[tuple[BaseSchedulerNode, BaseSchedulerNode]]
|
||||
possible_fusions_group_by_priority: dict[
|
||||
int, list[tuple[BaseSchedulerNode, BaseSchedulerNode]]
|
||||
] = {}
|
||||
|
||||
for node1, node2 in possible_fusions:
|
||||
@ -3828,7 +3820,7 @@ class Scheduler:
|
||||
backend = self.get_backend(device)
|
||||
return backend.benchmark_combo_kernel(node_list)
|
||||
|
||||
def speedup_by_combo_kernel(self, nodes: List[BaseSchedulerNode]) -> bool:
|
||||
def speedup_by_combo_kernel(self, nodes: list[BaseSchedulerNode]) -> bool:
|
||||
"""
|
||||
If config.benchmark_fusion is False, always return True.
|
||||
Otherwise, return True if fusion can brings speedup.
|
||||
|
@ -16,17 +16,8 @@ import textwrap
|
||||
import time
|
||||
from concurrent.futures import as_completed, ThreadPoolExecutor
|
||||
from io import StringIO
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import Self
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -86,7 +77,7 @@ from .virtualized import V
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# correctness checks struggle with fp16/tf32
|
||||
VERIFY: Dict[str, Any] = {}
|
||||
VERIFY: dict[str, Any] = {}
|
||||
PRINT_AUTOTUNE = True
|
||||
DEBUG = False
|
||||
|
||||
@ -104,14 +95,11 @@ class KernelNamespace:
|
||||
extern_kernels = KernelNamespace()
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound="AutotuneArgs")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class BenchmarkTensors:
|
"""Represents a set of inputs and outputs for autotuning with a template"""
input_tensors: List[torch.Tensor]
input_tensors: list[torch.Tensor]
output_tensor: Optional[torch.Tensor]
def unpack(self):
@ -139,13 +127,13 @@ class AutotuneArgs:
@classmethod
def from_choice_args(
cls: Type[_T],
example_inputs: List[torch.Tensor],
example_inputs_extern: List[torch.Tensor],
cls,
example_inputs: list[torch.Tensor],
example_inputs_extern: list[torch.Tensor],
out: torch.Tensor,
out_extern: torch.Tensor,
expected: Optional[torch.Tensor] = None,
) -> _T:
) -> Self:
"""Factory method to create AutotuneInputs from separate inputs/outputs"""
return cls(
triton=BenchmarkTensors(example_inputs, out),
@ -207,7 +195,7 @@ class SubgraphInfo:
ops_handler: Optional[V.WrapperHandler] = None # type: ignore[name-defined]
# only copied over if not None
range_trees: Optional[List["IterationRangesRoot"]] = None
range_trees: Optional[list["IterationRangesRoot"]] = None
numels = None # type: ignore[var-annotated]
def __post_init__(self):
@ -226,7 +214,7 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
self,
kernel,
subgraph_number: int,
fixed_inputs: Dict[str, Any],
fixed_inputs: dict[str, Any],
mask: Optional[str],
):
super().__init__(V.ops)
@ -290,7 +278,7 @@ class TritonTemplateKernel(TritonKernel):
prefix_args=0,
suffix_args=0,
epilogue_fn=identity,
subgraphs: Optional[List[ir.ComputedBuffer]] = None,
subgraphs: Optional[list[ir.ComputedBuffer]] = None,
workspace_arg: Optional[WorkspaceArg] = None,
) -> None:
numel = sympy_product(output_node.get_size())
@ -317,9 +305,9 @@ class TritonTemplateKernel(TritonKernel):
self.suffix_args = suffix_args
self.epilogue_fn = epilogue_fn
self.render_hooks = {} # type: ignore[var-annotated]
self.triton_meta: Optional[Dict[str, object]] = None
self.triton_meta: Optional[dict[str, object]] = None
# For Templated Attention this can be a list of ir.Subgraph
self.subgraphs: Optional[List[ir.ComputedBuffer]] = subgraphs
self.subgraphs: Optional[list[ir.ComputedBuffer]] = subgraphs
# Some templates use extra global memory as a workspace
self.workspace_arg = workspace_arg
@ -330,7 +318,7 @@ class TritonTemplateKernel(TritonKernel):
# used for triton kernel codegen.
# They are swapped onto the TritonTemplateKernel object by
# `set_subgraph_body`
self.subgraph_bodies: Dict[str, SubgraphInfo] = {}
self.subgraph_bodies: dict[str, SubgraphInfo] = {}
# input buffers which we are allowed to prologue fuse into
self.prologue_supported_inputs: OrderedSet[str] = OrderedSet()
@ -420,7 +408,7 @@ class TritonTemplateKernel(TritonKernel):
return "@triton.jit"
argdefs, _, signature, _ = self.args.python_argdefs()
triton_meta: Dict[str, Any] = {
triton_meta: dict[str, Any] = {
"signature": signature_to_meta(
signature, size_dtype=self.index_dtype, argdefs=argdefs
),
@ -622,7 +610,7 @@ class TritonTemplateKernel(TritonKernel):
)
with V.set_ops_handler(modification_handler):
assert isinstance(
subgraph, (ir.ComputedBuffer, List)
subgraph, (ir.ComputedBuffer, list)
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
# Handle scatter stores
if isinstance(subgraph, list):
@ -651,7 +639,7 @@ class TritonTemplateKernel(TritonKernel):
self,
input_name: str,
output_name: str,
indices: Union[List[Any], tuple[Any]],
indices: Union[list[Any], tuple[Any]],
mask: Optional[str] = None,
other: Optional[Union[float, int]] = 0.0,
indent_width: int = 4,
@ -826,7 +814,7 @@ class TritonTemplateKernel(TritonKernel):
def store_output(
self,
indices: Union[List[Any], tuple[Any]],
indices: Union[list[Any], tuple[Any]],
val: str,
mask: Optional[str] = None,
indent_width: int = 4,
@ -1032,7 +1020,7 @@ def _jinja2_env():
class TritonTemplate(KernelTemplate):
index_counter = itertools.count()
all_templates: Dict[str, "TritonTemplate"] = {}
all_templates: dict[str, "TritonTemplate"] = {}
def __init__(self, name: str, grid: Any, source: str, debug=False) -> None:
super().__init__(name)
@ -1180,7 +1168,7 @@ class TritonTemplate(KernelTemplate):
),
kwargs,
)
bmreq_cls: Type[TritonBenchmarkRequest]
bmreq_cls: type[TritonBenchmarkRequest]
if layout.device.type == "cpu":
bmreq_cls = TritonCPUBenchmarkRequest
else:
@ -1292,7 +1280,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
description,
bmreq,
log_info: Optional[
Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
] = None,
mutated_inputs=None,
workspace_arg: Optional[WorkspaceArg] = None,
@ -1303,7 +1291,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
self.bmreq: TritonBenchmarkRequest = bmreq
if log_info is None:
log_info = {}
self.log_info: Dict[str, Any] = log_info
self.log_info: dict[str, Any] = log_info
self.log_info.update(
{
"backend": "Triton",
@ -1351,7 +1339,7 @@ class TritonTemplateCaller(ir.TritonTemplateCallerBase):
)
)
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return self.log_info
@ -1447,7 +1435,7 @@ class ExternKernelCaller(ChoiceCaller):
return ir.TensorBox.create(inner)
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {
"backend": "extern",
@ -1589,7 +1577,7 @@ def create_inputs_key(input_nodes) -> str:
def create_precompile_key(
name: str, inputs_key: str, choices: List[ChoiceCaller]
name: str, inputs_key: str, choices: list[ChoiceCaller]
) -> str:
return ":".join(
[
@ -1609,18 +1597,18 @@ class AlgorithmSelectorCache(PersistentCache):
# no guarantee that the first lowering for a given key will also be the
# first to benchmark it. share a single precompilation function for all lowerings
# of a particular key
self.precompile_cache: Dict[str, Callable[[], None]] = {}
self.precompile_cache: dict[str, Callable[[], None]] = {}
# list of callbacks that are called after benchmarking
self.feedback_saver_fns: List[
self.feedback_saver_fns: list[
Callable[
[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
]
] = []
def __call__(
self,
name,
choices: List[ChoiceCaller],
choices: list[ChoiceCaller],
input_nodes,
layout,
# optional dict mapping arg indices to the functions
@ -1628,7 +1616,7 @@ class AlgorithmSelectorCache(PersistentCache):
# corresponding ir.Buffer. if passed for a given
# arg, the function will be called instead of
# generating a random torch.Tensor for benchmarking.
input_gen_fns: Optional[Dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
input_gen_fns: Optional[dict[int, Callable[[ir.Buffer], torch.Tensor]]] = None,
precompilation_timeout_seconds: int = 60 * 60,
return_multi_template=False,
):
@ -1761,9 +1749,9 @@ class AlgorithmSelectorCache(PersistentCache):
executor = ThreadPoolExecutor(max_workers=num_workers)
async_compile = torch._inductor.async_compile.AsyncCompile()
futures: Dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
start_times: Dict[concurrent.futures.Future[Any], float] = {}
elapsed_times: Dict[concurrent.futures.Future[Any], float] = {}
futures: dict[concurrent.futures.Future[Any], ChoiceCaller] = {}
start_times: dict[concurrent.futures.Future[Any], float] = {}
elapsed_times: dict[concurrent.futures.Future[Any], float] = {}
for c in choices:
if hasattr(c, "precompile"):
@ -1925,7 +1913,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {}
def get_inputs(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
) -> AutotuneArgs:
# de-duplicate args
unique_example_inputs = {
@ -1996,8 +1984,8 @@ class AlgorithmSelectorCache(PersistentCache):
return result
def benchmark_in_current_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]],
) -> Dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> dict[Union[ExternKernelCaller, TritonTemplateCaller], float]:
inputs = get_inputs(choices)
timings = {}
for choice in choices:
@ -2045,7 +2033,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings
def benchmark_in_sub_process(
choices: Union[List[ExternKernelCaller], List[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
):
from . import autotune_process
@ -2069,8 +2057,8 @@ class AlgorithmSelectorCache(PersistentCache):
@staticmethod
def log_results(
name: str,
input_nodes: List[ir.IRNode],
timings: Dict[ChoiceCaller, float],
input_nodes: list[ir.IRNode],
timings: dict[ChoiceCaller, float],
elapse: float,
precompile_elapse: float,
):
@ -2227,7 +2215,7 @@ class AlgorithmSelectorCache(PersistentCache):
def add_feedback_saver(
self,
fn: Callable[
[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None
[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None
],
):
self.feedback_saver_fns.append(fn)
@ -2250,7 +2238,7 @@ def autotune_select_algorithm(*args, **kwargs):
def add_feedback_saver(
fn: Callable[[Dict[ChoiceCaller, float], str, List[Any], List[ChoiceCaller]], None]
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
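
The hunks above swap typing.List/Dict/Type for the builtin generics standardized by PEP 585 and replace the TypeVar-based classmethod signature (cls: Type[_T], ...) -> _T with Self. A minimal sketch of both patterns, using a hypothetical class rather than code from this PR, assuming Python 3.9+ (typing.Self needs 3.11+, so typing_extensions is used here):

from dataclasses import dataclass
from typing import Optional

from typing_extensions import Self  # typing.Self on Python 3.11+


@dataclass
class BenchmarkCase:
    inputs: list[float]  # was: typing.List[float]
    params: dict[str, int]  # was: typing.Dict[str, int]
    expected: Optional[float] = None

    @classmethod
    def from_parts(cls, inputs: list[float], params: dict[str, int]) -> Self:
        # Self replaces the older `(cls: Type[_T], ...) -> _T` idiom.
        return cls(inputs=inputs, params=params)
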
@ -2,7 +2,8 @@
import functools
import itertools
import logging
from typing import Any, Callable, cast, Dict, Iterable, List, Optional, Sequence, Union
from collections.abc import Iterable, Sequence
from typing import Any, Callable, cast, Optional, Union
import sympy
from sympy import Expr
@ -60,7 +61,7 @@ class SizeVarAllocator:
shape_env = ShapeEnv()
self.shape_env = shape_env
self.var_to_val = self.shape_env.var_to_val
self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
# Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
# The basic idea is if we have some complicated sympy expression
# f(s0), we may choose to precompute it on the host and then replace
@ -71,8 +72,8 @@ class SizeVarAllocator:
# which potentially could have already had a precomputed replacement
# on it, we are obligated to invert the precomputed replacements
# (inv_precomputed_replacements).
self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
self.precomputed_replacements: dict[Expr, sympy.Symbol] = {}
self.inv_precomputed_replacements: dict[sympy.Symbol, Expr] = {}
self.stride_vars = self.make_stride_vars_cache()
self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
self._simplify_loops = self.make_simplify_loops_cache()
@ -84,7 +85,7 @@ class SizeVarAllocator:
"""
self._simplify_with_ranges() can be expensive, cache its results
"""
cache: Dict[tuple[Any, ...], Expr] = {}
cache: dict[tuple[Any, ...], Expr] = {}
replacement_count = len(self.replacements)
def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
@ -106,7 +107,7 @@ class SizeVarAllocator:
"""
self._simplify_with_ranges() can be expensive, cache its results
"""
cache: Dict[tuple[Any, ...], Any] = {}
cache: dict[tuple[Any, ...], Any] = {}
replacement_count = len(self.replacements)
def simplify_loops(index_vars, sizes, index_formulas):
@ -221,7 +222,7 @@ class SizeVarAllocator:
return expr
def _simplify_loops_impl(
self, index_vars: List[sympy.Symbol], sizes, index_formulas
self, index_vars: list[sympy.Symbol], sizes, index_formulas
):
"""
Try to remove as many axis from loop iterations as possible, by:
@ -337,7 +338,7 @@ class SizeVarAllocator:
return self.is_expr_static_and_true(sympy.Eq(left, right)) # type: ignore[arg-type]
# See Note - [On Statically Known]
def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
def statically_known_list_equals(self, left: list[Expr], right: list[Expr]) -> bool:
"""
Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
"""
@ -501,7 +502,7 @@ class SizeVarAllocator:
self.guard_equals(left, sympy.Integer(right))
return int(right)
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> List[int]:
def evaluate_static_shapes(self, left: Sequence[Union[Expr, int]]) -> list[int]:
return [self.evaluate_static_shape(x) for x in left]
def remove_precomputed_replacements(self, expr: Expr) -> Expr:
@ -582,7 +583,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Optional[Sequence[sympy.Symbol]] = None,
) -> List[Expr]:
) -> list[Expr]:
if not support_vars:
support_vars = vars
return cache(index, tuple(vars), tuple(support_vars))
@ -594,7 +595,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Sequence[sympy.Symbol],
) -> List[Expr]:
) -> list[Expr]:
"""Convert an indexing expression back into strides
NOTE: This is only valid if the index is a standard strided offset
@ -647,7 +648,7 @@ class SizeVarAllocator:
}
return expr.subs(size_dict)
def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
def offset_var(self, index: Expr, vars: list[sympy.Symbol]) -> Expr:
"""Extract offset part of an indexing expression"""
index = self.simplify(index)
return sympy_subs(index, {v: sympy.S.Zero for v in vars if v != 0})
@ -657,7 +658,7 @@ class SizeVarAllocator:
index: Expr,
vars: Sequence[sympy.Symbol],
support_vars: Optional[Sequence[sympy.Symbol]] = None,
) -> List[int]:
) -> list[int]:
for v in index.free_symbols:
if symbol_is_type(v, SymT.INDIRECT): # type: ignore[attr-defined]
index = sympy_subs(index, {v: 0}) # type: ignore[dict-item]
@ -669,7 +670,7 @@ class SizeVarAllocator:
result.append(0)
return result
def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
def stride_order(self, index: Expr, vars: list[sympy.Symbol]) -> list[int]:
strides = tuple(map(abs, self.stride_hints(index, vars)))
order = list(range(len(strides)))
order.sort(key=lambda x: (strides[x] == 0, strides[x]))
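
The import hunk above moves Iterable and Sequence from typing to collections.abc, which has been the supported home for subscriptable abstract container types since Python 3.9. A minimal sketch of that usage with hypothetical helpers, not code from this PR:

from collections.abc import Iterable, Sequence


def total_elements(shapes: Sequence[Sequence[int]]) -> int:
    # Multiply out each shape and sum the element counts.
    total = 0
    for shape in shapes:
        count = 1
        for dim in shape:
            count *= dim
        total += count
    return total


def flatten(nested: Iterable[Iterable[int]]) -> list[int]:
    # collections.abc types are subscriptable directly on Python 3.9+.
    return [x for group in nested for x in group]
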
@ -2,9 +2,10 @@
import functools
import operator
from collections.abc import Generator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, Callable, Dict, Generator, List, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -20,7 +21,7 @@ T = TypeVar("T")
_P = ParamSpec("_P")
OpOverload = torch._ops.OpOverload
LoweringDict = Dict[Union[OpOverload, str], Callable[..., Any]]
LoweringDict = dict[Union[OpOverload, str], Callable[..., Any]]
TargetType = Union[Callable[..., Any], str]
@ -30,13 +31,13 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
lowering object. Errors if buffers are created unexpectedly
"""
graph_outputs: Optional[List[ir.IRNode]]
graph_outputs: Optional[list[ir.IRNode]]
root_graph: torch._inductor.graph.GraphLowering
_current_op: Optional[TargetType]
# For backwards of buffer_grads with scatters we allow mutations
allowed_mutations: Optional[OrderedSet[OpOverload]]
additional_lowerings: Optional[LoweringDict]
buffers: List[ir.Buffer]
buffers: list[ir.Buffer]
mutated_buffers: OrderedSet[str]
def __init__(
@ -102,7 +103,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
self,
target: TargetType,
args: Any,
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> Any:
from .lowering import lowerings
@ -123,7 +124,7 @@ class PointwiseSubgraphLowering(torch.fx.Interpreter):
return lowerings[target](*args, **kwargs)
def output(self, target: str, args: tuple[Any], kwargs: Dict[str, Any]) -> None: # type: ignore[override]
def output(self, target: str, args: tuple[Any], kwargs: dict[str, Any]) -> None: # type: ignore[override]
assert len(args) == 1
self.graph_outputs = args[0]
@ -155,7 +156,7 @@ class TracingOpsHandler(WrapperHandler[T]):
def lower_pointwise_subgraph(
subgraph: ir.Subgraph, inputs: List[InputDescriptor]
subgraph: ir.Subgraph, inputs: list[InputDescriptor]
) -> Callable[_P, Any]:
# Lower subgraph to ir.Pointwise nodes
def fake_inner_fn(
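
One detail worth noting about the type-alias change above: a module-level alias such as LoweringDict is an ordinary runtime assignment, so subscripting the builtin dict there requires Python 3.9+ even in a module that uses `from __future__ import annotations` (which only affects annotations, not expressions). A small sketch with a hypothetical alias, not code from this PR:

from typing import Any, Callable, Union

# Evaluated at import time, so builtin-generic subscription must work at runtime.
HandlerTable = dict[Union[str, int], Callable[..., Any]]


def dispatch(table: HandlerTable, key: Union[str, int], *args: Any) -> Any:
    # Look up the handler registered under `key` and call it.
    return table[key](*args)


handlers: HandlerTable = {"double": lambda x: 2 * x}
print(dispatch(handlers, "double", 21))  # prints 42
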
@ -3,7 +3,7 @@ import logging
import os
import uuid
from pathlib import Path
from typing import List, Optional
from typing import Optional
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._utils_internal import justknobs_check
@ -48,7 +48,7 @@ class TritonKernelArtifacts:
kernel_hash: str
device: int
artifacts: List[TritonKernelArtifact]
artifacts: list[TritonKernelArtifact]
@dataclasses.dataclass(frozen=True)
@ -57,7 +57,7 @@ class TritonBundlerMetadata:
Metadata used for instrumentation
"""
cached_kernel_names: List[str]
cached_kernel_names: list[str]
class TritonBundler:
@ -76,7 +76,7 @@ class TritonBundler:
- TritonBundler.read_and_emit is called when a cache entry is read
"""
_entries: Optional[List[TritonBundleEntry]] = None
_entries: Optional[list[TritonBundleEntry]] = None
# __grp__kernel_name.json contains metadata with source code paths
# we use this as sentinal value for search and replace
@ -134,7 +134,7 @@ class TritonBundler:
@classmethod
def collect(
cls,
) -> tuple[List[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
) -> tuple[list[TritonKernelArtifacts], Optional[TritonBundlerMetadata]]:
"""
This is the main function called when a cache write happens. This function
converts all the previously remembered kernels into bundled format so that
@ -150,10 +150,10 @@ class TritonBundler:
with dynamo_timed(key="TritonBundler.collect", log_pt2_compile_event=True):
entries = cls._entries
if entries is not None:
result: List[TritonKernelArtifacts] = []
kernel_names: List[str] = []
result: list[TritonKernelArtifacts] = []
kernel_names: list[str] = []
for entry in entries:
artifacts: List[TritonKernelArtifact] = []
artifacts: list[TritonKernelArtifact] = []
path = os.path.join(entry.directory, entry.kernel_hash)
if not os.path.exists(path):
continue
@ -203,7 +203,7 @@ class TritonBundler:
@staticmethod
def read_and_emit(
bundle: List[TritonKernelArtifacts],
bundle: list[TritonKernelArtifacts],
) -> Optional[TritonBundlerMetadata]:
"""
This is the main function called when a cache read happens. This function
@ -223,7 +223,7 @@ class TritonBundler:
with dynamo_timed(
key="TritonBundler.read_and_emit", log_pt2_compile_event=True
):
kernel_names: List[str] = []
kernel_names: list[str] = []
for artifacts in bundle:
basedir = triton_cache_dir(artifacts.device)
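
The hunks above update dataclass fields and class-level attributes to the builtin-generic spelling. A minimal sketch of that shape with hypothetical classes, not code from this PR, assuming Python 3.9+:

import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class KernelArtifact:
    filename: str
    payload: bytes


@dataclasses.dataclass(frozen=True)
class KernelArtifacts:
    kernel_hash: str
    artifacts: list[KernelArtifact]  # was: typing.List[KernelArtifact]


class Bundler:
    # Optional class-level container, annotated with the builtin generic form.
    _entries: Optional[list[KernelArtifacts]] = None

    @classmethod
    def put(cls, entry: KernelArtifacts) -> None:
        if cls._entries is None:
            cls._entries = []
        cls._entries.append(entry)
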
@ -26,18 +26,13 @@ from io import StringIO
from typing import (
Any,
Callable,
Dict,
Generic,
Iterable,
List,
NamedTuple,
Optional,
Protocol,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
ValuesView,
)
from typing_extensions import Concatenate, dataclass_transform, ParamSpec, TypeGuard
from unittest import mock
@ -49,6 +44,8 @@ from torch._inductor.runtime.hints import DeviceProperties
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence, ValuesView
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from .codegen.common import WorkspaceArg
@ -94,7 +91,7 @@ _IS_WINDOWS = sys.platform == "win32"
log = logging.getLogger(__name__)
_T = TypeVar("_T")
VarRanges = Dict[sympy.Expr, sympy.Expr]
VarRanges = dict[sympy.Expr, sympy.Expr]
InputType = Optional[Union[torch.Tensor, int, torch.SymInt]]
GPU_KERNEL_BIN_EXTS = {"cuda": ".cubin", "xpu": ".spv"}
@ -308,7 +305,7 @@ def _type_of(key):
def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]]
) -> List[sympy.Expr]:
) -> list[sympy.Expr]:
"""
Gets the shape and stride of a tensor. For non-symbolic tensors, this is
trivial. But for symbolic tensors, we need to map from SymIntNode into
@ -319,7 +316,7 @@ def convert_shape_to_inductor(
def convert_shape_to_symint(
lst: Iterable[Union[int, sympy.Expr]]
) -> List[Union[int, torch.SymInt]]:
) -> list[Union[int, torch.SymInt]]:
"""
Takes a list of shapes from Inductor and converts them into symints (or just
ints if all shapes are static).
@ -433,7 +430,7 @@ def precompute_method(obj: Any, method: str):
setattr(obj, method, lambda: result)
def precompute_methods(obj: Any, methods: List[str]):
def precompute_methods(obj: Any, methods: list[str]):
"""Replace methods with new methods that returns a precomputed constants."""
for method in methods:
precompute_method(obj, method)
@ -451,7 +448,7 @@ def pad_listlike(x, size):
# Used to ensure that iterating over a set is deterministic
def tuple_sorted(x: tuple[_T, ...]) -> List[_T]:
def tuple_sorted(x: tuple[_T, ...]) -> list[_T]:
if len(x) == 0:
return []
@ -716,7 +713,7 @@ def sympy_index_symbol(name: str) -> sympy.Symbol:
return sympy.Symbol(name, integer=True, nonnegative=True)
def sympy_subs(expr: sympy.Expr, replacements: Dict[sympy.Expr, Any]) -> sympy.Expr:
def sympy_subs(expr: sympy.Expr, replacements: dict[sympy.Expr, Any]) -> sympy.Expr:
"""
When the passed replacement symbol v is a string, it is converted to a symbol with name v that
have the same replaced expression integer and nonnegative properties.
@ -804,7 +801,7 @@ def output_node(gm: torch.fx.GraphModule):
return last_node
_registered_caches: List[Any] = []
_registered_caches: list[Any] = []
def clear_on_fresh_inductor_cache(obj: Any):
@ -871,7 +868,7 @@ def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
clear_inductor_caches()
def argsort(seq) -> List[int]:
def argsort(seq) -> list[int]:
# preserve original order for equal strides
getter = seq.__getitem__
a_r = range(len(seq))
@ -880,7 +877,7 @@ def argsort(seq) -> List[int]:
def argsort_sym(
shape_env, seq: Sequence[Union[int, torch.SymInt, sympy.Expr]]
) -> List[int]:
) -> list[int]:
def cmp(a, b):
a_idx, a_val = a
b_idx, b_val = b
@ -1180,7 +1177,7 @@ def use_max_autotune() -> bool:
)
def _use_template_for_gpu(layout, allowed_layout_dtypes: List[torch.dtype]) -> bool:
def _use_template_for_gpu(layout, allowed_layout_dtypes: list[torch.dtype]) -> bool:
return (
is_gpu(layout.device.type)
and layout.dtype in allowed_layout_dtypes
@ -1462,10 +1459,10 @@ class DebugDirManager:
torch._dynamo.config.debug_dir_root = self.prev_debug_name
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, list[str]]:
from .graph import GraphLowering
source_codes: List[str] = []
source_codes: list[str] = []
def save_output_code(code: str):
source_codes.append(code)
@ -1476,7 +1473,7 @@ def run_and_get_code(fn, *args, **kwargs) -> tuple[Any, List[str]]:
return result, source_codes
def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, List[str]]:
def run_and_get_kernels(fn, *args, **kwargs) -> tuple[Any, list[str]]:
result, source_codes = run_and_get_code(fn, *args, **kwargs)
kernels = []
for code in source_codes:
@ -1497,7 +1494,7 @@ def get_code(fn, *args, **kwargs):
"""Get the inductor-generated code, but skip any actual compilation or running."""
from .graph import GraphLowering
source_codes: List[str] = []
source_codes: list[str] = []
def save_output_code(code: str):
source_codes.append(code)
@ -2217,13 +2214,13 @@ def shape_env_from_inputs(inputs: Sequence[InputType]):
def align_inputs_from_check_idxs(
model: Callable[[List[InputType]], Any],
model: Callable[[list[InputType]], Any],
inputs_to_check: Sequence[int],
) -> Callable[[List[InputType]], Any]:
) -> Callable[[list[InputType]], Any]:
if len(inputs_to_check) == 0:
return model
def run(new_inputs: List[InputType]):
def run(new_inputs: list[InputType]):
copy_misaligned_inputs(new_inputs, inputs_to_check)
return model(new_inputs)
@ -2243,7 +2240,7 @@ def clone_preserve_strides(x: torch.Tensor):
def copy_misaligned_inputs(
new_inputs: List[InputType], check_inputs_idxs: Sequence[int]
new_inputs: list[InputType], check_inputs_idxs: Sequence[int]
) -> None:
for i in check_inputs_idxs:
_inp = new_inputs[i]
@ -2408,7 +2405,7 @@ class OpDtypeRule:
override_return_dtype: Optional[torch.dtype]
op_dtype_propagation_rules: Dict[str, OpDtypeRule] = {}
op_dtype_propagation_rules: dict[str, OpDtypeRule] = {}
def register_op_dtype_propagation_rules(
@ -2445,7 +2442,7 @@ def ir_dataclass(cls=None, /, *, frozen: bool = True):
return wrap(cls)
def get_donated_idxs() -> Optional[List[int]]:
def get_donated_idxs() -> Optional[list[int]]:
tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None and tracing_context.fw_metadata:
return tracing_context.fw_metadata.bw_donated_idxs
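
The import hunk above drops Iterable, Sequence, and ValuesView from the typing import and re-imports them from collections.abc under TYPE_CHECKING, so they are only needed by type checkers. A minimal sketch of the pattern with hypothetical functions, not code from this PR; it assumes postponed evaluation of annotations (PEP 563), so the guarded names never have to exist at runtime:

from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence


def to_ints(values: Iterable[Union[int, str]]) -> list[int]:
    # The annotation stays a string at runtime, so the guarded import suffices.
    return [int(v) for v in values]


def first_dim(shape: Sequence[int]) -> int:
    return shape[0] if len(shape) else 1
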
@ -59,7 +59,7 @@ from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, List, Type, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
from torch.utils._ordered_set import OrderedSet
@ -108,7 +108,7 @@ class Virtualized(Generic[T]):
store other things, like booleans.
"""
def __init__(self, vname: str, default: Union[Callable[[], T], Type[NullHandler]]):
def __init__(self, vname: str, default: Union[Callable[[], T], type[NullHandler]]):
self._key: str = f"__torchinductor_{vname}"
self._default = default
@ -156,7 +156,7 @@ class NullKernelHandler(NullHandler):
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[List[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
_kernel: Virtualized[NullKernelHandler] = Virtualized(
"kernel", NullKernelHandler