PEP585 update - torch/_inductor/[_-i]* (#145137)

See #145101 for details.
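
PEP 585 (available on Python 3.9+) lets the built-in container types be used directly as generics, so annotations like `typing.Dict[str, Any]` and `typing.List[int]` become `dict[str, Any]` and `list[int]`, and ABCs such as `Sequence`, `Iterable`, and `Generator` are imported from `collections.abc` rather than `typing`. The sketch below (hypothetical `bucket_old`/`bucket_new` helpers, not code from this PR) shows the before/after pattern applied throughout these files:

# Minimal before/after sketch of the PEP 585 migration; the functions are
# illustrative only and do not appear in this PR.

# Before: generic aliases imported from `typing`.
from typing import Dict, List, Optional


def bucket_old(values: List[int], seed: Optional[Dict[str, int]] = None) -> Dict[str, int]:
    # Count occurrences of each value, starting from an optional seed dict.
    counts: Dict[str, int] = dict(seed or {})
    for v in values:
        counts[str(v)] = counts.get(str(v), 0) + 1
    return counts


# After: built-in generics; only `Optional` still comes from `typing`, and any
# ABCs (e.g. `Sequence`, `Generator`) would be imported from `collections.abc`.
from typing import Optional


def bucket_new(values: list[int], seed: Optional[dict[str, int]] = None) -> dict[str, int]:
    counts: dict[str, int] = dict(seed or {})
    for v in values:
        counts[str(v)] = counts.get(str(v), 0) + 1
    return counts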

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137
Approved by: https://github.com/bobrenjc93
Aaron Orenstein
2025-01-18 09:28:55 -08:00
committed by PyTorch MergeBot
parent cede43e06b
commit 893ca1dfe1
36 changed files with 727 additions and 765 deletions


@ -28,8 +28,8 @@ log = logging.getLogger(__name__)
def compile(
gm: torch.fx.GraphModule,
example_inputs: List[InputType],
options: Optional[Dict[str, Any]] = None,
example_inputs: list[InputType],
options: Optional[dict[str, Any]] = None,
):
"""
Compile a given FX graph with TorchInductor. This allows compiling
@ -54,7 +54,7 @@ def aoti_compile_and_package(
_deprecated_unused_kwargs=None,
*,
package_path: Optional[Union[str, io.BytesIO]] = None,
inductor_configs: Optional[Dict[str, Any]] = None,
inductor_configs: Optional[dict[str, Any]] = None,
) -> str:
"""
Compiles the exported program with AOTInductor, and packages it into a .pt2
@ -131,11 +131,11 @@ def _aoti_compile_and_package_inner(
gm: torch.nn.Module,
# flat_example_inputs: List[Any],
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
load_and_run: bool = False,
package_path: Optional[Union[str, io.BytesIO]] = None,
inductor_configs: Optional[Dict[str, Any]] = None,
inductor_configs: Optional[dict[str, Any]] = None,
):
"""
See docstring for aoti_compile_and_package.
@ -199,10 +199,10 @@ def aoti_load_package(path: Union[str, io.BytesIO]) -> Any: # type: ignore[type
def aot_compile(
gm: torch.fx.GraphModule,
args: tuple[Any],
kwargs: Optional[Dict[str, Any]] = None,
kwargs: Optional[dict[str, Any]] = None,
*,
options: Optional[Dict[str, Any]] = None,
) -> Union[str, List[str]]:
options: Optional[dict[str, Any]] = None,
) -> Union[str, list[str]]:
"""
Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
@ -232,7 +232,7 @@ def aot_compile(
def list_mode_options(
mode: Optional[str] = None, dynamic: Optional[bool] = None
) -> Dict[str, Any]:
) -> dict[str, Any]:
r"""Returns a dictionary describing the optimizations that each of the available
modes passed to `torch.compile()` performs.
@ -245,7 +245,7 @@ def list_mode_options(
>>> torch._inductor.list_mode_options()
"""
mode_options: Dict[str, Dict[str, bool]] = {
mode_options: dict[str, dict[str, bool]] = {
"default": {},
# enable cudagraphs
"reduce-overhead": {
@ -267,7 +267,7 @@ def list_mode_options(
return mode_options[mode] if mode else mode_options # type: ignore[return-value]
def list_options() -> List[str]:
def list_options() -> list[str]:
r"""Returns a dictionary describing the optimizations and debug configurations
that are available to `torch.compile()`.
@ -280,7 +280,7 @@ def list_options() -> List[str]:
from torch._inductor import config
current_config: Dict[str, Any] = config.get_config_copy()
current_config: dict[str, Any] = config.get_config_copy()
return list(current_config.keys())


@ -2,7 +2,7 @@ import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
from unittest import mock
import torch
@ -31,7 +31,7 @@ def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
def load_aoti_eager_cache(
ns: str, op_func_name_with_overload: str, device_type: str
) -> List[Optional[Dict[str, Any]]]:
) -> list[Optional[dict[str, Any]]]:
device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
if not op_conf.exists():
@ -81,7 +81,7 @@ def load_aoti_eager_cache(
return []
def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]:
return {int: torch.int32, float: torch.float, bool: torch.bool}
@ -90,8 +90,8 @@ def supported_scalar_types() -> tuple[type, ...]:
return tuple(type_to_torch_dtype.keys())
def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
metadata: Dict[str, Any] = {}
def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]:
metadata: dict[str, Any] = {}
metadata["is_dynamic"] = dynamic
assert isinstance(input, torch.Tensor)
@ -110,21 +110,21 @@ def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any
def extract_tensor_list_metadata(
dynamic: bool,
input: List[torch.Tensor],
) -> Dict[str, Any]:
input: list[torch.Tensor],
) -> dict[str, Any]:
metadata_list = []
for item in input:
assert isinstance(item, torch.Tensor)
metadata_list.append(extract_tensor_metadata(dynamic, item))
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["tensor_list"] = metadata_list
return metadata
def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]:
assert isinstance(input, supported_scalar_types())
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["is_dynamic"] = False
# Scalar tensor
metadata["device_type"] = device_type
@ -135,31 +135,31 @@ def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
return metadata
def extract_string_metadata(input: str) -> Dict[str, Any]:
def extract_string_metadata(input: str) -> dict[str, Any]:
assert isinstance(input, str)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["string_value"] = input
return metadata
def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]:
assert isinstance(input, torch.dtype)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["dtype_value"] = f"{input}"
return metadata
def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
def extract_device_metadata(input: torch.device) -> dict[str, Any]:
assert isinstance(input, torch.device)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["device_type_value"] = f"{input.type}"
metadata["device_index_value"] = input.index
return metadata
def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
def extract_layout_metadata(input: torch.layout) -> dict[str, Any]:
assert isinstance(input, torch.layout)
metadata: Dict[str, Any] = {}
metadata: dict[str, Any] = {}
metadata["layout_value"] = f"{input}"
return metadata
@ -171,10 +171,10 @@ def aoti_compile_with_persistent_cache(
dynamic: bool,
f: Callable[..., Any],
args: tuple[Any],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
*,
dynamic_shapes: Optional[Dict[str, Any]] = None,
options: Optional[Dict[str, Any]] = None,
dynamic_shapes: Optional[dict[str, Any]] = None,
options: Optional[dict[str, Any]] = None,
remove_runtime_assertions: bool = False,
disable_constraint_solver: bool = False,
) -> str:
@ -261,7 +261,7 @@ def aoti_compile_with_persistent_cache(
metadata["arg_order"] = idx
kernel_metadata_items.append(metadata)
kernel_meta_info: Dict[str, Any] = {}
kernel_meta_info: dict[str, Any] = {}
kernel_meta_info["meta_info"] = kernel_metadata_items
kernel_meta_info["kernel_path"] = (
Path(kernel_lib_path).relative_to(persistent_cache).as_posix()


@ -11,7 +11,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from typing import Any, Callable, Optional, TYPE_CHECKING
import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
@ -250,7 +250,7 @@ class AsyncCompile:
get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
return LambdaFuture(lambda: get_result().kernel)
def cpp_pybinding(self, argtypes: List[str], source_code: str):
def cpp_pybinding(self, argtypes: list[str], source_code: str):
kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
if get_compile_threads() <= 1:
return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
@ -299,7 +299,7 @@ class AsyncCompile:
)
return LambdaFuture(get_result)
def wait(self, scope: Dict[str, Any]) -> None:
def wait(self, scope: dict[str, Any]) -> None:
with dynamo_timed(
"async_compile.wait",
log_pt2_compile_event=True,


@ -1,7 +1,7 @@
import json
import os
from functools import partial
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
@ -50,16 +50,16 @@ class AutoHeuristic:
a heuristic (see torchgen/autoheuristic/).
"""
collected_feedback: Dict[Choice, Feedback]
collected_feedback: dict[Choice, Feedback]
def __init__(
self,
fallback: Callable[[], Choice],
choices: List[Choice],
choices: list[Choice],
feedback: Optional[LocalFeedback],
context: AHContext,
name: str,
augment_context: Optional[List[AHOperation]] = None,
augment_context: Optional[list[AHOperation]] = None,
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
) -> None:
"""
@ -135,8 +135,8 @@ class AutoHeuristic:
return self.fallback()
def get_top_k_choices(
self, top_k: int, always_included: Optional[List[str]] = None
) -> Optional[List[Choice]]:
self, top_k: int, always_included: Optional[list[str]] = None
) -> Optional[list[Choice]]:
if not self.satisfies_precondition():
return None
if torch._inductor.config.use_autoheuristic(self.name):
@ -223,11 +223,11 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
def __init__(
self,
fallback: Callable[[], Optional[ChoiceCaller]],
choices: List[ChoiceCaller],
input_nodes: List[Any],
choices: list[ChoiceCaller],
input_nodes: list[Any],
context: AHContext,
name: str,
augment_context: Optional[List[AHOperation]] = None,
augment_context: Optional[list[AHOperation]] = None,
precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
) -> None:
"""
@ -237,7 +237,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
have to be used here.
"""
self.input_nodes = input_nodes
self.choicestr2choice: Dict[str, ChoiceCaller] = {}
self.choicestr2choice: dict[str, ChoiceCaller] = {}
for choice in choices:
self.choicestr2choice[choice.autoheuristic_id()] = choice
choices_str = list(self.choicestr2choice.keys())
@ -266,7 +266,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
self.register_global_feedback(input_nodes, choices)
def register_global_feedback(
self, input_nodes: List[Any], choices: List[ChoiceCaller]
self, input_nodes: list[Any], choices: list[ChoiceCaller]
) -> None:
"""
Registers a callback in select_algorithm, which is called with the timing of each choice.
@ -281,10 +281,10 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
def store_global_feedback(
ah_inputs_key: str,
ah_precompile_key: str,
timings: Dict[ChoiceCaller, float],
timings: dict[ChoiceCaller, float],
name: str,
input_nodes: List[Any],
choices: List[ChoiceCaller],
input_nodes: list[Any],
choices: list[ChoiceCaller],
) -> None:
current_inputs_key = create_inputs_key(input_nodes)
if current_inputs_key != ah_inputs_key:
@ -307,8 +307,8 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
return self.choicestr2choice.get(choice, None)
def get_top_k_choices_caller(
self, top_k: int, always_included: Optional[List[str]] = None
) -> Optional[List[ChoiceCaller]]:
self, top_k: int, always_included: Optional[list[str]] = None
) -> Optional[list[ChoiceCaller]]:
choices = self.get_top_k_choices(top_k, always_included)
if choices is None:
return None


@ -1,5 +1,5 @@
import functools
from typing import Any, Callable, Dict, List
from typing import Any, Callable
import torch
@ -51,8 +51,8 @@ class AHContext:
information that will help to learn a heuristic.
"""
features: List[AHFeature]
context_dict: Dict[str, Value]
features: list[AHFeature]
context_dict: dict[str, Value]
def __init__(self) -> None:
self.features = []
@ -64,7 +64,7 @@ class AHContext:
self.features.append(AHFeature(name, value, is_categorical=is_categorical))
self.context_dict[name] = value
def get_numerical_and_categorical_features(self) -> tuple[List[str], List[str]]:
def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
numerical_features = []
categorical_features = []
for feature in self.features:
@ -84,7 +84,7 @@ class AHContext:
def get_value(self, name: str) -> Value:
return self.context_dict[name]
def apply_operations(self, operations: List[AHOperation]) -> None:
def apply_operations(self, operations: list[AHOperation]) -> None:
for op in operations:
op.apply_operation(self.context_dict)
@ -94,7 +94,7 @@ class AHMetadata:
self,
shared_memory: Any,
device_capa: tuple[int, int],
choices: List[Choice],
choices: list[Choice],
name: str,
) -> None:
# use amount of shared_memory and device_capability to identify GPU
@ -104,7 +104,7 @@ class AHMetadata:
self.choices = choices
self.name = name
def to_dict(self) -> Dict[str, Value]:
def to_dict(self) -> dict[str, Value]:
return {
"shared_memory": self.shared_memory,
"device_capa": self.device_capa,
@ -147,7 +147,7 @@ def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
return mat1_iscontig and not mat2_iscontig
def get_mult_dims_ops() -> List[AHOperation]:
def get_mult_dims_ops() -> list[AHOperation]:
m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
@ -163,7 +163,7 @@ def get_arith_intensity(data: Any) -> float:
return m * k * n / (m * k + k * n + m * n)
def pad_mm_operations() -> List[AHOperation]:
def pad_mm_operations() -> list[AHOperation]:
mult_dims_ops = get_mult_dims_ops()
k_div_m_times_n_op = AHOperation(
"k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
@ -200,7 +200,7 @@ def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
return data[dim] >= lower and data[dim] <= upper
def between_ops() -> List[AHOperation]:
def between_ops() -> list[AHOperation]:
dims = ["m", "k", "n"]
limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
ah_operations = []
@ -221,13 +221,13 @@ def pow2_op(data: Any, dim: str, exponent: int) -> bool:
return data[dim] == 2**exponent
def mm_operations() -> List[AHOperation]:
def mm_operations() -> list[AHOperation]:
mult_dims_ops = get_mult_dims_ops()
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
return mult_dims_ops + [arith_intensity_op]
def mixed_mm_operations() -> List[AHOperation]:
def mixed_mm_operations() -> list[AHOperation]:
return mm_operations() + between_ops()
@ -235,7 +235,7 @@ def is_multiple(data: Any, dim: str, mult: int) -> bool:
return data[dim] % mult == 0
def get_dims_multiple_ops() -> List[AHOperation]:
def get_dims_multiple_ops() -> list[AHOperation]:
multiples = [2, 4, 8, 16, 32]
dims = ["m", "k", "n"]
dims_multiple_ops = []
@ -249,7 +249,7 @@ def get_dims_multiple_ops() -> List[AHOperation]:
return dims_multiple_ops
def get_dims_need_padding_ops() -> List[AHOperation]:
def get_dims_need_padding_ops() -> list[AHOperation]:
def mat1_innermost_needs_padding_fn(data: Any) -> bool:
mat1_stride_0 = data["mat1_stride_0"]
mat1_stride_1 = data["mat1_stride_1"]
@ -303,7 +303,7 @@ def get_dims_need_padding_ops() -> List[AHOperation]:
return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
def get_is_contig_ops() -> List[AHOperation]:
def get_is_contig_ops() -> list[AHOperation]:
def mat1_is_contig_fn(data: Any) -> bool:
stride_0 = data["mat1_stride_0"]
stride_1 = data["mat1_stride_1"]


@ -2,7 +2,7 @@ import importlib
import inspect
import pkgutil
from collections import defaultdict
from typing import Any, Dict, List, Optional
from typing import Any, Optional
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
@ -14,7 +14,7 @@ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeur
def find_and_instantiate_subclasses(
package_name: str, base_class: Any
) -> List[LearnedHeuristic]:
) -> list[LearnedHeuristic]:
instances = []
package = importlib.import_module(package_name)
@ -49,7 +49,7 @@ class LearnedHeuristicController:
a way to get the decision of a learned heuristic.
"""
existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
"""
A dictionary that stores all the learned heuristics for each optimization.
The key is the optimization name, and the value is a list of LearnedHeuristic objects.
@ -69,7 +69,7 @@ class LearnedHeuristicController:
self.metadata = metadata
self.context = context
def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
"""
Returns a list of learned heuristics for the given optimization name.
"""
@ -105,7 +105,7 @@ class LearnedHeuristicController:
return heuristic.get_decision(self.context, self.metadata.choices)
return None
def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
heuristics = self.get_heuristics(self.metadata.name)
for heuristic in heuristics:
if heuristic.check_precondition(self.metadata, self.context):


@ -1,4 +1,4 @@
from typing import List, Optional
from typing import Optional
from torch._inductor.autoheuristic.autoheuristic_utils import (
AHContext,
@ -23,7 +23,7 @@ class LearnedHeuristic:
return True
def get_decision(
self, context: AHContext, choices: List[Choice]
self, context: AHContext, choices: list[Choice]
) -> Optional[Choice]:
return None
@ -33,7 +33,7 @@ class LearnedHeuristic:
def get_name(self) -> str:
return ""
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
return None
@ -45,7 +45,7 @@ class LearnedHeuristicRegression(LearnedHeuristic):
return 1.0
def get_decision(
self, context: AHContext, choices: List[Choice]
self, context: AHContext, choices: list[Choice]
) -> Optional[Choice]:
choice2feedback = {}
for choice in choices:
@ -68,7 +68,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
return None
def get_decision(
self, context: AHContext, choices: List[Choice]
self, context: AHContext, choices: list[Choice]
) -> Optional[Choice]:
best_choices = self.get_best_choices(context)
if not best_choices:
@ -78,7 +78,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
return None
return self.get_choice(best_choice_idx)
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
feedback_idx_list = self.get_best_choices(context)
if feedback_idx_list is None:
return None
@ -88,5 +88,5 @@ class LearnedHeuristicDecision(LearnedHeuristic):
choices = [choice for choice in choices if choice is not None]
return choices
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
return []


@ -10,19 +10,10 @@ import os
import queue
import time
import warnings
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
@ -396,8 +387,8 @@ class TuningProcessPool:
def benchmark(
self,
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
choices: list[TritonTemplateCaller],
) -> dict[TritonTemplateCaller, float]:
"""
Benchmark each choice in a separate process.
"""
@ -432,9 +423,9 @@ class TensorMeta:
@classmethod
def from_irnodes(
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
) -> Union[TensorMeta, List[TensorMeta]]:
) -> Union[TensorMeta, list[TensorMeta]]:
if isinstance(irnodes, Sequence):
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
assert all(isinstance(x, TensorMeta) for x in result)
return result
@ -488,8 +479,8 @@ class BenchmarkRequest:
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
) -> None:
# the kernel name defined in the module
@ -640,12 +631,12 @@ class TritonBenchmarkRequest(BenchmarkRequest):
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
module_path: str, # the path of the module defining the triton kernel
module_cache_key: str,
grid: List[int],
grid: list[int],
num_stages: int,
num_warps: int,
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
@ -770,8 +761,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
@ -889,8 +880,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
def __init__(
self,
kernel_name: str,
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
extra_args: Iterable[Any],
source_code: str,
) -> None:
@ -946,8 +937,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
def benchmark_in_sub_process(
choices: List[TritonTemplateCaller],
) -> Dict[TritonTemplateCaller, float]:
choices: list[TritonTemplateCaller],
) -> dict[TritonTemplateCaller, float]:
"""
Do benchmarking in a subprocess and return the perf number (latency).
"""


@ -1,7 +1,7 @@
import logging
import operator
from functools import partial
from typing import Any, Callable, Dict, Union
from typing import Any, Callable, Union
from sympy import Expr
@ -43,7 +43,7 @@ class BoundVars:
or "masked_subblock" in node.target
)
# To access this variable call `get_bounds()`
self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}
def __repr__(self) -> str:
return (
@ -55,7 +55,7 @@ class BoundVars:
)
@cache_on_self
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
submodules = self.swap_submodules(self.loop_body.submodules)
# Initialize the environment with the unbounded variables
@ -74,9 +74,9 @@ class BoundVars:
return self._bounds
def swap_submodules(
self, submodules: Dict[str, Callable[..., Any]]
) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
self, submodules: dict[str, Callable[..., Any]]
) -> dict[str, Callable[..., ValueRanges[Expr]]]:
result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
for key in submodules.keys():
if key == "get_index":
result[key] = self.get_index
@ -111,10 +111,10 @@ class BoundVars:
def masked_subblock(
self,
subblock: LoopBodyBlock,
env: Dict[torch.fx.Node, ValueRanges[Expr]],
env: dict[torch.fx.Node, ValueRanges[Expr]],
mask: Any,
value: Any,
submodules: Dict[str, Callable[..., Any]],
submodules: dict[str, Callable[..., Any]],
) -> ValueRanges[Expr]:
interp = InterpreterShim(subblock.graph, submodules)
interp.run(V.get_ops_handler(), initial_env=env)


@ -1,7 +1,7 @@
from __future__ import annotations
import typing
from typing import Any, Dict, List, Type, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import sympy
@ -42,11 +42,11 @@ class InductorChoices:
def triton_kernel_kwargs(
self,
kernel_cls: Type[TritonKernel],
kernel_cls: type[TritonKernel],
features: SIMDKernelFeatures,
groups: List[sympy.Expr],
kernel_kwargs: Dict[str, Any],
) -> Dict[str, Any]:
groups: list[sympy.Expr],
kernel_kwargs: dict[str, Any],
) -> dict[str, Any]:
"""Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
return kernel_kwargs


@ -36,13 +36,8 @@ from typing import (
Any,
Callable,
cast,
Dict,
Generator,
List,
NoReturn,
Optional,
Sequence,
Tuple,
TYPE_CHECKING,
TypeVar,
Union,
@ -127,7 +122,7 @@ else:
if TYPE_CHECKING:
from collections.abc import KeysView
from collections.abc import Generator, KeysView, Sequence
from concurrent.futures import Future
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
@ -168,7 +163,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
class CacheBase:
@staticmethod
@functools.lru_cache(None)
def get_system() -> Dict[str, Any]:
def get_system() -> dict[str, Any]:
try:
from triton.compiler.compiler import triton_key
@ -179,7 +174,7 @@ class CacheBase:
triton_version = None
try:
system: Dict[str, Any] = {
system: dict[str, Any] = {
"device": {"name": None},
"version": {
"triton": triton_version,
@ -217,7 +212,7 @@ class CacheBase:
def __init__(self) -> None:
self.system = CacheBase.get_system()
def get_local_cache(self) -> Dict[str, Any]:
def get_local_cache(self) -> dict[str, Any]:
local_cache_path = self.get_local_cache_path()
if not local_cache_path.is_file():
return {}
@ -225,7 +220,7 @@ class CacheBase:
local_cache = json.load(local_cache_fp)
return local_cache["cache"]
def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
def update_local_cache(self, local_cache: dict[str, Any]) -> None:
local_cache_path = self.get_local_cache_path()
write_atomic(
str(local_cache_path),
@ -235,7 +230,7 @@ class CacheBase:
class LocalCache(CacheBase):
def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
cache = self.get_local_cache()
sub_cache = cache
@ -261,7 +256,7 @@ class LocalCache(CacheBase):
class PersistentCache(CacheBase):
@functools.lru_cache(None) # noqa: B019
def get_global_cache(self) -> Dict[str, Any]:
def get_global_cache(self) -> dict[str, Any]:
global_cache_path = self.get_global_cache_path()
if global_cache_path is None or not global_cache_path.is_file():
return {}
@ -271,11 +266,11 @@ class PersistentCache(CacheBase):
def lookup(
self,
choices: List[ChoiceCaller],
choices: list[ChoiceCaller],
op: str,
inputs: str,
benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
) -> Dict[ChoiceCaller, float]:
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
) -> dict[ChoiceCaller, float]:
"""
Check to see if we have benchmarked the given choice callers. For each
choice caller:
@ -296,7 +291,7 @@ class PersistentCache(CacheBase):
)
timings = {}
def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool:
def check_cache(cache: dict[str, Any], callback: Any = None) -> bool:
"""Check if `cache` contains data for all the choices"""
hit = True
for choice in choices:
@ -456,7 +451,7 @@ class TensorMetadataAndValues:
"""
tensor_metadata: TensorMetadata
values: List[Any]
values: list[Any]
def _ident(x: T) -> T:
@ -584,7 +579,7 @@ class FxGraphCachePickler(pickle.Pickler):
def _reduce_graph_module(
self, gm: torch.fx.GraphModule
) -> tuple[Any, tuple[Dict[str, Any], str]]:
) -> tuple[Any, tuple[dict[str, Any], str]]:
"""
Custom reducer for graph module to handle irrelevant data for user
defined triton kernels
@ -624,7 +619,7 @@ class FxGraphCachePickler(pickle.Pickler):
serialized_data = self.dumps(obj)
return sha256_hash(serialized_data)
def debug_lines(self, inp: FxGraphHashDetails) -> List[str]:
def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
"""
Get a printable string describing in more detail all the attributes
comprising an object. Useful for debugging when one graph hashes
@ -659,7 +654,7 @@ class FxGraphCachePickler(pickle.Pickler):
def build_code_hash(
roots: List[str] | None, prefix: str, hasher: hashlib._Hash
roots: list[str] | None, prefix: str, hasher: hashlib._Hash
) -> None:
for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
spec = lib.module_finder.find_spec(lib.name, None)
@ -721,7 +716,7 @@ class OrderedSetHolder:
of set kwargs.
"""
items: List[Any]
items: list[Any]
class BypassFxGraphCache(Exception):
@ -753,7 +748,7 @@ class FxGraphHashDetails:
# Order kwargs so hashing is stable to changes in kwarg order. Although
# it's technically a _CompileFxKwargs we don't actually need it typed as
# such since we're just using it to generate a hash.
self.fx_kwargs: Dict[str, object] = {}
self.fx_kwargs: dict[str, object] = {}
for k, v in sorted(fx_kwargs.items()):
if k not in self.EXCLUDED_KWARGS:
if type(v) in (set, OrderedSet): # noqa: set_linter
@ -774,7 +769,7 @@ class FxGraphHashDetails:
# Node meta will not be part of gm's reduce function, so lets remember
# the kernel source code separately
self.user_defined_triton_source: List[Any] = []
self.user_defined_triton_source: list[Any] = []
if gm is not None:
for module in gm.modules():
if not isinstance(module, torch.fx.GraphModule):
@ -856,7 +851,7 @@ def compiled_fx_graph_hash(
example_inputs: Sequence[InputType],
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
) -> tuple[str, List[str]]:
) -> tuple[str, list[str]]:
"""
Generate a unique hash of the FX graph for caching.
"""
@ -952,7 +947,7 @@ class FxGraphCache:
return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
@staticmethod
def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]:
"""
Get the backed SymInt objects from the input list. Note that we can never
have guards that depend on unbacked symint.
@ -976,7 +971,7 @@ class FxGraphCache:
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
constants: CompiledFxGraphConstants,
) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
"""
Lookup a compiled graph in the cache by key. On a hit, return the
deserialized CompiledFxGraph object. On a miss, return None.
@ -988,7 +983,7 @@ class FxGraphCache:
hints = [hint_int(s) for s in symints]
def iterate_over_candidates() -> (
Generator[Tuple[CompiledFxGraph, bytes], None, None]
Generator[tuple[CompiledFxGraph, bytes], None, None]
):
if local:
subdir = FxGraphCache._get_tmp_dir_for_key(key)
@ -1021,7 +1016,7 @@ class FxGraphCache:
# their guards to determine whether there's a hit.
graph = None
pickled_content = None
cache_info: Dict[str, Any] = dict()
cache_info: dict[str, Any] = dict()
for candidate, pickled_content in iterate_over_candidates():
if not candidate.guards_expr:
@ -1234,7 +1229,7 @@ class FxGraphCache:
fx_kwargs: _CompileFxKwargs,
inputs_to_check: Sequence[int],
remote: bool,
) -> tuple[Optional[tuple[str, List[str]]], Dict[str, Any]]:
) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
"""
Checks that the inductor input is cacheable, then computes
and returns the cache key for the input.
@ -1280,13 +1275,13 @@ class FxGraphCache:
@staticmethod
def load_with_key(
key: str,
debug_lines: List[str],
debug_lines: list[str],
example_inputs: Sequence[InputType],
local: bool,
remote_cache: Optional[RemoteCache[JsonDataTy]],
is_backward: bool,
constants: CompiledFxGraphConstants,
) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
"""
Lookup the graph with the given key, and return results and metadata.
Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@ -1373,11 +1368,11 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
@clear_on_fresh_inductor_cache
class CudaKernelParamCache:
cache: Dict[str, Dict[str, str]] = {}
cache: dict[str, dict[str, str]] = {}
cache_clear = staticmethod(cache.clear)
@classmethod
def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None:
def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
_, path = write(
cubin,
bin_type,
@ -1391,7 +1386,7 @@ class CudaKernelParamCache:
cls.cache[key] = params
@classmethod
def get(cls, key: str) -> Optional[Dict[str, str]]:
def get(cls, key: str) -> Optional[dict[str, str]]:
return cls.cache.get(key, None)
@classmethod
@ -1407,8 +1402,8 @@ class AotCodeCompiler:
source_code: str,
serialized_extern_kernel_nodes: Optional[str],
device_type: str,
additional_files: List[str],
) -> Union[List[str], str]:
additional_files: list[str],
) -> Union[list[str], str]:
"""
Returns the .so path, or returns a list of files that were generated if
config.aot_inductor.package=True.
@ -1842,14 +1837,14 @@ def cpp_prefix() -> str:
# Given a path to an input cpp file and an output path,
# Attempts to compile the file, storing the output in "output_path"
def compile_file(
input_path: Union[str, List[str]], output_path: str, cmd: List[str]
input_path: Union[str, list[str]], output_path: str, cmd: list[str]
) -> None:
with dynamo_timed("compile_file"):
return _compile_file(input_path, output_path, cmd)
def _compile_file(
input_path: Union[str, List[str]], output_path: str, cmd: List[str]
input_path: Union[str, list[str]], output_path: str, cmd: list[str]
) -> None:
input_paths = [input_path] if isinstance(input_path, str) else input_path
input_files = [
@ -1948,9 +1943,9 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]:
@clear_on_fresh_inductor_cache
class CppCodeCache:
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags: Dict[str, Any] = {}
cpp_compile_command_flags: dict[str, Any] = {}
@staticmethod
def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
@ -2093,7 +2088,7 @@ def _worker_compile_cpp(
# Customized Python binding for cpp kernels
@clear_on_fresh_inductor_cache
class CppPythonBindingsCodeCache(CppCodeCache):
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags = {
# kernels have no dependency on libtorch
@ -2212,7 +2207,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
@classmethod
def load_pybinding_async(
cls,
argtypes: List[str],
argtypes: list[str],
source_code: str,
device_type: str = "cpu",
num_outputs: int = -1,
@ -2269,7 +2264,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
@clear_on_fresh_inductor_cache
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
cache_clear = staticmethod(cache.clear)
cpp_compile_command_flags = {
"include_pytorch": True,
@ -2335,7 +2330,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
@clear_on_fresh_inductor_cache
class HalideCodeCache(CppPythonBindingsCodeCache):
cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
cache_clear = staticmethod(cache.clear)
_standalone_runtime_path: Optional[str] = None
prefix = textwrap.dedent(
@ -2412,7 +2407,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
)
@classmethod
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]:
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
assert arg.shape is not None
assert arg.stride is not None and len(arg.shape) == len(arg.stride)
assert arg.offset is not None
@ -2573,7 +2568,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
donefile = str(dirpath / "done")
lockfile = str(dirpath / "lock")
need_compile = not os.path.exists(donefile)
jobs: List[Any] = []
jobs: list[Any] = []
if need_compile:
write_atomic(genfile, source_code)
cmd = [
@ -2685,7 +2680,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
return sofile
def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
from torch.utils._filelock import FileLock
try:
@ -2733,8 +2728,8 @@ class PyCodeCache:
# clearing the cache. Note also that we may load the same path more
# than once, but attach different attributes, i.e., due to different
# constant values.
modules: List[ModuleType] = []
linemaps: Dict[str, List[tuple[Any, ...]]] = {}
modules: list[ModuleType] = []
linemaps: dict[str, list[tuple[Any, ...]]] = {}
@classmethod
def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
@ -2745,8 +2740,8 @@ class PyCodeCache:
cls,
source_code: str,
extra: str = "",
linemap: Optional[List[tuple[int, str]]] = None,
attrs: Optional[Dict[str, Any]] = None,
linemap: Optional[list[tuple[int, str]]] = None,
attrs: Optional[dict[str, Any]] = None,
) -> ModuleType:
key, path = write(source_code, "py", extra=extra)
return cls.load_by_key_path(key, path, linemap, attrs)
@ -2756,8 +2751,8 @@ class PyCodeCache:
cls,
key: str,
path: str,
linemap: Optional[List[tuple[int, str]]] = None,
attrs: Optional[Dict[str, Any]] = None,
linemap: Optional[list[tuple[int, str]]] = None,
attrs: Optional[dict[str, Any]] = None,
) -> ModuleType:
if linemap is None:
linemap = []
@ -2798,7 +2793,7 @@ class PyCodeCache:
@functools.lru_cache(None)
def stack_frames_for_code(
cls, path: str, lineno: int
) -> Optional[List[Dict[str, Any]]]:
) -> Optional[list[dict[str, Any]]]:
if path not in cls.linemaps:
return None
# [(starting_line, <fx node>), ...]
@ -2810,7 +2805,7 @@ class PyCodeCache:
if not entry:
return None
def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
# ideally fx stores stack traces as data rather than a string
# but this is not along a performance critical path
regex = r'File "(.+)", line (\d+), in (.+)\n'
@ -2841,7 +2836,7 @@ def _cuda_compiler() -> Optional[str]:
return "nvcc"
def _cutlass_include_paths() -> List[str]:
def _cutlass_include_paths() -> list[str]:
if config.is_fbcode():
from libfb.py import parutil
@ -2857,14 +2852,14 @@ def _cutlass_include_paths() -> List[str]:
]
def _cuda_lib_options() -> List[str]:
def _cuda_lib_options() -> list[str]:
_set_gpu_runtime_env() # cpp_extension consults the env
from torch.utils import cpp_extension
lpaths = cpp_extension.library_paths(device_type="cuda") + [
sysconfig.get_config_var("LIBDIR")
]
extra_ldflags: List[str] = []
extra_ldflags: list[str] = []
if is_linux():
_transform_cuda_paths(lpaths)
for path in lpaths:
@ -2880,7 +2875,7 @@ def _cuda_lib_options() -> List[str]:
return extra_ldflags
def _nvcc_host_compiler_options() -> List[str]:
def _nvcc_host_compiler_options() -> list[str]:
return [
"-fPIC",
"-fno-strict-aliasing",
@ -2889,7 +2884,7 @@ def _nvcc_host_compiler_options() -> List[str]:
]
def _nvcc_compiler_options() -> List[str]:
def _nvcc_compiler_options() -> list[str]:
arch = cuda_env.get_cuda_arch()
if arch == "90":
# Required by cutlass compilation.
@ -2934,10 +2929,10 @@ def _nvcc_compiler_options() -> List[str]:
def cuda_compile_command(
src_files: List[str],
src_files: list[str],
dst_file: str,
dst_file_ext: str,
extra_args: Optional[List[str]] = None,
extra_args: Optional[list[str]] = None,
) -> str:
if extra_args is None:
extra_args = []
@ -3052,7 +3047,7 @@ class CUDACodeCache:
input_path: str
output_path: str
cache: Dict[str, CacheEntry] = {}
cache: dict[str, CacheEntry] = {}
cache_clear = staticmethod(cache.clear)
_SOURCE_CODE_SUFFIX = "cu"
@ -3073,7 +3068,7 @@ class CUDACodeCache:
@classmethod
def compile(
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
) -> tuple[str, str, str]:
"""
Compiles CUDA source_code into a file with dst_file_ext extension.
@ -3137,7 +3132,7 @@ class ROCmCodeCache:
input_path: str
output_path: str
cache: Dict[str, CacheEntry] = {}
cache: dict[str, CacheEntry] = {}
cache_clear = staticmethod(cache.clear)
_SOURCE_CODE_SUFFIX = "cpp"
_logged_compiler_version = False
@ -3159,7 +3154,7 @@ class ROCmCodeCache:
@classmethod
def compile(
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
) -> tuple[str, str, str]:
"""
Compiles source_code into a file with dst_file_ext extension,


@ -7,7 +7,7 @@ import logging
import operator
import sys
from collections import defaultdict
from typing import Any, Dict, List, TYPE_CHECKING
from typing import Any, TYPE_CHECKING
import torch
from torch.multiprocessing.reductions import StorageWeakRef
@ -33,7 +33,7 @@ if TYPE_CHECKING:
from .scheduler import BaseSchedulerNode
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
Greedily schedules waits as late as possible.
"""
@ -42,7 +42,7 @@ def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
)
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
"""
Greedily schedules comms as early as possible.
"""
@ -52,8 +52,8 @@ def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def reorder_compute_for_overlap(
snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
"""
This achieves the following overall scheduling procedure:
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
@ -71,11 +71,11 @@ def reorder_compute_for_overlap(
def _schedule_for_comm(
snodes: List[BaseSchedulerNode],
snodes: list[BaseSchedulerNode],
raise_comms: bool,
sink_waits: bool,
reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]:
) -> list[BaseSchedulerNode]:
"""
Schedule `snodes` for various comm optimization objectives.
@ -149,13 +149,13 @@ def _schedule_for_comm(
def __lt__(self, other):
return self.score < other.score
unmet_deps: Dict[BaseSchedulerNode, OrderedSet[str]] = {
unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
for snode in snodes
}
ready: List[Runnable] = []
buffer_users: Dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
ready: list[Runnable] = []
buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
for snode, deps in unmet_deps.items():
@ -226,8 +226,8 @@ def _schedule_for_comm(
def decide_global_ordering_of_comms(
nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]:
nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> list[BaseSchedulerNode]:
"""
Decide global ordering of comms, by just enforcing the ordering that's in the input graph
(might not be the same ordering as the eager mode program).
@ -303,8 +303,8 @@ def visualize_overlap(order):
def reorder_compute_and_comm_for_overlap(
snodes: List[BaseSchedulerNode],
) -> List[BaseSchedulerNode]:
snodes: list[BaseSchedulerNode],
) -> list[BaseSchedulerNode]:
order = snodes
for p in config.reorder_for_compute_comm_overlap_passes:
@ -653,10 +653,10 @@ def get_op_idx(snode):
def enforce_comm_ordering_for_fsdp(
snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
name_to_fused_node: Dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
name_to_fused_node: dict[str, BaseSchedulerNode],
) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
from . import scheduler
new_order: list[BaseSchedulerNode] = []


@ -16,11 +16,7 @@ from typing import (
Any,
Callable,
ContextManager,
Dict,
Generator,
List,
Optional,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
@ -121,6 +117,8 @@ from .virtualized import V
if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload
@ -157,7 +155,7 @@ static_inputs_log = torch._logging.getArtifactLogger(
)
def get_static_input_idxs(num_fixed: int) -> List[int]:
def get_static_input_idxs(num_fixed: int) -> list[int]:
# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
# of cudagraphs. Rather than copying these into cudagraph-owned memory
# like we do for normal inputs on each run, we will re-record a cudagraph if these
@ -208,7 +206,7 @@ def _unlift_graph(
) -> GraphModule:
from torch.export.unflatten import _assign_attr, _AttrKind
state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
for name, param in mod.named_parameters(remove_duplicate=False):
state_dict[name] = param
_assign_attr(
@ -227,7 +225,7 @@ def _unlift_graph(
)
placeholder_nodes = gm.graph.find_nodes(op="placeholder")
lifted_inputs: List[Optional[FQN]] = []
lifted_inputs: list[Optional[FQN]] = []
# In AOTI, module parameters and buffers are not lifted as graph inputs.
# As a result, mutation to buffers has side effect which makes their initial
@ -343,9 +341,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
def split_const_gm(
gm: GraphModule,
skip_constructor: bool = True,
lifted_constant_names: Optional[List[str]] = None,
lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> tuple[GraphModule, Dict[str, int]]:
) -> tuple[GraphModule, dict[str, int]]:
"""
This function takes an GraphModule input "gm".
The gm will be split into 2 components,
@ -488,8 +486,8 @@ def fake_tensor_prop(
# pass config dict back to user
def get_patched_config_dict(
config_patches: Optional[Union[str, Dict[str, Any]]] = None
) -> Dict[str, Any]:
config_patches: Optional[Union[str, dict[str, Any]]] = None
) -> dict[str, Any]:
with config.patch(config_patches):
return config.get_config_copy()
@ -515,7 +513,7 @@ class _CompileFxKwargs(TypedDict, total=False):
aot_mode: bool
is_inference: bool
layout_opt: Optional[bool]
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]]
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
boxed_forward_device_index: Optional[BoxedDeviceIndex]
@ -822,7 +820,7 @@ class _InProcessFxCompile(FxCompile):
aot_mode: bool = V.aot_compilation
is_inference: bool = graph_kwargs.get("is_inference", False)
extern_node_serializer: Optional[
Callable[[List[ExternKernelNode]], Any]
Callable[[list[ExternKernelNode]], Any]
] = graph_kwargs.get("extern_node_serializer", None)
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
"boxed_forward_device_index", None
@ -997,7 +995,7 @@ class _InProcessFxCompile(FxCompile):
metrics_helper = metrics.CachedMetricsHelper()
with V.set_graph_handler(graph):
graph.run(*example_inputs)
output_strides: List[Optional[tuple[_StrideExprStr, ...]]] = []
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None:
# We'll put the output strides in the compiled graph so we
# can later return them to the caller via TracingContext
@ -1189,7 +1187,7 @@ def cudagraphify(
static_input_idxs: Sequence[int] = (),
*,
device_index: int,
stack_traces: List[Optional[str]],
stack_traces: list[Optional[str]],
is_backward: bool,
is_inference: bool,
constants: tuple[torch.Tensor, ...] = (),
@ -1240,7 +1238,7 @@ def static_input(x: torch.Tensor) -> torch.Tensor:
def index_expanded_dims_and_copy_(
dst: torch.Tensor,
src: torch.Tensor,
expanded_dims: List[int],
expanded_dims: list[int],
) -> None:
"Index into expanded dimensions of both dst and src then copy_"
dst = index_expanded_dims(dst, expanded_dims)
@ -1250,9 +1248,9 @@ def index_expanded_dims_and_copy_(
def cudagraphify_impl(
model: Callable[..., Any],
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
static_input_idxs: Sequence[int] = (),
) -> Callable[[List[InputType]], Any]:
) -> Callable[[list[InputType]], Any]:
"""
Assumes inputs[static_input_idxs[i]] are always the same memory address
"""
@ -1304,7 +1302,7 @@ def cudagraphify_impl(
if config.size_asserts:
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims)
@ -1328,7 +1326,7 @@ def cudagraphify_impl(
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
]
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
for idx in copy_indices:
expanded_dims = inps_expanded_dims[idx]
src = new_inputs[idx]
@ -1343,16 +1341,16 @@ def cudagraphify_impl(
def compile_fx_aot(
model_: GraphModule,
example_inputs_: List[InputType],
example_inputs_: list[InputType],
inner_compile: _CompileFxCallable = compile_fx_inner,
config_patches: Optional[Dict[str, str]] = None,
) -> Union[List[str], str]:
config_patches: Optional[dict[str, str]] = None,
) -> Union[list[str], str]:
assert isinstance(model_, GraphModule), model_
# [See NOTE] Unwrapping subclasses AOT
unwrap_tensor_subclass_parameters(model_)
config_patches: Dict[str, Any] = (
config_patches: dict[str, Any] = (
{"cpp_wrapper": True}
if config_patches is None
else {**config_patches, "cpp_wrapper": True}
@ -1409,7 +1407,7 @@ def fw_compiler_freezing(
cudagraphs: BoxedBool,
graph_id: int,
forward_device: BoxedDeviceIndex,
) -> Callable[[List[object]], Sequence[torch.Tensor]]:
) -> Callable[[list[object]], Sequence[torch.Tensor]]:
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
# partition_fn won't be called
@ -1492,7 +1490,7 @@ def fw_compiler_freezing(
if V.aot_compilation:
return optimized_function
def wrapper(args: List[object]) -> Sequence[torch.Tensor]:
def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
args_new = [
args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
for i in preserved_arg_indices
@ -1505,7 +1503,7 @@ def fw_compiler_freezing(
return wrapper
def get_cpp_wrapper_config() -> Dict[str, object]:
def get_cpp_wrapper_config() -> dict[str, object]:
return {
# Set autotune_at_compile_time to True as default if the option is not explicitly set
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
@ -1551,9 +1549,9 @@ def compile_fx(
model_: GraphModule,
example_inputs_: Sequence[InputType],
inner_compile: Callable[..., OutputCode] = compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
config_patches: Optional[dict[str, Any]] = None,
decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
"""
Main entry point for compiling given FX graph. Despite the fact that this
lives in :mod:`torch._inductor`, this function is responsible for calling
@ -2017,11 +2015,11 @@ def _check_triton_bf16_support(graph: GraphLowering) -> None:
def _aoti_flatten_inputs(
gm: torch.fx.GraphModule,
args: Union[List[Any], tuple[Any, ...]],
kwargs: Optional[Dict[str, Any]] = None,
args: Union[list[Any], tuple[Any, ...]],
kwargs: Optional[dict[str, Any]] = None,
*,
options: Optional[Dict[str, Any]] = None,
) -> tuple[List[Any], Dict[str, Any]]:
options: Optional[dict[str, Any]] = None,
) -> tuple[list[Any], dict[str, Any]]:
"""
Flatten the inputs to the graph module and return the flat inputs and options.
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.


@ -5,7 +5,7 @@ import importlib
import logging
import os
import sys
from typing import Type, TypeVar
from typing import TypeVar
from torch._inductor.async_compile import pre_fork_setup
from torch._inductor.compile_worker.subproc_pool import (
@ -32,7 +32,7 @@ except ImportError:
pass
def _lookup_and_create_type(base: Type[_T], qname: str) -> _T:
def _lookup_and_create_type(base: type[_T], qname: str) -> _T:
"""
Given a base type and qualified name: import & lookup that name, check
that it's of the given type and then instantiate it.


@ -13,7 +13,7 @@ import typing
from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from enum import Enum
from typing import Any, BinaryIO, Callable, Dict, Optional, TypeVar
from typing import Any, BinaryIO, Callable, Optional, TypeVar
from typing_extensions import Never, ParamSpec
# _thread_safe_fork is needed because the subprocesses in the pool can read
@ -158,7 +158,7 @@ class SubprocPool:
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)
self.futures_lock = threading.Lock()
self.pending_futures: Dict[int, Future[Any]] = {}
self.pending_futures: dict[int, Future[Any]] = {}
self.job_id_count = itertools.count()
self.running = True


@ -7,7 +7,7 @@ import shutil
import sys
import tempfile
from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional
from typing import Callable, Optional
from torch._inductor.runtime.cache_dir_utils import cache_dir
@ -43,7 +43,7 @@ class ConfigChange(BinarySubsystem):
# Dictionary of backend -> subsystems
BACKENDS: Dict[str, List[Subsystem]] = {
BACKENDS: dict[str, list[Subsystem]] = {
# run dynamo without aot_autograd
"eager": [],
# run dynamo with aot_autograd, but no partitioner or decomps
@ -68,8 +68,8 @@ BACKENDS: Dict[str, List[Subsystem]] = {
], # TODO - add more - fusions ?
}
subsystem_call_counter: Dict[str, int] = collections.Counter()
call_counter_debug_info: Dict[int, str] = {}
subsystem_call_counter: dict[str, int] = collections.Counter()
call_counter_debug_info: dict[int, str] = {}
def reset_counters() -> None:
@ -123,13 +123,13 @@ class CompilerBisector:
return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}"
@classmethod
def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None:
os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as file:
file.writelines(lines)
@classmethod
def read_lines_from_file(cls, file_path: str) -> List[str]:
def read_lines_from_file(cls, file_path: str) -> list[str]:
if os.path.exists(file_path):
with open(file_path) as file:
return file.readlines()
@ -154,7 +154,7 @@ class CompilerBisector:
@classmethod
def set_config_values(
cls, backend: str, subsystem: str, config_data: Dict[str, object]
cls, backend: str, subsystem: str, config_data: dict[str, object]
) -> None:
file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt")
lines = [f"{k}={v}\n" for k, v in config_data.items()]
@ -267,7 +267,7 @@ class CompilerBisector:
cls.write_lines_to_file(file_path, lines)
@classmethod
def get_config_change(cls, config_name: str) -> Optional[Dict[str, object]]:
def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]:
backend = cls.get_backend()
subsystem = cls.get_subsystem()


@ -1,16 +1,6 @@
import os # noqa: C101
import sys
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
import torch
import torch._inductor.custom_graph_pass
@ -193,8 +183,8 @@ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
# hence custom IR passes built on top of it might break in the future.
_pre_fusion_custom_pass: Optional[
Callable[
[List["torch._inductor.scheduler.BaseSchedulerNode"]],
List["torch._inductor.scheduler.BaseSchedulerNode"],
[list["torch._inductor.scheduler.BaseSchedulerNode"]],
list["torch._inductor.scheduler.BaseSchedulerNode"],
]
] = None
@ -231,11 +221,11 @@ batch_fusion = True
# merge_splits_pass
# mutate_cat_pass
# split_cat_pass
pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
pre_grad_fusion_options: dict[str, dict[str, Any]] = {}
# Post grad fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
post_grad_fusion_options: dict[str, dict[str, Any]] = {}
# enable reordering pass for improving memory locality
reorder_for_locality = True
@ -257,7 +247,7 @@ use_mixed_mm = True
# floating point numbers, about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
fx_passes_numeric_check: Dict[str, Any] = {
fx_passes_numeric_check: dict[str, Any] = {
"pre_grad": False,
"precision": 1e-4,
"num_iterations": 1,
@ -287,12 +277,12 @@ reorder_for_compute_comm_overlap = False
# for built-in passes, use string name; for user-defined passes, pass in the function handle
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future.
reorder_for_compute_comm_overlap_passes: List[
reorder_for_compute_comm_overlap_passes: list[
Union[
str,
Callable[
[List["torch._inductor.scheduler.BaseSchedulerNode"]],
List["torch._inductor.scheduler.BaseSchedulerNode"],
[list["torch._inductor.scheduler.BaseSchedulerNode"]],
list["torch._inductor.scheduler.BaseSchedulerNode"],
],
]
] = [
@ -618,7 +608,7 @@ _fuse_ddp_bucket_size = 25
# overlapping. At this moment, this pass performs better than
# reorder_for_compute_comm_overlap_passes but we will add the logic of
# "schedule_comm_wait" in the future and remove the one here.
_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
"fuse_ddp_with_concat_op",
"schedule_comm_wait",
]
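
The `_pre_fusion_custom_pass` and `reorder_for_compute_comm_overlap_passes` entries above share one callable shape: take and return a list of scheduler nodes (the DDP fusion passes just above use a looser `Callable[..., None]`). A no-op sketch of that shape; the wiring line is hypothetical, and the config itself warns that the scheduler IR is still prototype-stage:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch._inductor.scheduler import BaseSchedulerNode

def noop_reorder_pass(nodes: "list[BaseSchedulerNode]") -> "list[BaseSchedulerNode]":
    # A real pass would reorder or regroup the nodes; this one just copies the list.
    return list(nodes)

# Hypothetical wiring against the knobs shown above:
# torch._inductor.config.reorder_for_compute_comm_overlap_passes.append(noop_reorder_pass)
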
@ -852,7 +842,7 @@ class cpp:
simdlen: Optional[int] = None
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
cxx: Tuple[None, str] = (
cxx: tuple[Literal[None], str] = (
None, # download gcc12 from conda-forge if conda is installed
os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
) # type: ignore[assignment]
@ -1157,7 +1147,7 @@ class aot_inductor:
# Dictionary of metadata users might want to save to pass to the runtime.
# TODO: Move this somewhere else, since it's no longer really a config
metadata: Dict[str, str] = {}
metadata: dict[str, str] = {}
# fbcode only. Whether to raise error if C++ codegen is too big to optimize
raise_error_on_ignored_optimization: bool = (
@ -1168,7 +1158,7 @@ class aot_inductor:
dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"
# Dictionary of presets that can be passed in
presets: Dict[str, Any] = {}
presets: dict[str, Any] = {}
# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
# should be run with this flag both on and off to make sure we have coverage.
@ -1265,11 +1255,11 @@ class cuda:
class rocm:
# Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
# If empty, the `native` arch is used
arch: List[str] = []
arch: list[str] = []
# Enable the CK backend for CDNA2 and CDNA3 only (for now)
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
ck_supported_arch: list[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
# Optimization level, use to balance compilation speed and runtime performance.
# The type will not necessarily be comprehensive and won't be enforced at runtime.
@ -1415,7 +1405,7 @@ class trace:
log_inductor_triton_kernel_to_post_grad_node_info: bool = True
_save_config_ignore: List[str] = [
_save_config_ignore: list[str] = [
# workaround: "Can't pickle <function ...>"
"trace.upload_tar",
"joint_custom_pre_pass",
@ -1423,7 +1413,7 @@ _save_config_ignore: List[str] = [
"pre_grad_custom_pass",
]
_cache_config_ignore_prefix: List[str] = [
_cache_config_ignore_prefix: list[str] = [
# trace functions are not relevant to config caching
"trace",
# uses absolute path
@ -1439,7 +1429,7 @@ _cache_config_ignore_prefix: List[str] = [
]
# External callable for matmul tuning candidates
external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
class test_configs:

View File

@ -1,5 +1,5 @@
import collections
from typing import Any, Callable, Dict, List, Optional
from typing import Any, Callable, Optional
import torch
import torch.utils._pytree as pytree
@ -57,7 +57,7 @@ def replace_node_with_constant(
def is_const_source(
node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
node: torch.fx.Node, lifted_constant_names: Optional[list[str]]
) -> bool:
return node.op == "get_attr" or node.name in (lifted_constant_names or ())
@ -67,12 +67,12 @@ class ConstantFolder(torch.fx.Interpreter):
self,
gm: torch.fx.GraphModule,
skip_constructors: bool = False,
lifted_constant_names: Optional[List[str]] = None,
lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
self.node_replacements: dict[torch.fx.Node, Any] = {}
self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object()
self.skip_constructors: bool = skip_constructors
@ -141,7 +141,7 @@ class ConstantFolder(torch.fx.Interpreter):
return True
return False
def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
last_non_output_use = collections.defaultdict(list)
seen_uses = OrderedSet[torch.fx.Node]()
output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
@ -264,11 +264,11 @@ class ConstantFolder(torch.fx.Interpreter):
self.node_replacements[node] = tensor
def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {}
env: dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env)
return super().run(initial_env=env)
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
env[n] = self.unknown_value # type: ignore[assignment]
if self.lifted_constant_names is None:
@ -309,7 +309,7 @@ def constant_fold(
def constant_graph_tag(
gm: torch.fx.GraphModule,
skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None,
lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None:
with torch.utils._python_dispatch._disable_current_modes():
@ -337,7 +337,7 @@ def constant_graph_tag(
def run_and_get_constant_graph(
gm: torch.fx.GraphModule,
skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None,
lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> torch.fx.GraphModule:
"""
@ -367,7 +367,7 @@ def run_and_get_constant_graph(
new_graph = torch.fx.Graph()
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
output_nodes = []
for node in gm.graph.nodes:
if node.meta[META_TAG] == MODULE_TAG:

View File

@ -16,10 +16,11 @@ import sys
import sysconfig
import textwrap
import warnings
from collections.abc import Sequence
from ctypes import cdll
from ctypes.util import find_library
from pathlib import Path
from typing import Any, List, Optional, Sequence, Union
from typing import Any, Optional, Union
import torch
from torch._dynamo.utils import dynamo_timed
@ -285,12 +286,12 @@ def get_compiler_version_info(compiler: str) -> str:
# =============================== cpp builder ===============================
def _append_list(dest_list: List[str], src_list: List[str]) -> None:
def _append_list(dest_list: list[str], src_list: list[str]) -> None:
dest_list.extend(copy.deepcopy(item) for item in src_list)
def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
new_list: List[str] = []
def _remove_duplication_in_list(orig_list: list[str]) -> list[str]:
new_list: list[str] = []
for item in orig_list:
if item not in new_list:
new_list.append(item)
@ -362,26 +363,26 @@ class BuildOptionsBase:
def __init__(
self,
compiler: str = "",
definitions: Optional[List[str]] = None,
include_dirs: Optional[List[str]] = None,
cflags: Optional[List[str]] = None,
ldflags: Optional[List[str]] = None,
libraries_dirs: Optional[List[str]] = None,
libraries: Optional[List[str]] = None,
passthrough_args: Optional[List[str]] = None,
definitions: Optional[list[str]] = None,
include_dirs: Optional[list[str]] = None,
cflags: Optional[list[str]] = None,
ldflags: Optional[list[str]] = None,
libraries_dirs: Optional[list[str]] = None,
libraries: Optional[list[str]] = None,
passthrough_args: Optional[list[str]] = None,
aot_mode: bool = False,
use_absolute_path: bool = False,
compile_only: bool = False,
) -> None:
self._compiler = compiler
self._definations: List[str] = definitions or []
self._include_dirs: List[str] = include_dirs or []
self._cflags: List[str] = cflags or []
self._ldflags: List[str] = ldflags or []
self._libraries_dirs: List[str] = libraries_dirs or []
self._libraries: List[str] = libraries or []
self._definations: list[str] = definitions or []
self._include_dirs: list[str] = include_dirs or []
self._cflags: list[str] = cflags or []
self._ldflags: list[str] = ldflags or []
self._libraries_dirs: list[str] = libraries_dirs or []
self._libraries: list[str] = libraries or []
# Some args are hard to abstract in an OS-compatible way, so pass them through directly.
self._passthrough_args: List[str] = passthrough_args or []
self._passthrough_args: list[str] = passthrough_args or []
self._aot_mode: bool = aot_mode
self._use_absolute_path: bool = use_absolute_path
@ -408,25 +409,25 @@ class BuildOptionsBase:
def get_compiler(self) -> str:
return self._compiler
def get_definations(self) -> List[str]:
def get_definations(self) -> list[str]:
return self._definations
def get_include_dirs(self) -> List[str]:
def get_include_dirs(self) -> list[str]:
return self._include_dirs
def get_cflags(self) -> List[str]:
def get_cflags(self) -> list[str]:
return self._cflags
def get_ldflags(self) -> List[str]:
def get_ldflags(self) -> list[str]:
return self._ldflags
def get_libraries_dirs(self) -> List[str]:
def get_libraries_dirs(self) -> list[str]:
return self._libraries_dirs
def get_libraries(self) -> List[str]:
def get_libraries(self) -> list[str]:
return self._libraries
def get_passthrough_args(self) -> List[str]:
def get_passthrough_args(self) -> list[str]:
return self._passthrough_args
def get_aot_mode(self) -> bool:
@ -457,14 +458,14 @@ class BuildOptionsBase:
json.dump(attrs, f)
def _get_warning_all_cflag(warning_all: bool = True) -> List[str]:
def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
if not _IS_WINDOWS:
return ["Wall"] if warning_all else []
else:
return []
def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
if _IS_WINDOWS:
"""
On Windows, only c++20 can support `std::enable_if_t`.
@ -479,7 +480,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
return [f"std={std_num}"]
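
Note that these helpers return flags without the leading "-"/"/" characters; the assumption here (worth verifying against CppBuilder) is that the builder prepends the platform-appropriate prefix when it renders the command line. A standalone sketch of that convention:

def render_cflags(cflags: list[str], is_windows: bool = False) -> str:
    # e.g. ["O2", "Wall", "std=c++17"] -> "-O2 -Wall -std=c++17" on POSIX
    prefix = "/" if is_windows else "-"
    return " ".join(f"{prefix}{flag}" for flag in cflags)

print(render_cflags(["O2", "Wall", "std=c++17"]))            # -O2 -Wall -std=c++17
print(render_cflags(["O2", "std:c++20"], is_windows=True))   # /O2 /std:c++20
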
def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
if _IS_WINDOWS:
cflags = [
"wd4819",
@ -506,7 +507,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
return cflags
def _get_ffast_math_flags() -> List[str]:
def _get_ffast_math_flags() -> list[str]:
# ffast-math is equivalent to these flags as in
# https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468
# however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have
@ -527,7 +528,7 @@ def _get_ffast_math_flags() -> List[str]:
return flags
def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
def _get_optimization_cflags(cpp_compiler: str) -> list[str]:
if _IS_WINDOWS:
return ["O2"]
else:
@ -554,7 +555,7 @@ def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
return cflags
def _get_shared_cflag(compile_only: bool) -> List[str]:
def _get_shared_cflag(compile_only: bool) -> list[str]:
if _IS_WINDOWS:
"""
MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
@ -578,14 +579,14 @@ def get_cpp_options(
compile_only: bool,
warning_all: bool = True,
extra_flags: Sequence[str] = (),
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
definations: List[str] = []
include_dirs: List[str] = []
cflags: List[str] = []
ldflags: List[str] = []
libraries_dirs: List[str] = []
libraries: List[str] = []
passthrough_args: List[str] = []
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: list[str] = []
include_dirs: list[str] = []
cflags: list[str] = []
ldflags: list[str] = []
libraries_dirs: list[str] = []
libraries: list[str] = []
passthrough_args: list[str] = []
cflags = (
_get_shared_cflag(compile_only)
@ -657,22 +658,22 @@ class CppOptions(BuildOptionsBase):
self._finalize_options()
def _get_glibcxx_abi_build_flags() -> List[str]:
def _get_glibcxx_abi_build_flags() -> list[str]:
if not _IS_WINDOWS:
return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
else:
return []
def _get_torch_cpp_wrapper_defination() -> List[str]:
def _get_torch_cpp_wrapper_defination() -> list[str]:
return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"]
def _use_custom_generated_macros() -> List[str]:
def _use_custom_generated_macros() -> list[str]:
return [" C10_USING_CUSTOM_GENERATED_MACROS"]
def _use_fb_internal_macros() -> List[str]:
def _use_fb_internal_macros() -> list[str]:
if not _IS_WINDOWS:
if config.is_fbcode():
fb_internal_macros = [
@ -697,12 +698,12 @@ def _setup_standard_sys_libs(
cpp_compiler: str,
aot_mode: bool,
use_absolute_path: bool,
) -> tuple[List[str], List[str], List[str]]:
) -> tuple[list[str], list[str], list[str]]:
from torch._inductor.codecache import _LINKER_SCRIPT
cflags: List[str] = []
include_dirs: List[str] = []
passthrough_args: List[str] = []
cflags: list[str] = []
include_dirs: list[str] = []
passthrough_args: list[str] = []
if _IS_WINDOWS:
return cflags, include_dirs, passthrough_args
@ -737,9 +738,9 @@ def _setup_standard_sys_libs(
return cflags, include_dirs, passthrough_args
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]]:
macros: List[str] = []
build_flags: List[str] = []
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]:
macros: list[str] = []
build_flags: list[str] = []
if vec_isa != invalid_vec_isa:
# Add Windows support later.
macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())
@ -759,7 +760,7 @@ def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]
def _get_torch_related_args(
include_pytorch: bool, aot_mode: bool
) -> tuple[List[str], List[str], List[str]]:
) -> tuple[list[str], list[str], list[str]]:
from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH
include_dirs = [
@ -783,7 +784,7 @@ def _get_torch_related_args(
return include_dirs, libraries_dirs, libraries
def _get_python_include_dirs() -> List[str]:
def _get_python_include_dirs() -> list[str]:
include_dir = Path(sysconfig.get_path("include"))
# On Darwin Python executable from a framework can return
# non-existing /Library/Python/... include path, in which case
@ -796,7 +797,7 @@ def _get_python_include_dirs() -> List[str]:
return [str(include_dir)]
def _get_python_related_args() -> tuple[List[str], List[str]]:
def _get_python_related_args() -> tuple[list[str], list[str]]:
python_include_dirs = _get_python_include_dirs()
python_include_path = sysconfig.get_path(
"include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
@ -893,13 +894,13 @@ def perload_icx_libomp_win(cpp_compiler: str) -> None:
def _get_openmp_args(
cpp_compiler: str,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str]]:
cflags: List[str] = []
ldflags: List[str] = []
include_dir_paths: List[str] = []
lib_dir_paths: List[str] = []
libs: List[str] = []
passthrough_args: List[str] = []
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]:
cflags: list[str] = []
ldflags: list[str] = []
include_dir_paths: list[str] = []
lib_dir_paths: list[str] = []
libs: list[str] = []
passthrough_args: list[str] = []
if _IS_MACOS:
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
cflags.append("Xclang")
@ -998,7 +999,7 @@ def _get_openmp_args(
return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args
def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]:
def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
macros = []
if use_mmap_weights:
macros.append(" USE_MMAP_SELF")
@ -1013,14 +1014,14 @@ def get_cpp_torch_options(
compile_only: bool,
use_absolute_path: bool,
use_mmap_weights: bool,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
definations: List[str] = []
include_dirs: List[str] = []
cflags: List[str] = []
ldflags: List[str] = []
libraries_dirs: List[str] = []
libraries: List[str] = []
passthrough_args: List[str] = []
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: list[str] = []
include_dirs: list[str] = []
cflags: list[str] = []
ldflags: list[str] = []
libraries_dirs: list[str] = []
libraries: list[str] = []
passthrough_args: list[str] = []
torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination()
use_custom_generated_macros_definations = _use_custom_generated_macros()
@ -1163,7 +1164,7 @@ def _set_gpu_runtime_env() -> None:
os.environ["CUDA_HOME"] = build_paths.sdk_home
def _transform_cuda_paths(lpaths: List[str]) -> None:
def _transform_cuda_paths(lpaths: list[str]) -> None:
# This handles two cases:
# 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs
# 2. Linux machines may have CUDA installed under either lib64/ or lib/
@ -1186,14 +1187,14 @@ def get_cpp_torch_device_options(
device_type: str,
aot_mode: bool = False,
compile_only: bool = False,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
definations: List[str] = []
include_dirs: List[str] = []
cflags: List[str] = []
ldflags: List[str] = []
libraries_dirs: List[str] = []
libraries: List[str] = []
passthrough_args: List[str] = []
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: list[str] = []
include_dirs: list[str] = []
cflags: list[str] = []
ldflags: list[str] = []
libraries_dirs: list[str] = []
libraries: list[str] = []
passthrough_args: list[str] = []
if (
config.is_fbcode()
and "CUDA_HOME" not in os.environ
@ -1287,13 +1288,13 @@ class CppTorchDeviceOptions(CppTorchOptions):
extra_flags=extra_flags,
)
device_definations: List[str] = []
device_include_dirs: List[str] = []
device_cflags: List[str] = []
device_ldflags: List[str] = []
device_libraries_dirs: List[str] = []
device_libraries: List[str] = []
device_passthrough_args: List[str] = []
device_definations: list[str] = []
device_include_dirs: list[str] = []
device_cflags: list[str] = []
device_ldflags: list[str] = []
device_libraries_dirs: list[str] = []
device_libraries: list[str] = []
device_passthrough_args: list[str] = []
(
device_definations,
@ -1379,7 +1380,7 @@ class CppBuilder:
def __init__(
self,
name: str,
sources: Union[str, List[str]],
sources: Union[str, list[str]],
BuildOption: BuildOptionsBase,
output_dir: str = "",
) -> None:

View File

@ -7,7 +7,7 @@ import re
import subprocess
import sys
import warnings
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Union
import torch
from torch._inductor import config
@ -33,9 +33,9 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
class VecISA:
_bit_width: int
_macro: List[str]
_macro: list[str]
_arch_flags: str
_dtype_nelements: Dict[torch.dtype, int]
_dtype_nelements: dict[torch.dtype, int]
# Note [Checking for Vectorized Support in Inductor]
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
@ -79,7 +79,7 @@ cdll.LoadLibrary("__lib_path__")
def nelements(self, dtype: torch.dtype = torch.float) -> int:
return self._dtype_nelements[dtype]
def build_macro(self) -> List[str]:
def build_macro(self) -> list[str]:
return self._macro
def build_arch_flags(self) -> str:
@ -300,11 +300,11 @@ class InvalidVecISA(VecISA):
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
def x86_isa_checker() -> List[str]:
supported_isa: List[str] = []
def x86_isa_checker() -> list[str]:
supported_isa: list[str] = []
def _check_and_append_supported_isa(
dest: List[str], isa_supported: bool, isa_name: str
dest: list[str], isa_supported: bool, isa_name: str
) -> None:
if isa_supported:
dest.append(isa_name)
@ -333,7 +333,7 @@ supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
def get_isa_from_cpu_capability(
capability: Union[str, None],
vec_isa_list: List[VecISA],
vec_isa_list: list[VecISA],
invalid_vec_isa: InvalidVecISA,
):
# AMX setting is not supported in eager
@ -364,8 +364,8 @@ def get_isa_from_cpu_capability(
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(None)
def valid_vec_isa_list() -> List[VecISA]:
isa_list: List[VecISA] = []
def valid_vec_isa_list() -> list[VecISA]:
isa_list: list[VecISA] = []
if sys.platform == "darwin" and platform.processor() == "arm":
isa_list.append(VecNEON())
@ -411,7 +411,7 @@ def pick_vec_isa() -> VecISA:
if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
return VecAVX2()
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
_valid_vec_isa_list: list[VecISA] = valid_vec_isa_list()
if not _valid_vec_isa_list:
return invalid_vec_isa

View File

@ -54,13 +54,7 @@ from typing import (
Callable,
cast,
ContextManager,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -99,6 +93,8 @@ from torch.utils.weak import TensorWeakRef
if TYPE_CHECKING:
from collections.abc import Generator, Iterator, Sequence
from torch._inductor.utils import InputType
from torch.types import _bool
@ -357,12 +353,12 @@ def get_manager(
def cudagraphify_impl(
model: ModelType,
inputs: List[InputType],
inputs: list[InputType],
static_input_idxs: Sequence[int],
*args: Any,
**kwargs: Any,
) -> ModelType:
fn_cache: Dict[tuple[int, ...], Callable[..., Any]] = {}
fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}
# Detect int inputs: we need to index on these
int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
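
The `fn_cache` above keys recorded functions on the tuple of integer inputs, so each distinct combination of dynamic ints gets its own entry. A toy version of the idea (not the real implementation):

from typing import Any, Callable

fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}

def get_specialized(int_key: tuple[int, ...]) -> Callable[..., Any]:
    if int_key not in fn_cache:
        # stand-in for "record a new cudagraph for this int combination"
        fn_cache[int_key] = lambda: f"recorded for {int_key}"
    return fn_cache[int_key]

print(get_specialized((8, 128))())  # records a new entry
print(get_specialized((8, 128))())  # reuses the cached entry
print(len(fn_cache))                # 1
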
@ -372,7 +368,7 @@ def cudagraphify_impl(
del inputs
def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
def deferred_cudagraphify(inputs: list[InputType]) -> OutputType:
nonlocal has_warn
int_key = get_ints(inputs)
@ -405,7 +401,7 @@ def cudagraphify_impl(
def cudagraphify(
model: ModelType,
inputs: List[InputType],
inputs: list[InputType],
static_input_idxs: Sequence[int] = (),
*,
device_index: int,
@ -466,7 +462,7 @@ class StorageWeakRefWrapper:
@classmethod
def from_weakref_and_data_ptr(
cls: Type[S],
cls: type[S],
cdata: Any,
data_ptr: int,
extra_ref_check: Optional[Callable[[], bool]] = None,
@ -561,9 +557,9 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
PathOutputIndex = tuple[int, int]
# For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]]
PathLiveness = list[list[bool]]
StackTraces = List[Optional[str]]
StackTraces = list[Optional[str]]
class CUDAWarmupNode:
@ -600,8 +596,8 @@ class CUDAWarmupNode:
self.wrapped_function = wrapped_function
self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
self.cuda_graphs_pool = cuda_graphs_pool
self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = []
self.tensor_weakrefs: list[Optional[TensorWeakRef]] = []
self.existing_cuda_graph = existing_cuda_graph
self.has_run = False
self.device_index = device_index
@ -619,7 +615,7 @@ class CUDAWarmupNode:
[t.data_ptr() for t in self.path_live_weakrefs() if t()]
)
def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]:
non_cudagraph_inps = [
weakref.ref(t.untyped_storage())
for t in itertools.chain(new_inputs, self.wrapped_function.constants)
@ -707,9 +703,9 @@ class CUDAWarmupNode:
# Aliases for List that say what the indices denote
InputList = List # input indexes
OutputList = List # output indexes
LevelList = List # levels (distance from root of tree)
InputList = list # input indexes
OutputList = list # output indexes
LevelList = list # levels (distance from root of tree)
class OutputAliasInfo:
@ -772,7 +768,7 @@ class CUDAGraphNode:
wrapped_function: WrappedFunction,
id: GraphID,
parent: Optional[CUDAGraphNode],
inputs: List[InputType],
inputs: list[InputType],
cuda_graphs_pool: tuple[int, int],
device_index: int,
stack_traces: Optional[StackTraces],
@ -800,7 +796,7 @@ class CUDAGraphNode:
# A single wrapped function may be recorded multiple times if memory patterns or
# invariants change from one execution to the next
self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)
# StorageWeakRef maintains whether the Storage C++ object remains allocated,
# not whether the corresponding memory has been deallocated. In order
@ -825,7 +821,7 @@ class CUDAGraphNode:
self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
# tensors which are outputs of previous graphs in the tree
self.cudagraph_managed_idxs: List[int] = [
self.cudagraph_managed_idxs: list[int] = [
idx
for idx, t in enumerate(inputs)
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
@ -845,7 +841,7 @@ class CUDAGraphNode:
# and also aliases an output of the current CUDAGraphNode
self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
self.static_input_idxs: List[int] = list(
self.static_input_idxs: list[int] = list(
OrderedSet(wrapped_function.static_input_idxs)
| OrderedSet(self.cudagraph_managed_idxs)
)
@ -866,8 +862,8 @@ class CUDAGraphNode:
def maybe_get_static_data_ptr(
idx: int,
inputs: List[InputType],
static_input_idxs: List[int],
inputs: list[InputType],
static_input_idxs: list[int],
) -> Optional[int]:
inp = inputs[idx]
if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
@ -888,7 +884,7 @@ class CUDAGraphNode:
# fresh allocations.
# precompute expanded dims to avoid computing in the hot path
self.expanded_dims: List[List[int]] = [
self.expanded_dims: list[list[int]] = [
get_expanded_dims(x)
if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
else []
@ -903,11 +899,11 @@ class CUDAGraphNode:
# List of Tuples of (depth, output_index) that index into node at depth
# number of nodes from root and output_index of outputs. Will index into
# path_weakrefs.
self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
self.expected_dead_indices_before_graph: list[PathOutputIndex] = []
self.expected_dead_indices_after_graph: list[PathOutputIndex] = []
# all live indices after graph recording
self.live_indices_after_graph: List[PathOutputIndex] = []
self.live_indices_after_graph: list[PathOutputIndex] = []
if self.parent is not None:
previous_liveness = self.parent.recorded_liveness_after_graph
@ -934,7 +930,7 @@ class CUDAGraphNode:
# we reconstruct tensors at the correct data pointers of our inputs which are
# non owning and do not prevent deallocation. On subsequent executions, input values
# will be copied over to these tensors.
self.reconstructed_inputs: List[InputType] = [
self.reconstructed_inputs: list[InputType] = [
self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
if isinstance(x, torch.Tensor)
else x
@ -983,7 +979,7 @@ class CUDAGraphNode:
self.recording_outputs: Optional[OutputType] = self._record(
wrapped_function.model, recording_inputs
)
self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = []
# As with inputs, we do not want to keep the outputs permanently alive because that would prevent
# their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
@ -1001,7 +997,7 @@ class CUDAGraphNode:
self.graph.replay()
def _copy_inputs_and_remove_from_src(
self, dsts: List[InputType], srcs: List[InputType]
self, dsts: list[InputType], srcs: list[InputType]
) -> None:
dst_tensors = []
src_tensors = []
@ -1016,7 +1012,7 @@ class CUDAGraphNode:
if dst_tensors:
torch._foreach_copy_(dst_tensors, src_tensors)
def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None:
# avoid checking managed tensor static points since we already checked those in check_invariants
if (
not self.rerecord_if_static_inputs_change
@ -1036,7 +1032,7 @@ class CUDAGraphNode:
)
torch._check(False, lambda: error_msg)
def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType:
if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_before_invocation()
@ -1048,7 +1044,7 @@ class CUDAGraphNode:
assert outputs is not None
return outputs
def run(self, new_inputs: List[InputType]) -> OutputType:
def run(self, new_inputs: list[InputType]) -> OutputType:
self.check_static_inputs_are_stable(new_inputs)
self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
@ -1130,7 +1126,7 @@ class CUDAGraphNode:
def prepare_alias_info_for_tensor_construction(
self,
out_alias_info: Optional[OutputAliasInfo],
metadata: Union[Dict[str, Any], int, None],
metadata: Union[dict[str, Any], int, None],
) -> Union[UntypedStorage, None, int]:
if (
isinstance(metadata, (int, type(None)))
@ -1149,7 +1145,7 @@ class CUDAGraphNode:
def prepare_storages_for_construction(
self,
) -> List[Union[UntypedStorage, None, int]]:
) -> list[Union[UntypedStorage, None, int]]:
output_storages = []
for output_storage_alias, metadata in zip(
self.output_storage_alias, self.outputs_metadata
@ -1173,7 +1169,7 @@ class CUDAGraphNode:
return False
return True
def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
"Record the model"
def static_input_iter() -> Generator[torch.Tensor, None, None]:
@ -1185,7 +1181,7 @@ class CUDAGraphNode:
yield _inp
# see: output_is_alias_of_persistent_static_inputs above
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = {
inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
for inp in itertools.chain(
static_input_iter(), self.wrapped_function.constants
@ -1229,7 +1225,7 @@ class CUDAGraphNode:
def _add_first_outputs(
self,
outputs: OutputType,
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper],
) -> None:
"Add the outputs from the first invocation of the node and set up metadata"
@ -1243,7 +1239,7 @@ class CUDAGraphNode:
assert len(self.outputs_weakrefs) == 0
# index from data pointer to index in outputs
output_new_storages_index: Dict[StorageDataPtr, int] = {}
output_new_storages_index: dict[StorageDataPtr, int] = {}
self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
self.static_output_tensors = [None for _ in range(len(outputs))]
@ -1431,8 +1427,8 @@ class CUDAGraphNode:
@staticmethod
def _check_liveness(
indices: List[PathOutputIndex],
output_refs: List[List[Optional[StorageWeakRefWrapper]]],
indices: list[PathOutputIndex],
output_refs: list[list[Optional[StorageWeakRefWrapper]]],
) -> bool:
"Check that all of the indices specified are dead references"
for depth, output_index in indices:
@ -1448,8 +1444,8 @@ class CUDAGraphNode:
@staticmethod
def _get_different_indices(
prev: List[List[bool]], curr: List[List[bool]]
) -> List[PathOutputIndex]:
prev: list[list[bool]], curr: list[list[bool]]
) -> list[PathOutputIndex]:
"Find indices where the two lists differ."
dead_indices = []
assert len(prev) <= len(curr)
@ -1463,8 +1459,8 @@ class CUDAGraphNode:
@staticmethod
def _get_liveness(
weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
) -> List[List[bool]]:
weakrefs: list[list[Optional[StorageWeakRefWrapper]]],
) -> list[list[bool]]:
"Maps weakrefs to true if the reference is alive and false otherwise"
if len(weakrefs) == 0:
return []
@ -1472,7 +1468,7 @@ class CUDAGraphNode:
return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
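
Liveness here is literally "does the weakref still resolve". A toy illustration outside of Inductor (the `Storage` class below is a stand-in, not a torch type):

import weakref

class Storage:          # stand-in for an output's untyped storage
    pass

s = Storage()
ref = weakref.ref(s)
print(ref() is not None)   # True: the output is still alive
del s
print(ref() is None)       # True: dead, so its memory may be reused by later recordings
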
def debug_assert_invariants(
self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex]
) -> None:
if not config.triton.fast_path_cudagraph_asserts:
return
@ -1520,7 +1516,7 @@ class CUDAGraphNode:
self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
)
def data_ptrs_dead_since_invocation(self) -> List[int]:
def data_ptrs_dead_since_invocation(self) -> list[int]:
"""
Since this node was invoked, return data ptrs of all tensor outputs that have died
in the current executing tree path.
@ -1568,7 +1564,7 @@ class CUDAGraphNode:
@staticmethod
def _tensor_metadata(
x: torch.Tensor, ignore_storage_offset: bool = True
) -> Dict[str, Any]:
) -> dict[str, Any]:
assert isinstance(x, torch.Tensor)
# We ignore the storage offset for inputs, but not for outputs
# TODO: - should we make the storage resizable ?
@ -1583,19 +1579,19 @@ class CUDAGraphNode:
}
def _reconstruct_from_tensor_metadata(
self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None
) -> Tensor:
s = self.create_storage(metadata) if storage is None else storage
return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type]
def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage:
return torch._C._construct_storage_from_data_pointer(
metadata["data_ptr"], metadata["device"], metadata["nbytes"]
)
def _allocate_and_copy_recording_inputs(
self, inputs: List[InputType]
) -> List[InputType]:
self, inputs: list[InputType]
) -> list[InputType]:
"""
Allocate inputs for non static, non cudagraph managed tensors in the memory pool
and copy over the tensor values.
@ -1603,7 +1599,7 @@ class CUDAGraphNode:
torch.cuda.synchronize()
self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: List[InputType] = []
recording_inputs: list[InputType] = []
with warnings.catch_warnings(record=True), torch.cuda.device(
self.device
@ -1627,7 +1623,7 @@ class CUDAGraphNode:
return recording_inputs
def check_invariants(
self, inputs: List[InputType]
self, inputs: list[InputType]
) -> tuple[CheckInvariantStatus, Callable[..., str]]:
"""
Checks if this node can be run. The same pattern of tensor liveness, static inputs,
@ -1714,7 +1710,7 @@ def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any:
return [segment for segment in segments if segment["segment_pool_id"] == pool_id]
def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[int]:
def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]:
blocks = []
for segment in get_cudagraph_segments(pool_id):
@ -1728,7 +1724,7 @@ def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[in
return blocks
def format_tb(frames: List[Any]) -> str:
def format_tb(frames: list[Any]) -> str:
formatted_traceback = [
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
for entry in frames
@ -1740,7 +1736,7 @@ def format_tb(frames: List[Any]) -> str:
def check_memory_pool(
device: int,
pool_id: tuple[int, int],
live_storages_ptrs: List[StorageWeakRefWrapper],
live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None:
assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
@ -1839,12 +1835,12 @@ class CUDAGraphTreeManager:
# when they are first invoked, none of their inputs are outputs
# of another node, nor are there any live outputs of another node whose
# liveness would create a dependency.
self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)
# mapping from function id to wrapped function
self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {}
self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}
self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {}
self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet()
# if we fail to increment generation, and are stuck warming up,
@ -1883,14 +1879,14 @@ class CUDAGraphTreeManager:
# mapping from graph_id to (function id to mutation type hint) since we are
# specializing on a particular combination of Parent Node -> Function ID.
self.non_cudagraph_managed_mutation_hint: Dict[
Optional[GraphID], Dict[FunctionID, bool]
self.non_cudagraph_managed_mutation_hint: dict[
Optional[GraphID], dict[FunctionID, bool]
] = defaultdict(dict)
self.warmup_node_counter = itertools.count(start=-1, step=-1)
# mapping from graph_id to (function id to re-record count). We fall back to
# eager function if a function is re-recorded frequently on a node.
self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict(
lambda: defaultdict(lambda: 0)
)
@ -1916,7 +1912,7 @@ class CUDAGraphTreeManager:
# number of instances we had to checkpoint the function
self.debug_checkpointing_counter = 0
self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
self.id_to_mode: dict[FunctionID, CompilationMode] = {}
# Note: [Backward Generation Handling]
# We generally perform a sequence of forward executions followed by backward executions.
@ -1943,7 +1939,7 @@ class CUDAGraphTreeManager:
)
)
def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
assert self.graph is not None, "Running CUDAGraph after shutdown"
self.mode = self.id_to_mode[function_id]
out = self._run(new_inputs, function_id)
@ -1971,7 +1967,7 @@ class CUDAGraphTreeManager:
return GraphID(next(self.warmup_node_counter))
def _update_non_cudagraph_managed_mutation(
self, function_id: FunctionID, inputs: List[InputType]
self, function_id: FunctionID, inputs: list[InputType]
) -> None:
node_id = self._get_node_id()
if maybe_mutation_str := check_for_mutation(
@ -2007,7 +2003,7 @@ class CUDAGraphTreeManager:
> torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
)
def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
# we will try to end the current execution lazily, since
# we dont want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@ -2152,7 +2148,7 @@ class CUDAGraphTreeManager:
self.current_node = None
def record_function(
self, new_inputs: List[InputType], function_id: FunctionID
self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType:
assert not isinstance(self.current_node, CUDAWarmupNode)
graph_id = self.new_graph_id()
@ -2183,7 +2179,7 @@ class CUDAGraphTreeManager:
return node.run_first_inputs(new_inputs)
def execute_node(
self, node: CUDAGraphNode, new_inputs: List[InputType]
self, node: CUDAGraphNode, new_inputs: list[InputType]
) -> OutputType:
self.current_node = node
self.path_state = ExecutionState.EXECUTION
@ -2191,7 +2187,7 @@ class CUDAGraphTreeManager:
return node.run(new_inputs)
def run_eager(
self, new_inputs: List[InputType], function_id: FunctionID
self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType:
# this is only stored on current node, because when we start a new path,
# we will deallocate it
@ -2229,7 +2225,7 @@ class CUDAGraphTreeManager:
def add_function(
self,
model: ModelType,
inputs: List[InputType],
inputs: list[InputType],
static_input_idxs: Sequence[int],
stack_traces: Optional[StackTraces],
mode: CompilationMode,
@ -2409,7 +2405,7 @@ class CUDAGraphTreeManager:
# TODO: we could also allow these weak refs to continue to be allocated,
# but that adds some complications.
stor_stack_trace: Dict[int, Optional[str]] = {}
stor_stack_trace: dict[int, Optional[str]] = {}
for node in self.current_node._path_from_root:
assert node.stack_traces is not None
assert len(node.tensor_weakrefs) == len(node.stack_traces)
@ -2475,7 +2471,7 @@ class CUDAGraphTreeManager:
assert state is not None and device is not None
# currently we deallocate instead of allowing stale recordings
stale_storages: List[int] = []
stale_storages: list[int] = []
# remove cached tensors, otherwise they would prevent memory from being
# reclaimed in subsequent recordings
@ -2506,7 +2502,7 @@ class CUDAGraphTreeManager:
def live_cudagraph_pool_storages_in_curr_execution(
self,
) -> List[StorageWeakRefPointer]:
) -> list[StorageWeakRefPointer]:
if self.current_node is None:
return []
# explicitly ignoring previous recorded outputs from past path

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters
@ -11,14 +11,18 @@ from torch._inductor.utils import InputType
from torch.utils._ordered_set import OrderedSet
if TYPE_CHECKING:
from collections.abc import Sequence
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
__name__, "cudagraph_static_inputs"
)
OutputType = List[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[List[InputType]], OutputType]
OutputType = list[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[list[InputType]], OutputType]
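
The same import migration in miniature: ABCs used only in annotations move to `collections.abc` and can be gated behind `TYPE_CHECKING`, which is safe because this file (like the sketch below) defers annotation evaluation with `from __future__ import annotations`. Names here are illustrative:

from __future__ import annotations      # annotations stay lazy, so TYPE_CHECKING-only names are fine

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    from collections.abc import Sequence   # previously: from typing import Sequence

def first_str(values: Sequence[Union[int, str]]) -> Optional[str]:
    for v in values:
        if isinstance(v, str):
            return v
    return None

print(first_str([1, 2, "three"]))   # three
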
@dataclasses.dataclass(frozen=True)
@ -38,7 +42,7 @@ class PlaceholderInfo:
name: str
stack_trace: Optional[str]
# This field is recursive, but never cyclic (since a node never uses itself)
users: List[PlaceholderInfo]
users: list[PlaceholderInfo]
mutating_use_stack_trace: Optional[str]
@ -92,7 +96,7 @@ def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)
def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]:
return [
to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
]
@ -123,7 +127,7 @@ def get_mutation_stack_trace(
def check_for_mutation(
func: WrappedFunction,
inputs: List[InputType],
inputs: list[InputType],
is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
# doesn't work for non-trees because the warmup run would apply mutation twice
@ -160,7 +164,7 @@ def _get_use_stack_trace(node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
@ -180,7 +184,7 @@ def check_multiple_devices_or_any_cpu_nodes(
def check_lowering_disable_cudagraph(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
):
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -262,7 +266,7 @@ class CheckInvariantStatus(Enum):
def log_data_ptr_mismatch(
placeholders: Sequence[PlaceholderInfo],
inputs: List[InputType],
inputs: list[InputType],
recorded_data_ptr: Sequence[Optional[int]],
target_idxs: Sequence[int],
mismatch: CheckInvariantStatus,
@ -292,7 +296,7 @@ def log_data_ptr_mismatch(
def maybe_warning_due_to_dynamic_shape(
fn_cache: Dict[tuple[int, ...], Callable[..., Any]],
fn_cache: dict[tuple[int, ...], Callable[..., Any]],
new_int_key: Any,
) -> bool:
num_cudagraphs = len(fn_cache.keys()) + 1
@ -327,5 +331,5 @@ class CudagraphCachedInfo:
"""
placeholders: Sequence[PlaceholderInfo]
stack_traces: List[Optional[str]]
cudagraph_fail_reasons: List[str]
stack_traces: list[Optional[str]]
cudagraph_fail_reasons: list[str]

View File

@ -12,7 +12,8 @@ import pickle
import pstats
import shutil
import subprocess
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
from collections.abc import Iterator
from typing import Any, Callable, IO, Optional, Union
from unittest.mock import patch
import torch
@ -39,7 +40,7 @@ from .virtualized import V
log = logging.getLogger(__name__)
SchedulerNodeList = List[Any]
SchedulerNodeList = list[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
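
The extra `-G` options appear to cap graphviz's layout effort so very large scheduler graphs still render in reasonable time; treat that reading, and the helper below, as an assumption rather than documented behavior. A sketch of how such a command might be invoked:

import subprocess

def render_dot(dot_source: str, out_path: str) -> None:
    # Inlined copy of the scalable command above, plus output format/path.
    cmd = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000", "-Tsvg", "-o", out_path]
    subprocess.run(cmd, input=dot_source.encode(), check=True)

# render_dot("digraph { a -> b }", "/tmp/graph.svg")   # requires graphviz on PATH
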
@ -54,7 +55,7 @@ def has_dot() -> bool:
def draw_buffers(
nodes: List[BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
print_graph: bool = False,
fname: Optional[str] = None,
) -> None:
@ -99,7 +100,7 @@ def draw_buffers(
)
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
"""
Creates a FX Graph from a list of SchedulerNode objects.
"""
@ -199,7 +200,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
def update_orig_fx_node_name_to_buf_name(
nodes: Optional[SchedulerNodeList],
node_name_to_buf_name: Dict[str, str],
node_name_to_buf_name: dict[str, str],
parent_buf_name: Optional[str] = None,
n_origins: int = 0,
) -> None:
@ -233,8 +234,8 @@ def update_orig_fx_node_name_to_buf_name(
def get_node_name_to_buf_meta(
node_name_to_buf_name: Dict[str, str]
) -> Dict[str, BufMeta]:
node_name_to_buf_name: dict[str, str]
) -> dict[str, BufMeta]:
buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items():
if buf_name not in buf_name_to_n_node:
@ -256,7 +257,7 @@ def annotate_orig_fx_with_snodes(
"""
Creates a FX Graph from a list of SchedulerNode objects.
"""
node_name_to_buf_name: Dict[str, str] = {}
node_name_to_buf_name: dict[str, str] = {}
update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
if node_name_to_buf_name is None:
return
@ -309,7 +310,7 @@ def enable_aot_logging() -> Iterator[None]:
class DebugContext:
_counter = itertools.count()
_inductor_triton_kernel_to_post_grad_node_info: Dict[str, List[str]] = {}
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
@staticmethod
def create_debug_dir(folder_name: str) -> Optional[str]:
@ -425,7 +426,7 @@ class DebugContext:
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[Any],
) -> None:
@ -474,7 +475,7 @@ class DebugFormatter:
def fx_graph(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
) -> None:
with self.fopen("fx_graph_runnable.py") as fd:
save_dir = None
@ -504,7 +505,7 @@ class DebugFormatter:
def fx_graph_transformed(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
) -> None:
with self.fopen("fx_graph_transformed.py") as fd:
fd.write(gm.print_readable(print_output=False))
@ -557,14 +558,14 @@ class DebugFormatter:
def log_autotuning_results(
self,
name: str,
input_nodes: List[ir.IRNode],
timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
input_nodes: list[ir.IRNode],
timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
elapse: float,
precompile_elapse: float,
) -> None:
from .ir import FixedLayout
def build_node_info(node: ir.IRNode) -> Dict[str, str]:
def build_node_info(node: ir.IRNode) -> dict[str, str]:
if hasattr(node, "name"):
node_name = node.name
else:
@ -725,7 +726,7 @@ def aot_inductor_minifier_wrapper(
func: Callable[..., str],
exported_program: torch.export.ExportedProgram,
*,
inductor_configs: Dict[str, Any],
inductor_configs: dict[str, Any],
package_path: Optional[Union[str, io.BytesIO]] = None,
) -> str:
from torch._inductor import config

View File

@ -4,7 +4,7 @@ import logging
import math
import sys
import typing
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition(
ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions:
@ -170,7 +170,7 @@ def clamp(
@register_decomposition([aten.full])
def full(
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
fill_value: torch.types.Number,
**kwargs: Any,
) -> torch.Tensor:
@ -205,8 +205,8 @@ def index_add(
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
size: List[Union[int, torch.SymInt]],
physical_layout: List[int],
size: list[Union[int, torch.SymInt]],
physical_layout: list[int],
**kwargs: Any,
) -> torch.Tensor:
perm = [0] * len(size)
@ -220,14 +220,14 @@ def convolution_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes: List[int],
stride: Union[int, List[int]],
padding: Union[int, List[int]],
dilation: Union[int, List[int]],
bias_sizes: list[int],
stride: Union[int, list[int]],
padding: Union[int, list[int]],
dilation: Union[int, list[int]],
transposed: bool,
output_padding: List[int],
output_padding: list[int],
groups: int,
output_mask: List[bool],
output_mask: list[bool],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not output_mask[2] or not is_gpu(grad_output.device.type):
return NotImplemented
@ -345,7 +345,7 @@ def mm(
# don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
dim: int = 0,
) -> torch.Tensor:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@ -515,7 +515,7 @@ def narrow_copy(
@register_decomposition([aten.view_copy.default])
def view_copy_default(
self: torch.Tensor,
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
) -> torch.Tensor:
return aten.view(self, size).clone()
@ -639,7 +639,7 @@ def randint_like_low(
@register_decomposition(aten.randint.default)
def randint(
high: int,
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
**kwargs: Any,
) -> torch.Tensor:
return aten.randint.low(0, high, size, **kwargs)
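
New entries follow the same shape as the hunks above: decorate a function with `register_decomposition` and express the op in terms of ops Inductor already lowers. Purely illustrative (a relu decomposition may already exist upstream, and the sketch assumes the same `aten`/`register_decomposition` names in scope as in this file), so treat it as a template rather than something to add:

@register_decomposition([aten.relu])
def relu_decomp(self: torch.Tensor) -> torch.Tensor:
    # relu(x) == max(x, 0)
    return aten.clamp_min(self, 0)
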
@ -731,11 +731,11 @@ def grid_sampler_2d(
@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
self: List[torch.Tensor],
left_tensors: List[torch.Tensor],
right_tensors: List[torch.Tensor],
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
right_tensors: list[torch.Tensor],
scalar: float = 1,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
)
@ -743,11 +743,11 @@ def _foreach_addcmul_scalar(
@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
self: List[torch.Tensor],
left_tensors: List[torch.Tensor],
right_tensors: List[torch.Tensor],
self: list[torch.Tensor],
left_tensors: list[torch.Tensor],
right_tensors: list[torch.Tensor],
scalar: float = 1,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
)
@ -755,10 +755,10 @@ def _foreach_addcdiv_scalar(
@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
start_tensors: List[torch.Tensor],
end_tensors: List[torch.Tensor],
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
weight: torch.types.Number,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.Scalar(
@ -769,10 +769,10 @@ def _foreach_lerp_scalar(
@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
start_tensors: List[torch.Tensor],
end_tensors: List[torch.Tensor],
scalars: List[torch.types.Number],
) -> List[torch.Tensor]:
start_tensors: list[torch.Tensor],
end_tensors: list[torch.Tensor],
scalars: list[torch.types.Number],
) -> list[torch.Tensor]:
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.ScalarList(
@ -814,13 +814,13 @@ def miopen_batch_norm(
@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
return {**decompositions, **extra_random_decomps}
# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
def select_decomp_table() -> dict[Any, Callable[..., Any]]:
"""decomps can change based on config"""
if config.fallback_random:
return decompositions
@ -965,10 +965,10 @@ def index_reduce(
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
x: torch.Tensor,
kernel_size: List[int],
stride: Optional[Union[int, List[int]]] = None,
padding: Union[int, List[int]] = 0,
dilation: Union[int, List[int]] = 1,
kernel_size: list[int],
stride: Optional[Union[int, list[int]]] = None,
padding: Union[int, list[int]] = 0,
dilation: Union[int, list[int]] = 1,
ceil_mode: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if dilation == 1:
@ -1015,7 +1015,7 @@ def max_pool2d_with_indices(
@register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(
x: torch.Tensor, output_size: List[int]
x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
*batch, h_in, w_in = x.shape
h_out, w_out = output_size

View File

@ -4,8 +4,8 @@ import dataclasses
import itertools
import logging
import re
import typing
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from collections.abc import Sequence
from typing import Any, Callable, Optional, TypeVar, Union
from unittest.mock import patch
import sympy
@ -38,7 +38,7 @@ class Dep(abc.ABC):
index: sympy.Expr
@abc.abstractmethod
def rename(self, renames: Dict[str, str]) -> "Dep":
def rename(self, renames: dict[str, str]) -> "Dep":
pass
@abc.abstractmethod
@ -197,7 +197,7 @@ class MemoryDep(Dep):
return out
@property
def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
"""{c0: 128, c1: 512, ...}"""
return dict(zip(self.var_names, self.size))
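
The docstring's `{c0: 128, c1: 512, ...}` is just the pairwise zip of loop variables with their extents; in plain Python (sympy-free, illustrative values):

var_names = ("c0", "c1")
size = (128, 512)
print(dict(zip(var_names, size)))   # {'c0': 128, 'c1': 512}
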
@ -221,7 +221,7 @@ class MemoryDep(Dep):
numel = numel * size
return numel # type: ignore[return-value]
def rename(self, renames: Dict[str, str]) -> "MemoryDep":
def rename(self, renames: dict[str, str]) -> "MemoryDep":
if self.name in renames:
return MemoryDep(
renames[self.name],
@ -299,7 +299,7 @@ class StarDep(Dep):
def get_numel(self) -> sympy.Expr:
return V.graph.get_numel(self.name) # type: ignore[return-value]
def rename(self, renames: Dict[str, str]) -> "StarDep":
def rename(self, renames: dict[str, str]) -> "StarDep":
if self.name in renames:
return StarDep(renames[self.name], self.mode)
return self
@ -347,7 +347,7 @@ class WeakDep(Dep):
def get_numel(self) -> sympy.Expr:
return sympy.S.One
def rename(self, renames: Dict[str, str]) -> "WeakDep":
def rename(self, renames: dict[str, str]) -> "WeakDep":
if self.name in renames:
return WeakDep(renames[self.name], self.mutating_buf)
return self
@ -374,10 +374,10 @@ class ReadWrites:
reads: OrderedSet[Dep]
writes: OrderedSet[Dep]
index_exprs: OrderedSet[IndexExprDep]
range_vars: Optional[List[sympy.Expr]] = None
range_vars: Optional[list[sympy.Expr]] = None
var_ranges: Optional[VarRanges] = None
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
def rename(self, renames: dict[str, str]) -> "ReadWrites":
return ReadWrites(
OrderedSet(dep.rename(renames) for dep in self.reads),
OrderedSet(dep.rename(renames) for dep in self.writes),
@ -405,7 +405,7 @@ class ReadWrites:
return ReadWrites(reads - writes, writes, index_exprs)
@staticmethod
def merge_list(read_writes: List["ReadWrites"]):
def merge_list(read_writes: list["ReadWrites"]):
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
@ -564,7 +564,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy
def index_vars_no_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str):
var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
return args, var_ranges
@ -572,8 +572,8 @@ def index_vars_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str = "d"):
from .ir import SqueezeView
var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Expr]] = []
new_sizes: List[List[sympy.Expr]] = []
args: list[list[sympy.Expr]] = []
new_sizes: list[list[sympy.Expr]] = []
for size in argsizes:
new_size, reindex = SqueezeView.squeezer(size)
new_sizes.append(new_size)
@ -653,7 +653,7 @@ def extract_loop_body_with_args(fn, args, var_ranges, normalize=False):
def extract_input_node_reduction_ranges(
input_node: "torch._inductor.ir.IRNode",
) -> tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
"""
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (where they exist) are all the same.
It's possible that a node has multiple inputs; some are Reduction nodes and others are Pointwise nodes.
@ -663,8 +663,8 @@ def extract_input_node_reduction_ranges(
from .ir import ComputedBuffer, ExternKernel, Loops
size: Optional[List[sympy.Expr]]
reduction_size: Optional[List[sympy.Expr]]
size: Optional[list[sympy.Expr]]
reduction_size: Optional[list[sympy.Expr]]
if isinstance(input_node.get_defining_op(), ComputedBuffer):
# Input node has already been realized. Return its size and reduction_size.
@ -683,11 +683,11 @@ def extract_input_node_reduction_ranges(
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
# Is there a way to check whether there are permutations in between?
reads = input_node.get_reads()
reduction_size: Optional[List[sympy.Expr]] = None
size: Optional[List[sympy.Expr]] = None
reduction_size: Optional[list[sympy.Expr]] = None
size: Optional[list[sympy.Expr]] = None
while reduction_size is None and len(reads) > 0:
seen: OrderedSet[str] = OrderedSet()
new_reads: List[Dep] = []
new_reads: list[Dep] = []
for read in reads:
if not isinstance(read, MemoryDep):
continue

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import functools
from typing import Callable, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
import sympy

View File

@ -4,7 +4,7 @@ import os
import tempfile
import textwrap
from functools import lru_cache
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback
@ -29,7 +29,7 @@ else:
class OperatorIssue(RuntimeError):
@staticmethod
def operator_str(target: Any, args: List[Any], kwargs: dict[str, Any]) -> str:
def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str:
lines = [f"target: {target}"] + [
f"args[{i}]: {arg}" for i, arg in enumerate(args)
]
@ -39,13 +39,13 @@ class OperatorIssue(RuntimeError):
class MissingOperatorWithoutDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
class MissingOperatorWithDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(
f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
@ -62,7 +62,7 @@ class MissingOperatorWithDecomp(OperatorIssue):
class LoweringException(OperatorIssue):
def __init__(
self, exc: Exception, target: Any, args: List[Any], kwargs: dict[str, Any]
self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any]
) -> None:
super().__init__(
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"

View File

@ -1,5 +1,4 @@
import json
from typing import List
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
@ -17,7 +16,7 @@ def serialize_extern_kernel_node(
def extern_node_json_serializer(
extern_kernel_nodes: List[inductor_ExternKernelNode],
extern_kernel_nodes: list[inductor_ExternKernelNode],
) -> str:
serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import itertools
import logging
import weakref
from typing import Any, List, Optional
from typing import Any, Optional
import torch
import torch.utils._pytree as pytree
@ -28,7 +28,7 @@ def replace_params_with_constants(
gm: torch.fx.GraphModule,
flat_params: list[Any],
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
"""
Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
Returns a list of indices representing the input parameters that were not converted to constants.
@ -66,8 +66,8 @@ def replace_params_with_constants(
def freeze(
dynamo_gm: torch.fx.GraphModule,
aot_autograd_gm: torch.fx.GraphModule,
example_inputs: List[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, List[int]]:
example_inputs: list[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, list[int]]:
"""
Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
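
The freeze() path above is normally reached indirectly from user code; a rough sketch under the assumption that the torch._inductor.config.freezing flag enables it during an inference compile (the flag name and behavior are assumptions of this sketch, not stated in the diff):

# Hedged sketch: enabling the freezing pass from user code (assumes the
# torch._inductor.config.freezing flag; not part of this diff).
import torch

torch._inductor.config.freezing = True  # assumption: folds non-mutated params into constants

model = torch.nn.Linear(8, 4).eval()
compiled = torch.compile(model)

with torch.no_grad():
    print(compiled(torch.randn(2, 8)).shape)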

View File

@ -6,21 +6,17 @@ import signal
import string
import sys
import traceback
from collections.abc import KeysView
from enum import Enum
from functools import partial, wraps
from types import FrameType
from typing import (
Any,
Callable,
Dict,
get_args,
get_origin,
KeysView,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
@ -92,14 +88,14 @@ class TypeExemplars:
}
@staticmethod
def example(t: Type[T]) -> Optional[T]:
def example(t: type[T]) -> Optional[T]:
"""
Return an example of a class.
"""
return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None)
@staticmethod
def contains(t: Type[T]) -> bool:
def contains(t: type[T]) -> bool:
return t.__name__ in TypeExemplars.TYPE_EXEMPLARS
@ -136,7 +132,7 @@ class Status(Enum):
# Sometimes the types of configs aren't expressive enough to be captured by the Python type system, so the options can be
# manually specified here:
TYPE_OVERRIDES: Dict[str, List[Any]] = {
TYPE_OVERRIDES: dict[str, list[Any]] = {
"post_grad_fusion_options": [
{
"batch_linear_post_grad": {
@ -160,7 +156,7 @@ TYPE_OVERRIDES: Dict[str, List[Any]] = {
"autoheuristic_collect": ["pad_mm", "mixed_mm"],
"autoheuristic_use": ["pad_mm", "mixed_mm"],
}
SamplingType = Callable[[str, Type[Any], Any], Any]
SamplingType = Callable[[str, type[Any], Any], Any]
class SamplingMethod(Enum):
@ -178,7 +174,7 @@ class SamplingMethod(Enum):
@staticmethod
def _generate_value_for_type(
random_sample: bool, field_name: str, type_hint: Type[Any], default: Any
random_sample: bool, field_name: str, type_hint: type[Any], default: Any
) -> Any:
"""
Generates a value of a type based on the setting.
@ -304,9 +300,11 @@ class SamplingMethod(Enum):
if random_sample:
return random.choice(type_hint.__args__)
else:
return random.choice(
[t for t in type_hint.__args__ if t != default]
)
choices = [t for t in type_hint.__args__ if t != default]
if choices:
return random.choice(choices)
else:
return default
except AttributeError as err:
raise ValueError("Literal type with no args") from err
elif is_optional_type(type_hint):
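
The guard added above prevents random.choice from being called on an empty list when the default is the only Literal option; a standalone sketch of the same fallback logic (names are hypothetical):

# Standalone sketch of the empty-choices guard shown above (hypothetical names).
import random
from typing import Any, Literal, get_args

Mode = Literal["default"]            # filtering out the default leaves no alternatives
Verbosity = Literal["low", "high"]   # here an alternative can always be chosen

def sample_literal(type_hint: Any, default: Any) -> Any:
    choices = [t for t in get_args(type_hint) if t != default]
    if choices:
        return random.choice(choices)
    return default  # fall back instead of raising IndexError on an empty list

print(sample_literal(Mode, "default"))   # -> "default"
print(sample_literal(Verbosity, "low"))  # -> "high"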
@ -374,7 +372,7 @@ class Default:
DEFAULT = Default()
# The combination of config settings being set (based on their strings)
ComboType = Tuple[str, ...]
ComboType = tuple[str, ...]
class ResultType:
@ -382,7 +380,7 @@ class ResultType:
The mapping of the combo strings to the result status after running the config fuzzer.
"""
_vals: Dict[ComboType, Status]
_vals: dict[ComboType, Status]
def __repr__(self) -> str:
return f"ResultType[{self._vals}]"
@ -416,7 +414,7 @@ class ResultType:
# Type that maps config strings to their default value
ConfigType = Dict[str, Any]
ConfigType = dict[str, Any]
# Callable that returns a bool
FactoryOutputType = Callable[[], bool]
# input function factory
@ -504,10 +502,10 @@ class ConfigFuzzer:
return
self.seed = seed
self.test_timeout = test_timeout
self.detailed_results: Dict[ComboType, Dict[str, Any]] = {}
self.detailed_results: dict[ComboType, dict[str, Any]] = {}
self.config_module = config_module
self.test_model_fn_factory = test_model_fn_factory
self.fields: Dict[str, _ConfigEntry] = self.config_module._config
self.fields: dict[str, _ConfigEntry] = self.config_module._config
self.sample = SamplingMethod.dispatch(sm)
if default is None:
@ -587,7 +585,7 @@ class ConfigFuzzer:
}
return ret
def reproduce(self, configs: List[ConfigType]) -> ResultType:
def reproduce(self, configs: list[ConfigType]) -> ResultType:
"""entrypoint to reproduce any failure"""
results = ResultType()
for conf in configs:
@ -675,7 +673,7 @@ class ConfigFuzzer:
for field, value in config.items():
print(f"{field} = {value}")
def get_error_info(exc: Exception) -> Dict[str, Any]:
def get_error_info(exc: Exception) -> dict[str, Any]:
return {
"exception": str(exc),
"traceback": traceback.format_exc(),
@ -741,7 +739,7 @@ class ConfigFuzzer:
else:
return handle_return("Function succeeded", Status.PASSED, False, None)
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> List[ConfigType]:
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]:
"""
Test configs and bisect to a minimal failing configuration.
"""
@ -749,7 +747,7 @@ class ConfigFuzzer:
random.seed(self.seed)
self._reset_configs()
results = ResultType()
ret: List[ConfigType] = []
ret: list[ConfigType] = []
for attempt in range(num_attempts):
print(f"Random attempt {attempt + 1}/{num_attempts}")
@ -783,7 +781,7 @@ class ConfigFuzzer:
return self._bisect_failing_config_helper(results, list(failing_config.items()))
def _bisect_failing_config_helper(
self, results: ResultType, failing_config: List[Tuple[str, Any]]
self, results: ResultType, failing_config: list[tuple[str, Any]]
) -> Optional[ConfigType]:
"""
Bisect a failing configuration to find the minimal set of configs that cause failure.
@ -795,7 +793,7 @@ class ConfigFuzzer:
if not failing_config:
return None
def test(x: List[Tuple[str, Any]]) -> Status:
def test(x: list[tuple[str, Any]]) -> Status:
d = dict(x)
result = self.test_config(results, d)
return result

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Type
from typing import Any, Callable, Optional
import sympy
@ -24,9 +24,9 @@ from .virtualized import V
# Check whether the pattern (nn.module, F.function/torch.Tensor.method) is matched.
# Works for length-2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(
pattern: tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]],
node: torch.fx.node.Node,
modules: Dict[str, torch.nn.modules.Module],
modules: dict[str, torch.nn.modules.Module],
) -> bool:
if len(node.args) == 0:
return False
@ -86,7 +86,7 @@ class FakeTensorUpdater:
return (node, node.target, id(node.args), id(node.kwargs))
def incremental_update(self):
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1
@ -208,7 +208,7 @@ def get_fake(x):
return x
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], Dict[str, Any]]:
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]:
"""
The first return value is a boolean indicating whether any of the input nodes lack a fake tensor.
"""

View File

@ -8,22 +8,10 @@ import re
import sys
import time
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from types import ModuleType
from typing import (
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
import sympy
from sympy import Expr
@ -198,8 +186,8 @@ def getattr_recursive(
return attr_itr
def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
ret: Dict[Node, tuple[int, ...]] = {}
def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
ret: dict[Node, tuple[int, ...]] = {}
output_node = g.find_nodes(op="output")[0]
if "user_visible_output_idxs" not in output_node.meta:
@ -212,7 +200,7 @@ def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
def mark_nodes_dislike_padding(
g: Graph, user_visible_output_strides: Dict[Node, tuple[int, ...]]
g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
) -> None:
"""
Nodes like convolution/convolution_backward want its input to be dense.
@ -282,7 +270,7 @@ def mark_nodes_dislike_padding(
class GraphLowering(torch.fx.Interpreter):
graph_outputs: List[ir.IRNode]
graph_outputs: list[ir.IRNode]
def symbolic_sizes_strides(
self, ex: torch.Tensor
@ -323,7 +311,7 @@ class GraphLowering(torch.fx.Interpreter):
def static_sizes_strides(
self, ex: torch.Tensor
) -> tuple[List[sympy.Expr], List[sympy.Expr]]:
) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
"""
Primarily used for weights
"""
@ -341,12 +329,12 @@ class GraphLowering(torch.fx.Interpreter):
aot_mode: bool = False,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any]
Callable[[list[ir.ExternKernelNode]], Any]
] = None,
is_inference: bool = False,
is_backward: bool = False,
is_const_graph: bool = False,
const_output_index: Optional[Dict[str, int]] = None,
const_output_index: Optional[dict[str, int]] = None,
const_code: Optional[str] = None,
const_module: Optional["GraphLowering"] = None,
name: Optional[str] = None,
@ -379,14 +367,14 @@ class GraphLowering(torch.fx.Interpreter):
# you don't start adding new ones in the lowering process
shape_env.freeze_runtime_asserts()
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: Dict[
Optional[sympy.Symbol], List[RuntimeAssert]
self.ras_by_symbol: dict[
Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: List[str] = []
self.graph_inputs: Dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.graph_input_names: list[str] = []
self.graph_inputs: dict[str, TensorBox] = {}
self.graph_inputs_original: dict[str, InputBuffer] = {}
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet()
@ -395,9 +383,9 @@ class GraphLowering(torch.fx.Interpreter):
const_module.device_idxs if const_module else OrderedSet()
)
self.device_type = "cpu"
self.buffers: List[ir.Buffer] = []
self.operations: List[ir.Operation] = []
self.const_output_index: Dict[str, int] = (
self.buffers: list[ir.Buffer] = []
self.operations: list[ir.Operation] = []
self.const_output_index: dict[str, int] = (
const_output_index if const_output_index else {}
)
self.folded_constants: OrderedSet[str] = (
@ -405,12 +393,12 @@ class GraphLowering(torch.fx.Interpreter):
if const_output_index
else OrderedSet()
)
self.constants: Dict[str, torch.Tensor] = (
self.constants: dict[str, torch.Tensor] = (
const_module.constants if const_module else {}
)
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
self.seen_subgraphs: Dict[str, ir.Subgraph] = {}
self.constant_reprs: Dict[str, str] = {}
self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
self.constant_reprs: dict[str, str] = {}
self.removed_operations = OrderedSet[str]()
self.removed_buffers = OrderedSet[str]()
self.removed_inplace_buffers = OrderedSet[str]()
@ -420,23 +408,23 @@ class GraphLowering(torch.fx.Interpreter):
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
from torch._inductor.extern_node_serializer import extern_node_json_serializer
self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
extern_node_serializer
if config.is_fbcode() and extern_node_serializer
else extern_node_json_serializer
)
self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: Dict[str, List[str]] = {}
self.lists: dict[str, list[str]] = {}
self.mutated_inputs = OrderedSet[str]()
self.mutated_input_idxs: List[int] = []
self.name_to_buffer: Dict[str, ir.Buffer] = {}
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
self.name_to_op: Dict[str, ir.Operation] = {}
self.mutated_input_idxs: list[int] = []
self.name_to_buffer: dict[str, ir.Buffer] = {}
self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
self.name_to_op: dict[str, ir.Operation] = {}
self.creation_time = time.time()
self.name = name # type: ignore[assignment]
self.cpp_wrapper = cpp_wrapper
@ -445,7 +433,7 @@ class GraphLowering(torch.fx.Interpreter):
# which sub-kernel is picked. Copy cpp_wrapper to another variable
# since the cpp_wrapper flag is set to false for the first pass of codegen.
self.record_multi_kernel_choice = cpp_wrapper
self.multi_kernel_to_choice: Dict[str, str] = {}
self.multi_kernel_to_choice: dict[str, str] = {}
self.aot_mode = aot_mode
self.graph_id = graph_id
@ -464,7 +452,7 @@ class GraphLowering(torch.fx.Interpreter):
mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
self.cache_key: str = "" # This is the cache key for the compiled artifact
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: List[
self.cache_linemap: list[
tuple[int, str]
] = (
[]
@ -473,18 +461,18 @@ class GraphLowering(torch.fx.Interpreter):
self.disable_cudagraphs_reason: Optional[str] = None
# only keeping one node per device for stack trace purposes
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr]
"dynamo_flat_name_to_original_fqn", {}
)
self.allocated_constant_name: Dict[str, str] = (
self.allocated_constant_name: dict[str, str] = (
const_module.allocated_constant_name if const_module is not None else {}
)
init_backend_registration()
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
self.aligned_inputs = OrderedSet[str]()
self.no_fuse_buffer_names = OrderedSet[str]()
@ -599,7 +587,7 @@ class GraphLowering(torch.fx.Interpreter):
if is_inference:
from torch.utils.flop_counter import FlopCounterMode
flop_counts: Dict[str, float] = defaultdict(float)
flop_counts: dict[str, float] = defaultdict(float)
for node in conv_nodes:
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
node
@ -702,7 +690,7 @@ class GraphLowering(torch.fx.Interpreter):
def make_subgraph(
self,
gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor],
example_inputs: list[torch.Tensor],
subgraph_name: str,
) -> "SubgraphLowering":
"""
@ -886,7 +874,7 @@ class GraphLowering(torch.fx.Interpreter):
buffer.name = name
return name
def register_operation_list(self, operation_names: List[str]) -> str:
def register_operation_list(self, operation_names: list[str]) -> str:
name = self.qualify_name("list_" + "_".join(operation_names))
self.lists[name] = operation_names
return name
@ -995,7 +983,7 @@ class GraphLowering(torch.fx.Interpreter):
)
def placeholder(
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
@ -1072,7 +1060,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs.add(target)
return tensor
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override]
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs)
@ -1155,7 +1143,7 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(
self, target: str, args: tuple[()], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1203,7 +1191,7 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError
def output(
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
@ -1306,9 +1294,9 @@ class GraphLowering(torch.fx.Interpreter):
self,
fx_node: torch.fx.Node,
old_args: tuple[Any],
old_kwargs: Dict[str, Any],
old_kwargs: dict[str, Any],
new_args: tuple[Any],
new_kwargs: Dict[str, Any],
new_kwargs: dict[str, Any],
) -> None:
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
@ -1803,7 +1791,7 @@ class GraphLowering(torch.fx.Interpreter):
self.const_module.wrapper_code.src_to_kernel
)
def codegen_with_cpp_wrapper(self) -> tuple[str, List[tuple[int, Node]]]:
def codegen_with_cpp_wrapper(self) -> tuple[str, list[tuple[int, Node]]]:
"""
For GPU, Triton kernels are autotuned and stored as cubin files
"""
@ -1902,7 +1890,7 @@ class GraphLowering(torch.fx.Interpreter):
# cpu
return self.codegen()
def codegen(self) -> tuple[str, List[tuple[int, Node]]]:
def codegen(self) -> tuple[str, list[tuple[int, Node]]]:
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
from .scheduler import Scheduler
@ -1949,7 +1937,7 @@ class GraphLowering(torch.fx.Interpreter):
def count_bytes(
self,
) -> tuple[
int, List[tuple[BaseSchedulerNode, int]], List[tuple[BaseSchedulerNode, float]]
int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
]:
total_bytes = 0
node_counts = []
@ -2041,7 +2029,7 @@ class GraphLowering(torch.fx.Interpreter):
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod
def get_output_names(self) -> List[str]:
def get_output_names(self) -> list[str]:
names = []
shape_counter = itertools.count(0)
none_counter = itertools.count(0)

View File

@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Callable, List, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING:
import torch
# Executed in the order they're registered
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = []
@contextlib.contextmanager

View File

@ -22,7 +22,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
"""
import itertools
from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Optional, overload, Union
from typing import Any, Callable, Literal, Optional, overload, Union
from typing_extensions import TypeAlias
import sympy
@ -196,8 +196,8 @@ class IndexPropagation:
def __init__(
self,
inner: Any,
iter_ranges: Dict[sympy.Symbol, sympy.Expr],
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
iter_ranges: dict[sympy.Symbol, sympy.Expr],
indirect_var_ranges: dict[sympy.Symbol, sympy.Expr],
) -> None:
self._inner = inner
self.shape_env = V.graph.sizevars.shape_env
@ -248,18 +248,18 @@ class IndexPropagation:
self,
name: Literal["indirect_indexing"],
args: tuple[Any, ...],
kwargs: Dict[str, Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
@overload
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult:
...
def fallback(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult:
# Fallback to the wrapped handler
new_args = [self.unwrap(a) for a in args]
@ -267,7 +267,7 @@ class IndexPropagation:
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
def propagate_sympy(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult:
# Build a new SymPy expression from this ops call
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
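
The overloads above let the type checker narrow the return type when name is the literal "indirect_indexing", while a single implementation handles every call at runtime; a minimal sketch of that pattern with illustrative names:

# Minimal sketch of Literal-based @overload narrowing (illustrative names).
from typing import Literal, Union, overload

@overload
def fetch(kind: Literal["int"], raw: str) -> int: ...
@overload
def fetch(kind: str, raw: str) -> Union[int, str]: ...

def fetch(kind: str, raw: str) -> Union[int, str]:
    # One runtime implementation; the overloads only refine what checkers see.
    return int(raw) if kind == "int" else raw

x: int = fetch("int", "42")   # checker infers int
y = fetch("name", "hello")    # checker infers Union[int, str]
print(x, y)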

View File

@ -2,12 +2,16 @@
from __future__ import annotations
import logging
from typing import Optional, Sequence
from typing import Optional, TYPE_CHECKING
import torch
from torch import _prims, Tensor
if TYPE_CHECKING:
from collections.abc import Sequence
log = logging.getLogger(__name__)
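
The hunk above can drop the runtime Sequence import because, with from __future__ import annotations, annotations stay as unevaluated strings and the name only needs to exist for the type checker; a small self-contained sketch of that pattern:

# Sketch of the TYPE_CHECKING-only import pattern used above.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Sequence  # visible to type checkers only; no runtime import

def total(xs: Sequence[int]) -> int:
    # The annotation is never evaluated at runtime, so Sequence need not be bound here.
    return sum(xs)

print(total([1, 2, 3]))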

View File

@ -8,6 +8,7 @@ import logging
import textwrap
import traceback
import typing
from collections.abc import Generator, Iterable, Sequence
from contextlib import nullcontext
from enum import Enum
from functools import partial
@ -16,14 +17,9 @@ from typing import (
Callable,
ClassVar,
ContextManager,
Dict,
Generator,
Iterable,
List,
Literal,
Optional,
overload,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
@ -165,11 +161,11 @@ e.g. it may be a graph input or compile time constant.
_NodeOrNodes: TypeAlias = Union[
int,
"TensorBox",
Dict[str, "TensorBox"],
dict[str, "TensorBox"],
"Symbol",
"IRNode",
Sequence[
Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
],
]
@ -426,7 +422,7 @@ class IRNode:
# NB: These are kinda weird,
origins: OrderedSet[Any] = dataclasses.field(init=False)
traceback: Optional[List[str]] = dataclasses.field(init=False)
traceback: Optional[list[str]] = dataclasses.field(init=False)
origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
@staticmethod
@ -455,7 +451,7 @@ class IRNode:
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads())
def get_traceback(self) -> Optional[List[str]]:
def get_traceback(self) -> Optional[list[str]]:
return self.traceback
def get_origin_node(self) -> Optional[torch.fx.Node]:
@ -604,18 +600,18 @@ class IRNode:
raise NotImplementedError(type(self).__name__)
def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False
self, order: list[int], allow_padding: bool = False
) -> None:
raise NotImplementedError(type(self).__name__)
def freeze_layout_with_fill_order(self, order: List[int]) -> None:
def freeze_layout_with_fill_order(self, order: list[int]) -> None:
raise NotImplementedError(type(self).__name__)
def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
raise NotImplementedError(type(self).__name__)
def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False
self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None:
raise NotImplementedError(type(self).__name__)
@ -703,7 +699,7 @@ class Operation:
def get_reads(self) -> OrderedSet[Dep]:
return self.get_read_writes().reads
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
raise NotImplementedError
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -936,7 +932,7 @@ class Scatter(Pointwise):
)
REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = {
REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
"any": ops_wrapper("logical_or"),
"max": ops_wrapper("maximum"),
"min": ops_wrapper("minimum"),
@ -1575,8 +1571,8 @@ class Reduction(Loops):
wrapper_fn: Callable[..., Any],
original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr],
new_ranges: List[Expr],
new_reduction_ranges: List[Integer],
new_ranges: list[Expr],
new_reduction_ranges: list[Integer],
reduction_type: str,
split: _IntLike,
reduction_hint: ReductionHint,
@ -1678,8 +1674,8 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr],
new_ranges: List[Integer],
new_reduction_ranges: List[Integer],
new_ranges: list[Integer],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1767,8 +1763,8 @@ class WelfordReduction(Reduction):
device: torch.device,
dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer],
reduction_ranges: List[Integer],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
) -> Sequence[TensorBox]:
@ -1893,8 +1889,8 @@ class WelfordReduction(Reduction):
device: torch.device,
dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer],
reduction_ranges: List[Integer],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
split: _IntLike,
reduction_hint: ReductionHint,
@ -1983,8 +1979,8 @@ class WelfordReduction(Reduction):
@ir_dataclass
class Scan(Loops):
scan_ranges: List[Integer]
size: List[Integer]
scan_ranges: list[Integer]
size: list[Integer]
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
reduction_hint: ReductionHint
@ -2055,7 +2051,7 @@ class Scan(Loops):
device: torch.device,
dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
size: List[Integer],
size: list[Integer],
axis: int,
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
@ -2114,7 +2110,7 @@ class Scan(Loops):
else:
scan_type = SplitScan
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> List[Expr]:
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
assert len(scan_index) == len(scan_ranges)
assert len(index) == len(pointwise_ranges)
return [*index[:axis], *scan_index, *index[axis:]]
@ -2152,8 +2148,8 @@ class Scan(Loops):
dtype: torch.dtype,
inner_fn: Callable[[Sequence[Expr]], OpsValue],
axis: int,
pointwise_ranges: List[Integer],
scan_ranges: List[Integer],
pointwise_ranges: list[Integer],
scan_ranges: list[Integer],
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
scan_numel: Expr,
) -> tuple[ReductionHint, _IntLike]:
@ -2182,8 +2178,8 @@ class SplitScan(Scan):
@ir_dataclass
class Sort(Loops):
# Sorts a tuple of key, value pairs
sort_ranges: List[Integer]
size: List[Integer]
sort_ranges: list[Integer]
size: list[Integer]
reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
reduction_hint: ReductionHint
output_index: int
@ -2251,8 +2247,8 @@ class Sort(Loops):
cls,
device: torch.device,
dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[List[Expr]], Any], ...],
size: List[Integer],
inner_fns: tuple[Callable[[list[Expr]], Any], ...],
size: list[Integer],
axis: int,
stable: bool,
descending: bool,
@ -2293,7 +2289,7 @@ class Sort(Loops):
for output_index in range(len(dtypes))
]
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> List[Expr]:
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
assert len(sort_index) == len(sort_ranges)
assert len(index) == len(pointwise_ranges)
return [*index[:axis], *sort_index, *index[axis:]]
@ -2509,7 +2505,7 @@ class BaseView(IRNode):
@ir_dataclass
class ExpandView(BaseView):
size: List[Expr]
size: list[Expr]
@staticmethod
def _normalize_size(x, new_size): # type: ignore[no-untyped-def]
@ -2588,7 +2584,7 @@ class ExpandView(BaseView):
@ir_dataclass
class PermuteView(BaseView):
dims: List[Expr]
dims: list[Expr]
@classmethod
def create(cls, x, dims): # type: ignore[no-untyped-def]
@ -2676,7 +2672,7 @@ class SqueezeView(BaseView):
not_one = [i for i, s in enumerate(size) if s != 1]
length = len(size)
def reindex(index: List[sympy.Expr]) -> tuple[sympy.Expr, ...]:
def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]:
assert len(index) == len(not_one), f"{index} {not_one}"
new_index = [sympy.S.Zero] * length
for idx, s in zip(not_one, index):
@ -2691,7 +2687,7 @@ class SqueezeView(BaseView):
@ir_dataclass
class GenericView(BaseView):
size: List[Expr]
size: list[Expr]
reindex: Callable[..., Any]
def make_reindexer(self): # type: ignore[no-untyped-def]
@ -3159,8 +3155,8 @@ class Layout(OutputSpec):
self,
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: Optional[List[Expr]] = None,
size: list[Expr],
stride: Optional[list[Expr]] = None,
offset: Expr = Integer(0),
) -> None:
if stride is None:
@ -3169,8 +3165,8 @@ class Layout(OutputSpec):
self.dtype = dtype
assert len(size) == len(stride), f"size={size}, stride={stride}"
assert all(isinstance(s, (Expr, int)) for s in size)
self.size: List[Expr] = size
self.stride: List[Expr] = stride
self.size: list[Expr] = size
self.stride: list[Expr] = stride
self.offset: Expr = offset
def __str__(self) -> str:
@ -3594,8 +3590,8 @@ class NoneLayout(OutputSpec):
# dependencies manually in scheduler
device: Optional[torch.device]
size: List[int] = dataclasses.field(default_factory=lambda: [0])
stride: List[int] = dataclasses.field(default_factory=lambda: [0])
size: list[int] = dataclasses.field(default_factory=lambda: [0])
stride: list[int] = dataclasses.field(default_factory=lambda: [0])
def storage_size(self) -> int:
return 0
@ -3620,7 +3616,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
V.graph.mark_buffer_mutated(name)
@property
def stride(self) -> List[Expr]:
def stride(self) -> list[Expr]:
return self.real_layout().stride
@stride.setter
@ -3725,7 +3721,7 @@ class Buffer(IRNode):
def get_size(self) -> Sequence[Expr]:
return [*self.get_layout().size]
def get_stride(self) -> List[Expr]:
def get_stride(self) -> list[Expr]:
return [*self.get_layout().stride]
def get_offset(self) -> Expr:
@ -3816,7 +3812,7 @@ class Buffer(IRNode):
@ir_dataclass(frozen=False)
class OperationBuffer(Buffer, Operation):
# An operation that produces a single output buffer
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return [self]
def get_defining_op(self) -> Operation:
@ -3977,7 +3973,7 @@ class ComputedBuffer(OperationBuffer):
assert isinstance(self.data, Pointwise)
return partial(self.data.store_output, self.name, indexer)
def get_fill_order(self) -> Optional[List[int]]:
def get_fill_order(self) -> Optional[list[int]]:
"""
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
@ -4028,9 +4024,9 @@ class ComputedBuffer(OperationBuffer):
def get_default_sizes_body(
self,
) -> tuple[
tuple[List[sympy.Expr], List[sympy.Expr]],
tuple[list[sympy.Expr], list[sympy.Expr]],
LoopBody,
tuple[List[sympy.Expr], List[sympy.Expr]],
tuple[list[sympy.Expr], list[sympy.Expr]],
]:
args, var_ranges = dependencies.index_vars_squeeze(
self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
@ -4043,7 +4039,7 @@ class ComputedBuffer(OperationBuffer):
*args,
)
index_vars = []
reduce_vars: List[Any] = []
reduce_vars: list[Any] = []
index_size = []
reduce_size = []
for v, s in var_ranges.items():
@ -4059,9 +4055,9 @@ class ComputedBuffer(OperationBuffer):
def simplify_and_reorder(
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
) -> tuple[tuple[List[sympy.Expr], List[sympy.Expr]], LoopBody]:
) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]:
"""
This is a main place where we do loop transformations in a
backend-agnostic way.
@ -4282,7 +4278,7 @@ class TemplateBuffer(OperationBuffer):
def simplify_and_reorder( # type: ignore[no-untyped-def]
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
):
return (
@ -4314,7 +4310,7 @@ class TritonTemplateBuffer(TemplateBuffer):
"""
super().__init__(layout, inputs, make_kernel_render)
self.mutated_inputs = mutated_inputs
self.outputs: List[Buffer] = [self]
self.outputs: list[Buffer] = [self]
if mutated_inputs is not None:
# Ensure that the mutated inputs are only allowed for certain nodes
allowed_set = (
@ -4335,7 +4331,7 @@ class TritonTemplateBuffer(TemplateBuffer):
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
)
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return self.outputs
def get_allowed_prologue_inps(self) -> OrderedSet[str]:
@ -4346,7 +4342,7 @@ class TritonTemplateBuffer(TemplateBuffer):
return out
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
class ChoiceCaller:
@ -4361,7 +4357,7 @@ class ChoiceCaller:
def __init__(
self,
name: str,
input_nodes: List[Buffer],
input_nodes: list[Buffer],
layout: Layout,
description: str,
) -> None:
@ -4389,7 +4385,7 @@ class ChoiceCaller:
def output_node(self) -> TensorBox:
raise NotImplementedError
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}
@ -4414,9 +4410,9 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
def __init__(
self,
layout: Layout,
inputs: List[IRNode],
choice_timings: Callable[[], Dict[ChoiceCaller, float]],
unfiltered_choices: List[ChoiceCaller],
inputs: list[IRNode],
choice_timings: Callable[[], dict[ChoiceCaller, float]],
unfiltered_choices: list[ChoiceCaller],
allowed_prologue_inps: OrderedSet[str],
) -> None:
super().__init__(
@ -4426,7 +4422,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
allowed_prologue_inps=allowed_prologue_inps,
)
self._choice_timings_fn = choice_timings
self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None
self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
self.original_inputs = inputs
self._output_plannable = all(
isinstance(choice, TritonTemplateCallerBase)
@ -4445,7 +4441,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
return self._output_plannable
@property
def choice_timings(self) -> Dict[ChoiceCaller, float]:
def choice_timings(self) -> dict[ChoiceCaller, float]:
if self._choice_timings is None:
self._choice_timings = self._choice_timings_fn()
return self._choice_timings
@ -4496,7 +4492,7 @@ class CppTemplateBuffer(TemplateBuffer):
super().__init__(layout, inputs, make_kernel_render)
self.template = template
self.choice = choice
self.outputs: Optional[List[Buffer]] = None
self.outputs: Optional[list[Buffer]] = None
def get_layout(self) -> Layout:
if isinstance(self.layout, MultiOutputLayout):
@ -4512,7 +4508,7 @@ class CppTemplateBuffer(TemplateBuffer):
@ir_dataclass(frozen=False)
class InputsKernel(OperationBuffer):
inputs: List[Buffer]
inputs: list[Buffer]
def get_read_writes(self) -> dependencies.ReadWrites:
reads = OrderedSet[dependencies.Dep]()
@ -4752,7 +4748,7 @@ class ConcatKernel(NopKernel):
@ir_dataclass(frozen=False)
class ExternKernel(InputsKernel):
constant_args: tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
python_kernel_name: Optional[str] = None
cpp_kernel_name: Optional[str] = None
@ -4764,12 +4760,12 @@ class ExternKernel(InputsKernel):
op_overload: Optional[
Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
] = None
arg_properties: Optional[List[Dict[str, Any]]] = None
kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None
unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
arg_properties: Optional[list[dict[str, Any]]] = None
kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
default_factory=dict
)
mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list)
mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
def __init__( # type: ignore[no-untyped-def]
self,
@ -4801,7 +4797,7 @@ class ExternKernel(InputsKernel):
self.mutation_outputs = []
self.fx_node = V.graph.current_node
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return [self, *self.mutation_outputs]
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -4919,10 +4915,10 @@ class ExternKernel(InputsKernel):
cls, kernel, *args, **kwargs
) -> tuple[
Any,
List[Any],
List[Any],
list[Any],
list[Any],
Callable[[Any, Any], Any],
Optional[Dict[sympy.Symbol, pytree.KeyPath]],
Optional[dict[sympy.Symbol, pytree.KeyPath]],
]:
binded_args = {"args": args, "kwargs": kwargs}
@ -4930,7 +4926,7 @@ class ExternKernel(InputsKernel):
is_arg_tensor = []
tensor_args = []
non_tensor_args: List[Any] = []
non_tensor_args: list[Any] = []
for arg in args_flat:
is_arg_tensor.append(isinstance(arg, IRNode))
if is_arg_tensor[-1]:
@ -4963,7 +4959,7 @@ class ExternKernel(InputsKernel):
# Rerun fake tensor propagation, because Inductor may have changed the
# strides of inputs and we need to determine accurately what the
# output stride will be.
example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = []
example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []
# We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a
@ -4984,7 +4980,7 @@ class ExternKernel(InputsKernel):
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
example_output = kernel(*new_args, **new_kwargs)
unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None
unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
if shape_env := V.fake_mode.shape_env:
rebind_unbacked(shape_env, V.current_node, example_output)
unbacked_bindings = compute_unbacked_bindings(
@ -5309,7 +5305,7 @@ class ExternKernel(InputsKernel):
)
return args
def codegen_const_args(self, names: Optional[List[str]] = None): # type: ignore[no-untyped-def]
def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def]
if V.graph.cpp_wrapper:
result = []
# Aten ops follow the convention that tensor args are before non-tensor args,
@ -5635,14 +5631,14 @@ class TMADescriptor(ExternKernel):
# as TMA descriptors are immutable,
# we can dedup them by the input args
_CACHE: Dict[Any, TMADescriptor] = {}
_CACHE: dict[Any, TMADescriptor] = {}
@classmethod
def create( # type: ignore[no-untyped-def]
cls,
tensor: IRNode,
dims: List[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]],
dims: list[Union[int, torch.SymInt]],
block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None,
):
key = (id(tensor), dims, block_dims, element_size)
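
Above, _CACHE deduplicates descriptors by their construction arguments; a generic sketch of that keyed-constructor cache pattern with a hypothetical class (the real key and argument handling are elided in the hunk):

# Generic sketch of a keyed constructor cache (hypothetical class, not the real TMADescriptor).
from __future__ import annotations

class Descriptor:
    _CACHE: dict[tuple[int, ...], Descriptor] = {}

    def __init__(self, dims: tuple[int, ...]) -> None:
        self.dims = dims

    @classmethod
    def create(cls, dims: tuple[int, ...]) -> Descriptor:
        # Reuse an existing instance for identical args instead of building a new one.
        if dims not in cls._CACHE:
            cls._CACHE[dims] = cls(dims)
        return cls._CACHE[dims]

a = Descriptor.create((8, 8))
b = Descriptor.create((8, 8))
print(a is b)  # True: deduplicated by key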
@ -5653,8 +5649,8 @@ class TMADescriptor(ExternKernel):
def __init__(
self,
tensor: IRNode,
dims: List[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]],
dims: list[Union[int, torch.SymInt]],
block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None,
) -> None:
assert len(dims) in (1, 2)
@ -5707,8 +5703,8 @@ class UserDefinedTritonKernel(ExternKernel):
kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = []
restore_value_args: List[str] = []
reset_to_zero_args: List[str] = []
restore_value_args: list[str] = []
reset_to_zero_args: list[str] = []
if isinstance(kernel, Autotuner):
# https://github.com/triton-lang/triton/pull/5083
# changes kernel.restore_idx to kernel.restore_value
@ -5871,7 +5867,7 @@ class UserDefinedTritonKernel(ExternKernel):
]
V.graph.register_operation(self)
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return list(self.mutation_outputs)
def get_device(self) -> Optional[torch.device]:
@ -6333,9 +6329,9 @@ class FallbackKernel(ExternKernelAlloc):
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]
# args that are aliased
self.alias_names: List[str] = []
self.alias_names: list[str] = []
# args that are mutated AND returned from the op
self.mutation_names: List[str] = []
self.mutation_names: list[str] = []
if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
# We assume here that HOPs with FallbackKernel are functional.
@ -6834,7 +6830,7 @@ class MultiOutput(ExternKernel):
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
)
def __init__(self, layout: OutputSpec, input, indices: List[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@ -6903,18 +6899,18 @@ class MutableBox(IRNode):
return self.data.freeze_layout()
def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False
self, order: list[int], allow_padding: bool = False
) -> None:
return self.data.freeze_layout_with_stride_order(order, allow_padding)
def freeze_layout_with_fill_order(self, order: List[int]) -> None:
def freeze_layout_with_fill_order(self, order: list[int]) -> None:
return self.data.freeze_layout_with_fill_order(order)
def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
return self.data.freeze_layout_with_same_order(stride)
def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False
self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None:
return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
@ -7119,11 +7115,11 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
@ir_dataclass(frozen=False)
class InvokeSubgraph(ExternKernel):
subgraph: Optional[Subgraph] = None
operands: Optional[List[TensorBox]] = None
outputs: Optional[List[MultiOutput]] = None
operands: Optional[list[TensorBox]] = None
outputs: Optional[list[MultiOutput]] = None
def __init__(
self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout
self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout
) -> None:
super().__init__(
name=None,
@ -7212,15 +7208,15 @@ class InvokeSubgraph(ExternKernel):
@ir_dataclass(frozen=False)
class Conditional(ExternKernel):
predicate: Optional[IRNode] = None
operands: Optional[List[TensorBox]] = None
operands: Optional[list[TensorBox]] = None
true_subgraph: Optional[Subgraph] = None
false_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None
outputs: Optional[list[MultiOutput]] = None
def __init__(
self,
predicate: IRNode,
operands: List[TensorBox],
operands: list[TensorBox],
true_subgraph: Subgraph,
false_subgraph: Subgraph,
layout: MultiOutputLayout,
@ -7250,7 +7246,7 @@ class Conditional(ExternKernel):
predicate: TensorBox,
true_fn: Subgraph,
false_fn: Subgraph,
operands: List[TensorBox],
operands: list[TensorBox],
):
predicate = cls.realize_input(predicate)
operands = [cls.realize_input(x) for x in operands]
@ -7332,16 +7328,16 @@ class Conditional(ExternKernel):
@ir_dataclass(frozen=False)
class WhileLoop(ExternKernel):
carried_inputs: Optional[List[TensorBox]] = None
additional_inputs: Optional[List[TensorBox]] = None
carried_inputs: Optional[list[TensorBox]] = None
additional_inputs: Optional[list[TensorBox]] = None
cond_subgraph: Optional[Subgraph] = None
body_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None
outputs: Optional[list[MultiOutput]] = None
def __init__(
self,
carried_inputs: List[TensorBox],
additional_inputs: List[TensorBox],
carried_inputs: list[TensorBox],
additional_inputs: list[TensorBox],
cond_subgraph: Subgraph,
body_subgraph: Subgraph,
layout: MultiOutputLayout,
@ -7365,8 +7361,8 @@ class WhileLoop(ExternKernel):
cls,
cond_fn: Subgraph,
body_fn: Subgraph,
carried_inputs: List[TensorBox],
additional_inputs: List[TensorBox],
carried_inputs: list[TensorBox],
additional_inputs: list[TensorBox],
):
carried_inputs = [cls.realize_input(x) for x in carried_inputs]
additional_inputs = [cls.realize_input(x) for x in additional_inputs]
@ -7411,8 +7407,8 @@ class WhileLoop(ExternKernel):
for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
def _guard_list_equals(
lhs_exprs: List[Union[int, sympy.expr]],
rhs_exprs: List[Union[int, sympy.expr]],
lhs_exprs: list[Union[int, sympy.expr]],
rhs_exprs: list[Union[int, sympy.expr]],
) -> None:
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
V.graph.sizevars.guard_equals(lhs, rhs)
@ -7549,7 +7545,7 @@ class _CollectiveKernel(FallbackKernel):
# mutation of the input buffers.
@classmethod
def create_inplace( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
) -> None:
with V.graph.fake_mode:
(
@ -7610,7 +7606,7 @@ class _CollectiveKernel(FallbackKernel):
# usage in the user program.
@classmethod
def create_out_of_place( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
):
with V.graph.fake_mode:
(