PEP585 update - torch/_inductor/[_-i]* (#145137)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137
Approved by: https://github.com/bobrenjc93
Authored by Aaron Orenstein on 2025-01-18 09:28:55 -08:00
Committed by PyTorch MergeBot
parent cede43e06b
commit 893ca1dfe1
36 changed files with 727 additions and 765 deletions
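
For background, PEP 585 (Python 3.9+) lets the builtin container types be used directly as generics, so annotations spelled with typing.List, typing.Dict, typing.Tuple, and typing.Type become list, dict, tuple, and type, and those names are dropped from the typing imports (Optional, Union, and Callable are still imported from typing in these files; a couple of files also move Iterable and Sequence to collections.abc). A minimal sketch of the mechanical rewrite, using a made-up helper that is not part of this diff:

from typing import Any, Optional

# Before (typing aliases, pre-PEP 585):
#     def lookup(keys: List[str], table: Dict[str, Any]) -> Optional[List[Any]]: ...
# After (builtin generics): List and Dict no longer need to be imported.
def lookup(keys: list[str], table: dict[str, Any]) -> Optional[list[Any]]:
    # Return the values for the given keys, or None if any key is missing.
    if any(k not in table for k in keys):
        return None
    return [table[k] for k in keys]

The runtime behavior is unchanged; only the annotations and imports differ.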

View File

@@ -28,8 +28,8 @@ log = logging.getLogger(__name__)
 def compile(
     gm: torch.fx.GraphModule,
-    example_inputs: List[InputType],
-    options: Optional[Dict[str, Any]] = None,
+    example_inputs: list[InputType],
+    options: Optional[dict[str, Any]] = None,
 ):
     """
     Compile a given FX graph with TorchInductor. This allows compiling
@@ -54,7 +54,7 @@ def aoti_compile_and_package(
     _deprecated_unused_kwargs=None,
     *,
     package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
 ) -> str:
     """
     Compiles the exported program with AOTInductor, and packages it into a .pt2
@@ -131,11 +131,11 @@ def _aoti_compile_and_package_inner(
     gm: torch.nn.Module,
     # flat_example_inputs: List[Any],
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
     load_and_run: bool = False,
     package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
 ):
     """
     See docstring for aoti_compile_and_package.
@@ -199,10 +199,10 @@ def aoti_load_package(path: Union[str, io.BytesIO]) -> Any:  # type: ignore[type
 def aot_compile(
     gm: torch.fx.GraphModule,
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    options: Optional[Dict[str, Any]] = None,
-) -> Union[str, List[str]]:
+    options: Optional[dict[str, Any]] = None,
+) -> Union[str, list[str]]:
     """
     Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
@@ -232,7 +232,7 @@ def aot_compile(
 def list_mode_options(
     mode: Optional[str] = None, dynamic: Optional[bool] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     r"""Returns a dictionary describing the optimizations that each of the available
     modes passed to `torch.compile()` performs.
@@ -245,7 +245,7 @@ def list_mode_options(
         >>> torch._inductor.list_mode_options()
     """
-    mode_options: Dict[str, Dict[str, bool]] = {
+    mode_options: dict[str, dict[str, bool]] = {
         "default": {},
         # enable cudagraphs
         "reduce-overhead": {
@@ -267,7 +267,7 @@ def list_mode_options(
     return mode_options[mode] if mode else mode_options  # type: ignore[return-value]
-def list_options() -> List[str]:
+def list_options() -> list[str]:
     r"""Returns a dictionary describing the optimizations and debug configurations
     that are available to `torch.compile()`.
@@ -280,7 +280,7 @@ def list_options() -> List[str]:
     from torch._inductor import config
-    current_config: Dict[str, Any] = config.get_config_copy()
+    current_config: dict[str, Any] = config.get_config_copy()
     return list(current_config.keys())

View File

@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
 from unittest import mock
 import torch
@@ -31,7 +31,7 @@ def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
 def load_aoti_eager_cache(
     ns: str, op_func_name_with_overload: str, device_type: str
-) -> List[Optional[Dict[str, Any]]]:
+) -> list[Optional[dict[str, Any]]]:
     device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
     op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
     if not op_conf.exists():
@@ -81,7 +81,7 @@ def load_aoti_eager_cache(
         return []
-def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
+def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]:
     return {int: torch.int32, float: torch.float, bool: torch.bool}
@@ -90,8 +90,8 @@ def supported_scalar_types() -> tuple[type, ...]:
     return tuple(type_to_torch_dtype.keys())
-def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
-    metadata: Dict[str, Any] = {}
+def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]:
+    metadata: dict[str, Any] = {}
     metadata["is_dynamic"] = dynamic
     assert isinstance(input, torch.Tensor)
@@ -110,21 +110,21 @@ def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any
 def extract_tensor_list_metadata(
     dynamic: bool,
-    input: List[torch.Tensor],
-) -> Dict[str, Any]:
+    input: list[torch.Tensor],
+) -> dict[str, Any]:
     metadata_list = []
     for item in input:
         assert isinstance(item, torch.Tensor)
         metadata_list.append(extract_tensor_metadata(dynamic, item))
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["tensor_list"] = metadata_list
     return metadata
-def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
+def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]:
     assert isinstance(input, supported_scalar_types())
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["is_dynamic"] = False
     # Scalar tensor
     metadata["device_type"] = device_type
@@ -135,31 +135,31 @@ def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
     return metadata
-def extract_string_metadata(input: str) -> Dict[str, Any]:
+def extract_string_metadata(input: str) -> dict[str, Any]:
     assert isinstance(input, str)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["string_value"] = input
     return metadata
-def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
+def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]:
     assert isinstance(input, torch.dtype)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["dtype_value"] = f"{input}"
     return metadata
-def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
+def extract_device_metadata(input: torch.device) -> dict[str, Any]:
     assert isinstance(input, torch.device)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["device_type_value"] = f"{input.type}"
     metadata["device_index_value"] = input.index
     return metadata
-def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
+def extract_layout_metadata(input: torch.layout) -> dict[str, Any]:
     assert isinstance(input, torch.layout)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["layout_value"] = f"{input}"
     return metadata
@@ -171,10 +171,10 @@ def aoti_compile_with_persistent_cache(
     dynamic: bool,
     f: Callable[..., Any],
     args: tuple[Any],
-    kwargs: Dict[str, Any],
+    kwargs: dict[str, Any],
     *,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
-    options: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict[str, Any]] = None,
+    options: Optional[dict[str, Any]] = None,
     remove_runtime_assertions: bool = False,
     disable_constraint_solver: bool = False,
 ) -> str:
@@ -261,7 +261,7 @@ def aoti_compile_with_persistent_cache(
         metadata["arg_order"] = idx
         kernel_metadata_items.append(metadata)
-    kernel_meta_info: Dict[str, Any] = {}
+    kernel_meta_info: dict[str, Any] = {}
     kernel_meta_info["meta_info"] = kernel_metadata_items
     kernel_meta_info["kernel_path"] = (
         Path(kernel_lib_path).relative_to(persistent_cache).as_posix()

View File

@@ -11,7 +11,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from functools import partial
 from time import time
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 import torch
 from torch._dynamo.device_interface import get_registered_device_interfaces
@@ -250,7 +250,7 @@ class AsyncCompile:
         get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
         return LambdaFuture(lambda: get_result().kernel)
-    def cpp_pybinding(self, argtypes: List[str], source_code: str):
+    def cpp_pybinding(self, argtypes: list[str], source_code: str):
         kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
         if get_compile_threads() <= 1:
             return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
@@ -299,7 +299,7 @@ class AsyncCompile:
         )
         return LambdaFuture(get_result)
-    def wait(self, scope: Dict[str, Any]) -> None:
+    def wait(self, scope: dict[str, Any]) -> None:
         with dynamo_timed(
             "async_compile.wait",
             log_pt2_compile_event=True,

View File

@@ -1,7 +1,7 @@
 import json
 import os
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
 import torch
 from torch._inductor.autoheuristic.autoheuristic_utils import (
@@ -50,16 +50,16 @@ class AutoHeuristic:
     a heuristic (see torchgen/autoheuristic/).
     """
-    collected_feedback: Dict[Choice, Feedback]
+    collected_feedback: dict[Choice, Feedback]
     def __init__(
         self,
         fallback: Callable[[], Choice],
-        choices: List[Choice],
+        choices: list[Choice],
         feedback: Optional[LocalFeedback],
         context: AHContext,
         name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
         precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
     ) -> None:
         """
@@ -135,8 +135,8 @@ class AutoHeuristic:
         return self.fallback()
     def get_top_k_choices(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[Choice]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[Choice]]:
         if not self.satisfies_precondition():
             return None
         if torch._inductor.config.use_autoheuristic(self.name):
@@ -223,11 +223,11 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
     def __init__(
         self,
         fallback: Callable[[], Optional[ChoiceCaller]],
-        choices: List[ChoiceCaller],
-        input_nodes: List[Any],
+        choices: list[ChoiceCaller],
+        input_nodes: list[Any],
         context: AHContext,
         name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
         precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
     ) -> None:
         """
@@ -237,7 +237,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         have to be used here.
         """
         self.input_nodes = input_nodes
-        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
+        self.choicestr2choice: dict[str, ChoiceCaller] = {}
         for choice in choices:
             self.choicestr2choice[choice.autoheuristic_id()] = choice
         choices_str = list(self.choicestr2choice.keys())
@@ -266,7 +266,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
             self.register_global_feedback(input_nodes, choices)
     def register_global_feedback(
-        self, input_nodes: List[Any], choices: List[ChoiceCaller]
+        self, input_nodes: list[Any], choices: list[ChoiceCaller]
     ) -> None:
         """
         Registers a callback in select_algorithm, which is called with the timing of each choice.
@@ -281,10 +281,10 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         def store_global_feedback(
             ah_inputs_key: str,
             ah_precompile_key: str,
-            timings: Dict[ChoiceCaller, float],
+            timings: dict[ChoiceCaller, float],
             name: str,
-            input_nodes: List[Any],
-            choices: List[ChoiceCaller],
+            input_nodes: list[Any],
+            choices: list[ChoiceCaller],
         ) -> None:
             current_inputs_key = create_inputs_key(input_nodes)
             if current_inputs_key != ah_inputs_key:
@@ -307,8 +307,8 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         return self.choicestr2choice.get(choice, None)
     def get_top_k_choices_caller(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[ChoiceCaller]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[ChoiceCaller]]:
         choices = self.get_top_k_choices(top_k, always_included)
         if choices is None:
             return None

View File

@@ -1,5 +1,5 @@
 import functools
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable
 import torch
@@ -51,8 +51,8 @@ class AHContext:
     information that will help to learn a heuristic.
     """
-    features: List[AHFeature]
-    context_dict: Dict[str, Value]
+    features: list[AHFeature]
+    context_dict: dict[str, Value]
     def __init__(self) -> None:
         self.features = []
@@ -64,7 +64,7 @@ class AHContext:
         self.features.append(AHFeature(name, value, is_categorical=is_categorical))
         self.context_dict[name] = value
-    def get_numerical_and_categorical_features(self) -> tuple[List[str], List[str]]:
+    def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
         numerical_features = []
         categorical_features = []
         for feature in self.features:
@@ -84,7 +84,7 @@ class AHContext:
     def get_value(self, name: str) -> Value:
         return self.context_dict[name]
-    def apply_operations(self, operations: List[AHOperation]) -> None:
+    def apply_operations(self, operations: list[AHOperation]) -> None:
         for op in operations:
             op.apply_operation(self.context_dict)
@@ -94,7 +94,7 @@ class AHMetadata:
         self,
         shared_memory: Any,
         device_capa: tuple[int, int],
-        choices: List[Choice],
+        choices: list[Choice],
         name: str,
     ) -> None:
         # use amount of shared_memory and device_capability to identify GPU
@@ -104,7 +104,7 @@ class AHMetadata:
         self.choices = choices
         self.name = name
-    def to_dict(self) -> Dict[str, Value]:
+    def to_dict(self) -> dict[str, Value]:
         return {
             "shared_memory": self.shared_memory,
             "device_capa": self.device_capa,
@@ -147,7 +147,7 @@ def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
     return mat1_iscontig and not mat2_iscontig
-def get_mult_dims_ops() -> List[AHOperation]:
+def get_mult_dims_ops() -> list[AHOperation]:
     m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
     m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
     k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
@@ -163,7 +163,7 @@ def get_arith_intensity(data: Any) -> float:
     return m * k * n / (m * k + k * n + m * n)
-def pad_mm_operations() -> List[AHOperation]:
+def pad_mm_operations() -> list[AHOperation]:
     mult_dims_ops = get_mult_dims_ops()
     k_div_m_times_n_op = AHOperation(
         "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
@@ -200,7 +200,7 @@ def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
     return data[dim] >= lower and data[dim] <= upper
-def between_ops() -> List[AHOperation]:
+def between_ops() -> list[AHOperation]:
     dims = ["m", "k", "n"]
     limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
     ah_operations = []
@@ -221,13 +221,13 @@ def pow2_op(data: Any, dim: str, exponent: int) -> bool:
     return data[dim] == 2**exponent
-def mm_operations() -> List[AHOperation]:
+def mm_operations() -> list[AHOperation]:
     mult_dims_ops = get_mult_dims_ops()
     arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
     return mult_dims_ops + [arith_intensity_op]
-def mixed_mm_operations() -> List[AHOperation]:
+def mixed_mm_operations() -> list[AHOperation]:
     return mm_operations() + between_ops()
@@ -235,7 +235,7 @@ def is_multiple(data: Any, dim: str, mult: int) -> bool:
     return data[dim] % mult == 0
-def get_dims_multiple_ops() -> List[AHOperation]:
+def get_dims_multiple_ops() -> list[AHOperation]:
     multiples = [2, 4, 8, 16, 32]
     dims = ["m", "k", "n"]
     dims_multiple_ops = []
@@ -249,7 +249,7 @@ def get_dims_multiple_ops() -> List[AHOperation]:
     return dims_multiple_ops
-def get_dims_need_padding_ops() -> List[AHOperation]:
+def get_dims_need_padding_ops() -> list[AHOperation]:
     def mat1_innermost_needs_padding_fn(data: Any) -> bool:
         mat1_stride_0 = data["mat1_stride_0"]
         mat1_stride_1 = data["mat1_stride_1"]
@@ -303,7 +303,7 @@ def get_dims_need_padding_ops() -> List[AHOperation]:
     return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
-def get_is_contig_ops() -> List[AHOperation]:
+def get_is_contig_ops() -> list[AHOperation]:
     def mat1_is_contig_fn(data: Any) -> bool:
         stride_0 = data["mat1_stride_0"]
         stride_1 = data["mat1_stride_1"]

View File

@@ -2,7 +2,7 @@ import importlib
 import inspect
 import pkgutil
 from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional
 from torch._inductor.autoheuristic.autoheuristic_utils import (
     AHContext,
@@ -14,7 +14,7 @@ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeur
 def find_and_instantiate_subclasses(
     package_name: str, base_class: Any
-) -> List[LearnedHeuristic]:
+) -> list[LearnedHeuristic]:
     instances = []
     package = importlib.import_module(package_name)
@@ -49,7 +49,7 @@ class LearnedHeuristicController:
     a way to get the decision of a learned heuristic.
     """
-    existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
+    existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
     """
     A dictionary that stores all the learned heuristics for each optimization.
     The key is the optimization name, and the value is a list of LearnedHeuristic objects.
@@ -69,7 +69,7 @@ class LearnedHeuristicController:
         self.metadata = metadata
         self.context = context
-    def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
+    def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
         """
         Returns a list of learned heuristics for the given optimization name.
         """
@@ -105,7 +105,7 @@ class LearnedHeuristicController:
                 return heuristic.get_decision(self.context, self.metadata.choices)
         return None
-    def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
+    def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
         heuristics = self.get_heuristics(self.metadata.name)
         for heuristic in heuristics:
             if heuristic.check_precondition(self.metadata, self.context):

View File

@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional
 from torch._inductor.autoheuristic.autoheuristic_utils import (
     AHContext,
@@ -23,7 +23,7 @@ class LearnedHeuristic:
         return True
     def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
     ) -> Optional[Choice]:
         return None
@@ -33,7 +33,7 @@ class LearnedHeuristic:
     def get_name(self) -> str:
         return ""
-    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
         return None
@@ -45,7 +45,7 @@ class LearnedHeuristicRegression(LearnedHeuristic):
         return 1.0
     def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
     ) -> Optional[Choice]:
         choice2feedback = {}
         for choice in choices:
@@ -68,7 +68,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
         return None
     def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
     ) -> Optional[Choice]:
         best_choices = self.get_best_choices(context)
         if not best_choices:
@@ -78,7 +78,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
             return None
         return self.get_choice(best_choice_idx)
-    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
         feedback_idx_list = self.get_best_choices(context)
         if feedback_idx_list is None:
             return None
@@ -88,5 +88,5 @@ class LearnedHeuristicDecision(LearnedHeuristic):
         choices = [choice for choice in choices if choice is not None]
         return choices
-    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
         return []

View File

@@ -10,19 +10,10 @@ import os
 import queue
 import time
 import warnings
+from collections.abc import Iterable, Sequence
 from concurrent.futures import ThreadPoolExecutor
 from ctypes import byref, c_size_t, c_void_p, CDLL
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
 import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
@@ -396,8 +387,8 @@ class TuningProcessPool:
     def benchmark(
         self,
-        choices: List[TritonTemplateCaller],
-    ) -> Dict[TritonTemplateCaller, float]:
+        choices: list[TritonTemplateCaller],
+    ) -> dict[TritonTemplateCaller, float]:
         """
         Benchmark each choice in a separate process.
         """
@@ -432,9 +423,9 @@ class TensorMeta:
     @classmethod
     def from_irnodes(
         cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
-    ) -> Union[TensorMeta, List[TensorMeta]]:
+    ) -> Union[TensorMeta, list[TensorMeta]]:
         if isinstance(irnodes, Sequence):
-            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
+            result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
             assert all(isinstance(x, TensorMeta) for x in result)
             return result
@@ -488,8 +479,8 @@ class BenchmarkRequest:
     def __init__(
         self,
         kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
         extra_args: Iterable[Any],
     ) -> None:
         # the kernel name defined in the module
@@ -640,12 +631,12 @@ class TritonBenchmarkRequest(BenchmarkRequest):
     def __init__(
         self,
         kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
         extra_args: Iterable[Any],
         module_path: str,  # the path of the module defining the triton kernel
         module_cache_key: str,
-        grid: List[int],
+        grid: list[int],
         num_stages: int,
         num_warps: int,
         matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
@@ -770,8 +761,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
     def __init__(
         self,
         kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
         extra_args: Iterable[Any],
         source_code: str,
     ) -> None:
@@ -889,8 +880,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
     def __init__(
         self,
         kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
         extra_args: Iterable[Any],
         source_code: str,
     ) -> None:
@@ -946,8 +937,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
 def benchmark_in_sub_process(
-    choices: List[TritonTemplateCaller],
-) -> Dict[TritonTemplateCaller, float]:
+    choices: list[TritonTemplateCaller],
+) -> dict[TritonTemplateCaller, float]:
     """
     Do benchmarking in a subprocess and return the perf number (latency).
     """

View File

@@ -1,7 +1,7 @@
 import logging
 import operator
 from functools import partial
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Union
 from sympy import Expr
@@ -43,7 +43,7 @@ class BoundVars:
             or "masked_subblock" in node.target
         )
         # To access this variable call `get_bounds()`
-        self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
+        self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}
     def __repr__(self) -> str:
         return (
@@ -55,7 +55,7 @@ class BoundVars:
         )
     @cache_on_self
-    def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
+    def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
         submodules = self.swap_submodules(self.loop_body.submodules)
         # Initialize the environment with the unbounded variables
@@ -74,9 +74,9 @@ class BoundVars:
         return self._bounds
     def swap_submodules(
-        self, submodules: Dict[str, Callable[..., Any]]
-    ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
-        result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
+        self, submodules: dict[str, Callable[..., Any]]
+    ) -> dict[str, Callable[..., ValueRanges[Expr]]]:
+        result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
         for key in submodules.keys():
             if key == "get_index":
                 result[key] = self.get_index
@@ -111,10 +111,10 @@ class BoundVars:
     def masked_subblock(
         self,
         subblock: LoopBodyBlock,
-        env: Dict[torch.fx.Node, ValueRanges[Expr]],
+        env: dict[torch.fx.Node, ValueRanges[Expr]],
         mask: Any,
         value: Any,
-        submodules: Dict[str, Callable[..., Any]],
+        submodules: dict[str, Callable[..., Any]],
     ) -> ValueRanges[Expr]:
         interp = InterpreterShim(subblock.graph, submodules)
         interp.run(V.get_ops_handler(), initial_env=env)

View File

@@ -1,7 +1,7 @@
 from __future__ import annotations
 import typing
-from typing import Any, Dict, List, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING
 import sympy
@@ -42,11 +42,11 @@ class InductorChoices:
     def triton_kernel_kwargs(
         self,
-        kernel_cls: Type[TritonKernel],
+        kernel_cls: type[TritonKernel],
         features: SIMDKernelFeatures,
-        groups: List[sympy.Expr],
-        kernel_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+        groups: list[sympy.Expr],
+        kernel_kwargs: dict[str, Any],
+    ) -> dict[str, Any]:
         """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
         return kernel_kwargs

View File

@@ -36,13 +36,8 @@ from typing import (
     Any,
     Callable,
     cast,
-    Dict,
-    Generator,
-    List,
     NoReturn,
     Optional,
-    Sequence,
-    Tuple,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -127,7 +122,7 @@ else:
 if TYPE_CHECKING:
-    from collections.abc import KeysView
+    from collections.abc import Generator, KeysView, Sequence
     from concurrent.futures import Future
     from .compile_fx import _CompileFxKwargs, CompiledFxGraph
@@ -168,7 +163,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
 class CacheBase:
     @staticmethod
     @functools.lru_cache(None)
-    def get_system() -> Dict[str, Any]:
+    def get_system() -> dict[str, Any]:
         try:
             from triton.compiler.compiler import triton_key
@@ -179,7 +174,7 @@ class CacheBase:
             triton_version = None
         try:
-            system: Dict[str, Any] = {
+            system: dict[str, Any] = {
                 "device": {"name": None},
                 "version": {
                     "triton": triton_version,
@@ -217,7 +212,7 @@ class CacheBase:
     def __init__(self) -> None:
         self.system = CacheBase.get_system()
-    def get_local_cache(self) -> Dict[str, Any]:
+    def get_local_cache(self) -> dict[str, Any]:
         local_cache_path = self.get_local_cache_path()
         if not local_cache_path.is_file():
             return {}
@@ -225,7 +220,7 @@ class CacheBase:
             local_cache = json.load(local_cache_fp)
         return local_cache["cache"]
-    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
+    def update_local_cache(self, local_cache: dict[str, Any]) -> None:
         local_cache_path = self.get_local_cache_path()
         write_atomic(
             str(local_cache_path),
@@ -235,7 +230,7 @@ class CacheBase:
 class LocalCache(CacheBase):
-    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
+    def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
         cache = self.get_local_cache()
         sub_cache = cache
@@ -261,7 +256,7 @@ class LocalCache(CacheBase):
 class PersistentCache(CacheBase):
     @functools.lru_cache(None)  # noqa: B019
-    def get_global_cache(self) -> Dict[str, Any]:
+    def get_global_cache(self) -> dict[str, Any]:
         global_cache_path = self.get_global_cache_path()
         if global_cache_path is None or not global_cache_path.is_file():
             return {}
@@ -271,11 +266,11 @@ class PersistentCache(CacheBase):
     def lookup(
         self,
-        choices: List[ChoiceCaller],
+        choices: list[ChoiceCaller],
         op: str,
         inputs: str,
-        benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
-    ) -> Dict[ChoiceCaller, float]:
+        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
+    ) -> dict[ChoiceCaller, float]:
         """
         Check to see if we have benchmarked the given choice callers. For each
         choice caller:
@@ -296,7 +291,7 @@ class PersistentCache(CacheBase):
         )
         timings = {}
-        def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool:
+        def check_cache(cache: dict[str, Any], callback: Any = None) -> bool:
            """Check if `cache` contains data for all the choices"""
            hit = True
            for choice in choices:
@@ -456,7 +451,7 @@ class TensorMetadataAndValues:
     """
     tensor_metadata: TensorMetadata
-    values: List[Any]
+    values: list[Any]
 def _ident(x: T) -> T:
@@ -584,7 +579,7 @@ class FxGraphCachePickler(pickle.Pickler):
     def _reduce_graph_module(
         self, gm: torch.fx.GraphModule
-    ) -> tuple[Any, tuple[Dict[str, Any], str]]:
+    ) -> tuple[Any, tuple[dict[str, Any], str]]:
         """
         Custom reducer for graph module to handle irrelevant data for user
         defined triton kernels
@@ -624,7 +619,7 @@ class FxGraphCachePickler(pickle.Pickler):
         serialized_data = self.dumps(obj)
         return sha256_hash(serialized_data)
-    def debug_lines(self, inp: FxGraphHashDetails) -> List[str]:
+    def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
         """
         Get a printable string describing in more detail all the attributes
         comprising an object. Useful for debugging when one graph hashes
@@ -659,7 +654,7 @@ class FxGraphCachePickler(pickle.Pickler):
 def build_code_hash(
-    roots: List[str] | None, prefix: str, hasher: hashlib._Hash
+    roots: list[str] | None, prefix: str, hasher: hashlib._Hash
 ) -> None:
     for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
         spec = lib.module_finder.find_spec(lib.name, None)
@@ -721,7 +716,7 @@ class OrderedSetHolder:
     of set kwargs.
     """
-    items: List[Any]
+    items: list[Any]
 class BypassFxGraphCache(Exception):
@@ -753,7 +748,7 @@ class FxGraphHashDetails:
         # Order kwargs so hashing is stable to changes in kwarg order. Although
         # it's technically a _CompileFxKwargs we don't actually need it typed as
         # such since we're just using it to generate a hash.
-        self.fx_kwargs: Dict[str, object] = {}
+        self.fx_kwargs: dict[str, object] = {}
         for k, v in sorted(fx_kwargs.items()):
             if k not in self.EXCLUDED_KWARGS:
                 if type(v) in (set, OrderedSet):  # noqa: set_linter
@@ -774,7 +769,7 @@ class FxGraphHashDetails:
         # Node meta will not be part of gm's reduce function, so lets remember
         # the kernel source code separately
-        self.user_defined_triton_source: List[Any] = []
+        self.user_defined_triton_source: list[Any] = []
         if gm is not None:
             for module in gm.modules():
                 if not isinstance(module, torch.fx.GraphModule):
@@ -856,7 +851,7 @@ def compiled_fx_graph_hash(
     example_inputs: Sequence[InputType],
     fx_kwargs: _CompileFxKwargs,
     inputs_to_check: Sequence[int],
-) -> tuple[str, List[str]]:
+) -> tuple[str, list[str]]:
     """
     Generate a unique hash of the FX graph for caching.
     """
@@ -952,7 +947,7 @@ class FxGraphCache:
         return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
     @staticmethod
-    def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
+    def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]:
         """
         Get the backed SymInt objects from the input list. Note that we can never
         have guards that depend on unbacked symint.
@@ -976,7 +971,7 @@ class FxGraphCache:
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
         constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
         """
         Lookup a compiled graph in the cache by key. On a hit, return the
         deserialized CompiledFxGraph object. On a miss, return None.
@@ -988,7 +983,7 @@ class FxGraphCache:
         hints = [hint_int(s) for s in symints]
         def iterate_over_candidates() -> (
-            Generator[Tuple[CompiledFxGraph, bytes], None, None]
+            Generator[tuple[CompiledFxGraph, bytes], None, None]
         ):
             if local:
                 subdir = FxGraphCache._get_tmp_dir_for_key(key)
@@ -1021,7 +1016,7 @@ class FxGraphCache:
         # their guards to determine whether there's a hit.
         graph = None
         pickled_content = None
-        cache_info: Dict[str, Any] = dict()
+        cache_info: dict[str, Any] = dict()
         for candidate, pickled_content in iterate_over_candidates():
             if not candidate.guards_expr:
@@ -1234,7 +1229,7 @@ class FxGraphCache:
         fx_kwargs: _CompileFxKwargs,
         inputs_to_check: Sequence[int],
         remote: bool,
-    ) -> tuple[Optional[tuple[str, List[str]]], Dict[str, Any]]:
+    ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
         """
         Checks that the inductor input is cacheable, then computes
        and returns the cache key for the input.
@@ -1280,13 +1275,13 @@ class FxGraphCache:
     @staticmethod
     def load_with_key(
         key: str,
-        debug_lines: List[str],
+        debug_lines: list[str],
         example_inputs: Sequence[InputType],
         local: bool,
         remote_cache: Optional[RemoteCache[JsonDataTy]],
         is_backward: bool,
         constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
         """
         Lookup the graph with the given key, and return results and metadata.
         Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1373,11 +1368,11 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
 @clear_on_fresh_inductor_cache
 class CudaKernelParamCache:
-    cache: Dict[str, Dict[str, str]] = {}
+    cache: dict[str, dict[str, str]] = {}
     cache_clear = staticmethod(cache.clear)
     @classmethod
-    def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None:
+    def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
         _, path = write(
             cubin,
             bin_type,
@@ -1391,7 +1386,7 @@ class CudaKernelParamCache:
         cls.cache[key] = params
     @classmethod
-    def get(cls, key: str) -> Optional[Dict[str, str]]:
+    def get(cls, key: str) -> Optional[dict[str, str]]:
         return cls.cache.get(key, None)
     @classmethod
@@ -1407,8 +1402,8 @@ class AotCodeCompiler:
         source_code: str,
         serialized_extern_kernel_nodes: Optional[str],
         device_type: str,
-        additional_files: List[str],
-    ) -> Union[List[str], str]:
+        additional_files: list[str],
+    ) -> Union[list[str], str]:
         """
         Returns the .so path, or returns a list of files that were generated if
         config.aot_inductor.package=True.
@@ -1842,14 +1837,14 @@ def cpp_prefix() -> str:
 # Given a path to an input cpp file and an output path,
 # Attempts to compile the file, storing the output in "output_path"
 def compile_file(
-    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
+    input_path: Union[str, list[str]], output_path: str, cmd: list[str]
 ) -> None:
     with dynamo_timed("compile_file"):
         return _compile_file(input_path, output_path, cmd)
 def _compile_file(
-    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
+    input_path: Union[str, list[str]], output_path: str, cmd: list[str]
 ) -> None:
     input_paths = [input_path] if isinstance(input_path, str) else input_path
     input_files = [
@@ -1948,9 +1943,9 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]:
 @clear_on_fresh_inductor_cache
 class CppCodeCache:
-    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
-    cpp_compile_command_flags: Dict[str, Any] = {}
+    cpp_compile_command_flags: dict[str, Any] = {}
     @staticmethod
     def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
@@ -2093,7 +2088,7 @@ def _worker_compile_cpp(
 # Customized Python binding for cpp kernels
 @clear_on_fresh_inductor_cache
 class CppPythonBindingsCodeCache(CppCodeCache):
-    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         # kernels have no dependency on libtorch
@@ -2212,7 +2207,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
     @classmethod
     def load_pybinding_async(
         cls,
-        argtypes: List[str],
+        argtypes: list[str],
         source_code: str,
         device_type: str = "cpu",
         num_outputs: int = -1,
@@ -2269,7 +2264,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
 @clear_on_fresh_inductor_cache
 class CppWrapperCodeCache(CppPythonBindingsCodeCache):
-    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         "include_pytorch": True,
@@ -2335,7 +2330,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
 @clear_on_fresh_inductor_cache
 class HalideCodeCache(CppPythonBindingsCodeCache):
-    cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
+    cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
     cache_clear = staticmethod(cache.clear)
     _standalone_runtime_path: Optional[str] = None
     prefix = textwrap.dedent(
@@ -2412,7 +2407,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
     )
     @classmethod
-    def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]:
+    def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
         assert arg.shape is not None
         assert arg.stride is not None and len(arg.shape) == len(arg.stride)
         assert arg.offset is not None
@@ -2573,7 +2568,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         donefile = str(dirpath / "done")
         lockfile = str(dirpath / "lock")
         need_compile = not os.path.exists(donefile)
-        jobs: List[Any] = []
+        jobs: list[Any] = []
         if need_compile:
             write_atomic(genfile, source_code)
             cmd = [
@@ -2685,7 +2680,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
         return sofile
-def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
+def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
     from torch.utils._filelock import FileLock
     try:
@@ -2733,8 +2728,8 @@ class PyCodeCache:
     # clearing the cache. Note also that we may load the same path more
     # than once, but attach different attributes, i.e., due to different
    # constant values.
-    modules: List[ModuleType] = []
-    linemaps: Dict[str, List[tuple[Any, ...]]] = {}
+    modules: list[ModuleType] = []
+    linemaps: dict[str, list[tuple[Any, ...]]] = {}
     @classmethod
     def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
@@ -2745,8 +2740,8 @@ class PyCodeCache:
         cls,
         source_code: str,
         extra: str = "",
-        linemap: Optional[List[tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
+        linemap: Optional[list[tuple[int, str]]] = None,
+        attrs: Optional[dict[str, Any]] = None,
     ) -> ModuleType:
         key, path = write(source_code, "py", extra=extra)
         return cls.load_by_key_path(key, path, linemap, attrs)
@@ -2756,8 +2751,8 @@ class PyCodeCache:
         cls,
         key: str,
         path: str,
-        linemap: Optional[List[tuple[int, str]]] = None,
-        attrs: Optional[Dict[str, Any]] = None,
+        linemap: Optional[list[tuple[int, str]]] = None,
+        attrs: Optional[dict[str, Any]] = None,
     ) -> ModuleType:
         if linemap is None:
             linemap = []
@@ -2798,7 +2793,7 @@ class PyCodeCache:
     @functools.lru_cache(None)
     def stack_frames_for_code(
         cls, path: str, lineno: int
-    ) -> Optional[List[Dict[str, Any]]]:
+    ) -> Optional[list[dict[str, Any]]]:
         if path not in cls.linemaps:
             return None
         # [(starting_line, <fx node>), ...]
@@ -2810,7 +2805,7 @@ class PyCodeCache:
         if not entry:
             return None
-        def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
+        def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
            # ideally fx stores stack traces as data rather than a string
            # but this is not along a performance critical path
            regex = r'File "(.+)", line (\d+), in (.+)\n'
@@ -2841,7 +2836,7 @@ def _cuda_compiler() -> Optional[str]:
     return "nvcc"
-def _cutlass_include_paths() -> List[str]:
+def _cutlass_include_paths() -> list[str]:
     if config.is_fbcode():
         from libfb.py import parutil
@@ -2857,14 +2852,14 @@ def _cutlass_include_paths() -> List[str]:
     ]
-def _cuda_lib_options() -> List[str]:
+def _cuda_lib_options() -> list[str]:
     _set_gpu_runtime_env()  # cpp_extension consults the env
     from torch.utils import cpp_extension
     lpaths = cpp_extension.library_paths(device_type="cuda") + [
         sysconfig.get_config_var("LIBDIR")
     ]
-    extra_ldflags: List[str] = []
+    extra_ldflags: list[str] = []
     if is_linux():
         _transform_cuda_paths(lpaths)
         for path in lpaths:
@@ -2880,7 +2875,7 @@ def _cuda_lib_options() -> List[str]:
     return extra_ldflags
-def _nvcc_host_compiler_options() -> List[str]:
+def _nvcc_host_compiler_options() -> list[str]:
     return [
         "-fPIC",
         "-fno-strict-aliasing",
@@ -2889,7 +2884,7 @@ def _nvcc_host_compiler_options() -> List[str]:
     ]
-def _nvcc_compiler_options() -> List[str]:
+def _nvcc_compiler_options() -> list[str]:
     arch = cuda_env.get_cuda_arch()
     if arch == "90":
         # Required by cutlass compilation.
@@ -2934,10 +2929,10 @@ def _nvcc_compiler_options() -> List[str]:
 def cuda_compile_command(
-    src_files: List[str],
+    src_files: list[str],
     dst_file: str,
     dst_file_ext: str,
-    extra_args: Optional[List[str]] = None,
+    extra_args: Optional[list[str]] = None,
 ) -> str:
     if extra_args is None:
         extra_args = []
@@ -3052,7 +3047,7 @@ class CUDACodeCache:
         input_path: str
         output_path: str
-    cache: Dict[str, CacheEntry] = {}
+    cache: dict[str, CacheEntry] = {}
     cache_clear = staticmethod(cache.clear)
     _SOURCE_CODE_SUFFIX = "cu"
@@ -3073,7 +3068,7 @@ class CUDACodeCache:
     @classmethod
     def compile(
-        cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
+        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
     ) -> tuple[str, str, str]:
         """
         Compiles CUDA source_code into a file with dst_file_ext extension.
@@ -3137,7 +3132,7 @@ class ROCmCodeCache:
         input_path: str
        output_path: str
-    cache: Dict[str, CacheEntry] = {}
+    cache: dict[str, CacheEntry] = {}
     cache_clear = staticmethod(cache.clear)
     _SOURCE_CODE_SUFFIX = "cpp"
     _logged_compiler_version = False
@@ -3159,7 +3154,7 @@ class ROCmCodeCache:
     @classmethod
     def compile(
-        cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
+        cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
     ) -> tuple[str, str, str]:
         """
         Compiles source_code into a file with dst_file_ext extension,

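The hunks above are all instances of the same mechanical rewrite from PEP 585: typing.List and typing.Dict become the builtin list and dict, which accept subscripts at runtime on Python 3.9+. A minimal sketch of the pattern with illustrative names (not code from this patch); Optional, Union, Callable and Any are still imported from typing throughout the diff:

from typing import Any, Optional

# Before (pre-PEP 585):
#     from typing import Dict, List
#     linemaps: Dict[str, List[tuple[Any, ...]]] = {}
# After (builtin generics, Python 3.9+):
linemaps: dict[str, list[tuple[Any, ...]]] = {}

def load_by_key(key: str, attrs: Optional[dict[str, Any]] = None) -> list[str]:
    # Builtin generics work both in annotations and at runtime.
    return list(attrs or {})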

@ -7,7 +7,7 @@ import logging
import operator import operator
import sys import sys
from collections import defaultdict from collections import defaultdict
from typing import Any, Dict, List, TYPE_CHECKING from typing import Any, TYPE_CHECKING
import torch import torch
from torch.multiprocessing.reductions import StorageWeakRef from torch.multiprocessing.reductions import StorageWeakRef
@ -33,7 +33,7 @@ if TYPE_CHECKING:
from .scheduler import BaseSchedulerNode from .scheduler import BaseSchedulerNode
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
""" """
Greedily schedules waits as late as possible. Greedily schedules waits as late as possible.
""" """
@ -42,7 +42,7 @@ def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
) )
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]: def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
""" """
Greedily schedules comms as early as possible. Greedily schedules comms as early as possible.
""" """
@ -52,8 +52,8 @@ def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
def reorder_compute_for_overlap( def reorder_compute_for_overlap(
snodes: List[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
) -> List[BaseSchedulerNode]: ) -> list[BaseSchedulerNode]:
""" """
This achieves the following overall scheduling procedure: This achieves the following overall scheduling procedure:
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
@ -71,11 +71,11 @@ def reorder_compute_for_overlap(
def _schedule_for_comm( def _schedule_for_comm(
snodes: List[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
raise_comms: bool, raise_comms: bool,
sink_waits: bool, sink_waits: bool,
reorder_for_overlap: bool, reorder_for_overlap: bool,
) -> List[BaseSchedulerNode]: ) -> list[BaseSchedulerNode]:
""" """
Schedule `snodes` for various comm optimization objectives. Schedule `snodes` for various comm optimization objectives.
@ -149,13 +149,13 @@ def _schedule_for_comm(
def __lt__(self, other): def __lt__(self, other):
return self.score < other.score return self.score < other.score
unmet_deps: Dict[BaseSchedulerNode, OrderedSet[str]] = { unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
snode: OrderedSet(dep.name for dep in snode.unmet_dependencies) snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
for snode in snodes for snode in snodes
} }
ready: List[Runnable] = [] ready: list[Runnable] = []
buffer_users: Dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet) buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes} snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
for snode, deps in unmet_deps.items(): for snode, deps in unmet_deps.items():
@ -226,8 +226,8 @@ def _schedule_for_comm(
def decide_global_ordering_of_comms( def decide_global_ordering_of_comms(
nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
) -> List[BaseSchedulerNode]: ) -> list[BaseSchedulerNode]:
""" """
Decide global ordering of comms, by just enforcing the ordering that's in the input graph Decide global ordering of comms, by just enforcing the ordering that's in the input graph
(might not be the same ordering as the eager mode program). (might not be the same ordering as the eager mode program).
@ -303,8 +303,8 @@ def visualize_overlap(order):
def reorder_compute_and_comm_for_overlap( def reorder_compute_and_comm_for_overlap(
snodes: List[BaseSchedulerNode], snodes: list[BaseSchedulerNode],
) -> List[BaseSchedulerNode]: ) -> list[BaseSchedulerNode]:
order = snodes order = snodes
for p in config.reorder_for_compute_comm_overlap_passes: for p in config.reorder_for_compute_comm_overlap_passes:
@ -653,10 +653,10 @@ def get_op_idx(snode):
def enforce_comm_ordering_for_fsdp( def enforce_comm_ordering_for_fsdp(
snodes: List[torch._inductor.scheduler.BaseSchedulerNode], snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer], name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
name_to_fused_node: Dict[str, BaseSchedulerNode], name_to_fused_node: dict[str, BaseSchedulerNode],
) -> List[torch._inductor.scheduler.BaseSchedulerNode]: ) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
from . import scheduler from . import scheduler
new_order: list[BaseSchedulerNode] = [] new_order: list[BaseSchedulerNode] = []

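The comm-reordering passes above all share the signature list[BaseSchedulerNode] -> list[BaseSchedulerNode]. As the config section later in this diff notes, user-defined passes can be registered by function handle; a hedged sketch of what such a pass could look like (the no-op body and the exact registration call are illustrative assumptions, not a documented recipe):

from __future__ import annotations

from typing import TYPE_CHECKING

import torch._inductor.config as inductor_config

if TYPE_CHECKING:
    from torch._inductor.scheduler import BaseSchedulerNode

def keep_order(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
    # A pass receives the tentatively scheduled nodes and returns a new
    # ordering; this placeholder simply preserves the input order.
    return list(snodes)

inductor_config.reorder_for_compute_comm_overlap = True
inductor_config.reorder_for_compute_comm_overlap_passes = [keep_order]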

@ -16,11 +16,7 @@ from typing import (
Any, Any,
Callable, Callable,
ContextManager, ContextManager,
Dict,
Generator,
List,
Optional, Optional,
Sequence,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
@ -121,6 +117,8 @@ from .virtualized import V
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator, Sequence
from torch._inductor.output_code import _StrideExprStr from torch._inductor.output_code import _StrideExprStr
from torch._ops import OpOverload from torch._ops import OpOverload
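The hunk above also swaps typing.Generator and typing.Sequence for their collections.abc counterparts, imported only under TYPE_CHECKING. PEP 585 deprecates the typing aliases; the abc classes are subscriptable at runtime on 3.9+, and postponed annotation evaluation keeps the import out of runtime entirely. A small self-contained sketch of the idiom with illustrative names:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from collections.abc import Generator, Sequence

def chunked(xs: Sequence[int], n: int) -> Generator[list[int], None, None]:
    # Yield successive n-sized chunks of xs; annotations are never evaluated
    # at runtime, so the TYPE_CHECKING-only import is sufficient.
    for i in range(0, len(xs), n):
        yield list(xs[i : i + n])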
@ -157,7 +155,7 @@ static_inputs_log = torch._logging.getArtifactLogger(
) )
def get_static_input_idxs(num_fixed: int) -> List[int]: def get_static_input_idxs(num_fixed: int) -> list[int]:
# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
# of cudagraphs. Rather than copying these into cudagraph-owned memory # of cudagraphs. Rather than copying these into cudagraph-owned memory
# like we do for normal inputs on each run, we will re-record a cudagraph if these # like we do for normal inputs on each run, we will re-record a cudagraph if these
@ -208,7 +206,7 @@ def _unlift_graph(
) -> GraphModule: ) -> GraphModule:
from torch.export.unflatten import _assign_attr, _AttrKind from torch.export.unflatten import _assign_attr, _AttrKind
state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {} state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
for name, param in mod.named_parameters(remove_duplicate=False): for name, param in mod.named_parameters(remove_duplicate=False):
state_dict[name] = param state_dict[name] = param
_assign_attr( _assign_attr(
@ -227,7 +225,7 @@ def _unlift_graph(
) )
placeholder_nodes = gm.graph.find_nodes(op="placeholder") placeholder_nodes = gm.graph.find_nodes(op="placeholder")
lifted_inputs: List[Optional[FQN]] = [] lifted_inputs: list[Optional[FQN]] = []
# In AOTI, module parameters and buffers are not lifted as graph inputs. # In AOTI, module parameters and buffers are not lifted as graph inputs.
# As a result, mutation to buffers has side effect which makes their initial # As a result, mutation to buffers has side effect which makes their initial
@ -343,9 +341,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
def split_const_gm( def split_const_gm(
gm: GraphModule, gm: GraphModule,
skip_constructor: bool = True, skip_constructor: bool = True,
lifted_constant_names: Optional[List[str]] = None, lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> tuple[GraphModule, Dict[str, int]]: ) -> tuple[GraphModule, dict[str, int]]:
""" """
This function takes an GraphModule input "gm". This function takes an GraphModule input "gm".
The gm will be split into 2 components, The gm will be split into 2 components,
@ -488,8 +486,8 @@ def fake_tensor_prop(
# pass config dict back to user # pass config dict back to user
def get_patched_config_dict( def get_patched_config_dict(
config_patches: Optional[Union[str, Dict[str, Any]]] = None config_patches: Optional[Union[str, dict[str, Any]]] = None
) -> Dict[str, Any]: ) -> dict[str, Any]:
with config.patch(config_patches): with config.patch(config_patches):
return config.get_config_copy() return config.get_config_copy()
@ -515,7 +513,7 @@ class _CompileFxKwargs(TypedDict, total=False):
aot_mode: bool aot_mode: bool
is_inference: bool is_inference: bool
layout_opt: Optional[bool] layout_opt: Optional[bool]
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]] extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
boxed_forward_device_index: Optional[BoxedDeviceIndex] boxed_forward_device_index: Optional[BoxedDeviceIndex]
@ -822,7 +820,7 @@ class _InProcessFxCompile(FxCompile):
aot_mode: bool = V.aot_compilation aot_mode: bool = V.aot_compilation
is_inference: bool = graph_kwargs.get("is_inference", False) is_inference: bool = graph_kwargs.get("is_inference", False)
extern_node_serializer: Optional[ extern_node_serializer: Optional[
Callable[[List[ExternKernelNode]], Any] Callable[[list[ExternKernelNode]], Any]
] = graph_kwargs.get("extern_node_serializer", None) ] = graph_kwargs.get("extern_node_serializer", None)
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get( boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
"boxed_forward_device_index", None "boxed_forward_device_index", None
@ -997,7 +995,7 @@ class _InProcessFxCompile(FxCompile):
metrics_helper = metrics.CachedMetricsHelper() metrics_helper = metrics.CachedMetricsHelper()
with V.set_graph_handler(graph): with V.set_graph_handler(graph):
graph.run(*example_inputs) graph.run(*example_inputs)
output_strides: List[Optional[tuple[_StrideExprStr, ...]]] = [] output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
if graph.graph_outputs is not None: if graph.graph_outputs is not None:
# We'll put the output strides in the compiled graph so we # We'll put the output strides in the compiled graph so we
# can later return them to the caller via TracingContext # can later return them to the caller via TracingContext
@ -1189,7 +1187,7 @@ def cudagraphify(
static_input_idxs: Sequence[int] = (), static_input_idxs: Sequence[int] = (),
*, *,
device_index: int, device_index: int,
stack_traces: List[Optional[str]], stack_traces: list[Optional[str]],
is_backward: bool, is_backward: bool,
is_inference: bool, is_inference: bool,
constants: tuple[torch.Tensor, ...] = (), constants: tuple[torch.Tensor, ...] = (),
@ -1240,7 +1238,7 @@ def static_input(x: torch.Tensor) -> torch.Tensor:
def index_expanded_dims_and_copy_( def index_expanded_dims_and_copy_(
dst: torch.Tensor, dst: torch.Tensor,
src: torch.Tensor, src: torch.Tensor,
expanded_dims: List[int], expanded_dims: list[int],
) -> None: ) -> None:
"Index into expanded dimensions of both dst and src then copy_" "Index into expanded dimensions of both dst and src then copy_"
dst = index_expanded_dims(dst, expanded_dims) dst = index_expanded_dims(dst, expanded_dims)
@ -1250,9 +1248,9 @@ def index_expanded_dims_and_copy_(
def cudagraphify_impl( def cudagraphify_impl(
model: Callable[..., Any], model: Callable[..., Any],
inputs: List[torch.Tensor], inputs: list[torch.Tensor],
static_input_idxs: Sequence[int] = (), static_input_idxs: Sequence[int] = (),
) -> Callable[[List[InputType]], Any]: ) -> Callable[[list[InputType]], Any]:
""" """
Assumes inputs[static_input_idxs[i]] are always the same memory address Assumes inputs[static_input_idxs[i]] are always the same memory address
""" """
@ -1304,7 +1302,7 @@ def cudagraphify_impl(
if config.size_asserts: if config.size_asserts:
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]: def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
assert len(static_inputs) == len(new_inputs) assert len(static_inputs) == len(new_inputs)
for idx, (dst, src, expanded_dims) in enumerate( for idx, (dst, src, expanded_dims) in enumerate(
zip(static_inputs, new_inputs, inps_expanded_dims) zip(static_inputs, new_inputs, inps_expanded_dims)
@ -1328,7 +1326,7 @@ def cudagraphify_impl(
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
] ]
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]: def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
for idx in copy_indices: for idx in copy_indices:
expanded_dims = inps_expanded_dims[idx] expanded_dims = inps_expanded_dims[idx]
src = new_inputs[idx] src = new_inputs[idx]
@ -1343,16 +1341,16 @@ def cudagraphify_impl(
def compile_fx_aot( def compile_fx_aot(
model_: GraphModule, model_: GraphModule,
example_inputs_: List[InputType], example_inputs_: list[InputType],
inner_compile: _CompileFxCallable = compile_fx_inner, inner_compile: _CompileFxCallable = compile_fx_inner,
config_patches: Optional[Dict[str, str]] = None, config_patches: Optional[dict[str, str]] = None,
) -> Union[List[str], str]: ) -> Union[list[str], str]:
assert isinstance(model_, GraphModule), model_ assert isinstance(model_, GraphModule), model_
# [See NOTE] Unwrapping subclasses AOT # [See NOTE] Unwrapping subclasses AOT
unwrap_tensor_subclass_parameters(model_) unwrap_tensor_subclass_parameters(model_)
config_patches: Dict[str, Any] = ( config_patches: dict[str, Any] = (
{"cpp_wrapper": True} {"cpp_wrapper": True}
if config_patches is None if config_patches is None
else {**config_patches, "cpp_wrapper": True} else {**config_patches, "cpp_wrapper": True}
@ -1409,7 +1407,7 @@ def fw_compiler_freezing(
cudagraphs: BoxedBool, cudagraphs: BoxedBool,
graph_id: int, graph_id: int,
forward_device: BoxedDeviceIndex, forward_device: BoxedDeviceIndex,
) -> Callable[[List[object]], Sequence[torch.Tensor]]: ) -> Callable[[list[object]], Sequence[torch.Tensor]]:
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
# partition_fn won't be called # partition_fn won't be called
@ -1492,7 +1490,7 @@ def fw_compiler_freezing(
if V.aot_compilation: if V.aot_compilation:
return optimized_function return optimized_function
def wrapper(args: List[object]) -> Sequence[torch.Tensor]: def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
args_new = [ args_new = [
args[i - unwrapped_args_offsets[min(i, max_offset_idx)]] args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
for i in preserved_arg_indices for i in preserved_arg_indices
@ -1505,7 +1503,7 @@ def fw_compiler_freezing(
return wrapper return wrapper
def get_cpp_wrapper_config() -> Dict[str, object]: def get_cpp_wrapper_config() -> dict[str, object]:
return { return {
# Set autotune_at_compile_time to True as default if the option is not explicitly set # Set autotune_at_compile_time to True as default if the option is not explicitly set
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
@ -1551,9 +1549,9 @@ def compile_fx(
model_: GraphModule, model_: GraphModule,
example_inputs_: Sequence[InputType], example_inputs_: Sequence[InputType],
inner_compile: Callable[..., OutputCode] = compile_fx_inner, inner_compile: Callable[..., OutputCode] = compile_fx_inner,
config_patches: Optional[Dict[str, Any]] = None, config_patches: Optional[dict[str, Any]] = None,
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None, decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]: ) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
""" """
Main entry point for compiling given FX graph. Despite the fact that this Main entry point for compiling given FX graph. Despite the fact that this
lives in :mod:`torch._inductor`, this function is responsible for calling lives in :mod:`torch._inductor`, this function is responsible for calling
@ -2017,11 +2015,11 @@ def _check_triton_bf16_support(graph: GraphLowering) -> None:
def _aoti_flatten_inputs( def _aoti_flatten_inputs(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
args: Union[List[Any], tuple[Any, ...]], args: Union[list[Any], tuple[Any, ...]],
kwargs: Optional[Dict[str, Any]] = None, kwargs: Optional[dict[str, Any]] = None,
*, *,
options: Optional[Dict[str, Any]] = None, options: Optional[dict[str, Any]] = None,
) -> tuple[List[Any], Dict[str, Any]]: ) -> tuple[list[Any], dict[str, Any]]:
""" """
Flatten the inputs to the graph module and return the flat inputs and options. Flatten the inputs to the graph module and return the flat inputs and options.
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options. Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.


@ -5,7 +5,7 @@ import importlib
import logging import logging
import os import os
import sys import sys
from typing import Type, TypeVar from typing import TypeVar
from torch._inductor.async_compile import pre_fork_setup from torch._inductor.async_compile import pre_fork_setup
from torch._inductor.compile_worker.subproc_pool import ( from torch._inductor.compile_worker.subproc_pool import (
@ -32,7 +32,7 @@ except ImportError:
pass pass
def _lookup_and_create_type(base: Type[_T], qname: str) -> _T: def _lookup_and_create_type(base: type[_T], qname: str) -> _T:
""" """
Given a base type and qualified name: import & lookup that name, check Given a base type and qualified name: import & lookup that name, check
that it's of the given type and then instantiate it. that it's of the given type and then instantiate it.

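The typing.Type to builtin type change above follows the same rule: builtin type accepts subscripts in annotations on 3.9+. A hedged, self-contained sketch of the idea (not the compile_worker helper itself):

from typing import TypeVar

_T = TypeVar("_T")

def create(base: type[_T]) -> _T:
    # Instantiate the given class and return a value typed as that class.
    return base()

pool: dict = create(dict)  # type checkers infer dict here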

@ -13,7 +13,7 @@ import typing
from concurrent.futures import Future, ProcessPoolExecutor from concurrent.futures import Future, ProcessPoolExecutor
from concurrent.futures.process import BrokenProcessPool from concurrent.futures.process import BrokenProcessPool
from enum import Enum from enum import Enum
from typing import Any, BinaryIO, Callable, Dict, Optional, TypeVar from typing import Any, BinaryIO, Callable, Optional, TypeVar
from typing_extensions import Never, ParamSpec from typing_extensions import Never, ParamSpec
# _thread_safe_fork is needed because the subprocesses in the pool can read # _thread_safe_fork is needed because the subprocesses in the pool can read
@ -158,7 +158,7 @@ class SubprocPool:
self.read_thread = threading.Thread(target=self._read_thread, daemon=True) self.read_thread = threading.Thread(target=self._read_thread, daemon=True)
self.futures_lock = threading.Lock() self.futures_lock = threading.Lock()
self.pending_futures: Dict[int, Future[Any]] = {} self.pending_futures: dict[int, Future[Any]] = {}
self.job_id_count = itertools.count() self.job_id_count = itertools.count()
self.running = True self.running = True


@ -7,7 +7,7 @@ import shutil
import sys import sys
import tempfile import tempfile
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Callable, Dict, List, Optional from typing import Callable, Optional
from torch._inductor.runtime.cache_dir_utils import cache_dir from torch._inductor.runtime.cache_dir_utils import cache_dir
@ -43,7 +43,7 @@ class ConfigChange(BinarySubsystem):
# Dictionary of backend -> subsystems # Dictionary of backend -> subsystems
BACKENDS: Dict[str, List[Subsystem]] = { BACKENDS: dict[str, list[Subsystem]] = {
# run dynamo without aot_autograd # run dynamo without aot_autograd
"eager": [], "eager": [],
# run dynamo with aot_autograd, but no partitioner or decomps # run dynamo with aot_autograd, but no partitioner or decomps
@ -68,8 +68,8 @@ BACKENDS: Dict[str, List[Subsystem]] = {
], # TODO - add more - fusions ? ], # TODO - add more - fusions ?
} }
subsystem_call_counter: Dict[str, int] = collections.Counter() subsystem_call_counter: dict[str, int] = collections.Counter()
call_counter_debug_info: Dict[int, str] = {} call_counter_debug_info: dict[int, str] = {}
def reset_counters() -> None: def reset_counters() -> None:
@ -123,13 +123,13 @@ class CompilerBisector:
return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}" return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}"
@classmethod @classmethod
def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None: def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None:
os.makedirs(os.path.dirname(file_path), exist_ok=True) os.makedirs(os.path.dirname(file_path), exist_ok=True)
with open(file_path, "w") as file: with open(file_path, "w") as file:
file.writelines(lines) file.writelines(lines)
@classmethod @classmethod
def read_lines_from_file(cls, file_path: str) -> List[str]: def read_lines_from_file(cls, file_path: str) -> list[str]:
if os.path.exists(file_path): if os.path.exists(file_path):
with open(file_path) as file: with open(file_path) as file:
return file.readlines() return file.readlines()
@ -154,7 +154,7 @@ class CompilerBisector:
@classmethod @classmethod
def set_config_values( def set_config_values(
cls, backend: str, subsystem: str, config_data: Dict[str, object] cls, backend: str, subsystem: str, config_data: dict[str, object]
) -> None: ) -> None:
file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt") file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt")
lines = [f"{k}={v}\n" for k, v in config_data.items()] lines = [f"{k}={v}\n" for k, v in config_data.items()]
@ -267,7 +267,7 @@ class CompilerBisector:
cls.write_lines_to_file(file_path, lines) cls.write_lines_to_file(file_path, lines)
@classmethod @classmethod
def get_config_change(cls, config_name: str) -> Optional[Dict[str, object]]: def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]:
backend = cls.get_backend() backend = cls.get_backend()
subsystem = cls.get_subsystem() subsystem = cls.get_subsystem()

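The bisector above persists per-subsystem config as key=value lines (set_config_values writes them, read_lines_from_file reads them back). A rough sketch of that round trip under assumed names; this is not the CompilerBisector API:

import os

def write_config(path: str, config_data: dict[str, object]) -> None:
    # Mirror of the write side: one "key=value" line per entry.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.writelines(f"{k}={v}\n" for k, v in config_data.items())

def read_config(path: str) -> dict[str, str]:
    # Values come back as strings; callers re-parse them as needed.
    with open(path) as f:
        return dict(line.rstrip("\n").split("=", 1) for line in f)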

@ -1,16 +1,6 @@
import os # noqa: C101 import os # noqa: C101
import sys import sys
from typing import ( from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
Any,
Callable,
Dict,
List,
Literal,
Optional,
Tuple,
TYPE_CHECKING,
Union,
)
import torch import torch
import torch._inductor.custom_graph_pass import torch._inductor.custom_graph_pass
@ -193,8 +183,8 @@ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
# hence custom IR passes built on top of it might break in the future. # hence custom IR passes built on top of it might break in the future.
_pre_fusion_custom_pass: Optional[ _pre_fusion_custom_pass: Optional[
Callable[ Callable[
[List["torch._inductor.scheduler.BaseSchedulerNode"]], [list["torch._inductor.scheduler.BaseSchedulerNode"]],
List["torch._inductor.scheduler.BaseSchedulerNode"], list["torch._inductor.scheduler.BaseSchedulerNode"],
] ]
] = None ] = None
@ -231,11 +221,11 @@ batch_fusion = True
# merge_splits_pass # merge_splits_pass
# mutate_cat_pass # mutate_cat_pass
# split_cat_pass # split_cat_pass
pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {} pre_grad_fusion_options: dict[str, dict[str, Any]] = {}
# Post grad fusion and options, set to empty dict to disable fusion. # Post grad fusion and options, set to empty dict to disable fusion.
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions. # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
post_grad_fusion_options: Dict[str, Dict[str, Any]] = {} post_grad_fusion_options: dict[str, dict[str, Any]] = {}
# enable reordering pass for improving memory locality # enable reordering pass for improving memory locality
reorder_for_locality = True reorder_for_locality = True
@ -257,7 +247,7 @@ use_mixed_mm = True
# floating point numbers,about 16 decimal digits for double precision floating point numbers) # floating point numbers,about 16 decimal digits for double precision floating point numbers)
# according to PyTorch documentation. # according to PyTorch documentation.
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
fx_passes_numeric_check: Dict[str, Any] = { fx_passes_numeric_check: dict[str, Any] = {
"pre_grad": False, "pre_grad": False,
"precision": 1e-4, "precision": 1e-4,
"num_iterations": 1, "num_iterations": 1,
@ -287,12 +277,12 @@ reorder_for_compute_comm_overlap = False
# for built-in passes, use string name; for user-defined passes, pass in the function handle # for built-in passes, use string name; for user-defined passes, pass in the function handle
# WARNING: Inductor scheduler IR is at prototype stage and subject to change, # WARNING: Inductor scheduler IR is at prototype stage and subject to change,
# hence custom IR passes built on top of it might break in the future. # hence custom IR passes built on top of it might break in the future.
reorder_for_compute_comm_overlap_passes: List[ reorder_for_compute_comm_overlap_passes: list[
Union[ Union[
str, str,
Callable[ Callable[
[List["torch._inductor.scheduler.BaseSchedulerNode"]], [list["torch._inductor.scheduler.BaseSchedulerNode"]],
List["torch._inductor.scheduler.BaseSchedulerNode"], list["torch._inductor.scheduler.BaseSchedulerNode"],
], ],
] ]
] = [ ] = [
@ -618,7 +608,7 @@ _fuse_ddp_bucket_size = 25
# overlapping. At this moment, this pass performs better than # overlapping. At this moment, this pass performs better than
# reorder_for_compute_comm_overlap_passes but we will add the logic of # reorder_for_compute_comm_overlap_passes but we will add the logic of
# "schedule_comm_wait" in the future and remove the one here. # "schedule_comm_wait" in the future and remove the one here.
_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [ _fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
"fuse_ddp_with_concat_op", "fuse_ddp_with_concat_op",
"schedule_comm_wait", "schedule_comm_wait",
] ]
@ -852,7 +842,7 @@ class cpp:
simdlen: Optional[int] = None simdlen: Optional[int] = None
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096")) min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
cxx: Tuple[None, str] = ( cxx: tuple[Literal[None], str] = (
None, # download gcc12 from conda-forge if conda is installed None, # download gcc12 from conda-forge if conda is installed
os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"), os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
) # type: ignore[assignment] ) # type: ignore[assignment]
@ -1157,7 +1147,7 @@ class aot_inductor:
# Dictionary of metadata users might want to save to pass to the runtime. # Dictionary of metadata users might want to save to pass to the runtime.
# TODO: Move this somewhere else, since it's no longer really a config # TODO: Move this somewhere else, since it's no longer really a config
metadata: Dict[str, str] = {} metadata: dict[str, str] = {}
# fbcode only. Whether to raise error if C++ codegen is too big to optimize # fbcode only. Whether to raise error if C++ codegen is too big to optimize
raise_error_on_ignored_optimization: bool = ( raise_error_on_ignored_optimization: bool = (
@ -1168,7 +1158,7 @@ class aot_inductor:
dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1" dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"
# Dictionary of presets that can be passed in # Dictionary of presets that can be passed in
presets: Dict[str, Any] = {} presets: dict[str, Any] = {}
# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
# should be run with this flag both on and off to make sure we have coverage. # should be run with this flag both on and off to make sure we have coverage.
@ -1265,11 +1255,11 @@ class cuda:
class rocm: class rocm:
# Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"]. # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
# If empty, the `native` arch is used # If empty, the `native` arch is used
arch: List[str] = [] arch: list[str] = []
# Enable the CK backend for CDNA2 and CDNA3 only (for now) # Enable the CK backend for CDNA2 and CDNA3 only (for now)
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"] ck_supported_arch: list[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
# Optimization level, use to balance compilation speed and runtime performance. # Optimization level, use to balance compilation speed and runtime performance.
# The type will not necessarily be comprehensive and won't be enforced at runtime. # The type will not necessarily be comprehensive and won't be enforced at runtime.
@ -1415,7 +1405,7 @@ class trace:
log_inductor_triton_kernel_to_post_grad_node_info: bool = True log_inductor_triton_kernel_to_post_grad_node_info: bool = True
_save_config_ignore: List[str] = [ _save_config_ignore: list[str] = [
# workaround: "Can't pickle <function ...>" # workaround: "Can't pickle <function ...>"
"trace.upload_tar", "trace.upload_tar",
"joint_custom_pre_pass", "joint_custom_pre_pass",
@ -1423,7 +1413,7 @@ _save_config_ignore: List[str] = [
"pre_grad_custom_pass", "pre_grad_custom_pass",
] ]
_cache_config_ignore_prefix: List[str] = [ _cache_config_ignore_prefix: list[str] = [
# trace functions are not relevant to config caching # trace functions are not relevant to config caching
"trace", "trace",
# uses absolute path # uses absolute path
@ -1439,7 +1429,7 @@ _cache_config_ignore_prefix: List[str] = [
] ]
# External callable for matmul tuning candidates # External callable for matmul tuning candidates
external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = [] external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
class test_configs: class test_configs:

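One detail worth noting for config-style modules like the one above: unlike function annotations, module-level variable annotations are evaluated at import time and stored in __annotations__, so spelling them with builtin generics relies on the Python 3.9+ runtime support from PEP 585 (or on `from __future__ import annotations`). A tiny illustration with a placeholder name:

from typing import Any

# Evaluated when the module is imported; fine on 3.9+ where dict[...] is legal.
fusion_options: dict[str, dict[str, Any]] = {}

def enable(pass_name: str, **opts: Any) -> None:
    # Record per-pass options keyed by pass name.
    fusion_options[pass_name] = dict(opts)

enable("merge_splits_pass", enabled=True)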

@ -1,5 +1,5 @@
import collections import collections
from typing import Any, Callable, Dict, List, Optional from typing import Any, Callable, Optional
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -57,7 +57,7 @@ def replace_node_with_constant(
def is_const_source( def is_const_source(
node: torch.fx.Node, lifted_constant_names: Optional[List[str]] node: torch.fx.Node, lifted_constant_names: Optional[list[str]]
) -> bool: ) -> bool:
return node.op == "get_attr" or node.name in (lifted_constant_names or ()) return node.op == "get_attr" or node.name in (lifted_constant_names or ())
@ -67,12 +67,12 @@ class ConstantFolder(torch.fx.Interpreter):
self, self,
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
skip_constructors: bool = False, skip_constructors: bool = False,
lifted_constant_names: Optional[List[str]] = None, lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None: ) -> None:
super().__init__(gm) super().__init__(gm)
self.node_replacements: Dict[torch.fx.Node, Any] = {} self.node_replacements: dict[torch.fx.Node, Any] = {}
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter() self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
self.unknown_value = object() self.unknown_value = object()
self.skip_constructors: bool = skip_constructors self.skip_constructors: bool = skip_constructors
@ -141,7 +141,7 @@ class ConstantFolder(torch.fx.Interpreter):
return True return True
return False return False
def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]: def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
last_non_output_use = collections.defaultdict(list) last_non_output_use = collections.defaultdict(list)
seen_uses = OrderedSet[torch.fx.Node]() seen_uses = OrderedSet[torch.fx.Node]()
output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr] output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
@ -264,11 +264,11 @@ class ConstantFolder(torch.fx.Interpreter):
self.node_replacements[node] = tensor self.node_replacements[node] = tensor
def run(self) -> Any: # type: ignore[override] def run(self) -> Any: # type: ignore[override]
env: Dict[torch.fx.Node, Any] = {} env: dict[torch.fx.Node, Any] = {}
self.insert_placerholder_values(env) self.insert_placerholder_values(env)
return super().run(initial_env=env) return super().run(initial_env=env)
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None: def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None:
for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr] for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
env[n] = self.unknown_value # type: ignore[assignment] env[n] = self.unknown_value # type: ignore[assignment]
if self.lifted_constant_names is None: if self.lifted_constant_names is None:
@ -309,7 +309,7 @@ def constant_fold(
def constant_graph_tag( def constant_graph_tag(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
skip_constructors: bool = True, skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None, lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> None: ) -> None:
with torch.utils._python_dispatch._disable_current_modes(): with torch.utils._python_dispatch._disable_current_modes():
@ -337,7 +337,7 @@ def constant_graph_tag(
def run_and_get_constant_graph( def run_and_get_constant_graph(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
skip_constructors: bool = True, skip_constructors: bool = True,
lifted_constant_names: Optional[List[str]] = None, lifted_constant_names: Optional[list[str]] = None,
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
) -> torch.fx.GraphModule: ) -> torch.fx.GraphModule:
""" """
@ -367,7 +367,7 @@ def run_and_get_constant_graph(
new_graph = torch.fx.Graph() new_graph = torch.fx.Graph()
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {} node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
output_nodes = [] output_nodes = []
for node in gm.graph.nodes: for node in gm.graph.nodes:
if node.meta[META_TAG] == MODULE_TAG: if node.meta[META_TAG] == MODULE_TAG:

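The ConstantFolder attributes above annotate collections.Counter and defaultdict values with the plain builtin dict, which works because both are dict subclasses. A short stand-alone illustration with placeholder names:

import collections

replaced_uses: dict[str, int] = collections.Counter()
last_non_output_use: dict[str, list[str]] = collections.defaultdict(list)

replaced_uses["node_a"] += 1              # Counter keeps its counting semantics
last_non_output_use["buf0"].append("n1")  # defaultdict creates the list lazily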

@ -16,10 +16,11 @@ import sys
import sysconfig import sysconfig
import textwrap import textwrap
import warnings import warnings
from collections.abc import Sequence
from ctypes import cdll from ctypes import cdll
from ctypes.util import find_library from ctypes.util import find_library
from pathlib import Path from pathlib import Path
from typing import Any, List, Optional, Sequence, Union from typing import Any, Optional, Union
import torch import torch
from torch._dynamo.utils import dynamo_timed from torch._dynamo.utils import dynamo_timed
@ -285,12 +286,12 @@ def get_compiler_version_info(compiler: str) -> str:
# =============================== cpp builder =============================== # =============================== cpp builder ===============================
def _append_list(dest_list: List[str], src_list: List[str]) -> None: def _append_list(dest_list: list[str], src_list: list[str]) -> None:
dest_list.extend(copy.deepcopy(item) for item in src_list) dest_list.extend(copy.deepcopy(item) for item in src_list)
def _remove_duplication_in_list(orig_list: List[str]) -> List[str]: def _remove_duplication_in_list(orig_list: list[str]) -> list[str]:
new_list: List[str] = [] new_list: list[str] = []
for item in orig_list: for item in orig_list:
if item not in new_list: if item not in new_list:
new_list.append(item) new_list.append(item)
@ -362,26 +363,26 @@ class BuildOptionsBase:
def __init__( def __init__(
self, self,
compiler: str = "", compiler: str = "",
definitions: Optional[List[str]] = None, definitions: Optional[list[str]] = None,
include_dirs: Optional[List[str]] = None, include_dirs: Optional[list[str]] = None,
cflags: Optional[List[str]] = None, cflags: Optional[list[str]] = None,
ldflags: Optional[List[str]] = None, ldflags: Optional[list[str]] = None,
libraries_dirs: Optional[List[str]] = None, libraries_dirs: Optional[list[str]] = None,
libraries: Optional[List[str]] = None, libraries: Optional[list[str]] = None,
passthrough_args: Optional[List[str]] = None, passthrough_args: Optional[list[str]] = None,
aot_mode: bool = False, aot_mode: bool = False,
use_absolute_path: bool = False, use_absolute_path: bool = False,
compile_only: bool = False, compile_only: bool = False,
) -> None: ) -> None:
self._compiler = compiler self._compiler = compiler
self._definations: List[str] = definitions or [] self._definations: list[str] = definitions or []
self._include_dirs: List[str] = include_dirs or [] self._include_dirs: list[str] = include_dirs or []
self._cflags: List[str] = cflags or [] self._cflags: list[str] = cflags or []
self._ldflags: List[str] = ldflags or [] self._ldflags: list[str] = ldflags or []
self._libraries_dirs: List[str] = libraries_dirs or [] self._libraries_dirs: list[str] = libraries_dirs or []
self._libraries: List[str] = libraries or [] self._libraries: list[str] = libraries or []
# Some args is hard to abstract to OS compatable, passthrough it directly. # Some args is hard to abstract to OS compatable, passthrough it directly.
self._passthrough_args: List[str] = passthrough_args or [] self._passthrough_args: list[str] = passthrough_args or []
self._aot_mode: bool = aot_mode self._aot_mode: bool = aot_mode
self._use_absolute_path: bool = use_absolute_path self._use_absolute_path: bool = use_absolute_path
@ -408,25 +409,25 @@ class BuildOptionsBase:
def get_compiler(self) -> str: def get_compiler(self) -> str:
return self._compiler return self._compiler
def get_definations(self) -> List[str]: def get_definations(self) -> list[str]:
return self._definations return self._definations
def get_include_dirs(self) -> List[str]: def get_include_dirs(self) -> list[str]:
return self._include_dirs return self._include_dirs
def get_cflags(self) -> List[str]: def get_cflags(self) -> list[str]:
return self._cflags return self._cflags
def get_ldflags(self) -> List[str]: def get_ldflags(self) -> list[str]:
return self._ldflags return self._ldflags
def get_libraries_dirs(self) -> List[str]: def get_libraries_dirs(self) -> list[str]:
return self._libraries_dirs return self._libraries_dirs
def get_libraries(self) -> List[str]: def get_libraries(self) -> list[str]:
return self._libraries return self._libraries
def get_passthrough_args(self) -> List[str]: def get_passthrough_args(self) -> list[str]:
return self._passthrough_args return self._passthrough_args
def get_aot_mode(self) -> bool: def get_aot_mode(self) -> bool:
@ -457,14 +458,14 @@ class BuildOptionsBase:
json.dump(attrs, f) json.dump(attrs, f)
def _get_warning_all_cflag(warning_all: bool = True) -> List[str]: def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
if not _IS_WINDOWS: if not _IS_WINDOWS:
return ["Wall"] if warning_all else [] return ["Wall"] if warning_all else []
else: else:
return [] return []
def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]: def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
if _IS_WINDOWS: if _IS_WINDOWS:
""" """
On Windows, only c++20 can support `std::enable_if_t`. On Windows, only c++20 can support `std::enable_if_t`.
@ -479,7 +480,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
return [f"std={std_num}"] return [f"std={std_num}"]
def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]: def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
if _IS_WINDOWS: if _IS_WINDOWS:
cflags = [ cflags = [
"wd4819", "wd4819",
@ -506,7 +507,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
return cflags return cflags
def _get_ffast_math_flags() -> List[str]: def _get_ffast_math_flags() -> list[str]:
# ffast-math is equivalent to these flags as in # ffast-math is equivalent to these flags as in
# https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468 # https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468
# however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have # however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have
@ -527,7 +528,7 @@ def _get_ffast_math_flags() -> List[str]:
return flags return flags
def _get_optimization_cflags(cpp_compiler: str) -> List[str]: def _get_optimization_cflags(cpp_compiler: str) -> list[str]:
if _IS_WINDOWS: if _IS_WINDOWS:
return ["O2"] return ["O2"]
else: else:
@ -554,7 +555,7 @@ def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
return cflags return cflags
def _get_shared_cflag(compile_only: bool) -> List[str]: def _get_shared_cflag(compile_only: bool) -> list[str]:
if _IS_WINDOWS: if _IS_WINDOWS:
""" """
MSVC `/MD` using python `ucrtbase.dll` lib as runtime. MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
@ -578,14 +579,14 @@ def get_cpp_options(
compile_only: bool, compile_only: bool,
warning_all: bool = True, warning_all: bool = True,
extra_flags: Sequence[str] = (), extra_flags: Sequence[str] = (),
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: List[str] = [] definations: list[str] = []
include_dirs: List[str] = [] include_dirs: list[str] = []
cflags: List[str] = [] cflags: list[str] = []
ldflags: List[str] = [] ldflags: list[str] = []
libraries_dirs: List[str] = [] libraries_dirs: list[str] = []
libraries: List[str] = [] libraries: list[str] = []
passthrough_args: List[str] = [] passthrough_args: list[str] = []
cflags = ( cflags = (
_get_shared_cflag(compile_only) _get_shared_cflag(compile_only)
@ -657,22 +658,22 @@ class CppOptions(BuildOptionsBase):
self._finalize_options() self._finalize_options()
def _get_glibcxx_abi_build_flags() -> List[str]: def _get_glibcxx_abi_build_flags() -> list[str]:
if not _IS_WINDOWS: if not _IS_WINDOWS:
return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))] return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
else: else:
return [] return []
def _get_torch_cpp_wrapper_defination() -> List[str]: def _get_torch_cpp_wrapper_defination() -> list[str]:
return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"] return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"]
def _use_custom_generated_macros() -> List[str]: def _use_custom_generated_macros() -> list[str]:
return [" C10_USING_CUSTOM_GENERATED_MACROS"] return [" C10_USING_CUSTOM_GENERATED_MACROS"]
def _use_fb_internal_macros() -> List[str]: def _use_fb_internal_macros() -> list[str]:
if not _IS_WINDOWS: if not _IS_WINDOWS:
if config.is_fbcode(): if config.is_fbcode():
fb_internal_macros = [ fb_internal_macros = [
@ -697,12 +698,12 @@ def _setup_standard_sys_libs(
cpp_compiler: str, cpp_compiler: str,
aot_mode: bool, aot_mode: bool,
use_absolute_path: bool, use_absolute_path: bool,
) -> tuple[List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str]]:
from torch._inductor.codecache import _LINKER_SCRIPT from torch._inductor.codecache import _LINKER_SCRIPT
cflags: List[str] = [] cflags: list[str] = []
include_dirs: List[str] = [] include_dirs: list[str] = []
passthrough_args: List[str] = [] passthrough_args: list[str] = []
if _IS_WINDOWS: if _IS_WINDOWS:
return cflags, include_dirs, passthrough_args return cflags, include_dirs, passthrough_args
@ -737,9 +738,9 @@ def _setup_standard_sys_libs(
return cflags, include_dirs, passthrough_args return cflags, include_dirs, passthrough_args
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]]: def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]:
macros: List[str] = [] macros: list[str] = []
build_flags: List[str] = [] build_flags: list[str] = []
if vec_isa != invalid_vec_isa: if vec_isa != invalid_vec_isa:
# Add Windows support later. # Add Windows support later.
macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro()) macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())
@ -759,7 +760,7 @@ def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]
def _get_torch_related_args( def _get_torch_related_args(
include_pytorch: bool, aot_mode: bool include_pytorch: bool, aot_mode: bool
) -> tuple[List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str]]:
from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH
include_dirs = [ include_dirs = [
@ -783,7 +784,7 @@ def _get_torch_related_args(
return include_dirs, libraries_dirs, libraries return include_dirs, libraries_dirs, libraries
def _get_python_include_dirs() -> List[str]: def _get_python_include_dirs() -> list[str]:
include_dir = Path(sysconfig.get_path("include")) include_dir = Path(sysconfig.get_path("include"))
# On Darwin Python executable from a framework can return # On Darwin Python executable from a framework can return
# non-existing /Library/Python/... include path, in which case # non-existing /Library/Python/... include path, in which case
@ -796,7 +797,7 @@ def _get_python_include_dirs() -> List[str]:
return [str(include_dir)] return [str(include_dir)]
def _get_python_related_args() -> tuple[List[str], List[str]]: def _get_python_related_args() -> tuple[list[str], list[str]]:
python_include_dirs = _get_python_include_dirs() python_include_dirs = _get_python_include_dirs()
python_include_path = sysconfig.get_path( python_include_path = sysconfig.get_path(
"include", scheme="nt" if _IS_WINDOWS else "posix_prefix" "include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
@ -893,13 +894,13 @@ def perload_icx_libomp_win(cpp_compiler: str) -> None:
def _get_openmp_args( def _get_openmp_args(
cpp_compiler: str, cpp_compiler: str,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]:
cflags: List[str] = [] cflags: list[str] = []
ldflags: List[str] = [] ldflags: list[str] = []
include_dir_paths: List[str] = [] include_dir_paths: list[str] = []
lib_dir_paths: List[str] = [] lib_dir_paths: list[str] = []
libs: List[str] = [] libs: list[str] = []
passthrough_args: List[str] = [] passthrough_args: list[str] = []
if _IS_MACOS: if _IS_MACOS:
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang` # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
cflags.append("Xclang") cflags.append("Xclang")
@ -998,7 +999,7 @@ def _get_openmp_args(
return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args
def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]: def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
macros = [] macros = []
if use_mmap_weights: if use_mmap_weights:
macros.append(" USE_MMAP_SELF") macros.append(" USE_MMAP_SELF")
@ -1013,14 +1014,14 @@ def get_cpp_torch_options(
compile_only: bool, compile_only: bool,
use_absolute_path: bool, use_absolute_path: bool,
use_mmap_weights: bool, use_mmap_weights: bool,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: List[str] = [] definations: list[str] = []
include_dirs: List[str] = [] include_dirs: list[str] = []
cflags: List[str] = [] cflags: list[str] = []
ldflags: List[str] = [] ldflags: list[str] = []
libraries_dirs: List[str] = [] libraries_dirs: list[str] = []
libraries: List[str] = [] libraries: list[str] = []
passthrough_args: List[str] = [] passthrough_args: list[str] = []
torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination() torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination()
use_custom_generated_macros_definations = _use_custom_generated_macros() use_custom_generated_macros_definations = _use_custom_generated_macros()
@ -1163,7 +1164,7 @@ def _set_gpu_runtime_env() -> None:
os.environ["CUDA_HOME"] = build_paths.sdk_home os.environ["CUDA_HOME"] = build_paths.sdk_home
def _transform_cuda_paths(lpaths: List[str]) -> None: def _transform_cuda_paths(lpaths: list[str]) -> None:
# This handles two cases: # This handles two cases:
# 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs # 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs
# 2. Linux machines may have CUDA installed under either lib64/ or lib/ # 2. Linux machines may have CUDA installed under either lib64/ or lib/
@ -1186,14 +1187,14 @@ def get_cpp_torch_device_options(
device_type: str, device_type: str,
aot_mode: bool = False, aot_mode: bool = False,
compile_only: bool = False, compile_only: bool = False,
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]: ) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
definations: List[str] = [] definations: list[str] = []
include_dirs: List[str] = [] include_dirs: list[str] = []
cflags: List[str] = [] cflags: list[str] = []
ldflags: List[str] = [] ldflags: list[str] = []
libraries_dirs: List[str] = [] libraries_dirs: list[str] = []
libraries: List[str] = [] libraries: list[str] = []
passthrough_args: List[str] = [] passthrough_args: list[str] = []
if ( if (
config.is_fbcode() config.is_fbcode()
and "CUDA_HOME" not in os.environ and "CUDA_HOME" not in os.environ
@ -1287,13 +1288,13 @@ class CppTorchDeviceOptions(CppTorchOptions):
extra_flags=extra_flags, extra_flags=extra_flags,
) )
device_definations: List[str] = [] device_definations: list[str] = []
device_include_dirs: List[str] = [] device_include_dirs: list[str] = []
device_cflags: List[str] = [] device_cflags: list[str] = []
device_ldflags: List[str] = [] device_ldflags: list[str] = []
device_libraries_dirs: List[str] = [] device_libraries_dirs: list[str] = []
device_libraries: List[str] = [] device_libraries: list[str] = []
device_passthrough_args: List[str] = [] device_passthrough_args: list[str] = []
( (
device_definations, device_definations,
@ -1379,7 +1380,7 @@ class CppBuilder:
def __init__( def __init__(
self, self,
name: str, name: str,
sources: Union[str, List[str]], sources: Union[str, list[str]],
BuildOption: BuildOptionsBase, BuildOption: BuildOptionsBase,
output_dir: str = "", output_dir: str = "",
) -> None: ) -> None:

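The Optional[list[str]] = None parameters of BuildOptionsBase above, normalized with `x or []` in the constructor, are the standard way to avoid the shared mutable-default pitfall. A hedged sketch of the idiom (illustrative class, not the real one):

from typing import Optional

class BuildOptions:
    def __init__(self, cflags: Optional[list[str]] = None) -> None:
        # `cflags or []` gives every instance its own fresh list.
        self._cflags: list[str] = cflags or []

    def get_cflags(self) -> list[str]:
        return self._cflags

a, b = BuildOptions(), BuildOptions()
a.get_cflags().append("-O2")
assert b.get_cflags() == []  # no state shared between instances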

@ -7,7 +7,7 @@ import re
import subprocess
import sys
import warnings
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Union
import torch
from torch._inductor import config
@ -33,9 +33,9 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
class VecISA:
_bit_width: int
-_macro: List[str]
+_macro: list[str]
_arch_flags: str
-_dtype_nelements: Dict[torch.dtype, int]
+_dtype_nelements: dict[torch.dtype, int]
# Note [Checking for Vectorized Support in Inductor]
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
@ -79,7 +79,7 @@ cdll.LoadLibrary("__lib_path__")
def nelements(self, dtype: torch.dtype = torch.float) -> int:
return self._dtype_nelements[dtype]
-def build_macro(self) -> List[str]:
+def build_macro(self) -> list[str]:
return self._macro
def build_arch_flags(self) -> str:
@ -300,11 +300,11 @@ class InvalidVecISA(VecISA):
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
-def x86_isa_checker() -> List[str]:
+def x86_isa_checker() -> list[str]:
-supported_isa: List[str] = []
+supported_isa: list[str] = []
def _check_and_append_supported_isa(
-dest: List[str], isa_supported: bool, isa_name: str
+dest: list[str], isa_supported: bool, isa_name: str
) -> None:
if isa_supported:
dest.append(isa_name)
@ -333,7 +333,7 @@ supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
def get_isa_from_cpu_capability(
capability: Union[str, None],
-vec_isa_list: List[VecISA],
+vec_isa_list: list[VecISA],
invalid_vec_isa: InvalidVecISA,
):
# AMX setting is not supported in eager
@ -364,8 +364,8 @@ def get_isa_from_cpu_capability(
# might have too much redundant content that is useless for ISA check. Hence,
# we only cache some key isa information.
@functools.lru_cache(None)
-def valid_vec_isa_list() -> List[VecISA]:
+def valid_vec_isa_list() -> list[VecISA]:
-isa_list: List[VecISA] = []
+isa_list: list[VecISA] = []
if sys.platform == "darwin" and platform.processor() == "arm":
isa_list.append(VecNEON())
@ -411,7 +411,7 @@ def pick_vec_isa() -> VecISA:
if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
return VecAVX2()
-_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
+_valid_vec_isa_list: list[VecISA] = valid_vec_isa_list()
if not _valid_vec_isa_list:
return invalid_vec_isa
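
The hunks above only touch annotations; the underlying pattern is an `@functools.lru_cache(None)`-memoized probe of host ISA support with a fallback when nothing is detected. A minimal sketch of that capability-list pattern, with illustrative `SupportedISA`/`valid_isa_list` names that are not part of this diff:

```python
# Sketch of the memoized capability-list pattern; names are illustrative.
from __future__ import annotations

import functools
import platform
import sys
from dataclasses import dataclass


@dataclass(frozen=True)
class SupportedISA:
    name: str
    bit_width: int


@functools.lru_cache(None)
def valid_isa_list() -> list[SupportedISA]:
    """Probe the host once and cache the result for later calls."""
    isas: list[SupportedISA] = []
    if sys.platform == "darwin" and platform.processor() == "arm":
        isas.append(SupportedISA("neon", 256))
    if platform.machine() in ("x86_64", "AMD64"):
        isas.append(SupportedISA("avx2", 256))
    return isas


def pick_isa() -> SupportedISA | None:
    # Fall back to "no vectorization" when nothing was detected.
    detected = valid_isa_list()
    return detected[0] if detected else None


if __name__ == "__main__":
    print(pick_isa())
```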

View File

@ -54,13 +54,7 @@ from typing import (
Callable,
cast,
ContextManager,
-Dict,
-Generator,
-Iterator,
-List,
Optional,
-Sequence,
-Type,
TYPE_CHECKING,
TypeVar,
Union,
@ -99,6 +93,8 @@ from torch.utils.weak import TensorWeakRef
if TYPE_CHECKING:
+from collections.abc import Generator, Iterator, Sequence
from torch._inductor.utils import InputType
from torch.types import _bool
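
This import hunk shows the migration pattern repeated across the commit: `typing.Dict`/`List`/`Type` become built-in generics, and abstract containers such as `Sequence`, `Generator`, and `Iterator` move to `collections.abc` imports guarded by `TYPE_CHECKING`. A small sketch of the resulting style (the function itself is illustrative, not from the diff):

```python
# PEP 585 style: built-in generics plus TYPE_CHECKING-guarded abc imports.
from __future__ import annotations

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Needed only by the type checker; avoids a runtime import.
    from collections.abc import Sequence


def first_positive(xs: Sequence[Union[int, float]]) -> Optional[float]:
    """Illustrative function annotated with built-in / abc generics."""
    found: list[float] = [float(x) for x in xs if x > 0]
    return found[0] if found else None


if __name__ == "__main__":
    print(first_positive([-1, 0, 3]))  # 3.0
```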
@ -357,12 +353,12 @@ def get_manager(
def cudagraphify_impl( def cudagraphify_impl(
model: ModelType, model: ModelType,
inputs: List[InputType], inputs: list[InputType],
static_input_idxs: Sequence[int], static_input_idxs: Sequence[int],
*args: Any, *args: Any,
**kwargs: Any, **kwargs: Any,
) -> ModelType: ) -> ModelType:
fn_cache: Dict[tuple[int, ...], Callable[..., Any]] = {} fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}
# Detect int inputs: we need to index on these # Detect int inputs: we need to index on these
int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)] int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
@ -372,7 +368,7 @@ def cudagraphify_impl(
del inputs del inputs
def deferred_cudagraphify(inputs: List[InputType]) -> OutputType: def deferred_cudagraphify(inputs: list[InputType]) -> OutputType:
nonlocal has_warn nonlocal has_warn
int_key = get_ints(inputs) int_key = get_ints(inputs)
@ -405,7 +401,7 @@ def cudagraphify_impl(
def cudagraphify( def cudagraphify(
model: ModelType, model: ModelType,
inputs: List[InputType], inputs: list[InputType],
static_input_idxs: Sequence[int] = (), static_input_idxs: Sequence[int] = (),
*, *,
device_index: int, device_index: int,
@ -466,7 +462,7 @@ class StorageWeakRefWrapper:
@classmethod @classmethod
def from_weakref_and_data_ptr( def from_weakref_and_data_ptr(
cls: Type[S], cls: type[S],
cdata: Any, cdata: Any,
data_ptr: int, data_ptr: int,
extra_ref_check: Optional[Callable[[], bool]] = None, extra_ref_check: Optional[Callable[[], bool]] = None,
@ -561,9 +557,9 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
PathOutputIndex = tuple[int, int] PathOutputIndex = tuple[int, int]
# For each node in the path, for each output, is the output alive # For each node in the path, for each output, is the output alive
PathLiveness = List[List[bool]] PathLiveness = list[list[bool]]
StackTraces = List[Optional[str]] StackTraces = list[Optional[str]]
class CUDAWarmupNode: class CUDAWarmupNode:
@ -600,8 +596,8 @@ class CUDAWarmupNode:
self.wrapped_function = wrapped_function self.wrapped_function = wrapped_function
self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
self.cuda_graphs_pool = cuda_graphs_pool self.cuda_graphs_pool = cuda_graphs_pool
self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = [] self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = []
self.tensor_weakrefs: List[Optional[TensorWeakRef]] = [] self.tensor_weakrefs: list[Optional[TensorWeakRef]] = []
self.existing_cuda_graph = existing_cuda_graph self.existing_cuda_graph = existing_cuda_graph
self.has_run = False self.has_run = False
self.device_index = device_index self.device_index = device_index
@ -619,7 +615,7 @@ class CUDAWarmupNode:
[t.data_ptr() for t in self.path_live_weakrefs() if t()] [t.data_ptr() for t in self.path_live_weakrefs() if t()]
) )
def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]: def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]:
non_cudagraph_inps = [ non_cudagraph_inps = [
weakref.ref(t.untyped_storage()) weakref.ref(t.untyped_storage())
for t in itertools.chain(new_inputs, self.wrapped_function.constants) for t in itertools.chain(new_inputs, self.wrapped_function.constants)
@ -707,9 +703,9 @@ class CUDAWarmupNode:
# Aliases for List that say what the indices denote # Aliases for List that say what the indices denote
InputList = List # input indexes InputList = list # input indexes
OutputList = List # output indexes OutputList = list # output indexes
LevelList = List # levels (distance from root of tree) LevelList = list # levels (distance from root of tree)
class OutputAliasInfo: class OutputAliasInfo:
@ -772,7 +768,7 @@ class CUDAGraphNode:
wrapped_function: WrappedFunction, wrapped_function: WrappedFunction,
id: GraphID, id: GraphID,
parent: Optional[CUDAGraphNode], parent: Optional[CUDAGraphNode],
inputs: List[InputType], inputs: list[InputType],
cuda_graphs_pool: tuple[int, int], cuda_graphs_pool: tuple[int, int],
device_index: int, device_index: int,
stack_traces: Optional[StackTraces], stack_traces: Optional[StackTraces],
@ -800,7 +796,7 @@ class CUDAGraphNode:
# A single wrapped function may be recorded multiple times if memory patterns or # A single wrapped function may be recorded multiple times if memory patterns or
# invariants change from one execution to the next # invariants change from one execution to the next
self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)
# StorageWeakRef maintains whether the Storage C++ object remains allocated, # StorageWeakRef maintains whether the Storage C++ object remains allocated,
# not whether the corresponding memory has been deallocated. In order # not whether the corresponding memory has been deallocated. In order
@ -825,7 +821,7 @@ class CUDAGraphNode:
self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = [] self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
# tensors which are outputs of previous graphs in the tree # tensors which are outputs of previous graphs in the tree
self.cudagraph_managed_idxs: List[int] = [ self.cudagraph_managed_idxs: list[int] = [
idx idx
for idx, t in enumerate(inputs) for idx, t in enumerate(inputs)
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t) if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
@ -845,7 +841,7 @@ class CUDAGraphNode:
# and also aliases an output of the current CUDAGraphNode # and also aliases an output of the current CUDAGraphNode
self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs) self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
self.static_input_idxs: List[int] = list( self.static_input_idxs: list[int] = list(
OrderedSet(wrapped_function.static_input_idxs) OrderedSet(wrapped_function.static_input_idxs)
| OrderedSet(self.cudagraph_managed_idxs) | OrderedSet(self.cudagraph_managed_idxs)
) )
@ -866,8 +862,8 @@ class CUDAGraphNode:
def maybe_get_static_data_ptr( def maybe_get_static_data_ptr(
idx: int, idx: int,
inputs: List[InputType], inputs: list[InputType],
static_input_idxs: List[int], static_input_idxs: list[int],
) -> Optional[int]: ) -> Optional[int]:
inp = inputs[idx] inp = inputs[idx]
if isinstance(inp, torch.Tensor) and idx in static_input_idxs: if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
@ -888,7 +884,7 @@ class CUDAGraphNode:
# fresh allocations. # fresh allocations.
# precompute expanded dims to avoid computing in the hot path # precompute expanded dims to avoid computing in the hot path
self.expanded_dims: List[List[int]] = [ self.expanded_dims: list[list[int]] = [
get_expanded_dims(x) get_expanded_dims(x)
if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
else [] else []
@ -903,11 +899,11 @@ class CUDAGraphNode:
# List of Tuples of (depth, output_index) that index into node at depth # List of Tuples of (depth, output_index) that index into node at depth
# number of nodes from root and output_index of outputs. Will index into # number of nodes from root and output_index of outputs. Will index into
# path_weakrefs. # path_weakrefs.
self.expected_dead_indices_before_graph: List[PathOutputIndex] = [] self.expected_dead_indices_before_graph: list[PathOutputIndex] = []
self.expected_dead_indices_after_graph: List[PathOutputIndex] = [] self.expected_dead_indices_after_graph: list[PathOutputIndex] = []
# all live indices after graph recording # all live indices after graph recording
self.live_indices_after_graph: List[PathOutputIndex] = [] self.live_indices_after_graph: list[PathOutputIndex] = []
if self.parent is not None: if self.parent is not None:
previous_liveness = self.parent.recorded_liveness_after_graph previous_liveness = self.parent.recorded_liveness_after_graph
@ -934,7 +930,7 @@ class CUDAGraphNode:
# we reconstruct tensors at the correct data pointers of our inputs which are # we reconstruct tensors at the correct data pointers of our inputs which are
# non owning and do not prevent deallocation. On subsequent executions, input values # non owning and do not prevent deallocation. On subsequent executions, input values
# will be copied over to these tensors. # will be copied over to these tensors.
self.reconstructed_inputs: List[InputType] = [ self.reconstructed_inputs: list[InputType] = [
self._reconstruct_from_tensor_metadata(self._tensor_metadata(x)) self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
if isinstance(x, torch.Tensor) if isinstance(x, torch.Tensor)
else x else x
@ -983,7 +979,7 @@ class CUDAGraphNode:
self.recording_outputs: Optional[OutputType] = self._record( self.recording_outputs: Optional[OutputType] = self._record(
wrapped_function.model, recording_inputs wrapped_function.model, recording_inputs
) )
self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = [] self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = []
# As with inputs, we do not want to keep the outputs permanently alive because that would prevent # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
# their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
@ -1001,7 +997,7 @@ class CUDAGraphNode:
self.graph.replay() self.graph.replay()
def _copy_inputs_and_remove_from_src( def _copy_inputs_and_remove_from_src(
self, dsts: List[InputType], srcs: List[InputType] self, dsts: list[InputType], srcs: list[InputType]
) -> None: ) -> None:
dst_tensors = [] dst_tensors = []
src_tensors = [] src_tensors = []
@ -1016,7 +1012,7 @@ class CUDAGraphNode:
if dst_tensors: if dst_tensors:
torch._foreach_copy_(dst_tensors, src_tensors) torch._foreach_copy_(dst_tensors, src_tensors)
def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None: def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None:
# avoid checking managed tensor static points since we already checked those in check_invariants # avoid checking managed tensor static points since we already checked those in check_invariants
if ( if (
not self.rerecord_if_static_inputs_change not self.rerecord_if_static_inputs_change
@ -1036,7 +1032,7 @@ class CUDAGraphNode:
) )
torch._check(False, lambda: error_msg) torch._check(False, lambda: error_msg)
def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType: def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType:
if config.triton.fast_path_cudagraph_asserts: if config.triton.fast_path_cudagraph_asserts:
self.debug_check_invariants_before_invocation() self.debug_check_invariants_before_invocation()
@ -1048,7 +1044,7 @@ class CUDAGraphNode:
assert outputs is not None assert outputs is not None
return outputs return outputs
def run(self, new_inputs: List[InputType]) -> OutputType: def run(self, new_inputs: list[InputType]) -> OutputType:
self.check_static_inputs_are_stable(new_inputs) self.check_static_inputs_are_stable(new_inputs)
self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs) self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
@ -1130,7 +1126,7 @@ class CUDAGraphNode:
def prepare_alias_info_for_tensor_construction( def prepare_alias_info_for_tensor_construction(
self, self,
out_alias_info: Optional[OutputAliasInfo], out_alias_info: Optional[OutputAliasInfo],
metadata: Union[Dict[str, Any], int, None], metadata: Union[dict[str, Any], int, None],
) -> Union[UntypedStorage, None, int]: ) -> Union[UntypedStorage, None, int]:
if ( if (
isinstance(metadata, (int, type(None))) isinstance(metadata, (int, type(None)))
@ -1149,7 +1145,7 @@ class CUDAGraphNode:
def prepare_storages_for_construction( def prepare_storages_for_construction(
self, self,
) -> List[Union[UntypedStorage, None, int]]: ) -> list[Union[UntypedStorage, None, int]]:
output_storages = [] output_storages = []
for output_storage_alias, metadata in zip( for output_storage_alias, metadata in zip(
self.output_storage_alias, self.outputs_metadata self.output_storage_alias, self.outputs_metadata
@ -1173,7 +1169,7 @@ class CUDAGraphNode:
return False return False
return True return True
def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType: def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
"Record the model" "Record the model"
def static_input_iter() -> Generator[torch.Tensor, None, None]: def static_input_iter() -> Generator[torch.Tensor, None, None]:
@ -1185,7 +1181,7 @@ class CUDAGraphNode:
yield _inp yield _inp
# see: output_is_alias_of_persistent_static_inputs above # see: output_is_alias_of_persistent_static_inputs above
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = { static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = {
inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp) inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
for inp in itertools.chain( for inp in itertools.chain(
static_input_iter(), self.wrapped_function.constants static_input_iter(), self.wrapped_function.constants
@ -1229,7 +1225,7 @@ class CUDAGraphNode:
def _add_first_outputs( def _add_first_outputs(
self, self,
outputs: OutputType, outputs: OutputType,
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper], static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper],
) -> None: ) -> None:
"Add the outputs from the first invocation of the node and set up metadata" "Add the outputs from the first invocation of the node and set up metadata"
@ -1243,7 +1239,7 @@ class CUDAGraphNode:
assert len(self.outputs_weakrefs) == 0 assert len(self.outputs_weakrefs) == 0
# index from data pointer to index in outputs # index from data pointer to index in outputs
output_new_storages_index: Dict[StorageDataPtr, int] = {} output_new_storages_index: dict[StorageDataPtr, int] = {}
self.unaliased_in_all_paths = [False for _ in range(len(outputs))] self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
self.static_output_tensors = [None for _ in range(len(outputs))] self.static_output_tensors = [None for _ in range(len(outputs))]
@ -1431,8 +1427,8 @@ class CUDAGraphNode:
@staticmethod @staticmethod
def _check_liveness( def _check_liveness(
indices: List[PathOutputIndex], indices: list[PathOutputIndex],
output_refs: List[List[Optional[StorageWeakRefWrapper]]], output_refs: list[list[Optional[StorageWeakRefWrapper]]],
) -> bool: ) -> bool:
"Check that all of the indices specified are dead references" "Check that all of the indices specified are dead references"
for depth, output_index in indices: for depth, output_index in indices:
@ -1448,8 +1444,8 @@ class CUDAGraphNode:
@staticmethod @staticmethod
def _get_different_indices( def _get_different_indices(
prev: List[List[bool]], curr: List[List[bool]] prev: list[list[bool]], curr: list[list[bool]]
) -> List[PathOutputIndex]: ) -> list[PathOutputIndex]:
"Find indices where the two lists differ." "Find indices where the two lists differ."
dead_indices = [] dead_indices = []
assert len(prev) <= len(curr) assert len(prev) <= len(curr)
@ -1463,8 +1459,8 @@ class CUDAGraphNode:
@staticmethod @staticmethod
def _get_liveness( def _get_liveness(
weakrefs: List[List[Optional[StorageWeakRefWrapper]]], weakrefs: list[list[Optional[StorageWeakRefWrapper]]],
) -> List[List[bool]]: ) -> list[list[bool]]:
"Maps weakrefs to true if the reference is alive and false otherwise" "Maps weakrefs to true if the reference is alive and false otherwise"
if len(weakrefs) == 0: if len(weakrefs) == 0:
return [] return []
@ -1472,7 +1468,7 @@ class CUDAGraphNode:
return [pytree.tree_map(is_live, outputs) for outputs in weakrefs] return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
def debug_assert_invariants( def debug_assert_invariants(
self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex] self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex]
) -> None: ) -> None:
if not config.triton.fast_path_cudagraph_asserts: if not config.triton.fast_path_cudagraph_asserts:
return return
@ -1520,7 +1516,7 @@ class CUDAGraphNode:
self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
) )
def data_ptrs_dead_since_invocation(self) -> List[int]: def data_ptrs_dead_since_invocation(self) -> list[int]:
""" """
Since this node was invoked, return data ptrs of all tensor outputs that have died Since this node was invoked, return data ptrs of all tensor outputs that have died
in the current executing tree path. in the current executing tree path.
@ -1568,7 +1564,7 @@ class CUDAGraphNode:
@staticmethod @staticmethod
def _tensor_metadata( def _tensor_metadata(
x: torch.Tensor, ignore_storage_offset: bool = True x: torch.Tensor, ignore_storage_offset: bool = True
) -> Dict[str, Any]: ) -> dict[str, Any]:
assert isinstance(x, torch.Tensor) assert isinstance(x, torch.Tensor)
# We ignore the storage offset for inputs, but not for outputs # We ignore the storage offset for inputs, but not for outputs
# TODO: - should we make the storage resizable ? # TODO: - should we make the storage resizable ?
@ -1583,19 +1579,19 @@ class CUDAGraphNode:
} }
def _reconstruct_from_tensor_metadata( def _reconstruct_from_tensor_metadata(
self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None
) -> Tensor: ) -> Tensor:
s = self.create_storage(metadata) if storage is None else storage s = self.create_storage(metadata) if storage is None else storage
return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type] return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type]
def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage: def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage:
return torch._C._construct_storage_from_data_pointer( return torch._C._construct_storage_from_data_pointer(
metadata["data_ptr"], metadata["device"], metadata["nbytes"] metadata["data_ptr"], metadata["device"], metadata["nbytes"]
) )
def _allocate_and_copy_recording_inputs( def _allocate_and_copy_recording_inputs(
self, inputs: List[InputType] self, inputs: list[InputType]
) -> List[InputType]: ) -> list[InputType]:
""" """
Allocate inputs for non static, non cudagraph managed tensors in the memory pool Allocate inputs for non static, non cudagraph managed tensors in the memory pool
and copy over the tensor values. and copy over the tensor values.
@ -1603,7 +1599,7 @@ class CUDAGraphNode:
torch.cuda.synchronize() torch.cuda.synchronize()
self.stream.wait_stream(torch.cuda.current_stream()) self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: List[InputType] = [] recording_inputs: list[InputType] = []
with warnings.catch_warnings(record=True), torch.cuda.device( with warnings.catch_warnings(record=True), torch.cuda.device(
self.device self.device
@ -1627,7 +1623,7 @@ class CUDAGraphNode:
return recording_inputs return recording_inputs
def check_invariants( def check_invariants(
self, inputs: List[InputType] self, inputs: list[InputType]
) -> tuple[CheckInvariantStatus, Callable[..., str]]: ) -> tuple[CheckInvariantStatus, Callable[..., str]]:
""" """
Checks if this node can be run. The same pattern of tensor liveness, static inputs, Checks if this node can be run. The same pattern of tensor liveness, static inputs,
@ -1714,7 +1710,7 @@ def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any:
return [segment for segment in segments if segment["segment_pool_id"] == pool_id] return [segment for segment in segments if segment["segment_pool_id"] == pool_id]
def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[int]: def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]:
blocks = [] blocks = []
for segment in get_cudagraph_segments(pool_id): for segment in get_cudagraph_segments(pool_id):
@ -1728,7 +1724,7 @@ def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[in
return blocks return blocks
def format_tb(frames: List[Any]) -> str: def format_tb(frames: list[Any]) -> str:
formatted_traceback = [ formatted_traceback = [
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"]) traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
for entry in frames for entry in frames
@ -1740,7 +1736,7 @@ def format_tb(frames: List[Any]) -> str:
def check_memory_pool( def check_memory_pool(
device: int, device: int,
pool_id: tuple[int, int], pool_id: tuple[int, int],
live_storages_ptrs: List[StorageWeakRefWrapper], live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None: ) -> None:
assert all( assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
@ -1839,12 +1835,12 @@ class CUDAGraphTreeManager:
# when they are first invoked, none of their inputs are outputs
# of another node, nor are there any live outputs of another node whose
# liveness would create a dependency.
self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list) self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)
# mapping from function id to wrapped function # mapping from function id to wrapped function
self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {} self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {}
self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {} self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {}
self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet() self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet()
# if we fail to increment generation, and are stuck warming up, # if we fail to increment generation, and are stuck warming up,
@ -1883,14 +1879,14 @@ class CUDAGraphTreeManager:
# mapping from graph_id to (function id to mutation type hint) since we are # mapping from graph_id to (function id to mutation type hint) since we are
# specializing on a particular combination of Parent Node -> Function ID. # specializing on a particular combination of Parent Node -> Function ID.
self.non_cudagraph_managed_mutation_hint: Dict[ self.non_cudagraph_managed_mutation_hint: dict[
Optional[GraphID], Dict[FunctionID, bool] Optional[GraphID], dict[FunctionID, bool]
] = defaultdict(dict) ] = defaultdict(dict)
self.warmup_node_counter = itertools.count(start=-1, step=-1) self.warmup_node_counter = itertools.count(start=-1, step=-1)
# mapping from graph_id to (function id to re-record count). We fall back to # mapping from graph_id to (function id to re-record count). We fall back to
# eager function if a function is re-recorded frequently on a node. # eager function if a function is re-recorded frequently on a node.
self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict( self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict(
lambda: defaultdict(lambda: 0) lambda: defaultdict(lambda: 0)
) )
@ -1916,7 +1912,7 @@ class CUDAGraphTreeManager:
# number of instances we had to checkpoint the function # number of instances we had to checkpoint the function
self.debug_checkpointing_counter = 0 self.debug_checkpointing_counter = 0
self.id_to_mode: Dict[FunctionID, CompilationMode] = {} self.id_to_mode: dict[FunctionID, CompilationMode] = {}
# Note: [Backward Generation Handling] # Note: [Backward Generation Handling]
# We generally perform a sequence of forward executions followed by backward executions. # We generally perform a sequence of forward executions followed by backward executions.
@ -1943,7 +1939,7 @@ class CUDAGraphTreeManager:
) )
) )
def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
assert self.graph is not None, "Running CUDAGraph after shutdown" assert self.graph is not None, "Running CUDAGraph after shutdown"
self.mode = self.id_to_mode[function_id] self.mode = self.id_to_mode[function_id]
out = self._run(new_inputs, function_id) out = self._run(new_inputs, function_id)
@ -1971,7 +1967,7 @@ class CUDAGraphTreeManager:
return GraphID(next(self.warmup_node_counter)) return GraphID(next(self.warmup_node_counter))
def _update_non_cudagraph_managed_mutation( def _update_non_cudagraph_managed_mutation(
self, function_id: FunctionID, inputs: List[InputType] self, function_id: FunctionID, inputs: list[InputType]
) -> None: ) -> None:
node_id = self._get_node_id() node_id = self._get_node_id()
if maybe_mutation_str := check_for_mutation( if maybe_mutation_str := check_for_mutation(
@ -2007,7 +2003,7 @@ class CUDAGraphTreeManager:
> torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
) )
def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType: def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
# we will try to end the current execution lazily, since
# we don't want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@ -2152,7 +2148,7 @@ class CUDAGraphTreeManager:
self.current_node = None self.current_node = None
def record_function( def record_function(
self, new_inputs: List[InputType], function_id: FunctionID self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType: ) -> OutputType:
assert not isinstance(self.current_node, CUDAWarmupNode) assert not isinstance(self.current_node, CUDAWarmupNode)
graph_id = self.new_graph_id() graph_id = self.new_graph_id()
@ -2183,7 +2179,7 @@ class CUDAGraphTreeManager:
return node.run_first_inputs(new_inputs) return node.run_first_inputs(new_inputs)
def execute_node( def execute_node(
self, node: CUDAGraphNode, new_inputs: List[InputType] self, node: CUDAGraphNode, new_inputs: list[InputType]
) -> OutputType: ) -> OutputType:
self.current_node = node self.current_node = node
self.path_state = ExecutionState.EXECUTION self.path_state = ExecutionState.EXECUTION
@ -2191,7 +2187,7 @@ class CUDAGraphTreeManager:
return node.run(new_inputs) return node.run(new_inputs)
def run_eager( def run_eager(
self, new_inputs: List[InputType], function_id: FunctionID self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType: ) -> OutputType:
# this is only stored on current node, because when we start a new path, # this is only stored on current node, because when we start a new path,
# we will deallocate it # we will deallocate it
@ -2229,7 +2225,7 @@ class CUDAGraphTreeManager:
def add_function( def add_function(
self, self,
model: ModelType, model: ModelType,
inputs: List[InputType], inputs: list[InputType],
static_input_idxs: Sequence[int], static_input_idxs: Sequence[int],
stack_traces: Optional[StackTraces], stack_traces: Optional[StackTraces],
mode: CompilationMode, mode: CompilationMode,
@ -2409,7 +2405,7 @@ class CUDAGraphTreeManager:
# TODO: we could also allow these weak refs to continue to be allocated,
# but that adds some complications.
stor_stack_trace: Dict[int, Optional[str]] = {} stor_stack_trace: dict[int, Optional[str]] = {}
for node in self.current_node._path_from_root: for node in self.current_node._path_from_root:
assert node.stack_traces is not None assert node.stack_traces is not None
assert len(node.tensor_weakrefs) == len(node.stack_traces) assert len(node.tensor_weakrefs) == len(node.stack_traces)
@ -2475,7 +2471,7 @@ class CUDAGraphTreeManager:
assert state is not None and device is not None assert state is not None and device is not None
# currently we deallocate instead of allowing stale recordings
stale_storages: List[int] = [] stale_storages: list[int] = []
# remove cached tensors, otherwise they would prevent memory from being # remove cached tensors, otherwise they would prevent memory from being
# reclaimed in subsequent recordings # reclaimed in subsequent recordings
@ -2506,7 +2502,7 @@ class CUDAGraphTreeManager:
def live_cudagraph_pool_storages_in_curr_execution( def live_cudagraph_pool_storages_in_curr_execution(
self, self,
) -> List[StorageWeakRefPointer]: ) -> list[StorageWeakRefPointer]:
if self.current_node is None: if self.current_node is None:
return [] return []
# explicitly ignoring previous recorded outputs from past path # explicitly ignoring previous recorded outputs from past path
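
The `cudagraphify_impl` hunks earlier in this file keep a `fn_cache` keyed by the tuple of integer inputs, so each combination of int arguments gets its own recorded callable. A rough sketch of that dispatch idea, using a stand-in for the recording/compilation step rather than the actual torch implementation:

```python
# Sketch of the "index on int inputs" caching idea; not torch's real code.
from typing import Any, Callable


def make_cached_runner(model: Callable[[list[Any]], Any]) -> Callable[[list[Any]], Any]:
    fn_cache: dict[tuple[int, ...], Callable[[list[Any]], Any]] = {}

    def get_ints(inputs: list[Any]) -> tuple[int, ...]:
        # The int values in the argument list select the cached variant.
        return tuple(v for v in inputs if isinstance(v, int))

    def run(inputs: list[Any]) -> Any:
        key = get_ints(inputs)
        if key not in fn_cache:
            # Stand-in for recording/compiling a specialized variant.
            fn_cache[key] = lambda inps: model(inps)
        return fn_cache[key](inputs)

    return run


if __name__ == "__main__":
    runner = make_cached_runner(lambda inps: sum(x for x in inps if isinstance(x, int)))
    print(runner([1, 2, 3]))  # 6
```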

View File

@ -3,7 +3,7 @@ from __future__ import annotations
import dataclasses
from enum import Enum
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union
import torch
from torch._dynamo.utils import counters
@ -11,14 +11,18 @@ from torch._inductor.utils import InputType
from torch.utils._ordered_set import OrderedSet
+if TYPE_CHECKING:
+from collections.abc import Sequence
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
__name__, "cudagraph_static_inputs"
)
-OutputType = List[Optional[Union[int, torch.Tensor]]]
+OutputType = list[Optional[Union[int, torch.Tensor]]]
-ModelType = Callable[[List[InputType]], OutputType]
+ModelType = Callable[[list[InputType]], OutputType]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -38,7 +42,7 @@ class PlaceholderInfo:
name: str name: str
stack_trace: Optional[str] stack_trace: Optional[str]
# This field is recursive, but never cyclic (since a node never uses itself) # This field is recursive, but never cyclic (since a node never uses itself)
users: List[PlaceholderInfo] users: list[PlaceholderInfo]
mutating_use_stack_trace: Optional[str] mutating_use_stack_trace: Optional[str]
@ -92,7 +96,7 @@ def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace) return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)
def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]: def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]:
return [ return [
to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder" to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
] ]
@ -123,7 +127,7 @@ def get_mutation_stack_trace(
def check_for_mutation( def check_for_mutation(
func: WrappedFunction, func: WrappedFunction,
inputs: List[InputType], inputs: list[InputType],
is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool], is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]: ) -> Optional[str]:
# doesn't work for non-trees because the warmup run would apply mutation twice
@ -160,7 +164,7 @@ def _get_use_stack_trace(node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes( def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: Dict[torch.device, torch.fx.Node] device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]: ) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")): if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})" msg = f"cpu device ({cpu_node.name})"
@ -180,7 +184,7 @@ def check_multiple_devices_or_any_cpu_nodes(
def check_lowering_disable_cudagraph( def check_lowering_disable_cudagraph(
device_node_mapping: Dict[torch.device, torch.fx.Node] device_node_mapping: dict[torch.device, torch.fx.Node]
): ):
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -262,7 +266,7 @@ class CheckInvariantStatus(Enum):
def log_data_ptr_mismatch( def log_data_ptr_mismatch(
placeholders: Sequence[PlaceholderInfo], placeholders: Sequence[PlaceholderInfo],
inputs: List[InputType], inputs: list[InputType],
recorded_data_ptr: Sequence[Optional[int]], recorded_data_ptr: Sequence[Optional[int]],
target_idxs: Sequence[int], target_idxs: Sequence[int],
mismatch: CheckInvariantStatus, mismatch: CheckInvariantStatus,
@ -292,7 +296,7 @@ def log_data_ptr_mismatch(
def maybe_warning_due_to_dynamic_shape( def maybe_warning_due_to_dynamic_shape(
fn_cache: Dict[tuple[int, ...], Callable[..., Any]], fn_cache: dict[tuple[int, ...], Callable[..., Any]],
new_int_key: Any, new_int_key: Any,
) -> bool: ) -> bool:
num_cudagraphs = len(fn_cache.keys()) + 1 num_cudagraphs = len(fn_cache.keys()) + 1
@ -327,5 +331,5 @@ class CudagraphCachedInfo:
""" """
placeholders: Sequence[PlaceholderInfo] placeholders: Sequence[PlaceholderInfo]
stack_traces: List[Optional[str]] stack_traces: list[Optional[str]]
cudagraph_fail_reasons: List[str] cudagraph_fail_reasons: list[str]

View File

@ -12,7 +12,8 @@ import pickle
import pstats
import shutil
import subprocess
-from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
+from collections.abc import Iterator
+from typing import Any, Callable, IO, Optional, Union
from unittest.mock import patch
import torch
@ -39,7 +40,7 @@ from .virtualized import V
log = logging.getLogger(__name__)
-SchedulerNodeList = List[Any]
+SchedulerNodeList = list[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
@ -54,7 +55,7 @@ def has_dot() -> bool:
def draw_buffers( def draw_buffers(
nodes: List[BaseSchedulerNode], nodes: list[BaseSchedulerNode],
print_graph: bool = False, print_graph: bool = False,
fname: Optional[str] = None, fname: Optional[str] = None,
) -> None: ) -> None:
@ -99,7 +100,7 @@ def draw_buffers(
) )
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph: def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
""" """
Creates a FX Graph from a list of SchedulerNode objects. Creates a FX Graph from a list of SchedulerNode objects.
""" """
@ -199,7 +200,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
def update_orig_fx_node_name_to_buf_name( def update_orig_fx_node_name_to_buf_name(
nodes: Optional[SchedulerNodeList], nodes: Optional[SchedulerNodeList],
node_name_to_buf_name: Dict[str, str], node_name_to_buf_name: dict[str, str],
parent_buf_name: Optional[str] = None, parent_buf_name: Optional[str] = None,
n_origins: int = 0, n_origins: int = 0,
) -> None: ) -> None:
@ -233,8 +234,8 @@ def update_orig_fx_node_name_to_buf_name(
def get_node_name_to_buf_meta( def get_node_name_to_buf_meta(
node_name_to_buf_name: Dict[str, str] node_name_to_buf_name: dict[str, str]
) -> Dict[str, BufMeta]: ) -> dict[str, BufMeta]:
buf_name_to_n_node = {} buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items(): for node_name, buf_name in node_name_to_buf_name.items():
if buf_name not in buf_name_to_n_node: if buf_name not in buf_name_to_n_node:
@ -256,7 +257,7 @@ def annotate_orig_fx_with_snodes(
""" """
Creates a FX Graph from a list of SchedulerNode objects. Creates a FX Graph from a list of SchedulerNode objects.
""" """
node_name_to_buf_name: Dict[str, str] = {} node_name_to_buf_name: dict[str, str] = {}
update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name) update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
if node_name_to_buf_name is None: if node_name_to_buf_name is None:
return return
@ -309,7 +310,7 @@ def enable_aot_logging() -> Iterator[None]:
class DebugContext: class DebugContext:
_counter = itertools.count() _counter = itertools.count()
_inductor_triton_kernel_to_post_grad_node_info: Dict[str, List[str]] = {} _inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
@staticmethod @staticmethod
def create_debug_dir(folder_name: str) -> Optional[str]: def create_debug_dir(folder_name: str) -> Optional[str]:
@ -425,7 +426,7 @@ class DebugContext:
def __exit__( def __exit__(
self, self,
exc_type: Optional[Type[BaseException]], exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException], exc_val: Optional[BaseException],
exc_tb: Optional[Any], exc_tb: Optional[Any],
) -> None: ) -> None:
@ -474,7 +475,7 @@ class DebugFormatter:
def fx_graph( def fx_graph(
self, self,
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
inputs: List[torch.Tensor], inputs: list[torch.Tensor],
) -> None: ) -> None:
with self.fopen("fx_graph_runnable.py") as fd: with self.fopen("fx_graph_runnable.py") as fd:
save_dir = None save_dir = None
@ -504,7 +505,7 @@ class DebugFormatter:
def fx_graph_transformed( def fx_graph_transformed(
self, self,
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
inputs: List[torch.Tensor], inputs: list[torch.Tensor],
) -> None: ) -> None:
with self.fopen("fx_graph_transformed.py") as fd: with self.fopen("fx_graph_transformed.py") as fd:
fd.write(gm.print_readable(print_output=False)) fd.write(gm.print_readable(print_output=False))
@ -557,14 +558,14 @@ class DebugFormatter:
def log_autotuning_results( def log_autotuning_results(
self, self,
name: str, name: str,
input_nodes: List[ir.IRNode], input_nodes: list[ir.IRNode],
timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821 timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
elapse: float, elapse: float,
precompile_elapse: float, precompile_elapse: float,
) -> None: ) -> None:
from .ir import FixedLayout from .ir import FixedLayout
def build_node_info(node: ir.IRNode) -> Dict[str, str]: def build_node_info(node: ir.IRNode) -> dict[str, str]:
if hasattr(node, "name"): if hasattr(node, "name"):
node_name = node.name node_name = node.name
else: else:
@ -725,7 +726,7 @@ def aot_inductor_minifier_wrapper(
func: Callable[..., str], func: Callable[..., str],
exported_program: torch.export.ExportedProgram, exported_program: torch.export.ExportedProgram,
*, *,
inductor_configs: Dict[str, Any], inductor_configs: dict[str, Any],
package_path: Optional[Union[str, io.BytesIO]] = None, package_path: Optional[Union[str, io.BytesIO]] = None,
) -> str: ) -> str:
from torch._inductor import config from torch._inductor import config
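
`get_node_name_to_buf_meta` above counts how many node names map to each buffer and attaches that count as `BufMeta.n_origin`. A self-contained sketch of the same aggregation, assuming plain string names rather than real scheduler nodes:

```python
# Sketch of the name -> BufMeta aggregation shown in the debug utilities above.
import collections

BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])


def node_name_to_buf_meta(node_name_to_buf_name: dict[str, str]) -> dict[str, BufMeta]:
    # Count how many node names share each buffer.
    buf_name_to_n_node: dict[str, int] = {}
    for buf_name in node_name_to_buf_name.values():
        buf_name_to_n_node[buf_name] = buf_name_to_n_node.get(buf_name, 0) + 1
    # Attach the per-buffer count to every node name.
    return {
        node: BufMeta(buf, buf_name_to_n_node[buf])
        for node, buf in node_name_to_buf_name.items()
    }


if __name__ == "__main__":
    print(node_name_to_buf_meta({"add": "buf0", "mul": "buf0", "relu": "buf1"}))
```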

View File

@ -4,7 +4,7 @@ import logging
import math
import sys
import typing
-from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec
import torch
@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition( def register_decomposition(
ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]] ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions: if op in decompositions:
@ -170,7 +170,7 @@ def clamp(
@register_decomposition([aten.full]) @register_decomposition([aten.full])
def full( def full(
size: List[Union[int, torch.SymInt]], size: list[Union[int, torch.SymInt]],
fill_value: torch.types.Number, fill_value: torch.types.Number,
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
@ -205,8 +205,8 @@ def index_add(
# cool with strides and everything goes to empty_strided) # cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default]) @register_decomposition([aten.empty_permuted.default])
def empty_permuted( def empty_permuted(
size: List[Union[int, torch.SymInt]], size: list[Union[int, torch.SymInt]],
physical_layout: List[int], physical_layout: list[int],
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
perm = [0] * len(size) perm = [0] * len(size)
@ -220,14 +220,14 @@ def convolution_backward(
grad_output: torch.Tensor, grad_output: torch.Tensor,
input: torch.Tensor, input: torch.Tensor,
weight: torch.Tensor, weight: torch.Tensor,
bias_sizes: List[int], bias_sizes: list[int],
stride: Union[int, List[int]], stride: Union[int, list[int]],
padding: Union[int, List[int]], padding: Union[int, list[int]],
dilation: Union[int, List[int]], dilation: Union[int, list[int]],
transposed: bool, transposed: bool,
output_padding: List[int], output_padding: list[int],
groups: int, groups: int,
output_mask: List[bool], output_mask: list[bool],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not output_mask[2] or not is_gpu(grad_output.device.type): if not output_mask[2] or not is_gpu(grad_output.device.type):
return NotImplemented return NotImplemented
@ -345,7 +345,7 @@ def mm(
# don't remove ALL empty tensors, only the naughty ones) # don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default]) @register_decomposition([aten.cat.default])
def cat( def cat(
tensors: List[torch.Tensor], tensors: list[torch.Tensor],
dim: int = 0, dim: int = 0,
) -> torch.Tensor: ) -> torch.Tensor:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@ -515,7 +515,7 @@ def narrow_copy(
@register_decomposition([aten.view_copy.default]) @register_decomposition([aten.view_copy.default])
def view_copy_default( def view_copy_default(
self: torch.Tensor, self: torch.Tensor,
size: List[Union[int, torch.SymInt]], size: list[Union[int, torch.SymInt]],
) -> torch.Tensor: ) -> torch.Tensor:
return aten.view(self, size).clone() return aten.view(self, size).clone()
@ -639,7 +639,7 @@ def randint_like_low(
@register_decomposition(aten.randint.default) @register_decomposition(aten.randint.default)
def randint( def randint(
high: int, high: int,
size: List[Union[int, torch.SymInt]], size: list[Union[int, torch.SymInt]],
**kwargs: Any, **kwargs: Any,
) -> torch.Tensor: ) -> torch.Tensor:
return aten.randint.low(0, high, size, **kwargs) return aten.randint.low(0, high, size, **kwargs)
@ -731,11 +731,11 @@ def grid_sampler_2d(
@register_decomposition(aten._foreach_addcmul.Scalar) @register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar( def _foreach_addcmul_scalar(
self: List[torch.Tensor], self: list[torch.Tensor],
left_tensors: List[torch.Tensor], left_tensors: list[torch.Tensor],
right_tensors: List[torch.Tensor], right_tensors: list[torch.Tensor],
scalar: float = 1, scalar: float = 1,
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
return aten._foreach_add.List( return aten._foreach_add.List(
self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
) )
@ -743,11 +743,11 @@ def _foreach_addcmul_scalar(
@register_decomposition(aten._foreach_addcdiv.Scalar) @register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar( def _foreach_addcdiv_scalar(
self: List[torch.Tensor], self: list[torch.Tensor],
left_tensors: List[torch.Tensor], left_tensors: list[torch.Tensor],
right_tensors: List[torch.Tensor], right_tensors: list[torch.Tensor],
scalar: float = 1, scalar: float = 1,
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
return aten._foreach_add.List( return aten._foreach_add.List(
self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
) )
@ -755,10 +755,10 @@ def _foreach_addcdiv_scalar(
@register_decomposition(aten._foreach_lerp.Scalar) @register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar( def _foreach_lerp_scalar(
start_tensors: List[torch.Tensor], start_tensors: list[torch.Tensor],
end_tensors: List[torch.Tensor], end_tensors: list[torch.Tensor],
weight: torch.types.Number, weight: torch.types.Number,
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
return aten._foreach_add.List( return aten._foreach_add.List(
start_tensors, start_tensors,
aten._foreach_mul.Scalar( aten._foreach_mul.Scalar(
@ -769,10 +769,10 @@ def _foreach_lerp_scalar(
@register_decomposition(aten._foreach_lerp.ScalarList) @register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist( def _foreach_lerp_scalarlist(
start_tensors: List[torch.Tensor], start_tensors: list[torch.Tensor],
end_tensors: List[torch.Tensor], end_tensors: list[torch.Tensor],
scalars: List[torch.types.Number], scalars: list[torch.types.Number],
) -> List[torch.Tensor]: ) -> list[torch.Tensor]:
return aten._foreach_add.List( return aten._foreach_add.List(
start_tensors, start_tensors,
aten._foreach_mul.ScalarList( aten._foreach_mul.ScalarList(
@ -814,13 +814,13 @@ def miopen_batch_norm(
@functools.lru_cache(None) @functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]: def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
return {**decompositions, **extra_random_decomps} return {**decompositions, **extra_random_decomps}
# TODO(aakhundov): replace this (and the above) Any by more # TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors # specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]: def select_decomp_table() -> dict[Any, Callable[..., Any]]:
"""decomps can change based on config""" """decomps can change based on config"""
if config.fallback_random: if config.fallback_random:
return decompositions return decompositions
@ -965,10 +965,10 @@ def index_reduce(
@register_decomposition(aten.max_pool2d_with_indices) @register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices( def max_pool2d_with_indices(
x: torch.Tensor, x: torch.Tensor,
kernel_size: List[int], kernel_size: list[int],
stride: Optional[Union[int, List[int]]] = None, stride: Optional[Union[int, list[int]]] = None,
padding: Union[int, List[int]] = 0, padding: Union[int, list[int]] = 0,
dilation: Union[int, List[int]] = 1, dilation: Union[int, list[int]] = 1,
ceil_mode: bool = False, ceil_mode: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
if dilation == 1: if dilation == 1:
@ -1015,7 +1015,7 @@ def max_pool2d_with_indices(
@register_decomposition(aten.adaptive_max_pool2d) @register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d( def adaptive_max_pool2d(
x: torch.Tensor, output_size: List[int] x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
*batch, h_in, w_in = x.shape *batch, h_in, w_in = x.shape
h_out, w_out = output_size h_out, w_out = output_size
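
`register_decomposition` above fills a module-level table mapping ops to rewrite functions, and `select_decomp_table`/`fast_random_decomps` return that table, optionally merged with extras. A toy sketch of the pattern, using string op keys instead of real aten overloads:

```python
# Toy sketch of a decomposition-registration table; keys are plain strings.
from typing import Any, Callable, Optional

decompositions: dict[Any, Callable[..., Any]] = {}


def register_decomposition(ops: list[Any]) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    def wrapper(fn: Callable[..., Any]) -> Callable[..., Any]:
        for op in ops:
            decompositions[op] = fn
        return fn

    return wrapper


@register_decomposition(["toy::cat"])
def cat(tensors: list[list[float]], dim: int = 0) -> list[float]:
    # Toy stand-in: flatten instead of real tensor concatenation.
    return [x for t in tensors for x in t]


def select_decomp_table(extra: Optional[dict[Any, Callable[..., Any]]] = None) -> dict[Any, Callable[..., Any]]:
    # Merge in config-dependent extras, as the real selector does.
    return {**decompositions, **(extra or {})}


if __name__ == "__main__":
    table = select_decomp_table()
    print(table["toy::cat"]([[1.0], [2.0, 3.0]]))
```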

View File

@ -4,8 +4,8 @@ import dataclasses
import itertools
import logging
import re
-import typing
+from collections.abc import Sequence
-from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
+from typing import Any, Callable, Optional, TypeVar, Union
from unittest.mock import patch
import sympy
@ -38,7 +38,7 @@ class Dep(abc.ABC):
index: sympy.Expr index: sympy.Expr
@abc.abstractmethod @abc.abstractmethod
def rename(self, renames: Dict[str, str]) -> "Dep": def rename(self, renames: dict[str, str]) -> "Dep":
pass pass
@abc.abstractmethod @abc.abstractmethod
@ -197,7 +197,7 @@ class MemoryDep(Dep):
return out return out
@property @property
def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]: def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
"""{c0: 128, c1: 512, ...}""" """{c0: 128, c1: 512, ...}"""
return dict(zip(self.var_names, self.size)) return dict(zip(self.var_names, self.size))
@ -221,7 +221,7 @@ class MemoryDep(Dep):
numel = numel * size numel = numel * size
return numel # type: ignore[return-value] return numel # type: ignore[return-value]
def rename(self, renames: Dict[str, str]) -> "MemoryDep": def rename(self, renames: dict[str, str]) -> "MemoryDep":
if self.name in renames: if self.name in renames:
return MemoryDep( return MemoryDep(
renames[self.name], renames[self.name],
@ -299,7 +299,7 @@ class StarDep(Dep):
def get_numel(self) -> sympy.Expr: def get_numel(self) -> sympy.Expr:
return V.graph.get_numel(self.name) # type: ignore[return-value] return V.graph.get_numel(self.name) # type: ignore[return-value]
def rename(self, renames: Dict[str, str]) -> "StarDep": def rename(self, renames: dict[str, str]) -> "StarDep":
if self.name in renames: if self.name in renames:
return StarDep(renames[self.name], self.mode) return StarDep(renames[self.name], self.mode)
return self return self
@ -347,7 +347,7 @@ class WeakDep(Dep):
def get_numel(self) -> sympy.Expr: def get_numel(self) -> sympy.Expr:
return sympy.S.One return sympy.S.One
def rename(self, renames: Dict[str, str]) -> "WeakDep": def rename(self, renames: dict[str, str]) -> "WeakDep":
if self.name in renames: if self.name in renames:
return WeakDep(renames[self.name], self.mutating_buf) return WeakDep(renames[self.name], self.mutating_buf)
return self return self
@ -374,10 +374,10 @@ class ReadWrites:
reads: OrderedSet[Dep] reads: OrderedSet[Dep]
writes: OrderedSet[Dep] writes: OrderedSet[Dep]
index_exprs: OrderedSet[IndexExprDep] index_exprs: OrderedSet[IndexExprDep]
range_vars: Optional[List[sympy.Expr]] = None range_vars: Optional[list[sympy.Expr]] = None
var_ranges: Optional[VarRanges] = None var_ranges: Optional[VarRanges] = None
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites": def rename(self, renames: dict[str, str]) -> "ReadWrites":
return ReadWrites( return ReadWrites(
OrderedSet(dep.rename(renames) for dep in self.reads), OrderedSet(dep.rename(renames) for dep in self.reads),
OrderedSet(dep.rename(renames) for dep in self.writes), OrderedSet(dep.rename(renames) for dep in self.writes),
@ -405,7 +405,7 @@ class ReadWrites:
return ReadWrites(reads - writes, writes, index_exprs) return ReadWrites(reads - writes, writes, index_exprs)
@staticmethod @staticmethod
def merge_list(read_writes: List["ReadWrites"]): def merge_list(read_writes: list["ReadWrites"]):
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes]) all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes]) all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
@ -564,7 +564,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy
def index_vars_no_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str): def index_vars_no_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str):
var_ranges, add_var = var_builder(prefix) var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes] args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
return args, var_ranges return args, var_ranges
@ -572,8 +572,8 @@ def index_vars_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str = "d"):
from .ir import SqueezeView from .ir import SqueezeView
var_ranges, add_var = var_builder(prefix) var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Expr]] = [] args: list[list[sympy.Expr]] = []
new_sizes: List[List[sympy.Expr]] = [] new_sizes: list[list[sympy.Expr]] = []
for size in argsizes: for size in argsizes:
new_size, reindex = SqueezeView.squeezer(size) new_size, reindex = SqueezeView.squeezer(size)
new_sizes.append(new_size) new_sizes.append(new_size)
@ -653,7 +653,7 @@ def extract_loop_body_with_args(fn, args, var_ranges, normalize=False):
def extract_input_node_reduction_ranges( def extract_input_node_reduction_ranges(
input_node: "torch._inductor.ir.IRNode", input_node: "torch._inductor.ir.IRNode",
) -> tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]: ) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
""" """
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same. Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes. It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
@ -663,8 +663,8 @@ def extract_input_node_reduction_ranges(
from .ir import ComputedBuffer, ExternKernel, Loops from .ir import ComputedBuffer, ExternKernel, Loops
size: Optional[List[sympy.Expr]] size: Optional[list[sympy.Expr]]
reduction_size: Optional[List[sympy.Expr]] reduction_size: Optional[list[sympy.Expr]]
if isinstance(input_node.get_defining_op(), ComputedBuffer): if isinstance(input_node.get_defining_op(), ComputedBuffer):
# Input node has already been realized. Return its size and reduction_size. # Input node has already been realized. Return its size and reduction_size.
@ -683,11 +683,11 @@ def extract_input_node_reduction_ranges(
# The current method still uses reduction ranges from the dependent realized node, which is not ideal. # The current method still uses reduction ranges from the dependent realized node, which is not ideal.
# Is there a way to check whether there are permutations in between?
reads = input_node.get_reads() reads = input_node.get_reads()
reduction_size: Optional[List[sympy.Expr]] = None reduction_size: Optional[list[sympy.Expr]] = None
size: Optional[List[sympy.Expr]] = None size: Optional[list[sympy.Expr]] = None
while reduction_size is None and len(reads) > 0: while reduction_size is None and len(reads) > 0:
seen: OrderedSet[str] = OrderedSet() seen: OrderedSet[str] = OrderedSet()
new_reads: List[Dep] = [] new_reads: list[Dep] = []
for read in reads: for read in reads:
if not isinstance(read, MemoryDep): if not isinstance(read, MemoryDep):
continue continue
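A minimal usage sketch for the function excerpted above (the call site is hypothetical; in practice input_node is an ir.IRNode reached during lowering):

size, reduction_size = extract_input_node_reduction_ranges(input_node)
if size is not None and reduction_size is not None:
    # every realized input agreed on its (size, reduction_size) ranges
    ...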

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import functools import functools
from typing import Callable, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union from collections.abc import Sequence
from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
import sympy import sympy

View File

@ -4,7 +4,7 @@ import os
import tempfile import tempfile
import textwrap import textwrap
from functools import lru_cache from functools import lru_cache
from typing import Any, List, Optional, TYPE_CHECKING from typing import Any, Optional, TYPE_CHECKING
from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback
@ -29,7 +29,7 @@ else:
class OperatorIssue(RuntimeError): class OperatorIssue(RuntimeError):
@staticmethod @staticmethod
def operator_str(target: Any, args: List[Any], kwargs: dict[str, Any]) -> str: def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str:
lines = [f"target: {target}"] + [ lines = [f"target: {target}"] + [
f"args[{i}]: {arg}" for i, arg in enumerate(args) f"args[{i}]: {arg}" for i, arg in enumerate(args)
] ]
@ -39,13 +39,13 @@ class OperatorIssue(RuntimeError):
class MissingOperatorWithoutDecomp(OperatorIssue): class MissingOperatorWithoutDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None: def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target) _record_missing_op(target)
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}") super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")
class MissingOperatorWithDecomp(OperatorIssue): class MissingOperatorWithDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None: def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target) _record_missing_op(target)
super().__init__( super().__init__(
f"missing decomposition\n{self.operator_str(target, args, kwargs)}" f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
@ -62,7 +62,7 @@ class MissingOperatorWithDecomp(OperatorIssue):
class LoweringException(OperatorIssue): class LoweringException(OperatorIssue):
def __init__( def __init__(
self, exc: Exception, target: Any, args: List[Any], kwargs: dict[str, Any] self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any]
) -> None: ) -> None:
super().__init__( super().__init__(
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}" f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"

View File

@ -1,5 +1,4 @@
import json import json
from typing import List
from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
@ -17,7 +16,7 @@ def serialize_extern_kernel_node(
def extern_node_json_serializer( def extern_node_json_serializer(
extern_kernel_nodes: List[inductor_ExternKernelNode], extern_kernel_nodes: list[inductor_ExternKernelNode],
) -> str: ) -> str:
serialized_nodes = ExternKernelNodes( serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes] nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]

View File

@ -4,7 +4,7 @@ from __future__ import annotations
import itertools import itertools
import logging import logging
import weakref import weakref
from typing import Any, List, Optional from typing import Any, Optional
import torch import torch
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
@ -28,7 +28,7 @@ def replace_params_with_constants(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
flat_params: list[Any], flat_params: list[Any],
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta, fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]: ) -> list[int]:
""" """
Replaces the parameters of a PyTorch GraphModule with constants wherever possible. Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
Returns a list of indices representing the input parameters that were not converted to constants. Returns a list of indices representing the input parameters that were not converted to constants.
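A hedged usage sketch matching the signature above (gm, flat_params, and fw_metadata are assumed to come from the AOTAutograd tracing step):

preserved_arg_indices = replace_params_with_constants(gm, flat_params, fw_metadata)
# indices into flat_params for the parameters that remained graph inputs
# (everything else was baked into the graph as a constant)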
@ -66,8 +66,8 @@ def replace_params_with_constants(
def freeze( def freeze(
dynamo_gm: torch.fx.GraphModule, dynamo_gm: torch.fx.GraphModule,
aot_autograd_gm: torch.fx.GraphModule, aot_autograd_gm: torch.fx.GraphModule,
example_inputs: List[torch._subclasses.FakeTensor], example_inputs: list[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, List[int]]: ) -> tuple[torch.fx.GraphModule, list[int]]:
""" """
Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency. and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
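Following the annotated signature, an assumed call shape for freeze is:

frozen_gm, preserved_arg_indices = freeze(dynamo_gm, aot_autograd_gm, example_inputs)
# frozen_gm has non-mutated parameters inlined as constants;
# preserved_arg_indices says which of the original inputs must still be passed in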

View File

@ -6,21 +6,17 @@ import signal
import string import string
import sys import sys
import traceback import traceback
from collections.abc import KeysView
from enum import Enum from enum import Enum
from functools import partial, wraps from functools import partial, wraps
from types import FrameType from types import FrameType
from typing import ( from typing import (
Any, Any,
Callable, Callable,
Dict,
get_args, get_args,
get_origin, get_origin,
KeysView,
List,
Literal, Literal,
Optional, Optional,
Tuple,
Type,
TypeVar, TypeVar,
Union, Union,
) )
@ -92,14 +88,14 @@ class TypeExemplars:
} }
@staticmethod @staticmethod
def example(t: Type[T]) -> Optional[T]: def example(t: type[T]) -> Optional[T]:
""" """
Return an example of a class. Return an example of a class.
""" """
return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None) return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None)
@staticmethod @staticmethod
def contains(t: Type[T]) -> bool: def contains(t: type[T]) -> bool:
return t.__name__ in TypeExemplars.TYPE_EXEMPLARS return t.__name__ in TypeExemplars.TYPE_EXEMPLARS
@ -136,7 +132,7 @@ class Status(Enum):
# Sometimes the types of configs aren't expressive enough to be captured by the Python type system, so the options can be # Sometimes the types of configs aren't expressive enough to be captured by the Python type system, so the options can be
# manually specified here: # manually specified here:
TYPE_OVERRIDES: Dict[str, List[Any]] = { TYPE_OVERRIDES: dict[str, list[Any]] = {
"post_grad_fusion_options": [ "post_grad_fusion_options": [
{ {
"batch_linear_post_grad": { "batch_linear_post_grad": {
@ -160,7 +156,7 @@ TYPE_OVERRIDES: Dict[str, List[Any]] = {
"autoheuristic_collect": ["pad_mm", "mixed_mm"], "autoheuristic_collect": ["pad_mm", "mixed_mm"],
"autoheuristic_use": ["pad_mm", "mixed_mm"], "autoheuristic_use": ["pad_mm", "mixed_mm"],
} }
SamplingType = Callable[[str, Type[Any], Any], Any] SamplingType = Callable[[str, type[Any], Any], Any]
class SamplingMethod(Enum): class SamplingMethod(Enum):
@ -178,7 +174,7 @@ class SamplingMethod(Enum):
@staticmethod @staticmethod
def _generate_value_for_type( def _generate_value_for_type(
random_sample: bool, field_name: str, type_hint: Type[Any], default: Any random_sample: bool, field_name: str, type_hint: type[Any], default: Any
) -> Any: ) -> Any:
""" """
Generates a value of a type based on the setting. Generates a value of a type based on the setting.
@ -304,9 +300,11 @@ class SamplingMethod(Enum):
if random_sample: if random_sample:
return random.choice(type_hint.__args__) return random.choice(type_hint.__args__)
else: else:
return random.choice( choices = [t for t in type_hint.__args__ if t != default]
[t for t in type_hint.__args__ if t != default] if choices:
) return random.choice(choices)
else:
return default
except AttributeError as err: except AttributeError as err:
raise ValueError("Literal type with no args") from err raise ValueError("Literal type with no args") from err
elif is_optional_type(type_hint): elif is_optional_type(type_hint):
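The Literal hunk above avoids calling random.choice on an empty list. An illustrative sketch of the new behavior (values hypothetical):

# Literal["a", "b"] with default "a": choices == ["b"], so a non-default value is returned
# Literal["a"]      with default "a": choices == [],   so the default "a" is returned
#                                     (previously random.choice([]) raised IndexError)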
@ -374,7 +372,7 @@ class Default:
DEFAULT = Default() DEFAULT = Default()
# The combination of config settings being set (based on their strings) # The combination of config settings being set (based on their strings)
ComboType = Tuple[str, ...] ComboType = tuple[str, ...]
class ResultType: class ResultType:
@ -382,7 +380,7 @@ class ResultType:
The mapping of the combo strings to the result status after running the config fuzzer. The mapping of the combo strings to the result status after running the config fuzzer.
""" """
_vals: Dict[ComboType, Status] _vals: dict[ComboType, Status]
def __repr__(self) -> str: def __repr__(self) -> str:
return f"ResultType[{self._vals}]" return f"ResultType[{self._vals}]"
@ -416,7 +414,7 @@ class ResultType:
# Type that maps config strings to their default value # Type that maps config strings to their default value
ConfigType = Dict[str, Any] ConfigType = dict[str, Any]
# Callable that returns a bool # Callable that returns a bool
FactoryOutputType = Callable[[], bool] FactoryOutputType = Callable[[], bool]
# input function factory # input function factory
@ -504,10 +502,10 @@ class ConfigFuzzer:
return return
self.seed = seed self.seed = seed
self.test_timeout = test_timeout self.test_timeout = test_timeout
self.detailed_results: Dict[ComboType, Dict[str, Any]] = {} self.detailed_results: dict[ComboType, dict[str, Any]] = {}
self.config_module = config_module self.config_module = config_module
self.test_model_fn_factory = test_model_fn_factory self.test_model_fn_factory = test_model_fn_factory
self.fields: Dict[str, _ConfigEntry] = self.config_module._config self.fields: dict[str, _ConfigEntry] = self.config_module._config
self.sample = SamplingMethod.dispatch(sm) self.sample = SamplingMethod.dispatch(sm)
if default is None: if default is None:
@ -587,7 +585,7 @@ class ConfigFuzzer:
} }
return ret return ret
def reproduce(self, configs: List[ConfigType]) -> ResultType: def reproduce(self, configs: list[ConfigType]) -> ResultType:
"""entrypoint to reproduce any failure""" """entrypoint to reproduce any failure"""
results = ResultType() results = ResultType()
for conf in configs: for conf in configs:
@ -675,7 +673,7 @@ class ConfigFuzzer:
for field, value in config.items(): for field, value in config.items():
print(f"{field} = {value}") print(f"{field} = {value}")
def get_error_info(exc: Exception) -> Dict[str, Any]: def get_error_info(exc: Exception) -> dict[str, Any]:
return { return {
"exception": str(exc), "exception": str(exc),
"traceback": traceback.format_exc(), "traceback": traceback.format_exc(),
@ -741,7 +739,7 @@ class ConfigFuzzer:
else: else:
return handle_return("Function succeeded", Status.PASSED, False, None) return handle_return("Function succeeded", Status.PASSED, False, None)
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> List[ConfigType]: def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]:
""" """
Test configs and bisect to minimal failing configuration. Test configs and bisect to minimal failing configuration.
""" """
@ -749,7 +747,7 @@ class ConfigFuzzer:
random.seed(self.seed) random.seed(self.seed)
self._reset_configs() self._reset_configs()
results = ResultType() results = ResultType()
ret: List[ConfigType] = [] ret: list[ConfigType] = []
for attempt in range(num_attempts): for attempt in range(num_attempts):
print(f"Random attempt {attempt + 1}/{num_attempts}") print(f"Random attempt {attempt + 1}/{num_attempts}")
@ -783,7 +781,7 @@ class ConfigFuzzer:
return self._bisect_failing_config_helper(results, list(failing_config.items())) return self._bisect_failing_config_helper(results, list(failing_config.items()))
def _bisect_failing_config_helper( def _bisect_failing_config_helper(
self, results: ResultType, failing_config: List[Tuple[str, Any]] self, results: ResultType, failing_config: list[tuple[str, Any]]
) -> Optional[ConfigType]: ) -> Optional[ConfigType]:
""" """
Bisect a failing configuration to find minimal set of configs that cause failure. Bisect a failing configuration to find minimal set of configs that cause failure.
@ -795,7 +793,7 @@ class ConfigFuzzer:
if not failing_config: if not failing_config:
return None return None
def test(x: List[Tuple[str, Any]]) -> Status: def test(x: list[tuple[str, Any]]) -> Status:
d = dict(x) d = dict(x)
result = self.test_config(results, d) result = self.test_config(results, d)
return result return result
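For intuition only, and not the Inductor implementation: the helper above follows the usual shape of halving-based minimization, roughly:

def minimize(items, still_fails):
    # keep dropping halves of the failing set while the failure still reproduces
    shrunk = True
    while shrunk and len(items) > 1:
        shrunk = False
        half = len(items) // 2
        for part in (items[:half], items[half:]):
            if still_fails(part):
                items, shrunk = part, True
                break
    return items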

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import operator import operator
from collections import defaultdict from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Type from typing import Any, Callable, Optional
import sympy import sympy
@ -24,9 +24,9 @@ from .virtualized import V
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched. # Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method. # Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern( def matches_module_function_pattern(
pattern: tuple[Type[torch.nn.modules.Module], Callable[..., Any]], pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]],
node: torch.fx.node.Node, node: torch.fx.node.Node,
modules: Dict[str, torch.nn.modules.Module], modules: dict[str, torch.nn.modules.Module],
) -> bool: ) -> bool:
if len(node.args) == 0: if len(node.args) == 0:
return False return False
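An assumed example of the (module, function) pattern shape the helper above matches:

import torch
# an nn.Linear call_module node whose output feeds directly into torch.relu
pattern = (torch.nn.Linear, torch.relu)
# matches_module_function_pattern(pattern, node, modules) is expected to return True
# when node is the torch.relu call and node.args[0] is the nn.Linear call_module node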
@ -86,7 +86,7 @@ class FakeTensorUpdater:
return (node, node.target, id(node.args), id(node.kwargs)) return (node, node.target, id(node.args), id(node.kwargs))
def incremental_update(self): def incremental_update(self):
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int) existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
for node in self.graph.nodes: for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1 existing_storages[get_node_storage(node)] += 1
@ -208,7 +208,7 @@ def get_fake(x):
return x return x
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], Dict[str, Any]]: def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]:
""" """
The first return value is a boolean indicating whether all of the input nodes have fake tensors. The first return value is a boolean indicating whether all of the input nodes have fake tensors.
""" """

View File

@ -8,22 +8,10 @@ import re
import sys import sys
import time import time
from collections import defaultdict from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager from contextlib import contextmanager
from types import ModuleType from types import ModuleType
from typing import ( from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)
import sympy import sympy
from sympy import Expr from sympy import Expr
@ -198,8 +186,8 @@ def getattr_recursive(
return attr_itr return attr_itr
def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]: def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
ret: Dict[Node, tuple[int, ...]] = {} ret: dict[Node, tuple[int, ...]] = {}
output_node = g.find_nodes(op="output")[0] output_node = g.find_nodes(op="output")[0]
if "user_visible_output_idxs" not in output_node.meta: if "user_visible_output_idxs" not in output_node.meta:
@ -212,7 +200,7 @@ def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
def mark_nodes_dislike_padding( def mark_nodes_dislike_padding(
g: Graph, user_visible_output_strides: Dict[Node, tuple[int, ...]] g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
) -> None: ) -> None:
""" """
Nodes like convolution/convolution_backward want their inputs to be dense. Nodes like convolution/convolution_backward want their inputs to be dense.
@ -282,7 +270,7 @@ def mark_nodes_dislike_padding(
class GraphLowering(torch.fx.Interpreter): class GraphLowering(torch.fx.Interpreter):
graph_outputs: List[ir.IRNode] graph_outputs: list[ir.IRNode]
def symbolic_sizes_strides( def symbolic_sizes_strides(
self, ex: torch.Tensor self, ex: torch.Tensor
@ -323,7 +311,7 @@ class GraphLowering(torch.fx.Interpreter):
def static_sizes_strides( def static_sizes_strides(
self, ex: torch.Tensor self, ex: torch.Tensor
) -> tuple[List[sympy.Expr], List[sympy.Expr]]: ) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
""" """
Primarily used for weights Primarily used for weights
""" """
@ -341,12 +329,12 @@ class GraphLowering(torch.fx.Interpreter):
aot_mode: bool = False, aot_mode: bool = False,
layout_opt: Optional[bool] = None, layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[ extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any] Callable[[list[ir.ExternKernelNode]], Any]
] = None, ] = None,
is_inference: bool = False, is_inference: bool = False,
is_backward: bool = False, is_backward: bool = False,
is_const_graph: bool = False, is_const_graph: bool = False,
const_output_index: Optional[Dict[str, int]] = None, const_output_index: Optional[dict[str, int]] = None,
const_code: Optional[str] = None, const_code: Optional[str] = None,
const_module: Optional["GraphLowering"] = None, const_module: Optional["GraphLowering"] = None,
name: Optional[str] = None, name: Optional[str] = None,
@ -379,14 +367,14 @@ class GraphLowering(torch.fx.Interpreter):
# you don't start adding new ones in the lowering process # you don't start adding new ones in the lowering process
shape_env.freeze_runtime_asserts() shape_env.freeze_runtime_asserts()
# We're going to mutate ras_by_symbol as we finish generating them # We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: Dict[ self.ras_by_symbol: dict[
Optional[sympy.Symbol], List[RuntimeAssert] Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy() ] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env) self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: List[str] = [] self.graph_input_names: list[str] = []
self.graph_inputs: Dict[str, TensorBox] = {} self.graph_inputs: dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {} self.graph_inputs_original: dict[str, InputBuffer] = {}
self.zero_dim_cpu_tensor_list = OrderedSet[str]() self.zero_dim_cpu_tensor_list = OrderedSet[str]()
self.device_types: OrderedSet[str] = ( self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet() const_module.device_types if const_module else OrderedSet()
@ -395,9 +383,9 @@ class GraphLowering(torch.fx.Interpreter):
const_module.device_idxs if const_module else OrderedSet() const_module.device_idxs if const_module else OrderedSet()
) )
self.device_type = "cpu" self.device_type = "cpu"
self.buffers: List[ir.Buffer] = [] self.buffers: list[ir.Buffer] = []
self.operations: List[ir.Operation] = [] self.operations: list[ir.Operation] = []
self.const_output_index: Dict[str, int] = ( self.const_output_index: dict[str, int] = (
const_output_index if const_output_index else {} const_output_index if const_output_index else {}
) )
self.folded_constants: OrderedSet[str] = ( self.folded_constants: OrderedSet[str] = (
@ -405,12 +393,12 @@ class GraphLowering(torch.fx.Interpreter):
if const_output_index if const_output_index
else OrderedSet() else OrderedSet()
) )
self.constants: Dict[str, torch.Tensor] = ( self.constants: dict[str, torch.Tensor] = (
const_module.constants if const_module else {} const_module.constants if const_module else {}
) )
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {} self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
self.seen_subgraphs: Dict[str, ir.Subgraph] = {} self.seen_subgraphs: dict[str, ir.Subgraph] = {}
self.constant_reprs: Dict[str, str] = {} self.constant_reprs: dict[str, str] = {}
self.removed_operations = OrderedSet[str]() self.removed_operations = OrderedSet[str]()
self.removed_buffers = OrderedSet[str]() self.removed_buffers = OrderedSet[str]()
self.removed_inplace_buffers = OrderedSet[str]() self.removed_inplace_buffers = OrderedSet[str]()
@ -420,23 +408,23 @@ class GraphLowering(torch.fx.Interpreter):
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment] self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment] self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details # See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = [] self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
from torch._inductor.extern_node_serializer import extern_node_json_serializer from torch._inductor.extern_node_serializer import extern_node_json_serializer
self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = ( self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
extern_node_serializer extern_node_serializer
if config.is_fbcode() and extern_node_serializer if config.is_fbcode() and extern_node_serializer
else extern_node_json_serializer else extern_node_json_serializer
) )
self.current_node: torch.fx.Node = None # type: ignore[assignment] self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: Dict[str, List[str]] = {} self.lists: dict[str, list[str]] = {}
self.mutated_inputs = OrderedSet[str]() self.mutated_inputs = OrderedSet[str]()
self.mutated_input_idxs: List[int] = [] self.mutated_input_idxs: list[int] = []
self.name_to_buffer: Dict[str, ir.Buffer] = {} self.name_to_buffer: dict[str, ir.Buffer] = {}
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list) self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
self.name_to_op: Dict[str, ir.Operation] = {} self.name_to_op: dict[str, ir.Operation] = {}
self.creation_time = time.time() self.creation_time = time.time()
self.name = name # type: ignore[assignment] self.name = name # type: ignore[assignment]
self.cpp_wrapper = cpp_wrapper self.cpp_wrapper = cpp_wrapper
@ -445,7 +433,7 @@ class GraphLowering(torch.fx.Interpreter):
# which sub-kernel is picked. Copy cpp_wrapper to another variable # which sub-kernel is picked. Copy cpp_wrapper to another variable
# since cpp_wrapper flag is set to false for the first pass of codegen. # since cpp_wrapper flag is set to false for the first pass of codegen.
self.record_multi_kernel_choice = cpp_wrapper self.record_multi_kernel_choice = cpp_wrapper
self.multi_kernel_to_choice: Dict[str, str] = {} self.multi_kernel_to_choice: dict[str, str] = {}
self.aot_mode = aot_mode self.aot_mode = aot_mode
self.graph_id = graph_id self.graph_id = graph_id
@ -464,7 +452,7 @@ class GraphLowering(torch.fx.Interpreter):
mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides) mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
self.cache_key: str = "" # This is the cache key for the compiled artifact self.cache_key: str = "" # This is the cache key for the compiled artifact
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: List[ self.cache_linemap: list[
tuple[int, str] tuple[int, str]
] = ( ] = (
[] []
@ -473,18 +461,18 @@ class GraphLowering(torch.fx.Interpreter):
self.disable_cudagraphs_reason: Optional[str] = None self.disable_cudagraphs_reason: Optional[str] = None
# only keeping one node per device for stack trace purposes # only keeping one node per device for stack trace purposes
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {} self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
self.orig_gm: torch.fx.GraphModule = gm.__copy__() self.orig_gm: torch.fx.GraphModule = gm.__copy__()
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr] self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr]
"dynamo_flat_name_to_original_fqn", {} "dynamo_flat_name_to_original_fqn", {}
) )
self.allocated_constant_name: Dict[str, str] = ( self.allocated_constant_name: dict[str, str] = (
const_module.allocated_constant_name if const_module is not None else {} const_module.allocated_constant_name if const_module is not None else {}
) )
init_backend_registration() init_backend_registration()
self.get_backend_features = functools.lru_cache(None)(get_backend_features) self.get_backend_features = functools.lru_cache(None)(get_backend_features)
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {} self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
self.aligned_inputs = OrderedSet[str]() self.aligned_inputs = OrderedSet[str]()
self.no_fuse_buffer_names = OrderedSet[str]() self.no_fuse_buffer_names = OrderedSet[str]()
@ -599,7 +587,7 @@ class GraphLowering(torch.fx.Interpreter):
if is_inference: if is_inference:
from torch.utils.flop_counter import FlopCounterMode from torch.utils.flop_counter import FlopCounterMode
flop_counts: Dict[str, float] = defaultdict(float) flop_counts: dict[str, float] = defaultdict(float)
for node in conv_nodes: for node in conv_nodes:
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs( success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
node node
@ -702,7 +690,7 @@ class GraphLowering(torch.fx.Interpreter):
def make_subgraph( def make_subgraph(
self, self,
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
example_inputs: List[torch.Tensor], example_inputs: list[torch.Tensor],
subgraph_name: str, subgraph_name: str,
) -> "SubgraphLowering": ) -> "SubgraphLowering":
""" """
@ -886,7 +874,7 @@ class GraphLowering(torch.fx.Interpreter):
buffer.name = name buffer.name = name
return name return name
def register_operation_list(self, operation_names: List[str]) -> str: def register_operation_list(self, operation_names: list[str]) -> str:
name = self.qualify_name("list_" + "_".join(operation_names)) name = self.qualify_name("list_" + "_".join(operation_names))
self.lists[name] = operation_names self.lists[name] = operation_names
return name return name
@ -995,7 +983,7 @@ class GraphLowering(torch.fx.Interpreter):
) )
def placeholder( def placeholder(
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
) -> Union[Expr, TensorBox, None]: ) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1 self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
@ -1072,7 +1060,7 @@ class GraphLowering(torch.fx.Interpreter):
self.aligned_inputs.add(target) self.aligned_inputs.add(target)
return tensor return tensor
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override] def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override]
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)): if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
return super().call_function(target, args, kwargs) return super().call_function(target, args, kwargs)
@ -1155,7 +1143,7 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8 return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr( def get_attr(
self, target: str, args: tuple[()], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant # this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type] value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1203,7 +1191,7 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError raise AssertionError
def output( def output(
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override] self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
) -> None: ) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type] result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)): if not isinstance(result, (tuple, list)):
@ -1306,9 +1294,9 @@ class GraphLowering(torch.fx.Interpreter):
self, self,
fx_node: torch.fx.Node, fx_node: torch.fx.Node,
old_args: tuple[Any], old_args: tuple[Any],
old_kwargs: Dict[str, Any], old_kwargs: dict[str, Any],
new_args: tuple[Any], new_args: tuple[Any],
new_kwargs: Dict[str, Any], new_kwargs: dict[str, Any],
) -> None: ) -> None:
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs. """Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
@ -1803,7 +1791,7 @@ class GraphLowering(torch.fx.Interpreter):
self.const_module.wrapper_code.src_to_kernel self.const_module.wrapper_code.src_to_kernel
) )
def codegen_with_cpp_wrapper(self) -> tuple[str, List[tuple[int, Node]]]: def codegen_with_cpp_wrapper(self) -> tuple[str, list[tuple[int, Node]]]:
""" """
For GPU, Triton kernels are autotuned and stored as cubin files For GPU, Triton kernels are autotuned and stored as cubin files
""" """
@ -1902,7 +1890,7 @@ class GraphLowering(torch.fx.Interpreter):
# cpu # cpu
return self.codegen() return self.codegen()
def codegen(self) -> tuple[str, List[tuple[int, Node]]]: def codegen(self) -> tuple[str, list[tuple[int, Node]]]:
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True): with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
from .scheduler import Scheduler from .scheduler import Scheduler
@ -1949,7 +1937,7 @@ class GraphLowering(torch.fx.Interpreter):
def count_bytes( def count_bytes(
self, self,
) -> tuple[ ) -> tuple[
int, List[tuple[BaseSchedulerNode, int]], List[tuple[BaseSchedulerNode, float]] int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
]: ]:
total_bytes = 0 total_bytes = 0
node_counts = [] node_counts = []
@ -2041,7 +2029,7 @@ class GraphLowering(torch.fx.Interpreter):
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug") V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
return mod return mod
def get_output_names(self) -> List[str]: def get_output_names(self) -> list[str]:
names = [] names = []
shape_counter = itertools.count(0) shape_counter = itertools.count(0)
none_counter = itertools.count(0) none_counter = itertools.count(0)

View File

@ -1,13 +1,13 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import contextlib import contextlib
from typing import Callable, List, TYPE_CHECKING from typing import Callable, TYPE_CHECKING
if TYPE_CHECKING: if TYPE_CHECKING:
import torch import torch
# Executed in the order they're registered # Executed in the order they're registered
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = [] INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = []
@contextlib.contextmanager @contextlib.contextmanager

View File

@ -22,7 +22,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
""" """
import itertools import itertools
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, Literal, Optional, overload, Union from typing import Any, Callable, Literal, Optional, overload, Union
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
import sympy import sympy
@ -196,8 +196,8 @@ class IndexPropagation:
def __init__( def __init__(
self, self,
inner: Any, inner: Any,
iter_ranges: Dict[sympy.Symbol, sympy.Expr], iter_ranges: dict[sympy.Symbol, sympy.Expr],
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr], indirect_var_ranges: dict[sympy.Symbol, sympy.Expr],
) -> None: ) -> None:
self._inner = inner self._inner = inner
self.shape_env = V.graph.sizevars.shape_env self.shape_env = V.graph.sizevars.shape_env
@ -248,18 +248,18 @@ class IndexPropagation:
self, self,
name: Literal["indirect_indexing"], name: Literal["indirect_indexing"],
args: tuple[Any, ...], args: tuple[Any, ...],
kwargs: Dict[str, Any], kwargs: dict[str, Any],
) -> IndexPropVar: ) -> IndexPropVar:
... ...
@overload @overload
def fallback( def fallback(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any] self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult: ) -> IndexPropResult:
... ...
def fallback( def fallback(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any] self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult: ) -> IndexPropResult:
# Fallback to the wrapped handler # Fallback to the wrapped handler
new_args = [self.unwrap(a) for a in args] new_args = [self.unwrap(a) for a in args]
@ -267,7 +267,7 @@ class IndexPropagation:
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs)) return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
def propagate_sympy( def propagate_sympy(
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any] self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
) -> IndexPropResult: ) -> IndexPropResult:
# Build a new SymPy expression from this ops call # Build a new SymPy expression from this ops call
def unwrap(a: Union[Any, IndexPropVar]) -> Any: def unwrap(a: Union[Any, IndexPropVar]) -> Any:

View File

@ -2,12 +2,16 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Optional, Sequence from typing import Optional, TYPE_CHECKING
import torch import torch
from torch import _prims, Tensor from torch import _prims, Tensor
if TYPE_CHECKING:
from collections.abc import Sequence
log = logging.getLogger(__name__) log = logging.getLogger(__name__)

View File

@ -8,6 +8,7 @@ import logging
import textwrap import textwrap
import traceback import traceback
import typing import typing
from collections.abc import Generator, Iterable, Sequence
from contextlib import nullcontext from contextlib import nullcontext
from enum import Enum from enum import Enum
from functools import partial from functools import partial
@ -16,14 +17,9 @@ from typing import (
Callable, Callable,
ClassVar, ClassVar,
ContextManager, ContextManager,
Dict,
Generator,
Iterable,
List,
Literal, Literal,
Optional, Optional,
overload, overload,
Sequence,
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Union, Union,
@ -165,11 +161,11 @@ e.g. it may be a graph input or compile time constant.
_NodeOrNodes: TypeAlias = Union[ _NodeOrNodes: TypeAlias = Union[
int, int,
"TensorBox", "TensorBox",
Dict[str, "TensorBox"], dict[str, "TensorBox"],
"Symbol", "Symbol",
"IRNode", "IRNode",
Sequence[ Sequence[
Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]] Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
], ],
] ]
@ -426,7 +422,7 @@ class IRNode:
# NB: These are kinda weird, # NB: These are kinda weird,
origins: OrderedSet[Any] = dataclasses.field(init=False) origins: OrderedSet[Any] = dataclasses.field(init=False)
traceback: Optional[List[str]] = dataclasses.field(init=False) traceback: Optional[list[str]] = dataclasses.field(init=False)
origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False) origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
@staticmethod @staticmethod
@ -455,7 +451,7 @@ class IRNode:
def get_read_names(self) -> OrderedSet[str]: def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads()) return OrderedSet(dep.name for dep in self.get_reads())
def get_traceback(self) -> Optional[List[str]]: def get_traceback(self) -> Optional[list[str]]:
return self.traceback return self.traceback
def get_origin_node(self) -> Optional[torch.fx.Node]: def get_origin_node(self) -> Optional[torch.fx.Node]:
@ -604,18 +600,18 @@ class IRNode:
raise NotImplementedError(type(self).__name__) raise NotImplementedError(type(self).__name__)
def freeze_layout_with_stride_order( def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False self, order: list[int], allow_padding: bool = False
) -> None: ) -> None:
raise NotImplementedError(type(self).__name__) raise NotImplementedError(type(self).__name__)
def freeze_layout_with_fill_order(self, order: List[int]) -> None: def freeze_layout_with_fill_order(self, order: list[int]) -> None:
raise NotImplementedError(type(self).__name__) raise NotImplementedError(type(self).__name__)
def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None: def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
raise NotImplementedError(type(self).__name__) raise NotImplementedError(type(self).__name__)
def freeze_layout_with_exact_strides( def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None: ) -> None:
raise NotImplementedError(type(self).__name__) raise NotImplementedError(type(self).__name__)
@ -703,7 +699,7 @@ class Operation:
def get_reads(self) -> OrderedSet[Dep]: def get_reads(self) -> OrderedSet[Dep]:
return self.get_read_writes().reads return self.get_read_writes().reads
def get_outputs(self) -> List[Buffer]: def get_outputs(self) -> list[Buffer]:
raise NotImplementedError raise NotImplementedError
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -936,7 +932,7 @@ class Scatter(Pointwise):
) )
REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = { REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
"any": ops_wrapper("logical_or"), "any": ops_wrapper("logical_or"),
"max": ops_wrapper("maximum"), "max": ops_wrapper("maximum"),
"min": ops_wrapper("minimum"), "min": ops_wrapper("minimum"),
@ -1575,8 +1571,8 @@ class Reduction(Loops):
wrapper_fn: Callable[..., Any], wrapper_fn: Callable[..., Any],
original_ranges: Sequence[Expr], original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr], original_reduction_ranges: Sequence[Expr],
new_ranges: List[Expr], new_ranges: list[Expr],
new_reduction_ranges: List[Integer], new_reduction_ranges: list[Integer],
reduction_type: str, reduction_type: str,
split: _IntLike, split: _IntLike,
reduction_hint: ReductionHint, reduction_hint: ReductionHint,
@ -1678,8 +1674,8 @@ class Reduction(Loops):
inner_fn: Callable[..., Any], inner_fn: Callable[..., Any],
original_ranges: Sequence[Expr], original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr], original_reduction_ranges: Sequence[Expr],
new_ranges: List[Integer], new_ranges: list[Integer],
new_reduction_ranges: List[Integer], new_reduction_ranges: list[Integer],
reduction_type: str, reduction_type: str,
reduction_hint: ReductionHint, reduction_hint: ReductionHint,
) -> TensorBox: ) -> TensorBox:
@ -1767,8 +1763,8 @@ class WelfordReduction(Reduction):
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]], inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer], ranges: list[Integer],
reduction_ranges: List[Integer], reduction_ranges: list[Integer],
reduction_type: str, reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT, reduction_hint: ReductionHint = ReductionHint.DEFAULT,
) -> Sequence[TensorBox]: ) -> Sequence[TensorBox]:
@ -1893,8 +1889,8 @@ class WelfordReduction(Reduction):
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]], inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer], ranges: list[Integer],
reduction_ranges: List[Integer], reduction_ranges: list[Integer],
reduction_type: str, reduction_type: str,
split: _IntLike, split: _IntLike,
reduction_hint: ReductionHint, reduction_hint: ReductionHint,
@ -1983,8 +1979,8 @@ class WelfordReduction(Reduction):
@ir_dataclass @ir_dataclass
class Scan(Loops): class Scan(Loops):
scan_ranges: List[Integer] scan_ranges: list[Integer]
size: List[Integer] size: list[Integer]
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]] combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]] reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
reduction_hint: ReductionHint reduction_hint: ReductionHint
@ -2055,7 +2051,7 @@ class Scan(Loops):
device: torch.device, device: torch.device,
dtypes: tuple[torch.dtype, ...], dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...], inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
size: List[Integer], size: list[Integer],
axis: int, axis: int,
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
reduction_hint: ReductionHint = ReductionHint.DEFAULT, reduction_hint: ReductionHint = ReductionHint.DEFAULT,
@ -2114,7 +2110,7 @@ class Scan(Loops):
else: else:
scan_type = SplitScan scan_type = SplitScan
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> List[Expr]: def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
assert len(scan_index) == len(scan_ranges) assert len(scan_index) == len(scan_ranges)
assert len(index) == len(pointwise_ranges) assert len(index) == len(pointwise_ranges)
return [*index[:axis], *scan_index, *index[axis:]] return [*index[:axis], *scan_index, *index[axis:]]
@ -2152,8 +2148,8 @@ class Scan(Loops):
dtype: torch.dtype, dtype: torch.dtype,
inner_fn: Callable[[Sequence[Expr]], OpsValue], inner_fn: Callable[[Sequence[Expr]], OpsValue],
axis: int, axis: int,
pointwise_ranges: List[Integer], pointwise_ranges: list[Integer],
scan_ranges: List[Integer], scan_ranges: list[Integer],
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]], combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
scan_numel: Expr, scan_numel: Expr,
) -> tuple[ReductionHint, _IntLike]: ) -> tuple[ReductionHint, _IntLike]:
@ -2182,8 +2178,8 @@ class SplitScan(Scan):
@ir_dataclass @ir_dataclass
class Sort(Loops): class Sort(Loops):
# Sorts a tuple of key, value pairs # Sorts a tuple of key, value pairs
sort_ranges: List[Integer] sort_ranges: list[Integer]
size: List[Integer] size: list[Integer]
reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]] reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
reduction_hint: ReductionHint reduction_hint: ReductionHint
output_index: int output_index: int
@ -2251,8 +2247,8 @@ class Sort(Loops):
cls, cls,
device: torch.device, device: torch.device,
dtypes: tuple[torch.dtype, ...], dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[List[Expr]], Any], ...], inner_fns: tuple[Callable[[list[Expr]], Any], ...],
size: List[Integer], size: list[Integer],
axis: int, axis: int,
stable: bool, stable: bool,
descending: bool, descending: bool,
@ -2293,7 +2289,7 @@ class Sort(Loops):
for output_index in range(len(dtypes)) for output_index in range(len(dtypes))
] ]
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> List[Expr]: def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
assert len(sort_index) == len(sort_ranges) assert len(sort_index) == len(sort_ranges)
assert len(index) == len(pointwise_ranges) assert len(index) == len(pointwise_ranges)
return [*index[:axis], *sort_index, *index[axis:]] return [*index[:axis], *sort_index, *index[axis:]]
@ -2509,7 +2505,7 @@ class BaseView(IRNode):
@ir_dataclass @ir_dataclass
class ExpandView(BaseView): class ExpandView(BaseView):
size: List[Expr] size: list[Expr]
@staticmethod @staticmethod
def _normalize_size(x, new_size): # type: ignore[no-untyped-def] def _normalize_size(x, new_size): # type: ignore[no-untyped-def]
@ -2588,7 +2584,7 @@ class ExpandView(BaseView):
@ir_dataclass @ir_dataclass
class PermuteView(BaseView): class PermuteView(BaseView):
dims: List[Expr] dims: list[Expr]
@classmethod @classmethod
def create(cls, x, dims): # type: ignore[no-untyped-def] def create(cls, x, dims): # type: ignore[no-untyped-def]
@ -2676,7 +2672,7 @@ class SqueezeView(BaseView):
not_one = [i for i, s in enumerate(size) if s != 1] not_one = [i for i, s in enumerate(size) if s != 1]
length = len(size) length = len(size)
def reindex(index: List[sympy.Expr]) -> tuple[sympy.Expr, ...]: def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]:
assert len(index) == len(not_one), f"{index} {not_one}" assert len(index) == len(not_one), f"{index} {not_one}"
new_index = [sympy.S.Zero] * length new_index = [sympy.S.Zero] * length
for idx, s in zip(not_one, index): for idx, s in zip(not_one, index):
@ -2691,7 +2687,7 @@ class SqueezeView(BaseView):
@ir_dataclass @ir_dataclass
class GenericView(BaseView): class GenericView(BaseView):
size: List[Expr] size: list[Expr]
reindex: Callable[..., Any] reindex: Callable[..., Any]
def make_reindexer(self): # type: ignore[no-untyped-def] def make_reindexer(self): # type: ignore[no-untyped-def]
@ -3159,8 +3155,8 @@ class Layout(OutputSpec):
self, self,
device: torch.device, device: torch.device,
dtype: torch.dtype, dtype: torch.dtype,
size: List[Expr], size: list[Expr],
stride: Optional[List[Expr]] = None, stride: Optional[list[Expr]] = None,
offset: Expr = Integer(0), offset: Expr = Integer(0),
) -> None: ) -> None:
if stride is None: if stride is None:
@ -3169,8 +3165,8 @@ class Layout(OutputSpec):
self.dtype = dtype self.dtype = dtype
assert len(size) == len(stride), f"size={size}, stride={stride}" assert len(size) == len(stride), f"size={size}, stride={stride}"
assert all(isinstance(s, (Expr, int)) for s in size) assert all(isinstance(s, (Expr, int)) for s in size)
self.size: List[Expr] = size self.size: list[Expr] = size
self.stride: List[Expr] = stride self.stride: list[Expr] = stride
self.offset: Expr = offset self.offset: Expr = offset
def __str__(self) -> str: def __str__(self) -> str:
@ -3594,8 +3590,8 @@ class NoneLayout(OutputSpec):
# dependencies manually in scheduler # dependencies manually in scheduler
device: Optional[torch.device] device: Optional[torch.device]
size: List[int] = dataclasses.field(default_factory=lambda: [0]) size: list[int] = dataclasses.field(default_factory=lambda: [0])
stride: List[int] = dataclasses.field(default_factory=lambda: [0]) stride: list[int] = dataclasses.field(default_factory=lambda: [0])
def storage_size(self) -> int: def storage_size(self) -> int:
return 0 return 0
@ -3620,7 +3616,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
V.graph.mark_buffer_mutated(name) V.graph.mark_buffer_mutated(name)
@property @property
def stride(self) -> List[Expr]: def stride(self) -> list[Expr]:
return self.real_layout().stride return self.real_layout().stride
@stride.setter @stride.setter
@ -3725,7 +3721,7 @@ class Buffer(IRNode):
def get_size(self) -> Sequence[Expr]: def get_size(self) -> Sequence[Expr]:
return [*self.get_layout().size] return [*self.get_layout().size]
def get_stride(self) -> List[Expr]: def get_stride(self) -> list[Expr]:
return [*self.get_layout().stride] return [*self.get_layout().stride]
def get_offset(self) -> Expr: def get_offset(self) -> Expr:
@ -3816,7 +3812,7 @@ class Buffer(IRNode):
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class OperationBuffer(Buffer, Operation): class OperationBuffer(Buffer, Operation):
# An operation that produces a single output buffer # An operation that produces a single output buffer
def get_outputs(self) -> List[Buffer]: def get_outputs(self) -> list[Buffer]:
return [self] return [self]
def get_defining_op(self) -> Operation: def get_defining_op(self) -> Operation:
@ -3977,7 +3973,7 @@ class ComputedBuffer(OperationBuffer):
assert isinstance(self.data, Pointwise) assert isinstance(self.data, Pointwise)
return partial(self.data.store_output, self.name, indexer) return partial(self.data.store_output, self.name, indexer)
def get_fill_order(self) -> Optional[List[int]]: def get_fill_order(self) -> Optional[list[int]]:
""" """
If our layout is still flexible, try to determine the stride order based on stride orders of reads. If our layout is still flexible, try to determine the stride order based on stride orders of reads.
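For intuition only (not the Inductor implementation), a fill order can be read off a stride vector as the argsort of the strides, fastest-varying dimension first:

def fill_order_from_strides(strides: list[int]) -> list[int]:
    # e.g. strides (1, 64, 8) -> [0, 2, 1]: dim 0 varies fastest, dim 1 slowest
    return [i for i, _ in sorted(enumerate(strides), key=lambda p: p[1])]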
@ -4028,9 +4024,9 @@ class ComputedBuffer(OperationBuffer):
def get_default_sizes_body( def get_default_sizes_body(
self, self,
) -> tuple[ ) -> tuple[
tuple[List[sympy.Expr], List[sympy.Expr]], tuple[list[sympy.Expr], list[sympy.Expr]],
LoopBody, LoopBody,
tuple[List[sympy.Expr], List[sympy.Expr]], tuple[list[sympy.Expr], list[sympy.Expr]],
]: ]:
args, var_ranges = dependencies.index_vars_squeeze( args, var_ranges = dependencies.index_vars_squeeze(
self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q" self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
@ -4043,7 +4039,7 @@ class ComputedBuffer(OperationBuffer):
*args, *args,
) )
index_vars = [] index_vars = []
reduce_vars: List[Any] = [] reduce_vars: list[Any] = []
index_size = [] index_size = []
reduce_size = [] reduce_size = []
for v, s in var_ranges.items(): for v, s in var_ranges.items():
@ -4059,9 +4055,9 @@ class ComputedBuffer(OperationBuffer):
def simplify_and_reorder( def simplify_and_reorder(
self, self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
) -> tuple[tuple[List[sympy.Expr], List[sympy.Expr]], LoopBody]: ) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]:
""" """
This is one of the main places where we do loop transformations in a This is one of the main places where we do loop transformations in a
backend-agnostic way. backend-agnostic way.
@ -4282,7 +4278,7 @@ class TemplateBuffer(OperationBuffer):
def simplify_and_reorder( # type: ignore[no-untyped-def] def simplify_and_reorder( # type: ignore[no-untyped-def]
self, self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None, extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None, recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
): ):
return ( return (
@ -4314,7 +4310,7 @@ class TritonTemplateBuffer(TemplateBuffer):
""" """
super().__init__(layout, inputs, make_kernel_render) super().__init__(layout, inputs, make_kernel_render)
self.mutated_inputs = mutated_inputs self.mutated_inputs = mutated_inputs
self.outputs: List[Buffer] = [self] self.outputs: list[Buffer] = [self]
if mutated_inputs is not None: if mutated_inputs is not None:
# Ensure that the mutated inputs are only allowed for certain nodes # Ensure that the mutated inputs are only allowed for certain nodes
allowed_set = ( allowed_set = (
@ -4335,7 +4331,7 @@ class TritonTemplateBuffer(TemplateBuffer):
allowed_prologue_inps if allowed_prologue_inps else OrderedSet() allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
) )
def get_outputs(self) -> List[Buffer]: def get_outputs(self) -> list[Buffer]:
return self.outputs return self.outputs
def get_allowed_prologue_inps(self) -> OrderedSet[str]: def get_allowed_prologue_inps(self) -> OrderedSet[str]:
@ -4346,7 +4342,7 @@ class TritonTemplateBuffer(TemplateBuffer):
return out return out
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]] PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
class ChoiceCaller: class ChoiceCaller:
@ -4361,7 +4357,7 @@ class ChoiceCaller:
def __init__( def __init__(
self, self,
name: str, name: str,
input_nodes: List[Buffer], input_nodes: list[Buffer],
layout: Layout, layout: Layout,
description: str, description: str,
) -> None: ) -> None:
@ -4389,7 +4385,7 @@ class ChoiceCaller:
def output_node(self) -> TensorBox: def output_node(self) -> TensorBox:
raise NotImplementedError raise NotImplementedError
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]: def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled.""" """Information returned here is logged to the autotune log file when that is enabled."""
return {} return {}
@ -4414,9 +4410,9 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
def __init__( def __init__(
self, self,
layout: Layout, layout: Layout,
inputs: List[IRNode], inputs: list[IRNode],
choice_timings: Callable[[], Dict[ChoiceCaller, float]], choice_timings: Callable[[], dict[ChoiceCaller, float]],
unfiltered_choices: List[ChoiceCaller], unfiltered_choices: list[ChoiceCaller],
allowed_prologue_inps: OrderedSet[str], allowed_prologue_inps: OrderedSet[str],
) -> None: ) -> None:
super().__init__( super().__init__(
@ -4426,7 +4422,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
allowed_prologue_inps=allowed_prologue_inps, allowed_prologue_inps=allowed_prologue_inps,
) )
self._choice_timings_fn = choice_timings self._choice_timings_fn = choice_timings
self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
self.original_inputs = inputs self.original_inputs = inputs
self._output_plannable = all( self._output_plannable = all(
isinstance(choice, TritonTemplateCallerBase) isinstance(choice, TritonTemplateCallerBase)
@ -4445,7 +4441,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
return self._output_plannable return self._output_plannable
@property @property
def choice_timings(self) -> Dict[ChoiceCaller, float]: def choice_timings(self) -> dict[ChoiceCaller, float]:
if self._choice_timings is None: if self._choice_timings is None:
self._choice_timings = self._choice_timings_fn() self._choice_timings = self._choice_timings_fn()
return self._choice_timings return self._choice_timings
@ -4496,7 +4492,7 @@ class CppTemplateBuffer(TemplateBuffer):
super().__init__(layout, inputs, make_kernel_render) super().__init__(layout, inputs, make_kernel_render)
self.template = template self.template = template
self.choice = choice self.choice = choice
self.outputs: Optional[List[Buffer]] = None self.outputs: Optional[list[Buffer]] = None
def get_layout(self) -> Layout: def get_layout(self) -> Layout:
if isinstance(self.layout, MultiOutputLayout): if isinstance(self.layout, MultiOutputLayout):
@ -4512,7 +4508,7 @@ class CppTemplateBuffer(TemplateBuffer):
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class InputsKernel(OperationBuffer): class InputsKernel(OperationBuffer):
inputs: List[Buffer] inputs: list[Buffer]
def get_read_writes(self) -> dependencies.ReadWrites: def get_read_writes(self) -> dependencies.ReadWrites:
reads = OrderedSet[dependencies.Dep]() reads = OrderedSet[dependencies.Dep]()
@ -4752,7 +4748,7 @@ class ConcatKernel(NopKernel):
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class ExternKernel(InputsKernel): class ExternKernel(InputsKernel):
constant_args: tuple[Any, ...] = () constant_args: tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict) kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None output_view: Optional[ReinterpretView] = None
python_kernel_name: Optional[str] = None python_kernel_name: Optional[str] = None
cpp_kernel_name: Optional[str] = None cpp_kernel_name: Optional[str] = None
@ -4764,12 +4760,12 @@ class ExternKernel(InputsKernel):
op_overload: Optional[ op_overload: Optional[
Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator] Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
] = None ] = None
arg_properties: Optional[List[Dict[str, Any]]] = None arg_properties: Optional[list[dict[str, Any]]] = None
kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field( unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
default_factory=dict default_factory=dict
) )
mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list) mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
def __init__( # type: ignore[no-untyped-def] def __init__( # type: ignore[no-untyped-def]
self, self,
@ -4801,7 +4797,7 @@ class ExternKernel(InputsKernel):
self.mutation_outputs = [] self.mutation_outputs = []
self.fx_node = V.graph.current_node self.fx_node = V.graph.current_node
def get_outputs(self) -> List[Buffer]: def get_outputs(self) -> list[Buffer]:
return [self, *self.mutation_outputs] return [self, *self.mutation_outputs]
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -4919,10 +4915,10 @@ class ExternKernel(InputsKernel):
cls, kernel, *args, **kwargs cls, kernel, *args, **kwargs
) -> tuple[ ) -> tuple[
Any, Any,
List[Any], list[Any],
List[Any], list[Any],
Callable[[Any, Any], Any], Callable[[Any, Any], Any],
Optional[Dict[sympy.Symbol, pytree.KeyPath]], Optional[dict[sympy.Symbol, pytree.KeyPath]],
]: ]:
binded_args = {"args": args, "kwargs": kwargs} binded_args = {"args": args, "kwargs": kwargs}
@ -4930,7 +4926,7 @@ class ExternKernel(InputsKernel):
is_arg_tensor = [] is_arg_tensor = []
tensor_args = [] tensor_args = []
non_tensor_args: List[Any] = [] non_tensor_args: list[Any] = []
for arg in args_flat: for arg in args_flat:
is_arg_tensor.append(isinstance(arg, IRNode)) is_arg_tensor.append(isinstance(arg, IRNode))
if is_arg_tensor[-1]: if is_arg_tensor[-1]:
@ -4963,7 +4959,7 @@ class ExternKernel(InputsKernel):
# Rerun fake tensor propagation, because Inductor may have changed the # Rerun fake tensor propagation, because Inductor may have changed the
# strides of inputs and we need to determine accurately what the # strides of inputs and we need to determine accurately what the
# output stride will be. # output stride will be.
example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = [] example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []
# We need to retain the constant values of fake tensors that we originally # We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a # propagated the graph with, because for some operators running without a
@ -4984,7 +4980,7 @@ class ExternKernel(InputsKernel):
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args) new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
example_output = kernel(*new_args, **new_kwargs) example_output = kernel(*new_args, **new_kwargs)
unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
if shape_env := V.fake_mode.shape_env: if shape_env := V.fake_mode.shape_env:
rebind_unbacked(shape_env, V.current_node, example_output) rebind_unbacked(shape_env, V.current_node, example_output)
unbacked_bindings = compute_unbacked_bindings( unbacked_bindings = compute_unbacked_bindings(
@ -5309,7 +5305,7 @@ class ExternKernel(InputsKernel):
) )
return args return args
def codegen_const_args(self, names: Optional[List[str]] = None): # type: ignore[no-untyped-def] def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def]
if V.graph.cpp_wrapper: if V.graph.cpp_wrapper:
result = [] result = []
# Aten ops follow the convention that tensor args are before non-tensor args, # Aten ops follow the convention that tensor args are before non-tensor args,
@ -5635,14 +5631,14 @@ class TMADescriptor(ExternKernel):
# as TMA descriptors are immutable, # as TMA descriptors are immutable,
# we can dedup them by the input args # we can dedup them by the input args
_CACHE: Dict[Any, TMADescriptor] = {} _CACHE: dict[Any, TMADescriptor] = {}
@classmethod @classmethod
def create( # type: ignore[no-untyped-def] def create( # type: ignore[no-untyped-def]
cls, cls,
tensor: IRNode, tensor: IRNode,
dims: List[Union[int, torch.SymInt]], dims: list[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]], block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None, element_size: Optional[int] = None,
): ):
key = (id(tensor), dims, block_dims, element_size) key = (id(tensor), dims, block_dims, element_size)
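As the comment above notes, descriptors are treated as immutable and deduplicated by their input args through a class-level _CACHE dict keyed on those args. A minimal sketch of that dedup-by-key pattern, assuming the inputs fold into a hashable key; the class and fields are placeholders, not the real TMADescriptor:

class CachedDescriptor:
    # Class-level cache shared by all instances; safe because instances
    # are never mutated after construction.
    _CACHE: dict[tuple, "CachedDescriptor"] = {}

    def __init__(self, tensor_id: int, dims: tuple[int, ...], element_size: int) -> None:
        self.tensor_id = tensor_id
        self.dims = dims
        self.element_size = element_size

    @classmethod
    def create(cls, tensor_id: int, dims: tuple[int, ...], element_size: int) -> "CachedDescriptor":
        key = (tensor_id, dims, element_size)  # hashable key built from the inputs
        if key not in cls._CACHE:              # construct only on a cache miss
            cls._CACHE[key] = cls(tensor_id, dims, element_size)
        return cls._CACHE[key]


# Repeated calls with the same inputs return the same object.
a = CachedDescriptor.create(1, (64, 64), 2)
b = CachedDescriptor.create(1, (64, 64), 2)
assert a is b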
@ -5653,8 +5649,8 @@ class TMADescriptor(ExternKernel):
def __init__( def __init__(
self, self,
tensor: IRNode, tensor: IRNode,
dims: List[Union[int, torch.SymInt]], dims: list[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]], block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None, element_size: Optional[int] = None,
) -> None: ) -> None:
assert len(dims) in (1, 2) assert len(dims) in (1, 2)
@ -5707,8 +5703,8 @@ class UserDefinedTritonKernel(ExternKernel):
kernel = kernel_side_table.get_kernel(self.kernel_idx) kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = [] configs = []
restore_value_args: List[str] = [] restore_value_args: list[str] = []
reset_to_zero_args: List[str] = [] reset_to_zero_args: list[str] = []
if isinstance(kernel, Autotuner): if isinstance(kernel, Autotuner):
# https://github.com/triton-lang/triton/pull/5083 # https://github.com/triton-lang/triton/pull/5083
# changes kernel.restore_idx to kernel.restore_value # changes kernel.restore_idx to kernel.restore_value
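The comment above points at a Triton change that renamed kernel.restore_idx to kernel.restore_value. A common way to stay compatible with both attribute names is a guarded getattr fallback; this is only a sketch of that general pattern, not necessarily how this code handles it, and the two attributes may hold different kinds of values in real Triton, so actual callers may need extra normalization:

def read_restore_attr(kernel) -> list:
    # Prefer the newer attribute name, fall back to the older one,
    # and default to an empty list if neither is present.
    for attr in ("restore_value", "restore_idx"):
        value = getattr(kernel, attr, None)
        if value is not None:
            return list(value)
    return []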
@ -5871,7 +5867,7 @@ class UserDefinedTritonKernel(ExternKernel):
] ]
V.graph.register_operation(self) V.graph.register_operation(self)
def get_outputs(self) -> List[Buffer]: def get_outputs(self) -> list[Buffer]:
return list(self.mutation_outputs) return list(self.mutation_outputs)
def get_device(self) -> Optional[torch.device]: def get_device(self) -> Optional[torch.device]:
@ -6333,9 +6329,9 @@ class FallbackKernel(ExternKernelAlloc):
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type] V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]
# args that are aliased # args that are aliased
self.alias_names: List[str] = [] self.alias_names: list[str] = []
# args that are mutated AND returned from the op # args that are mutated AND returned from the op
self.mutation_names: List[str] = [] self.mutation_names: list[str] = []
if isinstance(self.op_overload, torch._ops.HigherOrderOperator): if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
# We assume here that HOPs with FallbackKernel are functional. # We assume here that HOPs with FallbackKernel are functional.
@ -6834,7 +6830,7 @@ class MultiOutput(ExternKernel):
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices), self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
) )
def __init__(self, layout: OutputSpec, input, indices: List[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def] def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
super().__init__(None, layout, [input], ()) super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self) self.name = V.graph.register_buffer(self)
V.graph.register_operation(self) V.graph.register_operation(self)
@ -6903,18 +6899,18 @@ class MutableBox(IRNode):
return self.data.freeze_layout() return self.data.freeze_layout()
def freeze_layout_with_stride_order( def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False self, order: list[int], allow_padding: bool = False
) -> None: ) -> None:
return self.data.freeze_layout_with_stride_order(order, allow_padding) return self.data.freeze_layout_with_stride_order(order, allow_padding)
def freeze_layout_with_fill_order(self, order: List[int]) -> None: def freeze_layout_with_fill_order(self, order: list[int]) -> None:
return self.data.freeze_layout_with_fill_order(order) return self.data.freeze_layout_with_fill_order(order)
def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None: def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
return self.data.freeze_layout_with_same_order(stride) return self.data.freeze_layout_with_same_order(stride)
def freeze_layout_with_exact_strides( def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None: ) -> None:
return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding) return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
@ -7119,11 +7115,11 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class InvokeSubgraph(ExternKernel): class InvokeSubgraph(ExternKernel):
subgraph: Optional[Subgraph] = None subgraph: Optional[Subgraph] = None
operands: Optional[List[TensorBox]] = None operands: Optional[list[TensorBox]] = None
outputs: Optional[List[MultiOutput]] = None outputs: Optional[list[MultiOutput]] = None
def __init__( def __init__(
self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout
) -> None: ) -> None:
super().__init__( super().__init__(
name=None, name=None,
@ -7212,15 +7208,15 @@ class InvokeSubgraph(ExternKernel):
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class Conditional(ExternKernel): class Conditional(ExternKernel):
predicate: Optional[IRNode] = None predicate: Optional[IRNode] = None
operands: Optional[List[TensorBox]] = None operands: Optional[list[TensorBox]] = None
true_subgraph: Optional[Subgraph] = None true_subgraph: Optional[Subgraph] = None
false_subgraph: Optional[Subgraph] = None false_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None outputs: Optional[list[MultiOutput]] = None
def __init__( def __init__(
self, self,
predicate: IRNode, predicate: IRNode,
operands: List[TensorBox], operands: list[TensorBox],
true_subgraph: Subgraph, true_subgraph: Subgraph,
false_subgraph: Subgraph, false_subgraph: Subgraph,
layout: MultiOutputLayout, layout: MultiOutputLayout,
@ -7250,7 +7246,7 @@ class Conditional(ExternKernel):
predicate: TensorBox, predicate: TensorBox,
true_fn: Subgraph, true_fn: Subgraph,
false_fn: Subgraph, false_fn: Subgraph,
operands: List[TensorBox], operands: list[TensorBox],
): ):
predicate = cls.realize_input(predicate) predicate = cls.realize_input(predicate)
operands = [cls.realize_input(x) for x in operands] operands = [cls.realize_input(x) for x in operands]
@ -7332,16 +7328,16 @@ class Conditional(ExternKernel):
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
class WhileLoop(ExternKernel): class WhileLoop(ExternKernel):
carried_inputs: Optional[List[TensorBox]] = None carried_inputs: Optional[list[TensorBox]] = None
additional_inputs: Optional[List[TensorBox]] = None additional_inputs: Optional[list[TensorBox]] = None
cond_subgraph: Optional[Subgraph] = None cond_subgraph: Optional[Subgraph] = None
body_subgraph: Optional[Subgraph] = None body_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None outputs: Optional[list[MultiOutput]] = None
def __init__( def __init__(
self, self,
carried_inputs: List[TensorBox], carried_inputs: list[TensorBox],
additional_inputs: List[TensorBox], additional_inputs: list[TensorBox],
cond_subgraph: Subgraph, cond_subgraph: Subgraph,
body_subgraph: Subgraph, body_subgraph: Subgraph,
layout: MultiOutputLayout, layout: MultiOutputLayout,
@ -7365,8 +7361,8 @@ class WhileLoop(ExternKernel):
cls, cls,
cond_fn: Subgraph, cond_fn: Subgraph,
body_fn: Subgraph, body_fn: Subgraph,
carried_inputs: List[TensorBox], carried_inputs: list[TensorBox],
additional_inputs: List[TensorBox], additional_inputs: list[TensorBox],
): ):
carried_inputs = [cls.realize_input(x) for x in carried_inputs] carried_inputs = [cls.realize_input(x) for x in carried_inputs]
additional_inputs = [cls.realize_input(x) for x in additional_inputs] additional_inputs = [cls.realize_input(x) for x in additional_inputs]
@ -7411,8 +7407,8 @@ class WhileLoop(ExternKernel):
for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)): for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
def _guard_list_equals( def _guard_list_equals(
lhs_exprs: List[Union[int, sympy.expr]], lhs_exprs: list[Union[int, sympy.expr]],
rhs_exprs: List[Union[int, sympy.expr]], rhs_exprs: list[Union[int, sympy.expr]],
) -> None: ) -> None:
for lhs, rhs in zip(lhs_exprs, rhs_exprs): for lhs, rhs in zip(lhs_exprs, rhs_exprs):
V.graph.sizevars.guard_equals(lhs, rhs) V.graph.sizevars.guard_equals(lhs, rhs)
@ -7549,7 +7545,7 @@ class _CollectiveKernel(FallbackKernel):
# mutation of the input buffers. # mutation of the input buffers.
@classmethod @classmethod
def create_inplace( # type: ignore[no-untyped-def] def create_inplace( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
) -> None: ) -> None:
with V.graph.fake_mode: with V.graph.fake_mode:
( (
@ -7610,7 +7606,7 @@ class _CollectiveKernel(FallbackKernel):
# usage in the user program. # usage in the user program.
@classmethod @classmethod
def create_out_of_place( # type: ignore[no-untyped-def] def create_out_of_place( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
): ):
with V.graph.fake_mode: with V.graph.fake_mode:
( (