Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

PEP585 update - torch/_inductor/[_-i]* (#145137)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137
Approved by: https://github.com/bobrenjc93

commit 893ca1dfe1
parent cede43e06b
committed by PyTorch MergeBot
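For context, PEP 585 (Python 3.9+) lets the builtin container types be used as generics directly, so `typing.List[...]`/`typing.Dict[...]` annotations can be written as `list[...]`/`dict[...]`; that is the mechanical change this commit applies across torch/_inductor. The sketch below is illustrative only — the helper name and parameters are made up, not taken from the diff — and shows the before/after annotation pattern:

```python
from typing import Any, Optional

# Before (pre-PEP 585): from typing import Dict, List
#   def compile_options(opts: Optional[Dict[str, Any]] = None) -> List[str]: ...
#
# After (PEP 585): use the builtin generics directly; only Any/Optional etc.
# still come from typing.
def compile_options(opts: Optional[dict[str, Any]] = None) -> list[str]:
    """Hypothetical helper showing the annotation style this PR migrates to."""
    merged: dict[str, Any] = {"mode": "default", **(opts or {})}
    return [f"{key}={value}" for key, value in merged.items()]


if __name__ == "__main__":
    print(compile_options({"max_autotune": True}))
```

The diff below is the commit's hunk listing, reconstructed in unified-diff form.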
@@ -28,8 +28,8 @@ log = logging.getLogger(__name__)
 
 def compile(
     gm: torch.fx.GraphModule,
-    example_inputs: List[InputType],
-    options: Optional[Dict[str, Any]] = None,
+    example_inputs: list[InputType],
+    options: Optional[dict[str, Any]] = None,
 ):
     """
     Compile a given FX graph with TorchInductor. This allows compiling
@@ -54,7 +54,7 @@ def aoti_compile_and_package(
     _deprecated_unused_kwargs=None,
     *,
     package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
 ) -> str:
     """
     Compiles the exported program with AOTInductor, and packages it into a .pt2
@@ -131,11 +131,11 @@ def _aoti_compile_and_package_inner(
     gm: torch.nn.Module,
     # flat_example_inputs: List[Any],
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
     load_and_run: bool = False,
     package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
 ):
     """
     See docstring for aoti_compile_and_package.
@@ -199,10 +199,10 @@ def aoti_load_package(path: Union[str, io.BytesIO]) -> Any:  # type: ignore[type
 def aot_compile(
     gm: torch.fx.GraphModule,
     args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
     *,
-    options: Optional[Dict[str, Any]] = None,
-) -> Union[str, List[str]]:
+    options: Optional[dict[str, Any]] = None,
+) -> Union[str, list[str]]:
     """
     Ahead-of-time compile a given FX graph with TorchInductor into a shared library.
 
@@ -232,7 +232,7 @@ def aot_compile(
 
 def list_mode_options(
     mode: Optional[str] = None, dynamic: Optional[bool] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
     r"""Returns a dictionary describing the optimizations that each of the available
     modes passed to `torch.compile()` performs.
 
@@ -245,7 +245,7 @@ def list_mode_options(
         >>> torch._inductor.list_mode_options()
     """
 
-    mode_options: Dict[str, Dict[str, bool]] = {
+    mode_options: dict[str, dict[str, bool]] = {
        "default": {},
        # enable cudagraphs
        "reduce-overhead": {
@@ -267,7 +267,7 @@ def list_mode_options(
     return mode_options[mode] if mode else mode_options  # type: ignore[return-value]
 
 
-def list_options() -> List[str]:
+def list_options() -> list[str]:
     r"""Returns a dictionary describing the optimizations and debug configurations
     that are available to `torch.compile()`.
 
@@ -280,7 +280,7 @@ def list_options() -> List[str]:
 
     from torch._inductor import config
 
-    current_config: Dict[str, Any] = config.get_config_copy()
+    current_config: dict[str, Any] = config.get_config_copy()
 
     return list(current_config.keys())
 
@@ -2,7 +2,7 @@ import json
 import logging
 import os
 from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
 from unittest import mock
 
 import torch
@@ -31,7 +31,7 @@ def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:
 
 def load_aoti_eager_cache(
     ns: str, op_func_name_with_overload: str, device_type: str
-) -> List[Optional[Dict[str, Any]]]:
+) -> list[Optional[dict[str, Any]]]:
     device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
     op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
     if not op_conf.exists():
@@ -81,7 +81,7 @@ def load_aoti_eager_cache(
         return []
 
 
-def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
+def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]:
     return {int: torch.int32, float: torch.float, bool: torch.bool}
 
 
@@ -90,8 +90,8 @@ def supported_scalar_types() -> tuple[type, ...]:
     return tuple(type_to_torch_dtype.keys())
 
 
-def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
-    metadata: Dict[str, Any] = {}
+def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]:
+    metadata: dict[str, Any] = {}
     metadata["is_dynamic"] = dynamic
 
     assert isinstance(input, torch.Tensor)
@@ -110,21 +110,21 @@ def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any
 
 def extract_tensor_list_metadata(
     dynamic: bool,
-    input: List[torch.Tensor],
-) -> Dict[str, Any]:
+    input: list[torch.Tensor],
+) -> dict[str, Any]:
     metadata_list = []
     for item in input:
         assert isinstance(item, torch.Tensor)
         metadata_list.append(extract_tensor_metadata(dynamic, item))
 
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["tensor_list"] = metadata_list
     return metadata
 
 
-def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
+def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]:
     assert isinstance(input, supported_scalar_types())
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["is_dynamic"] = False
     # Scalar tensor
     metadata["device_type"] = device_type
@@ -135,31 +135,31 @@ def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
     return metadata
 
 
-def extract_string_metadata(input: str) -> Dict[str, Any]:
+def extract_string_metadata(input: str) -> dict[str, Any]:
     assert isinstance(input, str)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["string_value"] = input
     return metadata
 
 
-def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
+def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]:
     assert isinstance(input, torch.dtype)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["dtype_value"] = f"{input}"
     return metadata
 
 
-def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
+def extract_device_metadata(input: torch.device) -> dict[str, Any]:
     assert isinstance(input, torch.device)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["device_type_value"] = f"{input.type}"
     metadata["device_index_value"] = input.index
     return metadata
 
 
-def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
+def extract_layout_metadata(input: torch.layout) -> dict[str, Any]:
     assert isinstance(input, torch.layout)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
     metadata["layout_value"] = f"{input}"
     return metadata
 
@@ -171,10 +171,10 @@ def aoti_compile_with_persistent_cache(
     dynamic: bool,
     f: Callable[..., Any],
     args: tuple[Any],
-    kwargs: Dict[str, Any],
+    kwargs: dict[str, Any],
     *,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
-    options: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict[str, Any]] = None,
+    options: Optional[dict[str, Any]] = None,
     remove_runtime_assertions: bool = False,
     disable_constraint_solver: bool = False,
 ) -> str:
@@ -261,7 +261,7 @@ def aoti_compile_with_persistent_cache(
             metadata["arg_order"] = idx
             kernel_metadata_items.append(metadata)
 
-        kernel_meta_info: Dict[str, Any] = {}
+        kernel_meta_info: dict[str, Any] = {}
         kernel_meta_info["meta_info"] = kernel_metadata_items
         kernel_meta_info["kernel_path"] = (
             Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
@@ -11,7 +11,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from functools import partial
 from time import time
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING
 
 import torch
 from torch._dynamo.device_interface import get_registered_device_interfaces
@@ -250,7 +250,7 @@ class AsyncCompile:
         get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
         return LambdaFuture(lambda: get_result().kernel)
 
-    def cpp_pybinding(self, argtypes: List[str], source_code: str):
+    def cpp_pybinding(self, argtypes: list[str], source_code: str):
         kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
         if get_compile_threads() <= 1:
             return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
@@ -299,7 +299,7 @@ class AsyncCompile:
         )
         return LambdaFuture(get_result)
 
-    def wait(self, scope: Dict[str, Any]) -> None:
+    def wait(self, scope: dict[str, Any]) -> None:
         with dynamo_timed(
             "async_compile.wait",
             log_pt2_compile_event=True,
@@ -1,7 +1,7 @@
 import json
 import os
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
 
 import torch
 from torch._inductor.autoheuristic.autoheuristic_utils import (
@@ -50,16 +50,16 @@ class AutoHeuristic:
     a heuristic (see torchgen/autoheuristic/).
     """
 
-    collected_feedback: Dict[Choice, Feedback]
+    collected_feedback: dict[Choice, Feedback]
 
     def __init__(
         self,
         fallback: Callable[[], Choice],
-        choices: List[Choice],
+        choices: list[Choice],
         feedback: Optional[LocalFeedback],
         context: AHContext,
         name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
         precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
     ) -> None:
         """
@@ -135,8 +135,8 @@ class AutoHeuristic:
         return self.fallback()
 
     def get_top_k_choices(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[Choice]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[Choice]]:
         if not self.satisfies_precondition():
             return None
         if torch._inductor.config.use_autoheuristic(self.name):
@@ -223,11 +223,11 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
     def __init__(
         self,
         fallback: Callable[[], Optional[ChoiceCaller]],
-        choices: List[ChoiceCaller],
-        input_nodes: List[Any],
+        choices: list[ChoiceCaller],
+        input_nodes: list[Any],
         context: AHContext,
         name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
         precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
     ) -> None:
         """
@@ -237,7 +237,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         have to be used here.
         """
         self.input_nodes = input_nodes
-        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
+        self.choicestr2choice: dict[str, ChoiceCaller] = {}
         for choice in choices:
             self.choicestr2choice[choice.autoheuristic_id()] = choice
         choices_str = list(self.choicestr2choice.keys())
@@ -266,7 +266,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
             self.register_global_feedback(input_nodes, choices)
 
     def register_global_feedback(
-        self, input_nodes: List[Any], choices: List[ChoiceCaller]
+        self, input_nodes: list[Any], choices: list[ChoiceCaller]
     ) -> None:
         """
         Registers a callback in select_algorithm, which is called with the timing of each choice.
@@ -281,10 +281,10 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         def store_global_feedback(
             ah_inputs_key: str,
             ah_precompile_key: str,
-            timings: Dict[ChoiceCaller, float],
+            timings: dict[ChoiceCaller, float],
             name: str,
-            input_nodes: List[Any],
-            choices: List[ChoiceCaller],
+            input_nodes: list[Any],
+            choices: list[ChoiceCaller],
         ) -> None:
             current_inputs_key = create_inputs_key(input_nodes)
             if current_inputs_key != ah_inputs_key:
@@ -307,8 +307,8 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
         return self.choicestr2choice.get(choice, None)
 
     def get_top_k_choices_caller(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[ChoiceCaller]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[ChoiceCaller]]:
         choices = self.get_top_k_choices(top_k, always_included)
         if choices is None:
             return None
@ -1,5 +1,5 @@
|
||||
import functools
|
||||
from typing import Any, Callable, Dict, List
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
|
||||
@ -51,8 +51,8 @@ class AHContext:
|
||||
information that will help to learn a heuristic.
|
||||
"""
|
||||
|
||||
features: List[AHFeature]
|
||||
context_dict: Dict[str, Value]
|
||||
features: list[AHFeature]
|
||||
context_dict: dict[str, Value]
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.features = []
|
||||
@ -64,7 +64,7 @@ class AHContext:
|
||||
self.features.append(AHFeature(name, value, is_categorical=is_categorical))
|
||||
self.context_dict[name] = value
|
||||
|
||||
def get_numerical_and_categorical_features(self) -> tuple[List[str], List[str]]:
|
||||
def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
|
||||
numerical_features = []
|
||||
categorical_features = []
|
||||
for feature in self.features:
|
||||
@ -84,7 +84,7 @@ class AHContext:
|
||||
def get_value(self, name: str) -> Value:
|
||||
return self.context_dict[name]
|
||||
|
||||
def apply_operations(self, operations: List[AHOperation]) -> None:
|
||||
def apply_operations(self, operations: list[AHOperation]) -> None:
|
||||
for op in operations:
|
||||
op.apply_operation(self.context_dict)
|
||||
|
||||
@ -94,7 +94,7 @@ class AHMetadata:
|
||||
self,
|
||||
shared_memory: Any,
|
||||
device_capa: tuple[int, int],
|
||||
choices: List[Choice],
|
||||
choices: list[Choice],
|
||||
name: str,
|
||||
) -> None:
|
||||
# use amount of shared_memory and device_capability to identify GPU
|
||||
@ -104,7 +104,7 @@ class AHMetadata:
|
||||
self.choices = choices
|
||||
self.name = name
|
||||
|
||||
def to_dict(self) -> Dict[str, Value]:
|
||||
def to_dict(self) -> dict[str, Value]:
|
||||
return {
|
||||
"shared_memory": self.shared_memory,
|
||||
"device_capa": self.device_capa,
|
||||
@ -147,7 +147,7 @@ def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
|
||||
return mat1_iscontig and not mat2_iscontig
|
||||
|
||||
|
||||
def get_mult_dims_ops() -> List[AHOperation]:
|
||||
def get_mult_dims_ops() -> list[AHOperation]:
|
||||
m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
|
||||
m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
|
||||
k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
|
||||
@ -163,7 +163,7 @@ def get_arith_intensity(data: Any) -> float:
|
||||
return m * k * n / (m * k + k * n + m * n)
|
||||
|
||||
|
||||
def pad_mm_operations() -> List[AHOperation]:
|
||||
def pad_mm_operations() -> list[AHOperation]:
|
||||
mult_dims_ops = get_mult_dims_ops()
|
||||
k_div_m_times_n_op = AHOperation(
|
||||
"k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
|
||||
@ -200,7 +200,7 @@ def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
|
||||
return data[dim] >= lower and data[dim] <= upper
|
||||
|
||||
|
||||
def between_ops() -> List[AHOperation]:
|
||||
def between_ops() -> list[AHOperation]:
|
||||
dims = ["m", "k", "n"]
|
||||
limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
|
||||
ah_operations = []
|
||||
@ -221,13 +221,13 @@ def pow2_op(data: Any, dim: str, exponent: int) -> bool:
|
||||
return data[dim] == 2**exponent
|
||||
|
||||
|
||||
def mm_operations() -> List[AHOperation]:
|
||||
def mm_operations() -> list[AHOperation]:
|
||||
mult_dims_ops = get_mult_dims_ops()
|
||||
arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
|
||||
return mult_dims_ops + [arith_intensity_op]
|
||||
|
||||
|
||||
def mixed_mm_operations() -> List[AHOperation]:
|
||||
def mixed_mm_operations() -> list[AHOperation]:
|
||||
return mm_operations() + between_ops()
|
||||
|
||||
|
||||
@ -235,7 +235,7 @@ def is_multiple(data: Any, dim: str, mult: int) -> bool:
|
||||
return data[dim] % mult == 0
|
||||
|
||||
|
||||
def get_dims_multiple_ops() -> List[AHOperation]:
|
||||
def get_dims_multiple_ops() -> list[AHOperation]:
|
||||
multiples = [2, 4, 8, 16, 32]
|
||||
dims = ["m", "k", "n"]
|
||||
dims_multiple_ops = []
|
||||
@ -249,7 +249,7 @@ def get_dims_multiple_ops() -> List[AHOperation]:
|
||||
return dims_multiple_ops
|
||||
|
||||
|
||||
def get_dims_need_padding_ops() -> List[AHOperation]:
|
||||
def get_dims_need_padding_ops() -> list[AHOperation]:
|
||||
def mat1_innermost_needs_padding_fn(data: Any) -> bool:
|
||||
mat1_stride_0 = data["mat1_stride_0"]
|
||||
mat1_stride_1 = data["mat1_stride_1"]
|
||||
@ -303,7 +303,7 @@ def get_dims_need_padding_ops() -> List[AHOperation]:
|
||||
return [mat1_innermost_op, mat2_innermost_op, num_dims_op]
|
||||
|
||||
|
||||
def get_is_contig_ops() -> List[AHOperation]:
|
||||
def get_is_contig_ops() -> list[AHOperation]:
|
||||
def mat1_is_contig_fn(data: Any) -> bool:
|
||||
stride_0 = data["mat1_stride_0"]
|
||||
stride_1 = data["mat1_stride_1"]
|
||||
|
@ -2,7 +2,7 @@ import importlib
|
||||
import inspect
|
||||
import pkgutil
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
||||
AHContext,
|
||||
@ -14,7 +14,7 @@ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeur
|
||||
|
||||
def find_and_instantiate_subclasses(
|
||||
package_name: str, base_class: Any
|
||||
) -> List[LearnedHeuristic]:
|
||||
) -> list[LearnedHeuristic]:
|
||||
instances = []
|
||||
|
||||
package = importlib.import_module(package_name)
|
||||
@ -49,7 +49,7 @@ class LearnedHeuristicController:
|
||||
a way to get the decision of a learned heuristic.
|
||||
"""
|
||||
|
||||
existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
|
||||
existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
|
||||
"""
|
||||
A dictionary that stores all the learned heuristics for each optimization.
|
||||
The key is the optimization name, and the value is a list of LearnedHeuristic objects.
|
||||
@ -69,7 +69,7 @@ class LearnedHeuristicController:
|
||||
self.metadata = metadata
|
||||
self.context = context
|
||||
|
||||
def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
|
||||
def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
|
||||
"""
|
||||
Returns a list of learned heuristics for the given optimization name.
|
||||
"""
|
||||
@ -105,7 +105,7 @@ class LearnedHeuristicController:
|
||||
return heuristic.get_decision(self.context, self.metadata.choices)
|
||||
return None
|
||||
|
||||
def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
|
||||
def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
|
||||
heuristics = self.get_heuristics(self.metadata.name)
|
||||
for heuristic in heuristics:
|
||||
if heuristic.check_precondition(self.metadata, self.context):
|
||||
|
@ -1,4 +1,4 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from torch._inductor.autoheuristic.autoheuristic_utils import (
|
||||
AHContext,
|
||||
@ -23,7 +23,7 @@ class LearnedHeuristic:
|
||||
return True
|
||||
|
||||
def get_decision(
|
||||
self, context: AHContext, choices: List[Choice]
|
||||
self, context: AHContext, choices: list[Choice]
|
||||
) -> Optional[Choice]:
|
||||
return None
|
||||
|
||||
@ -33,7 +33,7 @@ class LearnedHeuristic:
|
||||
def get_name(self) -> str:
|
||||
return ""
|
||||
|
||||
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
|
||||
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
|
||||
return None
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ class LearnedHeuristicRegression(LearnedHeuristic):
|
||||
return 1.0
|
||||
|
||||
def get_decision(
|
||||
self, context: AHContext, choices: List[Choice]
|
||||
self, context: AHContext, choices: list[Choice]
|
||||
) -> Optional[Choice]:
|
||||
choice2feedback = {}
|
||||
for choice in choices:
|
||||
@ -68,7 +68,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
|
||||
return None
|
||||
|
||||
def get_decision(
|
||||
self, context: AHContext, choices: List[Choice]
|
||||
self, context: AHContext, choices: list[Choice]
|
||||
) -> Optional[Choice]:
|
||||
best_choices = self.get_best_choices(context)
|
||||
if not best_choices:
|
||||
@ -78,7 +78,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
|
||||
return None
|
||||
return self.get_choice(best_choice_idx)
|
||||
|
||||
def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
|
||||
def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
|
||||
feedback_idx_list = self.get_best_choices(context)
|
||||
if feedback_idx_list is None:
|
||||
return None
|
||||
@ -88,5 +88,5 @@ class LearnedHeuristicDecision(LearnedHeuristic):
|
||||
choices = [choice for choice in choices if choice is not None]
|
||||
return choices
|
||||
|
||||
def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
|
||||
def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
|
||||
return []
|
||||
|
@ -10,19 +10,10 @@ import os
|
||||
import queue
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import Iterable, Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from ctypes import byref, c_size_t, c_void_p, CDLL
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
@ -396,8 +387,8 @@ class TuningProcessPool:
|
||||
|
||||
def benchmark(
|
||||
self,
|
||||
choices: List[TritonTemplateCaller],
|
||||
) -> Dict[TritonTemplateCaller, float]:
|
||||
choices: list[TritonTemplateCaller],
|
||||
) -> dict[TritonTemplateCaller, float]:
|
||||
"""
|
||||
Benchmark each choice in a separate process.
|
||||
"""
|
||||
@ -432,9 +423,9 @@ class TensorMeta:
|
||||
@classmethod
|
||||
def from_irnodes(
|
||||
cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
|
||||
) -> Union[TensorMeta, List[TensorMeta]]:
|
||||
) -> Union[TensorMeta, list[TensorMeta]]:
|
||||
if isinstance(irnodes, Sequence):
|
||||
result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
|
||||
result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
|
||||
assert all(isinstance(x, TensorMeta) for x in result)
|
||||
return result
|
||||
|
||||
@ -488,8 +479,8 @@ class BenchmarkRequest:
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
extra_args: Iterable[Any],
|
||||
) -> None:
|
||||
# the kernel name defined in the module
|
||||
@ -640,12 +631,12 @@ class TritonBenchmarkRequest(BenchmarkRequest):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
extra_args: Iterable[Any],
|
||||
module_path: str, # the path of the module defining the triton kernel
|
||||
module_cache_key: str,
|
||||
grid: List[int],
|
||||
grid: list[int],
|
||||
num_stages: int,
|
||||
num_warps: int,
|
||||
matrix_instr_nonkdim: int = 0, # only used for hip to choose the shape of mfma instruction.
|
||||
@ -770,8 +761,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
extra_args: Iterable[Any],
|
||||
source_code: str,
|
||||
) -> None:
|
||||
@ -889,8 +880,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
|
||||
input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
|
||||
extra_args: Iterable[Any],
|
||||
source_code: str,
|
||||
) -> None:
|
||||
@ -946,8 +937,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
|
||||
|
||||
|
||||
def benchmark_in_sub_process(
|
||||
choices: List[TritonTemplateCaller],
|
||||
) -> Dict[TritonTemplateCaller, float]:
|
||||
choices: list[TritonTemplateCaller],
|
||||
) -> dict[TritonTemplateCaller, float]:
|
||||
"""
|
||||
Do benchmarking in a subprocess and return the perf number (latency).
|
||||
"""
|
||||
|
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import operator
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, Union
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from sympy import Expr
|
||||
|
||||
@ -43,7 +43,7 @@ class BoundVars:
|
||||
or "masked_subblock" in node.target
|
||||
)
|
||||
# To access this variable call `get_bounds()`
|
||||
self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
|
||||
self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
@ -55,7 +55,7 @@ class BoundVars:
|
||||
)
|
||||
|
||||
@cache_on_self
|
||||
def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
|
||||
def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
|
||||
submodules = self.swap_submodules(self.loop_body.submodules)
|
||||
|
||||
# Initialize the environment with the unbounded variables
|
||||
@ -74,9 +74,9 @@ class BoundVars:
|
||||
return self._bounds
|
||||
|
||||
def swap_submodules(
|
||||
self, submodules: Dict[str, Callable[..., Any]]
|
||||
) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
|
||||
result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
|
||||
self, submodules: dict[str, Callable[..., Any]]
|
||||
) -> dict[str, Callable[..., ValueRanges[Expr]]]:
|
||||
result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
|
||||
for key in submodules.keys():
|
||||
if key == "get_index":
|
||||
result[key] = self.get_index
|
||||
@ -111,10 +111,10 @@ class BoundVars:
|
||||
def masked_subblock(
|
||||
self,
|
||||
subblock: LoopBodyBlock,
|
||||
env: Dict[torch.fx.Node, ValueRanges[Expr]],
|
||||
env: dict[torch.fx.Node, ValueRanges[Expr]],
|
||||
mask: Any,
|
||||
value: Any,
|
||||
submodules: Dict[str, Callable[..., Any]],
|
||||
submodules: dict[str, Callable[..., Any]],
|
||||
) -> ValueRanges[Expr]:
|
||||
interp = InterpreterShim(subblock.graph, submodules)
|
||||
interp.run(V.get_ops_handler(), initial_env=env)
|
||||
|
@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, Dict, List, Type, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
@ -42,11 +42,11 @@ class InductorChoices:
|
||||
|
||||
def triton_kernel_kwargs(
|
||||
self,
|
||||
kernel_cls: Type[TritonKernel],
|
||||
kernel_cls: type[TritonKernel],
|
||||
features: SIMDKernelFeatures,
|
||||
groups: List[sympy.Expr],
|
||||
kernel_kwargs: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
groups: list[sympy.Expr],
|
||||
kernel_kwargs: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
|
||||
return kernel_kwargs
|
||||
|
||||
|
@ -36,13 +36,8 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -127,7 +122,7 @@ else:
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import KeysView
|
||||
from collections.abc import Generator, KeysView, Sequence
|
||||
from concurrent.futures import Future
|
||||
|
||||
from .compile_fx import _CompileFxKwargs, CompiledFxGraph
|
||||
@ -168,7 +163,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
|
||||
class CacheBase:
|
||||
@staticmethod
|
||||
@functools.lru_cache(None)
|
||||
def get_system() -> Dict[str, Any]:
|
||||
def get_system() -> dict[str, Any]:
|
||||
try:
|
||||
from triton.compiler.compiler import triton_key
|
||||
|
||||
@ -179,7 +174,7 @@ class CacheBase:
|
||||
triton_version = None
|
||||
|
||||
try:
|
||||
system: Dict[str, Any] = {
|
||||
system: dict[str, Any] = {
|
||||
"device": {"name": None},
|
||||
"version": {
|
||||
"triton": triton_version,
|
||||
@ -217,7 +212,7 @@ class CacheBase:
|
||||
def __init__(self) -> None:
|
||||
self.system = CacheBase.get_system()
|
||||
|
||||
def get_local_cache(self) -> Dict[str, Any]:
|
||||
def get_local_cache(self) -> dict[str, Any]:
|
||||
local_cache_path = self.get_local_cache_path()
|
||||
if not local_cache_path.is_file():
|
||||
return {}
|
||||
@ -225,7 +220,7 @@ class CacheBase:
|
||||
local_cache = json.load(local_cache_fp)
|
||||
return local_cache["cache"]
|
||||
|
||||
def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
|
||||
def update_local_cache(self, local_cache: dict[str, Any]) -> None:
|
||||
local_cache_path = self.get_local_cache_path()
|
||||
write_atomic(
|
||||
str(local_cache_path),
|
||||
@ -235,7 +230,7 @@ class CacheBase:
|
||||
|
||||
|
||||
class LocalCache(CacheBase):
|
||||
def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
|
||||
def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
|
||||
cache = self.get_local_cache()
|
||||
|
||||
sub_cache = cache
|
||||
@ -261,7 +256,7 @@ class LocalCache(CacheBase):
|
||||
|
||||
class PersistentCache(CacheBase):
|
||||
@functools.lru_cache(None) # noqa: B019
|
||||
def get_global_cache(self) -> Dict[str, Any]:
|
||||
def get_global_cache(self) -> dict[str, Any]:
|
||||
global_cache_path = self.get_global_cache_path()
|
||||
if global_cache_path is None or not global_cache_path.is_file():
|
||||
return {}
|
||||
@ -271,11 +266,11 @@ class PersistentCache(CacheBase):
|
||||
|
||||
def lookup(
|
||||
self,
|
||||
choices: List[ChoiceCaller],
|
||||
choices: list[ChoiceCaller],
|
||||
op: str,
|
||||
inputs: str,
|
||||
benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
|
||||
) -> Dict[ChoiceCaller, float]:
|
||||
benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
|
||||
) -> dict[ChoiceCaller, float]:
|
||||
"""
|
||||
Check to see if we have benchmarked the given choice callers. For each
|
||||
choice caller:
|
||||
@ -296,7 +291,7 @@ class PersistentCache(CacheBase):
|
||||
)
|
||||
timings = {}
|
||||
|
||||
def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool:
|
||||
def check_cache(cache: dict[str, Any], callback: Any = None) -> bool:
|
||||
"""Check if `cache` contains data for all the choices"""
|
||||
hit = True
|
||||
for choice in choices:
|
||||
@ -456,7 +451,7 @@ class TensorMetadataAndValues:
|
||||
"""
|
||||
|
||||
tensor_metadata: TensorMetadata
|
||||
values: List[Any]
|
||||
values: list[Any]
|
||||
|
||||
|
||||
def _ident(x: T) -> T:
|
||||
@ -584,7 +579,7 @@ class FxGraphCachePickler(pickle.Pickler):
|
||||
|
||||
def _reduce_graph_module(
|
||||
self, gm: torch.fx.GraphModule
|
||||
) -> tuple[Any, tuple[Dict[str, Any], str]]:
|
||||
) -> tuple[Any, tuple[dict[str, Any], str]]:
|
||||
"""
|
||||
Custom reducer for graph module to handle irrelevant data for user
|
||||
defined triton kernels
|
||||
@ -624,7 +619,7 @@ class FxGraphCachePickler(pickle.Pickler):
|
||||
serialized_data = self.dumps(obj)
|
||||
return sha256_hash(serialized_data)
|
||||
|
||||
def debug_lines(self, inp: FxGraphHashDetails) -> List[str]:
|
||||
def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
|
||||
"""
|
||||
Get a printable string describing in more detail all the attributes
|
||||
comprising an object. Useful for debugging when one graph hashes
|
||||
@ -659,7 +654,7 @@ class FxGraphCachePickler(pickle.Pickler):
|
||||
|
||||
|
||||
def build_code_hash(
|
||||
roots: List[str] | None, prefix: str, hasher: hashlib._Hash
|
||||
roots: list[str] | None, prefix: str, hasher: hashlib._Hash
|
||||
) -> None:
|
||||
for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
|
||||
spec = lib.module_finder.find_spec(lib.name, None)
|
||||
@ -721,7 +716,7 @@ class OrderedSetHolder:
|
||||
of set kwargs.
|
||||
"""
|
||||
|
||||
items: List[Any]
|
||||
items: list[Any]
|
||||
|
||||
|
||||
class BypassFxGraphCache(Exception):
|
||||
@ -753,7 +748,7 @@ class FxGraphHashDetails:
|
||||
# Order kwargs so hashing is stable to changes in kwarg order. Although
|
||||
# it's technically a _CompileFxKwargs we don't actually need it typed as
|
||||
# such since we're just using it to generate a hash.
|
||||
self.fx_kwargs: Dict[str, object] = {}
|
||||
self.fx_kwargs: dict[str, object] = {}
|
||||
for k, v in sorted(fx_kwargs.items()):
|
||||
if k not in self.EXCLUDED_KWARGS:
|
||||
if type(v) in (set, OrderedSet): # noqa: set_linter
|
||||
@ -774,7 +769,7 @@ class FxGraphHashDetails:
|
||||
|
||||
# Node meta will not be part of gm's reduce function, so lets remember
|
||||
# the kernel source code separately
|
||||
self.user_defined_triton_source: List[Any] = []
|
||||
self.user_defined_triton_source: list[Any] = []
|
||||
if gm is not None:
|
||||
for module in gm.modules():
|
||||
if not isinstance(module, torch.fx.GraphModule):
|
||||
@ -856,7 +851,7 @@ def compiled_fx_graph_hash(
|
||||
example_inputs: Sequence[InputType],
|
||||
fx_kwargs: _CompileFxKwargs,
|
||||
inputs_to_check: Sequence[int],
|
||||
) -> tuple[str, List[str]]:
|
||||
) -> tuple[str, list[str]]:
|
||||
"""
|
||||
Generate a unique hash of the FX graph for caching.
|
||||
"""
|
||||
@ -952,7 +947,7 @@ class FxGraphCache:
|
||||
return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)
|
||||
|
||||
@staticmethod
|
||||
def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
|
||||
def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]:
|
||||
"""
|
||||
Get the backed SymInt objects from the input list. Note that we can never
|
||||
have guards that depend on unbacked symint.
|
||||
@ -976,7 +971,7 @@ class FxGraphCache:
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
constants: CompiledFxGraphConstants,
|
||||
) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
|
||||
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
|
||||
"""
|
||||
Lookup a compiled graph in the cache by key. On a hit, return the
|
||||
deserialized CompiledFxGraph object. On a miss, return None.
|
||||
@ -988,7 +983,7 @@ class FxGraphCache:
|
||||
hints = [hint_int(s) for s in symints]
|
||||
|
||||
def iterate_over_candidates() -> (
|
||||
Generator[Tuple[CompiledFxGraph, bytes], None, None]
|
||||
Generator[tuple[CompiledFxGraph, bytes], None, None]
|
||||
):
|
||||
if local:
|
||||
subdir = FxGraphCache._get_tmp_dir_for_key(key)
|
||||
@ -1021,7 +1016,7 @@ class FxGraphCache:
|
||||
# their guards to determine whether there's a hit.
|
||||
graph = None
|
||||
pickled_content = None
|
||||
cache_info: Dict[str, Any] = dict()
|
||||
cache_info: dict[str, Any] = dict()
|
||||
|
||||
for candidate, pickled_content in iterate_over_candidates():
|
||||
if not candidate.guards_expr:
|
||||
@ -1234,7 +1229,7 @@ class FxGraphCache:
|
||||
fx_kwargs: _CompileFxKwargs,
|
||||
inputs_to_check: Sequence[int],
|
||||
remote: bool,
|
||||
) -> tuple[Optional[tuple[str, List[str]]], Dict[str, Any]]:
|
||||
) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
|
||||
"""
|
||||
Checks that the inductor input is cacheable, then computes
|
||||
and returns the cache key for the input.
|
||||
@ -1280,13 +1275,13 @@ class FxGraphCache:
|
||||
@staticmethod
|
||||
def load_with_key(
|
||||
key: str,
|
||||
debug_lines: List[str],
|
||||
debug_lines: list[str],
|
||||
example_inputs: Sequence[InputType],
|
||||
local: bool,
|
||||
remote_cache: Optional[RemoteCache[JsonDataTy]],
|
||||
is_backward: bool,
|
||||
constants: CompiledFxGraphConstants,
|
||||
) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
|
||||
) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
|
||||
"""
|
||||
Lookup the graph with the given key, and return results and metadata.
|
||||
Doesn't do any logging on its own, because AOTAutograd handles a cache miss
|
||||
@ -1373,11 +1368,11 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
class CudaKernelParamCache:
|
||||
cache: Dict[str, Dict[str, str]] = {}
|
||||
cache: dict[str, dict[str, str]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
|
||||
@classmethod
|
||||
def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None:
|
||||
def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
|
||||
_, path = write(
|
||||
cubin,
|
||||
bin_type,
|
||||
@ -1391,7 +1386,7 @@ class CudaKernelParamCache:
|
||||
cls.cache[key] = params
|
||||
|
||||
@classmethod
|
||||
def get(cls, key: str) -> Optional[Dict[str, str]]:
|
||||
def get(cls, key: str) -> Optional[dict[str, str]]:
|
||||
return cls.cache.get(key, None)
|
||||
|
||||
@classmethod
|
||||
@ -1407,8 +1402,8 @@ class AotCodeCompiler:
|
||||
source_code: str,
|
||||
serialized_extern_kernel_nodes: Optional[str],
|
||||
device_type: str,
|
||||
additional_files: List[str],
|
||||
) -> Union[List[str], str]:
|
||||
additional_files: list[str],
|
||||
) -> Union[list[str], str]:
|
||||
"""
|
||||
Returns the .so path, or returns a list of files that were generated if
|
||||
config.aot_inductor.package=True.
|
||||
@ -1842,14 +1837,14 @@ def cpp_prefix() -> str:
|
||||
# Given a path to an input cpp file and an output path,
|
||||
# Attempts to compile the file, storing the output in "output_path"
|
||||
def compile_file(
|
||||
input_path: Union[str, List[str]], output_path: str, cmd: List[str]
|
||||
input_path: Union[str, list[str]], output_path: str, cmd: list[str]
|
||||
) -> None:
|
||||
with dynamo_timed("compile_file"):
|
||||
return _compile_file(input_path, output_path, cmd)
|
||||
|
||||
|
||||
def _compile_file(
|
||||
input_path: Union[str, List[str]], output_path: str, cmd: List[str]
|
||||
input_path: Union[str, list[str]], output_path: str, cmd: list[str]
|
||||
) -> None:
|
||||
input_paths = [input_path] if isinstance(input_path, str) else input_path
|
||||
input_files = [
|
||||
@ -1948,9 +1943,9 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]:
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
class CppCodeCache:
|
||||
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags: Dict[str, Any] = {}
|
||||
cpp_compile_command_flags: dict[str, Any] = {}
|
||||
|
||||
@staticmethod
|
||||
def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
|
||||
@ -2093,7 +2088,7 @@ def _worker_compile_cpp(
|
||||
# Customized Python binding for cpp kernels
|
||||
@clear_on_fresh_inductor_cache
|
||||
class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags = {
|
||||
# kernels have no dependency on libtorch
|
||||
@ -2212,7 +2207,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
@classmethod
|
||||
def load_pybinding_async(
|
||||
cls,
|
||||
argtypes: List[str],
|
||||
argtypes: list[str],
|
||||
source_code: str,
|
||||
device_type: str = "cpu",
|
||||
num_outputs: int = -1,
|
||||
@ -2269,7 +2264,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
cpp_compile_command_flags = {
|
||||
"include_pytorch": True,
|
||||
@ -2335,7 +2330,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
_standalone_runtime_path: Optional[str] = None
|
||||
prefix = textwrap.dedent(
|
||||
@ -2412,7 +2407,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]:
|
||||
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
|
||||
assert arg.shape is not None
|
||||
assert arg.stride is not None and len(arg.shape) == len(arg.stride)
|
||||
assert arg.offset is not None
|
||||
@ -2573,7 +2568,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
donefile = str(dirpath / "done")
|
||||
lockfile = str(dirpath / "lock")
|
||||
need_compile = not os.path.exists(donefile)
|
||||
jobs: List[Any] = []
|
||||
jobs: list[Any] = []
|
||||
if need_compile:
|
||||
write_atomic(genfile, source_code)
|
||||
cmd = [
|
||||
@ -2685,7 +2680,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||
return sofile
|
||||
|
||||
|
||||
def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
|
||||
def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
|
||||
from torch.utils._filelock import FileLock
|
||||
|
||||
try:
|
||||
@ -2733,8 +2728,8 @@ class PyCodeCache:
|
||||
# clearing the cache. Note also that we may load the same path more
|
||||
# than once, but attach different attributes, i.e., due to different
|
||||
# constant values.
|
||||
modules: List[ModuleType] = []
|
||||
linemaps: Dict[str, List[tuple[Any, ...]]] = {}
|
||||
modules: list[ModuleType] = []
|
||||
linemaps: dict[str, list[tuple[Any, ...]]] = {}
|
||||
|
||||
@classmethod
|
||||
def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
|
||||
@ -2745,8 +2740,8 @@ class PyCodeCache:
|
||||
cls,
|
||||
source_code: str,
|
||||
extra: str = "",
|
||||
linemap: Optional[List[tuple[int, str]]] = None,
|
||||
attrs: Optional[Dict[str, Any]] = None,
|
||||
linemap: Optional[list[tuple[int, str]]] = None,
|
||||
attrs: Optional[dict[str, Any]] = None,
|
||||
) -> ModuleType:
|
||||
key, path = write(source_code, "py", extra=extra)
|
||||
return cls.load_by_key_path(key, path, linemap, attrs)
|
||||
@ -2756,8 +2751,8 @@ class PyCodeCache:
|
||||
cls,
|
||||
key: str,
|
||||
path: str,
|
||||
linemap: Optional[List[tuple[int, str]]] = None,
|
||||
attrs: Optional[Dict[str, Any]] = None,
|
||||
linemap: Optional[list[tuple[int, str]]] = None,
|
||||
attrs: Optional[dict[str, Any]] = None,
|
||||
) -> ModuleType:
|
||||
if linemap is None:
|
||||
linemap = []
|
||||
@ -2798,7 +2793,7 @@ class PyCodeCache:
|
||||
@functools.lru_cache(None)
|
||||
def stack_frames_for_code(
|
||||
cls, path: str, lineno: int
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
) -> Optional[list[dict[str, Any]]]:
|
||||
if path not in cls.linemaps:
|
||||
return None
|
||||
# [(starting_line, <fx node>), ...]
|
||||
@ -2810,7 +2805,7 @@ class PyCodeCache:
|
||||
if not entry:
|
||||
return None
|
||||
|
||||
def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
|
||||
def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
|
||||
# ideally fx stores stack traces as data rather than a string
|
||||
# but this is not along a performance critical path
|
||||
regex = r'File "(.+)", line (\d+), in (.+)\n'
|
||||
@ -2841,7 +2836,7 @@ def _cuda_compiler() -> Optional[str]:
|
||||
return "nvcc"
|
||||
|
||||
|
||||
def _cutlass_include_paths() -> List[str]:
|
||||
def _cutlass_include_paths() -> list[str]:
|
||||
if config.is_fbcode():
|
||||
from libfb.py import parutil
|
||||
|
||||
@ -2857,14 +2852,14 @@ def _cutlass_include_paths() -> List[str]:
|
||||
]
|
||||
|
||||
|
||||
def _cuda_lib_options() -> List[str]:
|
||||
def _cuda_lib_options() -> list[str]:
|
||||
_set_gpu_runtime_env() # cpp_extension consults the env
|
||||
from torch.utils import cpp_extension
|
||||
|
||||
lpaths = cpp_extension.library_paths(device_type="cuda") + [
|
||||
sysconfig.get_config_var("LIBDIR")
|
||||
]
|
||||
extra_ldflags: List[str] = []
|
||||
extra_ldflags: list[str] = []
|
||||
if is_linux():
|
||||
_transform_cuda_paths(lpaths)
|
||||
for path in lpaths:
|
||||
@ -2880,7 +2875,7 @@ def _cuda_lib_options() -> List[str]:
|
||||
return extra_ldflags
|
||||
|
||||
|
||||
def _nvcc_host_compiler_options() -> List[str]:
|
||||
def _nvcc_host_compiler_options() -> list[str]:
|
||||
return [
|
||||
"-fPIC",
|
||||
"-fno-strict-aliasing",
|
||||
@ -2889,7 +2884,7 @@ def _nvcc_host_compiler_options() -> List[str]:
|
||||
]
|
||||
|
||||
|
||||
def _nvcc_compiler_options() -> List[str]:
|
||||
def _nvcc_compiler_options() -> list[str]:
|
||||
arch = cuda_env.get_cuda_arch()
|
||||
if arch == "90":
|
||||
# Required by cutlass compilation.
|
||||
@ -2934,10 +2929,10 @@ def _nvcc_compiler_options() -> List[str]:
|
||||
|
||||
|
||||
def cuda_compile_command(
|
||||
src_files: List[str],
|
||||
src_files: list[str],
|
||||
dst_file: str,
|
||||
dst_file_ext: str,
|
||||
extra_args: Optional[List[str]] = None,
|
||||
extra_args: Optional[list[str]] = None,
|
||||
) -> str:
|
||||
if extra_args is None:
|
||||
extra_args = []
|
||||
@ -3052,7 +3047,7 @@ class CUDACodeCache:
|
||||
input_path: str
|
||||
output_path: str
|
||||
|
||||
cache: Dict[str, CacheEntry] = {}
|
||||
cache: dict[str, CacheEntry] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
_SOURCE_CODE_SUFFIX = "cu"
|
||||
|
||||
@ -3073,7 +3068,7 @@ class CUDACodeCache:
|
||||
|
||||
@classmethod
|
||||
def compile(
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Compiles CUDA source_code into a file with dst_file_ext extension.
|
||||
@ -3137,7 +3132,7 @@ class ROCmCodeCache:
|
||||
input_path: str
|
||||
output_path: str
|
||||
|
||||
cache: Dict[str, CacheEntry] = {}
|
||||
cache: dict[str, CacheEntry] = {}
|
||||
cache_clear = staticmethod(cache.clear)
|
||||
_SOURCE_CODE_SUFFIX = "cpp"
|
||||
_logged_compiler_version = False
|
||||
@ -3159,7 +3154,7 @@ class ROCmCodeCache:
|
||||
|
||||
@classmethod
|
||||
def compile(
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
|
||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Compiles source_code into a file with dst_file_ext extension,
|
||||
|
@ -7,7 +7,7 @@ import logging
|
||||
import operator
|
||||
import sys
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, List, TYPE_CHECKING
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
from .scheduler import BaseSchedulerNode
|
||||
|
||||
|
||||
def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Greedily schedules waits as late as possible.
|
||||
"""
|
||||
@ -42,7 +42,7 @@ def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
)
|
||||
|
||||
|
||||
def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Greedily schedules comms as early as possible.
|
||||
"""
|
||||
@ -52,8 +52,8 @@ def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
|
||||
|
||||
|
||||
def reorder_compute_for_overlap(
|
||||
snodes: List[BaseSchedulerNode],
|
||||
) -> List[BaseSchedulerNode]:
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
This achieves the following overall scheduling procedure:
|
||||
Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
|
||||
@ -71,11 +71,11 @@ def reorder_compute_for_overlap(
|
||||
|
||||
|
||||
def _schedule_for_comm(
|
||||
snodes: List[BaseSchedulerNode],
|
||||
snodes: list[BaseSchedulerNode],
|
||||
raise_comms: bool,
|
||||
sink_waits: bool,
|
||||
reorder_for_overlap: bool,
|
||||
) -> List[BaseSchedulerNode]:
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Schedule `snodes` for various comm optimization objectives.
|
||||
|
||||
@ -149,13 +149,13 @@ def _schedule_for_comm(
|
||||
def __lt__(self, other):
|
||||
return self.score < other.score
|
||||
|
||||
unmet_deps: Dict[BaseSchedulerNode, OrderedSet[str]] = {
|
||||
unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
|
||||
snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
|
||||
for snode in snodes
|
||||
}
|
||||
|
||||
ready: List[Runnable] = []
|
||||
buffer_users: Dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
|
||||
ready: list[Runnable] = []
|
||||
buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
|
||||
snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
|
||||
|
||||
for snode, deps in unmet_deps.items():
|
||||
@ -226,8 +226,8 @@ def _schedule_for_comm(
|
||||
|
||||
|
||||
def decide_global_ordering_of_comms(
|
||||
nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
|
||||
) -> List[BaseSchedulerNode]:
|
||||
nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
|
||||
) -> list[BaseSchedulerNode]:
|
||||
"""
|
||||
Decide global ordering of comms, by just enforcing the ordering that's in the input graph
|
||||
(might not be the same ordering as the eager mode program).
|
||||
@ -303,8 +303,8 @@ def visualize_overlap(order):
|
||||
|
||||
|
||||
def reorder_compute_and_comm_for_overlap(
|
||||
snodes: List[BaseSchedulerNode],
|
||||
) -> List[BaseSchedulerNode]:
|
||||
snodes: list[BaseSchedulerNode],
|
||||
) -> list[BaseSchedulerNode]:
|
||||
order = snodes
|
||||
|
||||
for p in config.reorder_for_compute_comm_overlap_passes:
|
||||
@ -653,10 +653,10 @@ def get_op_idx(snode):
|
||||
|
||||
|
||||
def enforce_comm_ordering_for_fsdp(
|
||||
snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
|
||||
name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
|
||||
name_to_fused_node: Dict[str, BaseSchedulerNode],
|
||||
) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
|
||||
snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
|
||||
name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
|
||||
name_to_fused_node: dict[str, BaseSchedulerNode],
|
||||
) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
|
||||
from . import scheduler
|
||||
|
||||
new_order: list[BaseSchedulerNode] = []
|
||||
|
@ -16,11 +16,7 @@ from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -121,6 +117,8 @@ from .virtualized import V
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Sequence
|
||||
|
||||
from torch._inductor.output_code import _StrideExprStr
|
||||
from torch._ops import OpOverload
|
||||
|
||||
@ -157,7 +155,7 @@ static_inputs_log = torch._logging.getArtifactLogger(
|
||||
)
|
||||
|
||||
|
||||
def get_static_input_idxs(num_fixed: int) -> List[int]:
|
||||
def get_static_input_idxs(num_fixed: int) -> list[int]:
|
||||
# If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
|
||||
# of cudagraphs. Rather than copying these into cudagraph-owned memory
|
||||
# like we do for normal inputs on each run, we will re-record a cudagraph if these
|
||||
@ -208,7 +206,7 @@ def _unlift_graph(
|
||||
) -> GraphModule:
|
||||
from torch.export.unflatten import _assign_attr, _AttrKind
|
||||
|
||||
state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
|
||||
state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
|
||||
for name, param in mod.named_parameters(remove_duplicate=False):
|
||||
state_dict[name] = param
|
||||
_assign_attr(
|
||||
@ -227,7 +225,7 @@ def _unlift_graph(
|
||||
)
|
||||
|
||||
placeholder_nodes = gm.graph.find_nodes(op="placeholder")
|
||||
lifted_inputs: List[Optional[FQN]] = []
|
||||
lifted_inputs: list[Optional[FQN]] = []
|
||||
|
||||
# In AOTI, module parameters and buffers are not lifted as graph inputs.
|
||||
# As a result, mutation to buffers has side effect which makes their initial
|
||||
@ -343,9 +341,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
|
||||
def split_const_gm(
|
||||
gm: GraphModule,
|
||||
skip_constructor: bool = True,
|
||||
lifted_constant_names: Optional[List[str]] = None,
|
||||
lifted_constant_names: Optional[list[str]] = None,
|
||||
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> tuple[GraphModule, Dict[str, int]]:
|
||||
) -> tuple[GraphModule, dict[str, int]]:
|
||||
"""
|
||||
This function takes an GraphModule input "gm".
|
||||
The gm will be split into 2 components,
|
||||
@ -488,8 +486,8 @@ def fake_tensor_prop(
|
||||
|
||||
# pass config dict back to user
|
||||
def get_patched_config_dict(
|
||||
config_patches: Optional[Union[str, Dict[str, Any]]] = None
|
||||
) -> Dict[str, Any]:
|
||||
config_patches: Optional[Union[str, dict[str, Any]]] = None
|
||||
) -> dict[str, Any]:
|
||||
with config.patch(config_patches):
|
||||
return config.get_config_copy()
|
||||
|
||||
@ -515,7 +513,7 @@ class _CompileFxKwargs(TypedDict, total=False):
|
||||
aot_mode: bool
|
||||
is_inference: bool
|
||||
layout_opt: Optional[bool]
|
||||
extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]]
|
||||
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
|
||||
boxed_forward_device_index: Optional[BoxedDeviceIndex]
|
||||
|
||||
|
||||
@ -822,7 +820,7 @@ class _InProcessFxCompile(FxCompile):
|
||||
aot_mode: bool = V.aot_compilation
|
||||
is_inference: bool = graph_kwargs.get("is_inference", False)
|
||||
extern_node_serializer: Optional[
|
||||
Callable[[List[ExternKernelNode]], Any]
|
||||
Callable[[list[ExternKernelNode]], Any]
|
||||
] = graph_kwargs.get("extern_node_serializer", None)
|
||||
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
|
||||
"boxed_forward_device_index", None
|
||||
@ -997,7 +995,7 @@ class _InProcessFxCompile(FxCompile):
|
||||
metrics_helper = metrics.CachedMetricsHelper()
|
||||
with V.set_graph_handler(graph):
|
||||
graph.run(*example_inputs)
|
||||
output_strides: List[Optional[tuple[_StrideExprStr, ...]]] = []
|
||||
output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
|
||||
if graph.graph_outputs is not None:
|
||||
# We'll put the output strides in the compiled graph so we
|
||||
# can later return them to the caller via TracingContext
|
||||
@ -1189,7 +1187,7 @@ def cudagraphify(
|
||||
static_input_idxs: Sequence[int] = (),
|
||||
*,
|
||||
device_index: int,
|
||||
stack_traces: List[Optional[str]],
|
||||
stack_traces: list[Optional[str]],
|
||||
is_backward: bool,
|
||||
is_inference: bool,
|
||||
constants: tuple[torch.Tensor, ...] = (),
|
||||
@ -1240,7 +1238,7 @@ def static_input(x: torch.Tensor) -> torch.Tensor:
|
||||
def index_expanded_dims_and_copy_(
|
||||
dst: torch.Tensor,
|
||||
src: torch.Tensor,
|
||||
expanded_dims: List[int],
|
||||
expanded_dims: list[int],
|
||||
) -> None:
|
||||
"Index into expanded dimensions of both dst and src then copy_"
|
||||
dst = index_expanded_dims(dst, expanded_dims)
|
||||
@ -1250,9 +1248,9 @@ def index_expanded_dims_and_copy_(
|
||||
|
||||
def cudagraphify_impl(
|
||||
model: Callable[..., Any],
|
||||
inputs: List[torch.Tensor],
|
||||
inputs: list[torch.Tensor],
|
||||
static_input_idxs: Sequence[int] = (),
|
||||
) -> Callable[[List[InputType]], Any]:
|
||||
) -> Callable[[list[InputType]], Any]:
|
||||
"""
|
||||
Assumes inputs[static_input_idxs[i]] are always the same memory address
|
||||
"""
|
||||
@ -1304,7 +1302,7 @@ def cudagraphify_impl(
|
||||
|
||||
if config.size_asserts:
|
||||
|
||||
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
|
||||
def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
|
||||
assert len(static_inputs) == len(new_inputs)
|
||||
for idx, (dst, src, expanded_dims) in enumerate(
|
||||
zip(static_inputs, new_inputs, inps_expanded_dims)
|
||||
@ -1328,7 +1326,7 @@ def cudagraphify_impl(
|
||||
idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
|
||||
]
|
||||
|
||||
def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
|
||||
def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
|
||||
for idx in copy_indices:
|
||||
expanded_dims = inps_expanded_dims[idx]
|
||||
src = new_inputs[idx]
|
||||
@ -1343,16 +1341,16 @@ def cudagraphify_impl(
|
||||
|
||||
def compile_fx_aot(
|
||||
model_: GraphModule,
|
||||
example_inputs_: List[InputType],
|
||||
example_inputs_: list[InputType],
|
||||
inner_compile: _CompileFxCallable = compile_fx_inner,
|
||||
config_patches: Optional[Dict[str, str]] = None,
|
||||
) -> Union[List[str], str]:
|
||||
config_patches: Optional[dict[str, str]] = None,
|
||||
) -> Union[list[str], str]:
|
||||
assert isinstance(model_, GraphModule), model_
|
||||
|
||||
# [See NOTE] Unwrapping subclasses AOT
|
||||
unwrap_tensor_subclass_parameters(model_)
|
||||
|
||||
config_patches: Dict[str, Any] = (
|
||||
config_patches: dict[str, Any] = (
|
||||
{"cpp_wrapper": True}
|
||||
if config_patches is None
|
||||
else {**config_patches, "cpp_wrapper": True}
|
||||
@ -1409,7 +1407,7 @@ def fw_compiler_freezing(
|
||||
cudagraphs: BoxedBool,
|
||||
graph_id: int,
|
||||
forward_device: BoxedDeviceIndex,
|
||||
) -> Callable[[List[object]], Sequence[torch.Tensor]]:
|
||||
) -> Callable[[list[object]], Sequence[torch.Tensor]]:
|
||||
from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze
|
||||
|
||||
# partition_fn won't be called
|
||||
@ -1492,7 +1490,7 @@ def fw_compiler_freezing(
|
||||
if V.aot_compilation:
|
||||
return optimized_function
|
||||
|
||||
def wrapper(args: List[object]) -> Sequence[torch.Tensor]:
|
||||
def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
|
||||
args_new = [
|
||||
args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
|
||||
for i in preserved_arg_indices
|
||||
@ -1505,7 +1503,7 @@ def fw_compiler_freezing(
|
||||
return wrapper
|
||||
|
||||
|
||||
def get_cpp_wrapper_config() -> Dict[str, object]:
|
||||
def get_cpp_wrapper_config() -> dict[str, object]:
|
||||
return {
|
||||
# Set autotune_at_compile_time to True as default if the option is not explicitly set
|
||||
"triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
|
||||
@ -1551,9 +1549,9 @@ def compile_fx(
|
||||
model_: GraphModule,
|
||||
example_inputs_: Sequence[InputType],
|
||||
inner_compile: Callable[..., OutputCode] = compile_fx_inner,
|
||||
config_patches: Optional[Dict[str, Any]] = None,
|
||||
decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
|
||||
) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
|
||||
config_patches: Optional[dict[str, Any]] = None,
|
||||
decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
|
||||
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
|
||||
"""
|
||||
Main entry point for compiling given FX graph. Despite the fact that this
|
||||
lives in :mod:`torch._inductor`, this function is responsible for calling
|
||||
@ -2017,11 +2015,11 @@ def _check_triton_bf16_support(graph: GraphLowering) -> None:
|
||||
|
||||
def _aoti_flatten_inputs(
|
||||
gm: torch.fx.GraphModule,
|
||||
args: Union[List[Any], tuple[Any, ...]],
|
||||
kwargs: Optional[Dict[str, Any]] = None,
|
||||
args: Union[list[Any], tuple[Any, ...]],
|
||||
kwargs: Optional[dict[str, Any]] = None,
|
||||
*,
|
||||
options: Optional[Dict[str, Any]] = None,
|
||||
) -> tuple[List[Any], Dict[str, Any]]:
|
||||
options: Optional[dict[str, Any]] = None,
|
||||
) -> tuple[list[Any], dict[str, Any]]:
|
||||
"""
|
||||
Flatten the inputs to the graph module and return the flat inputs and options.
|
||||
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
|
||||
|
@ -5,7 +5,7 @@ import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from typing import Type, TypeVar
|
||||
from typing import TypeVar
|
||||
|
||||
from torch._inductor.async_compile import pre_fork_setup
|
||||
from torch._inductor.compile_worker.subproc_pool import (
|
||||
@ -32,7 +32,7 @@ except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
def _lookup_and_create_type(base: Type[_T], qname: str) -> _T:
|
||||
def _lookup_and_create_type(base: type[_T], qname: str) -> _T:
|
||||
"""
|
||||
Given a base type and qualified name: import & lookup that name, check
|
||||
that it's of the given type and then instantiate it.
|
||||
|
@ -13,7 +13,7 @@ import typing
|
||||
from concurrent.futures import Future, ProcessPoolExecutor
|
||||
from concurrent.futures.process import BrokenProcessPool
|
||||
from enum import Enum
|
||||
from typing import Any, BinaryIO, Callable, Dict, Optional, TypeVar
|
||||
from typing import Any, BinaryIO, Callable, Optional, TypeVar
|
||||
from typing_extensions import Never, ParamSpec
|
||||
|
||||
# _thread_safe_fork is needed because the subprocesses in the pool can read
|
||||
@ -158,7 +158,7 @@ class SubprocPool:
|
||||
self.read_thread = threading.Thread(target=self._read_thread, daemon=True)
|
||||
|
||||
self.futures_lock = threading.Lock()
|
||||
self.pending_futures: Dict[int, Future[Any]] = {}
|
||||
self.pending_futures: dict[int, Future[Any]] = {}
|
||||
self.job_id_count = itertools.count()
|
||||
|
||||
self.running = True
|
||||
|
@ -7,7 +7,7 @@ import shutil
|
||||
import sys
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, List, Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch._inductor.runtime.cache_dir_utils import cache_dir
|
||||
|
||||
@ -43,7 +43,7 @@ class ConfigChange(BinarySubsystem):
|
||||
|
||||
|
||||
# Dictionary of backend -> subsystems
|
||||
BACKENDS: Dict[str, List[Subsystem]] = {
|
||||
BACKENDS: dict[str, list[Subsystem]] = {
|
||||
# run dynamo without aot_autograd
|
||||
"eager": [],
|
||||
# run dynamo with aot_autograd, but no partitioner or decomps
|
||||
@ -68,8 +68,8 @@ BACKENDS: Dict[str, List[Subsystem]] = {
|
||||
], # TODO - add more - fusions ?
|
||||
}
|
||||
|
||||
subsystem_call_counter: Dict[str, int] = collections.Counter()
|
||||
call_counter_debug_info: Dict[int, str] = {}
|
||||
subsystem_call_counter: dict[str, int] = collections.Counter()
|
||||
call_counter_debug_info: dict[int, str] = {}
|
||||
|
||||
|
||||
def reset_counters() -> None:
|
||||
@ -123,13 +123,13 @@ class CompilerBisector:
|
||||
return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}"
|
||||
|
||||
@classmethod
|
||||
def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
|
||||
def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None:
|
||||
os.makedirs(os.path.dirname(file_path), exist_ok=True)
|
||||
with open(file_path, "w") as file:
|
||||
file.writelines(lines)
|
||||
|
||||
@classmethod
|
||||
def read_lines_from_file(cls, file_path: str) -> List[str]:
|
||||
def read_lines_from_file(cls, file_path: str) -> list[str]:
|
||||
if os.path.exists(file_path):
|
||||
with open(file_path) as file:
|
||||
return file.readlines()
|
||||
@ -154,7 +154,7 @@ class CompilerBisector:
|
||||
|
||||
@classmethod
|
||||
def set_config_values(
|
||||
cls, backend: str, subsystem: str, config_data: Dict[str, object]
|
||||
cls, backend: str, subsystem: str, config_data: dict[str, object]
|
||||
) -> None:
|
||||
file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt")
|
||||
lines = [f"{k}={v}\n" for k, v in config_data.items()]
|
||||
@ -267,7 +267,7 @@ class CompilerBisector:
|
||||
cls.write_lines_to_file(file_path, lines)
|
||||
|
||||
@classmethod
|
||||
def get_config_change(cls, config_name: str) -> Optional[Dict[str, object]]:
|
||||
def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]:
|
||||
backend = cls.get_backend()
|
||||
subsystem = cls.get_subsystem()
|
||||
|
||||
|
@ -1,16 +1,6 @@
|
||||
import os # noqa: C101
|
||||
import sys
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch
|
||||
import torch._inductor.custom_graph_pass
|
||||
@ -193,8 +183,8 @@ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
|
||||
# hence custom IR passes built on top of it might break in the future.
|
||||
_pre_fusion_custom_pass: Optional[
|
||||
Callable[
|
||||
[List["torch._inductor.scheduler.BaseSchedulerNode"]],
|
||||
List["torch._inductor.scheduler.BaseSchedulerNode"],
|
||||
[list["torch._inductor.scheduler.BaseSchedulerNode"]],
|
||||
list["torch._inductor.scheduler.BaseSchedulerNode"],
|
||||
]
|
||||
] = None
|
||||
|
||||
@ -231,11 +221,11 @@ batch_fusion = True
|
||||
# merge_splits_pass
|
||||
# mutate_cat_pass
|
||||
# split_cat_pass
|
||||
pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
|
||||
pre_grad_fusion_options: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# Post grad fusion and options, set to empty dict to disable fusion.
|
||||
# Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
|
||||
post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
|
||||
post_grad_fusion_options: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# enable reordering pass for improving memory locality
|
||||
reorder_for_locality = True
|
||||
@ -257,7 +247,7 @@ use_mixed_mm = True
|
||||
# floating point numbers,about 16 decimal digits for double precision floating point numbers)
|
||||
# according to PyTorch documentation.
|
||||
# https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
|
||||
fx_passes_numeric_check: Dict[str, Any] = {
|
||||
fx_passes_numeric_check: dict[str, Any] = {
|
||||
"pre_grad": False,
|
||||
"precision": 1e-4,
|
||||
"num_iterations": 1,
|
||||
@ -287,12 +277,12 @@ reorder_for_compute_comm_overlap = False
|
||||
# for built-in passes, use string name; for user-defined passes, pass in the function handle
|
||||
# WARNING: Inductor scheduler IR is at prototype stage and subject to change,
|
||||
# hence custom IR passes built on top of it might break in the future.
|
||||
reorder_for_compute_comm_overlap_passes: List[
|
||||
reorder_for_compute_comm_overlap_passes: list[
|
||||
Union[
|
||||
str,
|
||||
Callable[
|
||||
[List["torch._inductor.scheduler.BaseSchedulerNode"]],
|
||||
List["torch._inductor.scheduler.BaseSchedulerNode"],
|
||||
[list["torch._inductor.scheduler.BaseSchedulerNode"]],
|
||||
list["torch._inductor.scheduler.BaseSchedulerNode"],
|
||||
],
|
||||
]
|
||||
] = [
|
||||
@ -618,7 +608,7 @@ _fuse_ddp_bucket_size = 25
|
||||
# overlapping. At this moment, this pass performs better than
|
||||
# reorder_for_compute_comm_overlap_passes but we will add the logic of
|
||||
# "schedule_comm_wait" in the future and remove the one here.
|
||||
_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
|
||||
_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
|
||||
"fuse_ddp_with_concat_op",
|
||||
"schedule_comm_wait",
|
||||
]
|
||||
@ -852,7 +842,7 @@ class cpp:
|
||||
simdlen: Optional[int] = None
|
||||
min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))
|
||||
|
||||
cxx: Tuple[None, str] = (
|
||||
cxx: tuple[Literal[None], str] = (
|
||||
None, # download gcc12 from conda-forge if conda is installed
|
||||
os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
|
||||
) # type: ignore[assignment]
|
||||
@ -1157,7 +1147,7 @@ class aot_inductor:
|
||||
|
||||
# Dictionary of metadata users might want to save to pass to the runtime.
|
||||
# TODO: Move this somewhere else, since it's no longer really a config
|
||||
metadata: Dict[str, str] = {}
|
||||
metadata: dict[str, str] = {}
|
||||
|
||||
# fbcode only. Whether to raise error if C++ codegen is too big to optimize
|
||||
raise_error_on_ignored_optimization: bool = (
|
||||
@ -1168,7 +1158,7 @@ class aot_inductor:
|
||||
dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"
|
||||
|
||||
# Dictionary of presets that can be passed in
|
||||
presets: Dict[str, Any] = {}
|
||||
presets: dict[str, Any] = {}
|
||||
|
||||
# Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
|
||||
# should be run with this flag both on and off to make sure we have coverage.
|
||||
@ -1265,11 +1255,11 @@ class cuda:
|
||||
class rocm:
|
||||
# Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
|
||||
# If empty, the `native` arch is used
|
||||
arch: List[str] = []
|
||||
arch: list[str] = []
|
||||
|
||||
# Enable the CK backend for CDNA2 and CDNA3 only (for now)
|
||||
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
|
||||
ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
|
||||
ck_supported_arch: list[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
|
||||
|
||||
# Optimization level, use to balance compilation speed and runtime performance.
|
||||
# The type will not necessarily be comprehensive and won't be enforced at runtime.
|
||||
@ -1415,7 +1405,7 @@ class trace:
|
||||
log_inductor_triton_kernel_to_post_grad_node_info: bool = True
|
||||
|
||||
|
||||
_save_config_ignore: List[str] = [
|
||||
_save_config_ignore: list[str] = [
|
||||
# workaround: "Can't pickle <function ...>"
|
||||
"trace.upload_tar",
|
||||
"joint_custom_pre_pass",
|
||||
@ -1423,7 +1413,7 @@ _save_config_ignore: List[str] = [
|
||||
"pre_grad_custom_pass",
|
||||
]
|
||||
|
||||
_cache_config_ignore_prefix: List[str] = [
|
||||
_cache_config_ignore_prefix: list[str] = [
|
||||
# trace functions are not relevant to config caching
|
||||
"trace",
|
||||
# uses absolute path
|
||||
@ -1439,7 +1429,7 @@ _cache_config_ignore_prefix: List[str] = [
|
||||
]
|
||||
|
||||
# External callable for matmul tuning candidates
|
||||
external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
|
||||
external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
|
||||
|
||||
|
||||
class test_configs:
|
||||
|
@ -1,5 +1,5 @@
|
||||
import collections
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
@ -57,7 +57,7 @@ def replace_node_with_constant(
|
||||
|
||||
|
||||
def is_const_source(
|
||||
node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
|
||||
node: torch.fx.Node, lifted_constant_names: Optional[list[str]]
|
||||
) -> bool:
|
||||
return node.op == "get_attr" or node.name in (lifted_constant_names or ())
|
||||
|
||||
@ -67,12 +67,12 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
skip_constructors: bool = False,
|
||||
lifted_constant_names: Optional[List[str]] = None,
|
||||
lifted_constant_names: Optional[list[str]] = None,
|
||||
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> None:
|
||||
super().__init__(gm)
|
||||
self.node_replacements: Dict[torch.fx.Node, Any] = {}
|
||||
self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
|
||||
self.node_replacements: dict[torch.fx.Node, Any] = {}
|
||||
self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
|
||||
self.unknown_value = object()
|
||||
self.skip_constructors: bool = skip_constructors
|
||||
|
||||
@ -141,7 +141,7 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||
return True
|
||||
return False
|
||||
|
||||
def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
|
||||
def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
|
||||
last_non_output_use = collections.defaultdict(list)
|
||||
seen_uses = OrderedSet[torch.fx.Node]()
|
||||
output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
|
||||
@ -264,11 +264,11 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||
self.node_replacements[node] = tensor
|
||||
|
||||
def run(self) -> Any: # type: ignore[override]
|
||||
env: Dict[torch.fx.Node, Any] = {}
|
||||
env: dict[torch.fx.Node, Any] = {}
|
||||
self.insert_placerholder_values(env)
|
||||
return super().run(initial_env=env)
|
||||
|
||||
def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
|
||||
def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None:
|
||||
for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
|
||||
env[n] = self.unknown_value # type: ignore[assignment]
|
||||
if self.lifted_constant_names is None:
|
||||
@ -309,7 +309,7 @@ def constant_fold(
|
||||
def constant_graph_tag(
|
||||
gm: torch.fx.GraphModule,
|
||||
skip_constructors: bool = True,
|
||||
lifted_constant_names: Optional[List[str]] = None,
|
||||
lifted_constant_names: Optional[list[str]] = None,
|
||||
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> None:
|
||||
with torch.utils._python_dispatch._disable_current_modes():
|
||||
@ -337,7 +337,7 @@ def constant_graph_tag(
|
||||
def run_and_get_constant_graph(
|
||||
gm: torch.fx.GraphModule,
|
||||
skip_constructors: bool = True,
|
||||
lifted_constant_names: Optional[List[str]] = None,
|
||||
lifted_constant_names: Optional[list[str]] = None,
|
||||
skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
|
||||
) -> torch.fx.GraphModule:
|
||||
"""
|
||||
@ -367,7 +367,7 @@ def run_and_get_constant_graph(
|
||||
|
||||
new_graph = torch.fx.Graph()
|
||||
|
||||
node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
|
||||
output_nodes = []
|
||||
for node in gm.graph.nodes:
|
||||
if node.meta[META_TAG] == MODULE_TAG:
|
||||
|
@ -16,10 +16,11 @@ import sys
|
||||
import sysconfig
|
||||
import textwrap
|
||||
import warnings
|
||||
from collections.abc import Sequence
|
||||
from ctypes import cdll
|
||||
from ctypes.util import find_library
|
||||
from pathlib import Path
|
||||
from typing import Any, List, Optional, Sequence, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch._dynamo.utils import dynamo_timed
|
||||
@ -285,12 +286,12 @@ def get_compiler_version_info(compiler: str) -> str:
|
||||
|
||||
|
||||
# =============================== cpp builder ===============================
|
||||
def _append_list(dest_list: List[str], src_list: List[str]) -> None:
|
||||
def _append_list(dest_list: list[str], src_list: list[str]) -> None:
|
||||
dest_list.extend(copy.deepcopy(item) for item in src_list)
|
||||
|
||||
|
||||
def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
|
||||
new_list: List[str] = []
|
||||
def _remove_duplication_in_list(orig_list: list[str]) -> list[str]:
|
||||
new_list: list[str] = []
|
||||
for item in orig_list:
|
||||
if item not in new_list:
|
||||
new_list.append(item)
|
||||
@ -362,26 +363,26 @@ class BuildOptionsBase:
|
||||
def __init__(
|
||||
self,
|
||||
compiler: str = "",
|
||||
definitions: Optional[List[str]] = None,
|
||||
include_dirs: Optional[List[str]] = None,
|
||||
cflags: Optional[List[str]] = None,
|
||||
ldflags: Optional[List[str]] = None,
|
||||
libraries_dirs: Optional[List[str]] = None,
|
||||
libraries: Optional[List[str]] = None,
|
||||
passthrough_args: Optional[List[str]] = None,
|
||||
definitions: Optional[list[str]] = None,
|
||||
include_dirs: Optional[list[str]] = None,
|
||||
cflags: Optional[list[str]] = None,
|
||||
ldflags: Optional[list[str]] = None,
|
||||
libraries_dirs: Optional[list[str]] = None,
|
||||
libraries: Optional[list[str]] = None,
|
||||
passthrough_args: Optional[list[str]] = None,
|
||||
aot_mode: bool = False,
|
||||
use_absolute_path: bool = False,
|
||||
compile_only: bool = False,
|
||||
) -> None:
|
||||
self._compiler = compiler
|
||||
self._definations: List[str] = definitions or []
|
||||
self._include_dirs: List[str] = include_dirs or []
|
||||
self._cflags: List[str] = cflags or []
|
||||
self._ldflags: List[str] = ldflags or []
|
||||
self._libraries_dirs: List[str] = libraries_dirs or []
|
||||
self._libraries: List[str] = libraries or []
|
||||
self._definations: list[str] = definitions or []
|
||||
self._include_dirs: list[str] = include_dirs or []
|
||||
self._cflags: list[str] = cflags or []
|
||||
self._ldflags: list[str] = ldflags or []
|
||||
self._libraries_dirs: list[str] = libraries_dirs or []
|
||||
self._libraries: list[str] = libraries or []
|
||||
# Some args is hard to abstract to OS compatable, passthrough it directly.
|
||||
self._passthrough_args: List[str] = passthrough_args or []
|
||||
self._passthrough_args: list[str] = passthrough_args or []
|
||||
|
||||
self._aot_mode: bool = aot_mode
|
||||
self._use_absolute_path: bool = use_absolute_path
|
||||
@ -408,25 +409,25 @@ class BuildOptionsBase:
|
||||
def get_compiler(self) -> str:
|
||||
return self._compiler
|
||||
|
||||
def get_definations(self) -> List[str]:
|
||||
def get_definations(self) -> list[str]:
|
||||
return self._definations
|
||||
|
||||
def get_include_dirs(self) -> List[str]:
|
||||
def get_include_dirs(self) -> list[str]:
|
||||
return self._include_dirs
|
||||
|
||||
def get_cflags(self) -> List[str]:
|
||||
def get_cflags(self) -> list[str]:
|
||||
return self._cflags
|
||||
|
||||
def get_ldflags(self) -> List[str]:
|
||||
def get_ldflags(self) -> list[str]:
|
||||
return self._ldflags
|
||||
|
||||
def get_libraries_dirs(self) -> List[str]:
|
||||
def get_libraries_dirs(self) -> list[str]:
|
||||
return self._libraries_dirs
|
||||
|
||||
def get_libraries(self) -> List[str]:
|
||||
def get_libraries(self) -> list[str]:
|
||||
return self._libraries
|
||||
|
||||
def get_passthrough_args(self) -> List[str]:
|
||||
def get_passthrough_args(self) -> list[str]:
|
||||
return self._passthrough_args
|
||||
|
||||
def get_aot_mode(self) -> bool:
|
||||
@ -457,14 +458,14 @@ class BuildOptionsBase:
|
||||
json.dump(attrs, f)
|
||||
|
||||
|
||||
def _get_warning_all_cflag(warning_all: bool = True) -> List[str]:
|
||||
def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
|
||||
if not _IS_WINDOWS:
|
||||
return ["Wall"] if warning_all else []
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
|
||||
def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
"""
|
||||
On Windows, only c++20 can support `std::enable_if_t`.
|
||||
@ -479,7 +480,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
|
||||
return [f"std={std_num}"]
|
||||
|
||||
|
||||
def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
|
||||
def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
cflags = [
|
||||
"wd4819",
|
||||
@ -506,7 +507,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
|
||||
return cflags
|
||||
|
||||
|
||||
def _get_ffast_math_flags() -> List[str]:
|
||||
def _get_ffast_math_flags() -> list[str]:
|
||||
# ffast-math is equivalent to these flags as in
|
||||
# https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468
|
||||
# however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have
|
||||
@ -527,7 +528,7 @@ def _get_ffast_math_flags() -> List[str]:
|
||||
return flags
|
||||
|
||||
|
||||
def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
|
||||
def _get_optimization_cflags(cpp_compiler: str) -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
return ["O2"]
|
||||
else:
|
||||
@ -554,7 +555,7 @@ def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
|
||||
return cflags
|
||||
|
||||
|
||||
def _get_shared_cflag(compile_only: bool) -> List[str]:
|
||||
def _get_shared_cflag(compile_only: bool) -> list[str]:
|
||||
if _IS_WINDOWS:
|
||||
"""
|
||||
MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
|
||||
@ -578,14 +579,14 @@ def get_cpp_options(
|
||||
compile_only: bool,
|
||||
warning_all: bool = True,
|
||||
extra_flags: Sequence[str] = (),
|
||||
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
|
||||
definations: List[str] = []
|
||||
include_dirs: List[str] = []
|
||||
cflags: List[str] = []
|
||||
ldflags: List[str] = []
|
||||
libraries_dirs: List[str] = []
|
||||
libraries: List[str] = []
|
||||
passthrough_args: List[str] = []
|
||||
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
|
||||
definations: list[str] = []
|
||||
include_dirs: list[str] = []
|
||||
cflags: list[str] = []
|
||||
ldflags: list[str] = []
|
||||
libraries_dirs: list[str] = []
|
||||
libraries: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
|
||||
cflags = (
|
||||
_get_shared_cflag(compile_only)
|
||||
@ -657,22 +658,22 @@ class CppOptions(BuildOptionsBase):
|
||||
self._finalize_options()
|
||||
|
||||
|
||||
def _get_glibcxx_abi_build_flags() -> List[str]:
|
||||
def _get_glibcxx_abi_build_flags() -> list[str]:
|
||||
if not _IS_WINDOWS:
|
||||
return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
def _get_torch_cpp_wrapper_defination() -> List[str]:
|
||||
def _get_torch_cpp_wrapper_defination() -> list[str]:
|
||||
return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"]
|
||||
|
||||
|
||||
def _use_custom_generated_macros() -> List[str]:
|
||||
def _use_custom_generated_macros() -> list[str]:
|
||||
return [" C10_USING_CUSTOM_GENERATED_MACROS"]
|
||||
|
||||
|
||||
def _use_fb_internal_macros() -> List[str]:
|
||||
def _use_fb_internal_macros() -> list[str]:
|
||||
if not _IS_WINDOWS:
|
||||
if config.is_fbcode():
|
||||
fb_internal_macros = [
|
||||
@ -697,12 +698,12 @@ def _setup_standard_sys_libs(
|
||||
cpp_compiler: str,
|
||||
aot_mode: bool,
|
||||
use_absolute_path: bool,
|
||||
) -> tuple[List[str], List[str], List[str]]:
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
from torch._inductor.codecache import _LINKER_SCRIPT
|
||||
|
||||
cflags: List[str] = []
|
||||
include_dirs: List[str] = []
|
||||
passthrough_args: List[str] = []
|
||||
cflags: list[str] = []
|
||||
include_dirs: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
if _IS_WINDOWS:
|
||||
return cflags, include_dirs, passthrough_args
|
||||
|
||||
@ -737,9 +738,9 @@ def _setup_standard_sys_libs(
|
||||
return cflags, include_dirs, passthrough_args
|
||||
|
||||
|
||||
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]]:
|
||||
macros: List[str] = []
|
||||
build_flags: List[str] = []
|
||||
def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]:
|
||||
macros: list[str] = []
|
||||
build_flags: list[str] = []
|
||||
if vec_isa != invalid_vec_isa:
|
||||
# Add Windows support later.
|
||||
macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())
|
||||
@ -759,7 +760,7 @@ def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]
|
||||
|
||||
def _get_torch_related_args(
|
||||
include_pytorch: bool, aot_mode: bool
|
||||
) -> tuple[List[str], List[str], List[str]]:
|
||||
) -> tuple[list[str], list[str], list[str]]:
|
||||
from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH
|
||||
|
||||
include_dirs = [
|
||||
@ -783,7 +784,7 @@ def _get_torch_related_args(
|
||||
return include_dirs, libraries_dirs, libraries
|
||||
|
||||
|
||||
def _get_python_include_dirs() -> List[str]:
|
||||
def _get_python_include_dirs() -> list[str]:
|
||||
include_dir = Path(sysconfig.get_path("include"))
|
||||
# On Darwin Python executable from a framework can return
|
||||
# non-existing /Library/Python/... include path, in which case
|
||||
@ -796,7 +797,7 @@ def _get_python_include_dirs() -> List[str]:
|
||||
return [str(include_dir)]
|
||||
|
||||
|
||||
def _get_python_related_args() -> tuple[List[str], List[str]]:
|
||||
def _get_python_related_args() -> tuple[list[str], list[str]]:
|
||||
python_include_dirs = _get_python_include_dirs()
|
||||
python_include_path = sysconfig.get_path(
|
||||
"include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
|
||||
@ -893,13 +894,13 @@ def perload_icx_libomp_win(cpp_compiler: str) -> None:
|
||||
|
||||
def _get_openmp_args(
|
||||
cpp_compiler: str,
|
||||
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str]]:
|
||||
cflags: List[str] = []
|
||||
ldflags: List[str] = []
|
||||
include_dir_paths: List[str] = []
|
||||
lib_dir_paths: List[str] = []
|
||||
libs: List[str] = []
|
||||
passthrough_args: List[str] = []
|
||||
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]:
|
||||
cflags: list[str] = []
|
||||
ldflags: list[str] = []
|
||||
include_dir_paths: list[str] = []
|
||||
lib_dir_paths: list[str] = []
|
||||
libs: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
if _IS_MACOS:
|
||||
# Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
|
||||
cflags.append("Xclang")
|
||||
@ -998,7 +999,7 @@ def _get_openmp_args(
|
||||
return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args
|
||||
|
||||
|
||||
def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]:
|
||||
def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
|
||||
macros = []
|
||||
if use_mmap_weights:
|
||||
macros.append(" USE_MMAP_SELF")
|
||||
@ -1013,14 +1014,14 @@ def get_cpp_torch_options(
|
||||
compile_only: bool,
|
||||
use_absolute_path: bool,
|
||||
use_mmap_weights: bool,
|
||||
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
|
||||
definations: List[str] = []
|
||||
include_dirs: List[str] = []
|
||||
cflags: List[str] = []
|
||||
ldflags: List[str] = []
|
||||
libraries_dirs: List[str] = []
|
||||
libraries: List[str] = []
|
||||
passthrough_args: List[str] = []
|
||||
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
|
||||
definations: list[str] = []
|
||||
include_dirs: list[str] = []
|
||||
cflags: list[str] = []
|
||||
ldflags: list[str] = []
|
||||
libraries_dirs: list[str] = []
|
||||
libraries: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
|
||||
torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination()
|
||||
use_custom_generated_macros_definations = _use_custom_generated_macros()
|
||||
@ -1163,7 +1164,7 @@ def _set_gpu_runtime_env() -> None:
|
||||
os.environ["CUDA_HOME"] = build_paths.sdk_home
|
||||
|
||||
|
||||
def _transform_cuda_paths(lpaths: List[str]) -> None:
|
||||
def _transform_cuda_paths(lpaths: list[str]) -> None:
|
||||
# This handles two cases:
|
||||
# 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs
|
||||
# 2. Linux machines may have CUDA installed under either lib64/ or lib/
|
||||
@ -1186,14 +1187,14 @@ def get_cpp_torch_device_options(
|
||||
device_type: str,
|
||||
aot_mode: bool = False,
|
||||
compile_only: bool = False,
|
||||
) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
|
||||
definations: List[str] = []
|
||||
include_dirs: List[str] = []
|
||||
cflags: List[str] = []
|
||||
ldflags: List[str] = []
|
||||
libraries_dirs: List[str] = []
|
||||
libraries: List[str] = []
|
||||
passthrough_args: List[str] = []
|
||||
) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
|
||||
definations: list[str] = []
|
||||
include_dirs: list[str] = []
|
||||
cflags: list[str] = []
|
||||
ldflags: list[str] = []
|
||||
libraries_dirs: list[str] = []
|
||||
libraries: list[str] = []
|
||||
passthrough_args: list[str] = []
|
||||
if (
|
||||
config.is_fbcode()
|
||||
and "CUDA_HOME" not in os.environ
|
||||
@ -1287,13 +1288,13 @@ class CppTorchDeviceOptions(CppTorchOptions):
|
||||
extra_flags=extra_flags,
|
||||
)
|
||||
|
||||
device_definations: List[str] = []
|
||||
device_include_dirs: List[str] = []
|
||||
device_cflags: List[str] = []
|
||||
device_ldflags: List[str] = []
|
||||
device_libraries_dirs: List[str] = []
|
||||
device_libraries: List[str] = []
|
||||
device_passthrough_args: List[str] = []
|
||||
device_definations: list[str] = []
|
||||
device_include_dirs: list[str] = []
|
||||
device_cflags: list[str] = []
|
||||
device_ldflags: list[str] = []
|
||||
device_libraries_dirs: list[str] = []
|
||||
device_libraries: list[str] = []
|
||||
device_passthrough_args: list[str] = []
|
||||
|
||||
(
|
||||
device_definations,
|
||||
@ -1379,7 +1380,7 @@ class CppBuilder:
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
sources: Union[str, List[str]],
|
||||
sources: Union[str, list[str]],
|
||||
BuildOption: BuildOptionsBase,
|
||||
output_dir: str = "",
|
||||
) -> None:
|
||||
|
@ -7,7 +7,7 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Union
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
import torch
|
||||
from torch._inductor import config
|
||||
@ -33,9 +33,9 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:
|
||||
|
||||
class VecISA:
|
||||
_bit_width: int
|
||||
_macro: List[str]
|
||||
_macro: list[str]
|
||||
_arch_flags: str
|
||||
_dtype_nelements: Dict[torch.dtype, int]
|
||||
_dtype_nelements: dict[torch.dtype, int]
|
||||
|
||||
# Note [Checking for Vectorized Support in Inductor]
|
||||
# TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
|
||||
@ -79,7 +79,7 @@ cdll.LoadLibrary("__lib_path__")
|
||||
def nelements(self, dtype: torch.dtype = torch.float) -> int:
|
||||
return self._dtype_nelements[dtype]
|
||||
|
||||
def build_macro(self) -> List[str]:
|
||||
def build_macro(self) -> list[str]:
|
||||
return self._macro
|
||||
|
||||
def build_arch_flags(self) -> str:
|
||||
@ -300,11 +300,11 @@ class InvalidVecISA(VecISA):
|
||||
__hash__: Callable[[VecISA], Any] = VecISA.__hash__
|
||||
|
||||
|
||||
def x86_isa_checker() -> List[str]:
|
||||
supported_isa: List[str] = []
|
||||
def x86_isa_checker() -> list[str]:
|
||||
supported_isa: list[str] = []
|
||||
|
||||
def _check_and_append_supported_isa(
|
||||
dest: List[str], isa_supported: bool, isa_name: str
|
||||
dest: list[str], isa_supported: bool, isa_name: str
|
||||
) -> None:
|
||||
if isa_supported:
|
||||
dest.append(isa_name)
|
||||
@ -333,7 +333,7 @@ supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]
|
||||
|
||||
def get_isa_from_cpu_capability(
|
||||
capability: Union[str, None],
|
||||
vec_isa_list: List[VecISA],
|
||||
vec_isa_list: list[VecISA],
|
||||
invalid_vec_isa: InvalidVecISA,
|
||||
):
|
||||
# AMX setting is not supported in eager
|
||||
@ -364,8 +364,8 @@ def get_isa_from_cpu_capability(
|
||||
# might have too much redundant content that is useless for ISA check. Hence,
|
||||
# we only cache some key isa information.
|
||||
@functools.lru_cache(None)
|
||||
def valid_vec_isa_list() -> List[VecISA]:
|
||||
isa_list: List[VecISA] = []
|
||||
def valid_vec_isa_list() -> list[VecISA]:
|
||||
isa_list: list[VecISA] = []
|
||||
if sys.platform == "darwin" and platform.processor() == "arm":
|
||||
isa_list.append(VecNEON())
|
||||
|
||||
@ -411,7 +411,7 @@ def pick_vec_isa() -> VecISA:
|
||||
if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
|
||||
return VecAVX2()
|
||||
|
||||
_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
|
||||
_valid_vec_isa_list: list[VecISA] = valid_vec_isa_list()
|
||||
if not _valid_vec_isa_list:
|
||||
return invalid_vec_isa
|
||||
|
||||
|
@ -54,13 +54,7 @@ from typing import (
|
||||
Callable,
|
||||
cast,
|
||||
ContextManager,
|
||||
Dict,
|
||||
Generator,
|
||||
Iterator,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
@ -99,6 +93,8 @@ from torch.utils.weak import TensorWeakRef
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator, Iterator, Sequence
|
||||
|
||||
from torch._inductor.utils import InputType
|
||||
from torch.types import _bool
|
||||
|
||||
@ -357,12 +353,12 @@ def get_manager(
|
||||
|
||||
def cudagraphify_impl(
|
||||
model: ModelType,
|
||||
inputs: List[InputType],
|
||||
inputs: list[InputType],
|
||||
static_input_idxs: Sequence[int],
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> ModelType:
|
||||
fn_cache: Dict[tuple[int, ...], Callable[..., Any]] = {}
|
||||
fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}
|
||||
|
||||
# Detect int inputs: we need to index on these
|
||||
int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
|
||||
@ -372,7 +368,7 @@ def cudagraphify_impl(
|
||||
|
||||
del inputs
|
||||
|
||||
def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
|
||||
def deferred_cudagraphify(inputs: list[InputType]) -> OutputType:
|
||||
nonlocal has_warn
|
||||
|
||||
int_key = get_ints(inputs)
|
||||
@ -405,7 +401,7 @@ def cudagraphify_impl(
|
||||
|
||||
def cudagraphify(
|
||||
model: ModelType,
|
||||
inputs: List[InputType],
|
||||
inputs: list[InputType],
|
||||
static_input_idxs: Sequence[int] = (),
|
||||
*,
|
||||
device_index: int,
|
||||
@ -466,7 +462,7 @@ class StorageWeakRefWrapper:
|
||||
|
||||
@classmethod
|
||||
def from_weakref_and_data_ptr(
|
||||
cls: Type[S],
|
||||
cls: type[S],
|
||||
cdata: Any,
|
||||
data_ptr: int,
|
||||
extra_ref_check: Optional[Callable[[], bool]] = None,
|
||||
@ -561,9 +557,9 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
|
||||
PathOutputIndex = tuple[int, int]
|
||||
|
||||
# For each node in the path, for each output, is the output alive
|
||||
PathLiveness = List[List[bool]]
|
||||
PathLiveness = list[list[bool]]
|
||||
|
||||
StackTraces = List[Optional[str]]
|
||||
StackTraces = list[Optional[str]]
|
||||
|
||||
|
||||
class CUDAWarmupNode:
|
||||
@ -600,8 +596,8 @@ class CUDAWarmupNode:
|
||||
self.wrapped_function = wrapped_function
|
||||
self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
|
||||
self.cuda_graphs_pool = cuda_graphs_pool
|
||||
self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
|
||||
self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
|
||||
self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = []
|
||||
self.tensor_weakrefs: list[Optional[TensorWeakRef]] = []
|
||||
self.existing_cuda_graph = existing_cuda_graph
|
||||
self.has_run = False
|
||||
self.device_index = device_index
|
||||
@ -619,7 +615,7 @@ class CUDAWarmupNode:
|
||||
[t.data_ptr() for t in self.path_live_weakrefs() if t()]
|
||||
)
|
||||
|
||||
def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
|
||||
def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]:
|
||||
non_cudagraph_inps = [
|
||||
weakref.ref(t.untyped_storage())
|
||||
for t in itertools.chain(new_inputs, self.wrapped_function.constants)
|
||||
@ -707,9 +703,9 @@ class CUDAWarmupNode:
|
||||
|
||||
|
||||
# Aliases for List that say what the indices denote
|
||||
InputList = List # input indexes
|
||||
OutputList = List # output indexes
|
||||
LevelList = List # levels (distance from root of tree)
|
||||
InputList = list # input indexes
|
||||
OutputList = list # output indexes
|
||||
LevelList = list # levels (distance from root of tree)
|
||||
|
||||
|
||||
class OutputAliasInfo:
|
||||
@ -772,7 +768,7 @@ class CUDAGraphNode:
|
||||
wrapped_function: WrappedFunction,
|
||||
id: GraphID,
|
||||
parent: Optional[CUDAGraphNode],
|
||||
inputs: List[InputType],
|
||||
inputs: list[InputType],
|
||||
cuda_graphs_pool: tuple[int, int],
|
||||
device_index: int,
|
||||
stack_traces: Optional[StackTraces],
|
||||
@ -800,7 +796,7 @@ class CUDAGraphNode:
|
||||
|
||||
# A single wrapped function may be recorded multiple times if memory patterns or
|
||||
# invariants change from one execution to the next
|
||||
self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
|
||||
self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)
|
||||
|
||||
# StorageWeakRef maintains whether the Storage C++ object remains allocated,
|
||||
# not whether the corresponding memory has been deallocated. In order
|
||||
@ -825,7 +821,7 @@ class CUDAGraphNode:
|
||||
self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []
|
||||
|
||||
# tensors which are outputs of previous graphs in the tree
|
||||
self.cudagraph_managed_idxs: List[int] = [
|
||||
self.cudagraph_managed_idxs: list[int] = [
|
||||
idx
|
||||
for idx, t in enumerate(inputs)
|
||||
if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
|
||||
@ -845,7 +841,7 @@ class CUDAGraphNode:
|
||||
# and also aliases an output of the current CUDAGraphNode
|
||||
self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)
|
||||
|
||||
self.static_input_idxs: List[int] = list(
|
||||
self.static_input_idxs: list[int] = list(
|
||||
OrderedSet(wrapped_function.static_input_idxs)
|
||||
| OrderedSet(self.cudagraph_managed_idxs)
|
||||
)
|
||||
@ -866,8 +862,8 @@ class CUDAGraphNode:
|
||||
|
||||
def maybe_get_static_data_ptr(
|
||||
idx: int,
|
||||
inputs: List[InputType],
|
||||
static_input_idxs: List[int],
|
||||
inputs: list[InputType],
|
||||
static_input_idxs: list[int],
|
||||
) -> Optional[int]:
|
||||
inp = inputs[idx]
|
||||
if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
|
||||
@ -888,7 +884,7 @@ class CUDAGraphNode:
|
||||
# fresh allocations.
|
||||
|
||||
# precompute expanded dims to avoid computing in the hot path
|
||||
self.expanded_dims: List[List[int]] = [
|
||||
self.expanded_dims: list[list[int]] = [
|
||||
get_expanded_dims(x)
|
||||
if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
|
||||
else []
|
||||
@ -903,11 +899,11 @@ class CUDAGraphNode:
|
||||
# List of Tuples of (depth, output_index) that index into node at depth
|
||||
# number of nodes from root and output_index of outputs. Will index into
|
||||
# path_weakrefs.
|
||||
self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
|
||||
self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
|
||||
self.expected_dead_indices_before_graph: list[PathOutputIndex] = []
|
||||
self.expected_dead_indices_after_graph: list[PathOutputIndex] = []
|
||||
|
||||
# all live indices after graph recording
|
||||
self.live_indices_after_graph: List[PathOutputIndex] = []
|
||||
self.live_indices_after_graph: list[PathOutputIndex] = []
|
||||
|
||||
if self.parent is not None:
|
||||
previous_liveness = self.parent.recorded_liveness_after_graph
|
||||
@ -934,7 +930,7 @@ class CUDAGraphNode:
|
||||
# we reconstruct tensors at the correct data pointers of our inputs which are
|
||||
# non owning and do not prevent deallocation. On subsequent executions, input values
|
||||
# will be copied over to these tensors.
|
||||
self.reconstructed_inputs: List[InputType] = [
|
||||
self.reconstructed_inputs: list[InputType] = [
|
||||
self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
|
||||
if isinstance(x, torch.Tensor)
|
||||
else x
|
||||
@ -983,7 +979,7 @@ class CUDAGraphNode:
|
||||
self.recording_outputs: Optional[OutputType] = self._record(
|
||||
wrapped_function.model, recording_inputs
|
||||
)
|
||||
self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
|
||||
self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = []
|
||||
|
||||
# As with inputs, we do not want to keep the outputs permanently alive because that would prevent
|
||||
# their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
|
||||
@ -1001,7 +997,7 @@ class CUDAGraphNode:
|
||||
self.graph.replay()
|
||||
|
||||
def _copy_inputs_and_remove_from_src(
|
||||
self, dsts: List[InputType], srcs: List[InputType]
|
||||
self, dsts: list[InputType], srcs: list[InputType]
|
||||
) -> None:
|
||||
dst_tensors = []
|
||||
src_tensors = []
|
||||
@ -1016,7 +1012,7 @@ class CUDAGraphNode:
|
||||
if dst_tensors:
|
||||
torch._foreach_copy_(dst_tensors, src_tensors)
|
||||
|
||||
def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
|
||||
def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None:
|
||||
# avoid checking managed tensor static points since we already checked those in check_invariants
|
||||
if (
|
||||
not self.rerecord_if_static_inputs_change
|
||||
@ -1036,7 +1032,7 @@ class CUDAGraphNode:
|
||||
)
|
||||
torch._check(False, lambda: error_msg)
|
||||
|
||||
def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
|
||||
def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType:
|
||||
if config.triton.fast_path_cudagraph_asserts:
|
||||
self.debug_check_invariants_before_invocation()
|
||||
|
||||
@ -1048,7 +1044,7 @@ class CUDAGraphNode:
|
||||
assert outputs is not None
|
||||
return outputs
|
||||
|
||||
def run(self, new_inputs: List[InputType]) -> OutputType:
|
||||
def run(self, new_inputs: list[InputType]) -> OutputType:
|
||||
self.check_static_inputs_are_stable(new_inputs)
|
||||
|
||||
self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
|
||||
@ -1130,7 +1126,7 @@ class CUDAGraphNode:
|
||||
def prepare_alias_info_for_tensor_construction(
|
||||
self,
|
||||
out_alias_info: Optional[OutputAliasInfo],
|
||||
metadata: Union[Dict[str, Any], int, None],
|
||||
metadata: Union[dict[str, Any], int, None],
|
||||
) -> Union[UntypedStorage, None, int]:
|
||||
if (
|
||||
isinstance(metadata, (int, type(None)))
|
||||
@ -1149,7 +1145,7 @@ class CUDAGraphNode:
|
||||
|
||||
def prepare_storages_for_construction(
|
||||
self,
|
||||
) -> List[Union[UntypedStorage, None, int]]:
|
||||
) -> list[Union[UntypedStorage, None, int]]:
|
||||
output_storages = []
|
||||
for output_storage_alias, metadata in zip(
|
||||
self.output_storage_alias, self.outputs_metadata
|
||||
@ -1173,7 +1169,7 @@ class CUDAGraphNode:
|
||||
return False
|
||||
return True
|
||||
|
||||
def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
|
||||
def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
|
||||
"Record the model"
|
||||
|
||||
def static_input_iter() -> Generator[torch.Tensor, None, None]:
|
||||
@ -1185,7 +1181,7 @@ class CUDAGraphNode:
|
||||
yield _inp
|
||||
|
||||
# see: output_is_alias_of_persistent_static_inputs above
|
||||
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
|
||||
static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = {
|
||||
inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
|
||||
for inp in itertools.chain(
|
||||
static_input_iter(), self.wrapped_function.constants
|
||||
@ -1229,7 +1225,7 @@ class CUDAGraphNode:
|
||||
def _add_first_outputs(
|
||||
self,
|
||||
outputs: OutputType,
|
||||
static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
|
||||
static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper],
|
||||
) -> None:
|
||||
"Add the outputs from the first invocation of the node and set up metadata"
|
||||
|
||||
@ -1243,7 +1239,7 @@ class CUDAGraphNode:
|
||||
|
||||
assert len(self.outputs_weakrefs) == 0
|
||||
# index from data pointer to index in outputs
|
||||
output_new_storages_index: Dict[StorageDataPtr, int] = {}
|
||||
output_new_storages_index: dict[StorageDataPtr, int] = {}
|
||||
|
||||
self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
|
||||
self.static_output_tensors = [None for _ in range(len(outputs))]
|
||||
@ -1431,8 +1427,8 @@ class CUDAGraphNode:
|
||||
|
||||
@staticmethod
|
||||
def _check_liveness(
|
||||
indices: List[PathOutputIndex],
|
||||
output_refs: List[List[Optional[StorageWeakRefWrapper]]],
|
||||
indices: list[PathOutputIndex],
|
||||
output_refs: list[list[Optional[StorageWeakRefWrapper]]],
|
||||
) -> bool:
|
||||
"Check that all of the indices specified are dead references"
|
||||
for depth, output_index in indices:
|
||||
@ -1448,8 +1444,8 @@ class CUDAGraphNode:
|
||||
|
||||
@staticmethod
|
||||
def _get_different_indices(
|
||||
prev: List[List[bool]], curr: List[List[bool]]
|
||||
) -> List[PathOutputIndex]:
|
||||
prev: list[list[bool]], curr: list[list[bool]]
|
||||
) -> list[PathOutputIndex]:
|
||||
"Find indices where the two lists differ."
|
||||
dead_indices = []
|
||||
assert len(prev) <= len(curr)
|
||||
@ -1463,8 +1459,8 @@ class CUDAGraphNode:
|
||||
|
||||
@staticmethod
|
||||
def _get_liveness(
|
||||
weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
|
||||
) -> List[List[bool]]:
|
||||
weakrefs: list[list[Optional[StorageWeakRefWrapper]]],
|
||||
) -> list[list[bool]]:
|
||||
"Maps weakrefs to true if the reference is alive and false otherwise"
|
||||
if len(weakrefs) == 0:
|
||||
return []
|
||||
@ -1472,7 +1468,7 @@ class CUDAGraphNode:
|
||||
return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]
|
||||
|
||||
def debug_assert_invariants(
|
||||
self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
|
||||
self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex]
|
||||
) -> None:
|
||||
if not config.triton.fast_path_cudagraph_asserts:
|
||||
return
|
||||
@ -1520,7 +1516,7 @@ class CUDAGraphNode:
|
||||
self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
|
||||
)
|
||||
|
||||
def data_ptrs_dead_since_invocation(self) -> List[int]:
|
||||
def data_ptrs_dead_since_invocation(self) -> list[int]:
|
||||
"""
|
||||
Since this node was invoked, return data ptrs of all tensor outputs that have died
|
||||
in the current executing tree path.
|
||||
@ -1568,7 +1564,7 @@ class CUDAGraphNode:
|
||||
@staticmethod
|
||||
def _tensor_metadata(
|
||||
x: torch.Tensor, ignore_storage_offset: bool = True
|
||||
) -> Dict[str, Any]:
|
||||
) -> dict[str, Any]:
|
||||
assert isinstance(x, torch.Tensor)
# We ignore the storage offset for inputs, but not for outputs
# TODO: - should we make the storage resizable ?
@ -1583,19 +1579,19 @@ class CUDAGraphNode:
}

def _reconstruct_from_tensor_metadata(
self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None
) -> Tensor:
s = self.create_storage(metadata) if storage is None else storage
return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type]

def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage:
return torch._C._construct_storage_from_data_pointer(
metadata["data_ptr"], metadata["device"], metadata["nbytes"]
)

def _allocate_and_copy_recording_inputs(
self, inputs: List[InputType]
) -> List[InputType]:
self, inputs: list[InputType]
) -> list[InputType]:
"""
Allocate inputs for non static, non cudagraph managed tensors in the memory pool
and copy over the tensor values.
@ -1603,7 +1599,7 @@ class CUDAGraphNode:

torch.cuda.synchronize()
self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: List[InputType] = []
recording_inputs: list[InputType] = []

with warnings.catch_warnings(record=True), torch.cuda.device(
self.device
@ -1627,7 +1623,7 @@ class CUDAGraphNode:
return recording_inputs

def check_invariants(
self, inputs: List[InputType]
self, inputs: list[InputType]
) -> tuple[CheckInvariantStatus, Callable[..., str]]:
"""
Checks if this node can be run. The same pattern of tensor liveness, static inputs,
@ -1714,7 +1710,7 @@ def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any:
return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[int]:
def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]:
blocks = []

for segment in get_cudagraph_segments(pool_id):
@ -1728,7 +1724,7 @@ def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[in
return blocks


def format_tb(frames: List[Any]) -> str:
def format_tb(frames: list[Any]) -> str:
formatted_traceback = [
traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
for entry in frames
@ -1740,7 +1736,7 @@ def format_tb(frames: List[Any]) -> str:
def check_memory_pool(
device: int,
pool_id: tuple[int, int],
live_storages_ptrs: List[StorageWeakRefWrapper],
live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None:
assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
@ -1839,12 +1835,12 @@ class CUDAGraphTreeManager:
# when they are first invoked, none of their inputs are outputs are outputs
# of another node, nor are there any live outputs of another node whose
# liveness would create a dependency.
self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)

# mapping from function id to wrapped function
self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {}

self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}
self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {}

self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet()
# if we fail to increment generation, and are stuck warming up,
@ -1883,14 +1879,14 @@ class CUDAGraphTreeManager:

# mapping from graph_id to (function id to mutation type hint) since we are
# specializing on a particular combination of Parent Node -> Function ID.
self.non_cudagraph_managed_mutation_hint: Dict[
Optional[GraphID], Dict[FunctionID, bool]
self.non_cudagraph_managed_mutation_hint: dict[
Optional[GraphID], dict[FunctionID, bool]
] = defaultdict(dict)
self.warmup_node_counter = itertools.count(start=-1, step=-1)

# mapping from graph_id to (function id to re-record count). We fall back to
# eager function if a function is re-recorded frequently on a node.
self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict(
lambda: defaultdict(lambda: 0)
)

@ -1916,7 +1912,7 @@ class CUDAGraphTreeManager:
# number of instances we had to checkpoint the function
self.debug_checkpointing_counter = 0

self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
self.id_to_mode: dict[FunctionID, CompilationMode] = {}

# Note: [Backward Generation Handling]
# We generally perform a sequence of forward executions followed by backward executions.
@ -1943,7 +1939,7 @@ class CUDAGraphTreeManager:
)
)

def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
assert self.graph is not None, "Running CUDAGraph after shutdown"
self.mode = self.id_to_mode[function_id]
out = self._run(new_inputs, function_id)
@ -1971,7 +1967,7 @@ class CUDAGraphTreeManager:
return GraphID(next(self.warmup_node_counter))

def _update_non_cudagraph_managed_mutation(
self, function_id: FunctionID, inputs: List[InputType]
self, function_id: FunctionID, inputs: list[InputType]
) -> None:
node_id = self._get_node_id()
if maybe_mutation_str := check_for_mutation(
@ -2007,7 +2003,7 @@ class CUDAGraphTreeManager:
> torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
)

def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
# we will try to end the current execution lazily, since
# we dont want to do unnecessary checking of the existing outputs
# on the hot path, but both recording and warmup only happen once
@ -2152,7 +2148,7 @@ class CUDAGraphTreeManager:
self.current_node = None

def record_function(
self, new_inputs: List[InputType], function_id: FunctionID
self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType:
assert not isinstance(self.current_node, CUDAWarmupNode)
graph_id = self.new_graph_id()
@ -2183,7 +2179,7 @@ class CUDAGraphTreeManager:
return node.run_first_inputs(new_inputs)

def execute_node(
self, node: CUDAGraphNode, new_inputs: List[InputType]
self, node: CUDAGraphNode, new_inputs: list[InputType]
) -> OutputType:
self.current_node = node
self.path_state = ExecutionState.EXECUTION
@ -2191,7 +2187,7 @@ class CUDAGraphTreeManager:
return node.run(new_inputs)

def run_eager(
self, new_inputs: List[InputType], function_id: FunctionID
self, new_inputs: list[InputType], function_id: FunctionID
) -> OutputType:
# this is only stored on current node, because when we start a new path,
# we will deallocate it
@ -2229,7 +2225,7 @@ class CUDAGraphTreeManager:
def add_function(
self,
model: ModelType,
inputs: List[InputType],
inputs: list[InputType],
static_input_idxs: Sequence[int],
stack_traces: Optional[StackTraces],
mode: CompilationMode,
@ -2409,7 +2405,7 @@ class CUDAGraphTreeManager:
# TODO: we could also allow the these weak refs to continue to be allocated,
# but that adds some complications.

stor_stack_trace: Dict[int, Optional[str]] = {}
stor_stack_trace: dict[int, Optional[str]] = {}
for node in self.current_node._path_from_root:
assert node.stack_traces is not None
assert len(node.tensor_weakrefs) == len(node.stack_traces)
@ -2475,7 +2471,7 @@ class CUDAGraphTreeManager:
assert state is not None and device is not None

# currently we deallocate on instead of allowing stale recordings
stale_storages: List[int] = []
stale_storages: list[int] = []

# remove cached tensors, otherwise they would prevent memory from being
# reclaimed in subsequent recordings
@ -2506,7 +2502,7 @@ class CUDAGraphTreeManager:

def live_cudagraph_pool_storages_in_curr_execution(
self,
) -> List[StorageWeakRefPointer]:
) -> list[StorageWeakRefPointer]:
if self.current_node is None:
return []
# explicitly ignoring previous recorded outputs from past path
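Editor's note: every hunk above applies the same mechanical rewrite from typing.Dict/typing.List to the PEP 585 builtin generics. The snippet below is a minimal, self-contained sketch of the resulting style, not taken from the diff; the names are invented and it assumes Python 3.9+ so the builtin containers are subscriptable at runtime.

from __future__ import annotations

from typing import Optional, TYPE_CHECKING, Union

if TYPE_CHECKING:
    # Imported only for the type checker; no runtime import cost.
    from collections.abc import Sequence

# PEP 585: builtin containers replace typing.Dict / typing.List in aliases
# and annotations.
CacheType = dict[str, list[int]]

def collect(values: "Sequence[Union[int, None]]", cache: Optional[CacheType] = None) -> list[int]:
    # Keep the non-None values, mirroring the dict[...] / list[...] style above.
    out = [v for v in values if v is not None]
    if cache is not None:
        cache.setdefault("seen", []).extend(out)
    return out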
@ -3,7 +3,7 @@ from __future__ import annotations

import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters
@ -11,14 +11,18 @@ from torch._inductor.utils import InputType
from torch.utils._ordered_set import OrderedSet


if TYPE_CHECKING:
from collections.abc import Sequence


perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
__name__, "cudagraph_static_inputs"
)


OutputType = List[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[List[InputType]], OutputType]
OutputType = list[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[list[InputType]], OutputType]


@dataclasses.dataclass(frozen=True)
@ -38,7 +42,7 @@ class PlaceholderInfo:
name: str
stack_trace: Optional[str]
# This field is recursive, but never cyclic (since a node never uses itself)
users: List[PlaceholderInfo]
users: list[PlaceholderInfo]
mutating_use_stack_trace: Optional[str]


@ -92,7 +96,7 @@ def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)


def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]:
return [
to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
]
@ -123,7 +127,7 @@ def get_mutation_stack_trace(

def check_for_mutation(
func: WrappedFunction,
inputs: List[InputType],
inputs: list[InputType],
is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
# doesnt work for non-trees because the warmup run would apply mutation twice
@ -160,7 +164,7 @@ def _get_use_stack_trace(node) -> Optional[str]:


def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
@ -180,7 +184,7 @@ def check_multiple_devices_or_any_cpu_nodes(


def check_lowering_disable_cudagraph(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
):
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)

@ -262,7 +266,7 @@ class CheckInvariantStatus(Enum):

def log_data_ptr_mismatch(
placeholders: Sequence[PlaceholderInfo],
inputs: List[InputType],
inputs: list[InputType],
recorded_data_ptr: Sequence[Optional[int]],
target_idxs: Sequence[int],
mismatch: CheckInvariantStatus,
@ -292,7 +296,7 @@ def log_data_ptr_mismatch(


def maybe_warning_due_to_dynamic_shape(
fn_cache: Dict[tuple[int, ...], Callable[..., Any]],
fn_cache: dict[tuple[int, ...], Callable[..., Any]],
new_int_key: Any,
) -> bool:
num_cudagraphs = len(fn_cache.keys()) + 1
@ -327,5 +331,5 @@ class CudagraphCachedInfo:
"""

placeholders: Sequence[PlaceholderInfo]
stack_traces: List[Optional[str]]
cudagraph_fail_reasons: List[str]
stack_traces: list[Optional[str]]
cudagraph_fail_reasons: list[str]
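Editor's note: the OutputType and ModelType aliases updated above describe how a cudagraph-wrapped callable is invoked: a list of inputs in, a list of optional ints/tensors out. The sketch below restates that shape locally with a toy callable; it is illustrative only (InputType is simplified to torch.Tensor, and toy_model is invented), not code from this diff.

import torch
from typing import Callable, Optional, Union

# Same shape as the aliases in the hunk above, restated for the sketch.
SketchOutputType = list[Optional[Union[int, torch.Tensor]]]
SketchModelType = Callable[[list[torch.Tensor]], SketchOutputType]

def toy_model(inputs: list[torch.Tensor]) -> SketchOutputType:
    # One tensor per input plus a sentinel int, matching the alias.
    return [x * 2 for x in inputs] + [len(inputs)]

model: SketchModelType = toy_model
outs = model([torch.ones(2), torch.zeros(3)])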
@ -12,7 +12,8 @@ import pickle
|
||||
import pstats
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
|
||||
from collections.abc import Iterator
|
||||
from typing import Any, Callable, IO, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
@ -39,7 +40,7 @@ from .virtualized import V
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
SchedulerNodeList = List[Any]
|
||||
SchedulerNodeList = list[Any]
|
||||
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
|
||||
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]
|
||||
|
||||
@ -54,7 +55,7 @@ def has_dot() -> bool:
|
||||
|
||||
|
||||
def draw_buffers(
|
||||
nodes: List[BaseSchedulerNode],
|
||||
nodes: list[BaseSchedulerNode],
|
||||
print_graph: bool = False,
|
||||
fname: Optional[str] = None,
|
||||
) -> None:
|
||||
@ -99,7 +100,7 @@ def draw_buffers(
|
||||
)
|
||||
|
||||
|
||||
def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
|
||||
def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
|
||||
"""
|
||||
Creates a FX Graph from a list of SchedulerNode objects.
|
||||
"""
|
||||
@ -199,7 +200,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
|
||||
|
||||
def update_orig_fx_node_name_to_buf_name(
|
||||
nodes: Optional[SchedulerNodeList],
|
||||
node_name_to_buf_name: Dict[str, str],
|
||||
node_name_to_buf_name: dict[str, str],
|
||||
parent_buf_name: Optional[str] = None,
|
||||
n_origins: int = 0,
|
||||
) -> None:
|
||||
@ -233,8 +234,8 @@ def update_orig_fx_node_name_to_buf_name(
|
||||
|
||||
|
||||
def get_node_name_to_buf_meta(
|
||||
node_name_to_buf_name: Dict[str, str]
|
||||
) -> Dict[str, BufMeta]:
|
||||
node_name_to_buf_name: dict[str, str]
|
||||
) -> dict[str, BufMeta]:
|
||||
buf_name_to_n_node = {}
|
||||
for node_name, buf_name in node_name_to_buf_name.items():
|
||||
if buf_name not in buf_name_to_n_node:
|
||||
@ -256,7 +257,7 @@ def annotate_orig_fx_with_snodes(
|
||||
"""
|
||||
Creates a FX Graph from a list of SchedulerNode objects.
|
||||
"""
|
||||
node_name_to_buf_name: Dict[str, str] = {}
|
||||
node_name_to_buf_name: dict[str, str] = {}
|
||||
update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
|
||||
if node_name_to_buf_name is None:
|
||||
return
|
||||
@ -309,7 +310,7 @@ def enable_aot_logging() -> Iterator[None]:
|
||||
|
||||
class DebugContext:
|
||||
_counter = itertools.count()
|
||||
_inductor_triton_kernel_to_post_grad_node_info: Dict[str, List[str]] = {}
|
||||
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}
|
||||
|
||||
@staticmethod
|
||||
def create_debug_dir(folder_name: str) -> Optional[str]:
|
||||
@ -425,7 +426,7 @@ class DebugContext:
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
exc_type: Optional[Type[BaseException]],
|
||||
exc_type: Optional[type[BaseException]],
|
||||
exc_val: Optional[BaseException],
|
||||
exc_tb: Optional[Any],
|
||||
) -> None:
|
||||
@ -474,7 +475,7 @@ class DebugFormatter:
|
||||
def fx_graph(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
inputs: List[torch.Tensor],
|
||||
inputs: list[torch.Tensor],
|
||||
) -> None:
|
||||
with self.fopen("fx_graph_runnable.py") as fd:
|
||||
save_dir = None
|
||||
@ -504,7 +505,7 @@ class DebugFormatter:
|
||||
def fx_graph_transformed(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
inputs: List[torch.Tensor],
|
||||
inputs: list[torch.Tensor],
|
||||
) -> None:
|
||||
with self.fopen("fx_graph_transformed.py") as fd:
|
||||
fd.write(gm.print_readable(print_output=False))
|
||||
@ -557,14 +558,14 @@ class DebugFormatter:
|
||||
def log_autotuning_results(
|
||||
self,
|
||||
name: str,
|
||||
input_nodes: List[ir.IRNode],
|
||||
timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
|
||||
input_nodes: list[ir.IRNode],
|
||||
timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
|
||||
elapse: float,
|
||||
precompile_elapse: float,
|
||||
) -> None:
|
||||
from .ir import FixedLayout
|
||||
|
||||
def build_node_info(node: ir.IRNode) -> Dict[str, str]:
|
||||
def build_node_info(node: ir.IRNode) -> dict[str, str]:
|
||||
if hasattr(node, "name"):
|
||||
node_name = node.name
|
||||
else:
|
||||
@ -725,7 +726,7 @@ def aot_inductor_minifier_wrapper(
|
||||
func: Callable[..., str],
|
||||
exported_program: torch.export.ExportedProgram,
|
||||
*,
|
||||
inductor_configs: Dict[str, Any],
|
||||
inductor_configs: dict[str, Any],
|
||||
package_path: Optional[Union[str, io.BytesIO]] = None,
|
||||
) -> str:
|
||||
from torch._inductor import config
|
||||
|
@ -4,7 +4,7 @@ import logging
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
import torch
|
||||
@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
|
||||
|
||||
|
||||
def register_decomposition(
|
||||
ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
|
||||
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
|
||||
if op in decompositions:
|
||||
@ -170,7 +170,7 @@ def clamp(
|
||||
|
||||
@register_decomposition([aten.full])
|
||||
def full(
|
||||
size: List[Union[int, torch.SymInt]],
|
||||
size: list[Union[int, torch.SymInt]],
|
||||
fill_value: torch.types.Number,
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
@ -205,8 +205,8 @@ def index_add(
|
||||
# cool with strides and everything goes to empty_strided)
|
||||
@register_decomposition([aten.empty_permuted.default])
|
||||
def empty_permuted(
|
||||
size: List[Union[int, torch.SymInt]],
|
||||
physical_layout: List[int],
|
||||
size: list[Union[int, torch.SymInt]],
|
||||
physical_layout: list[int],
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
perm = [0] * len(size)
|
||||
@ -220,14 +220,14 @@ def convolution_backward(
|
||||
grad_output: torch.Tensor,
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
bias_sizes: List[int],
|
||||
stride: Union[int, List[int]],
|
||||
padding: Union[int, List[int]],
|
||||
dilation: Union[int, List[int]],
|
||||
bias_sizes: list[int],
|
||||
stride: Union[int, list[int]],
|
||||
padding: Union[int, list[int]],
|
||||
dilation: Union[int, list[int]],
|
||||
transposed: bool,
|
||||
output_padding: List[int],
|
||||
output_padding: list[int],
|
||||
groups: int,
|
||||
output_mask: List[bool],
|
||||
output_mask: list[bool],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
if not output_mask[2] or not is_gpu(grad_output.device.type):
|
||||
return NotImplemented
|
||||
@ -345,7 +345,7 @@ def mm(
|
||||
# don't remove ALL empty tensors, only the naughty ones)
|
||||
@register_decomposition([aten.cat.default])
|
||||
def cat(
|
||||
tensors: List[torch.Tensor],
|
||||
tensors: list[torch.Tensor],
|
||||
dim: int = 0,
|
||||
) -> torch.Tensor:
|
||||
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
||||
@ -515,7 +515,7 @@ def narrow_copy(
|
||||
@register_decomposition([aten.view_copy.default])
|
||||
def view_copy_default(
|
||||
self: torch.Tensor,
|
||||
size: List[Union[int, torch.SymInt]],
|
||||
size: list[Union[int, torch.SymInt]],
|
||||
) -> torch.Tensor:
|
||||
return aten.view(self, size).clone()
|
||||
|
||||
@ -639,7 +639,7 @@ def randint_like_low(
|
||||
@register_decomposition(aten.randint.default)
|
||||
def randint(
|
||||
high: int,
|
||||
size: List[Union[int, torch.SymInt]],
|
||||
size: list[Union[int, torch.SymInt]],
|
||||
**kwargs: Any,
|
||||
) -> torch.Tensor:
|
||||
return aten.randint.low(0, high, size, **kwargs)
|
||||
@ -731,11 +731,11 @@ def grid_sampler_2d(
|
||||
|
||||
@register_decomposition(aten._foreach_addcmul.Scalar)
|
||||
def _foreach_addcmul_scalar(
|
||||
self: List[torch.Tensor],
|
||||
left_tensors: List[torch.Tensor],
|
||||
right_tensors: List[torch.Tensor],
|
||||
self: list[torch.Tensor],
|
||||
left_tensors: list[torch.Tensor],
|
||||
right_tensors: list[torch.Tensor],
|
||||
scalar: float = 1,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
return aten._foreach_add.List(
|
||||
self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
|
||||
)
|
||||
@ -743,11 +743,11 @@ def _foreach_addcmul_scalar(
|
||||
|
||||
@register_decomposition(aten._foreach_addcdiv.Scalar)
|
||||
def _foreach_addcdiv_scalar(
|
||||
self: List[torch.Tensor],
|
||||
left_tensors: List[torch.Tensor],
|
||||
right_tensors: List[torch.Tensor],
|
||||
self: list[torch.Tensor],
|
||||
left_tensors: list[torch.Tensor],
|
||||
right_tensors: list[torch.Tensor],
|
||||
scalar: float = 1,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
return aten._foreach_add.List(
|
||||
self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
|
||||
)
|
||||
@ -755,10 +755,10 @@ def _foreach_addcdiv_scalar(
|
||||
|
||||
@register_decomposition(aten._foreach_lerp.Scalar)
|
||||
def _foreach_lerp_scalar(
|
||||
start_tensors: List[torch.Tensor],
|
||||
end_tensors: List[torch.Tensor],
|
||||
start_tensors: list[torch.Tensor],
|
||||
end_tensors: list[torch.Tensor],
|
||||
weight: torch.types.Number,
|
||||
) -> List[torch.Tensor]:
|
||||
) -> list[torch.Tensor]:
|
||||
return aten._foreach_add.List(
|
||||
start_tensors,
|
||||
aten._foreach_mul.Scalar(
|
||||
@ -769,10 +769,10 @@ def _foreach_lerp_scalar(
|
||||
|
||||
@register_decomposition(aten._foreach_lerp.ScalarList)
|
||||
def _foreach_lerp_scalarlist(
|
||||
start_tensors: List[torch.Tensor],
|
||||
end_tensors: List[torch.Tensor],
|
||||
scalars: List[torch.types.Number],
|
||||
) -> List[torch.Tensor]:
|
||||
start_tensors: list[torch.Tensor],
|
||||
end_tensors: list[torch.Tensor],
|
||||
scalars: list[torch.types.Number],
|
||||
) -> list[torch.Tensor]:
|
||||
return aten._foreach_add.List(
|
||||
start_tensors,
|
||||
aten._foreach_mul.ScalarList(
|
||||
@ -814,13 +814,13 @@ def miopen_batch_norm(
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
|
||||
def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
|
||||
return {**decompositions, **extra_random_decomps}
|
||||
|
||||
|
||||
# TODO(aakhundov): replace this (and the above) Any by more
|
||||
# specific type and fix all the cascading mypy errors
|
||||
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
|
||||
def select_decomp_table() -> dict[Any, Callable[..., Any]]:
|
||||
"""decomps can change based on config"""
|
||||
if config.fallback_random:
|
||||
return decompositions
|
||||
@ -965,10 +965,10 @@ def index_reduce(
|
||||
@register_decomposition(aten.max_pool2d_with_indices)
|
||||
def max_pool2d_with_indices(
|
||||
x: torch.Tensor,
|
||||
kernel_size: List[int],
|
||||
stride: Optional[Union[int, List[int]]] = None,
|
||||
padding: Union[int, List[int]] = 0,
|
||||
dilation: Union[int, List[int]] = 1,
|
||||
kernel_size: list[int],
|
||||
stride: Optional[Union[int, list[int]]] = None,
|
||||
padding: Union[int, list[int]] = 0,
|
||||
dilation: Union[int, list[int]] = 1,
|
||||
ceil_mode: bool = False,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
if dilation == 1:
|
||||
@ -1015,7 +1015,7 @@ def max_pool2d_with_indices(
|
||||
|
||||
@register_decomposition(aten.adaptive_max_pool2d)
|
||||
def adaptive_max_pool2d(
|
||||
x: torch.Tensor, output_size: List[int]
|
||||
x: torch.Tensor, output_size: list[int]
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
*batch, h_in, w_in = x.shape
|
||||
h_out, w_out = output_size
|
||||
|
@ -4,8 +4,8 @@ import dataclasses
|
||||
import itertools
|
||||
import logging
|
||||
import re
|
||||
import typing
|
||||
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
@ -38,7 +38,7 @@ class Dep(abc.ABC):
|
||||
index: sympy.Expr
|
||||
|
||||
@abc.abstractmethod
|
||||
def rename(self, renames: Dict[str, str]) -> "Dep":
|
||||
def rename(self, renames: dict[str, str]) -> "Dep":
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -197,7 +197,7 @@ class MemoryDep(Dep):
|
||||
return out
|
||||
|
||||
@property
|
||||
def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
|
||||
def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
|
||||
"""{c0: 128, c1: 512, ...}"""
|
||||
return dict(zip(self.var_names, self.size))
|
||||
|
||||
@ -221,7 +221,7 @@ class MemoryDep(Dep):
|
||||
numel = numel * size
|
||||
return numel # type: ignore[return-value]
|
||||
|
||||
def rename(self, renames: Dict[str, str]) -> "MemoryDep":
|
||||
def rename(self, renames: dict[str, str]) -> "MemoryDep":
|
||||
if self.name in renames:
|
||||
return MemoryDep(
|
||||
renames[self.name],
|
||||
@ -299,7 +299,7 @@ class StarDep(Dep):
|
||||
def get_numel(self) -> sympy.Expr:
|
||||
return V.graph.get_numel(self.name) # type: ignore[return-value]
|
||||
|
||||
def rename(self, renames: Dict[str, str]) -> "StarDep":
|
||||
def rename(self, renames: dict[str, str]) -> "StarDep":
|
||||
if self.name in renames:
|
||||
return StarDep(renames[self.name], self.mode)
|
||||
return self
|
||||
@ -347,7 +347,7 @@ class WeakDep(Dep):
|
||||
def get_numel(self) -> sympy.Expr:
|
||||
return sympy.S.One
|
||||
|
||||
def rename(self, renames: Dict[str, str]) -> "WeakDep":
|
||||
def rename(self, renames: dict[str, str]) -> "WeakDep":
|
||||
if self.name in renames:
|
||||
return WeakDep(renames[self.name], self.mutating_buf)
|
||||
return self
|
||||
@ -374,10 +374,10 @@ class ReadWrites:
|
||||
reads: OrderedSet[Dep]
|
||||
writes: OrderedSet[Dep]
|
||||
index_exprs: OrderedSet[IndexExprDep]
|
||||
range_vars: Optional[List[sympy.Expr]] = None
|
||||
range_vars: Optional[list[sympy.Expr]] = None
|
||||
var_ranges: Optional[VarRanges] = None
|
||||
|
||||
def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
|
||||
def rename(self, renames: dict[str, str]) -> "ReadWrites":
|
||||
return ReadWrites(
|
||||
OrderedSet(dep.rename(renames) for dep in self.reads),
|
||||
OrderedSet(dep.rename(renames) for dep in self.writes),
|
||||
@ -405,7 +405,7 @@ class ReadWrites:
|
||||
return ReadWrites(reads - writes, writes, index_exprs)
|
||||
|
||||
@staticmethod
|
||||
def merge_list(read_writes: List["ReadWrites"]):
|
||||
def merge_list(read_writes: list["ReadWrites"]):
|
||||
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
|
||||
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
|
||||
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
|
||||
@ -564,7 +564,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy
|
||||
|
||||
def index_vars_no_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str):
|
||||
var_ranges, add_var = var_builder(prefix)
|
||||
args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
|
||||
args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
|
||||
return args, var_ranges
|
||||
|
||||
|
||||
@ -572,8 +572,8 @@ def index_vars_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str = "d"):
|
||||
from .ir import SqueezeView
|
||||
|
||||
var_ranges, add_var = var_builder(prefix)
|
||||
args: List[List[sympy.Expr]] = []
|
||||
new_sizes: List[List[sympy.Expr]] = []
|
||||
args: list[list[sympy.Expr]] = []
|
||||
new_sizes: list[list[sympy.Expr]] = []
|
||||
for size in argsizes:
|
||||
new_size, reindex = SqueezeView.squeezer(size)
|
||||
new_sizes.append(new_size)
|
||||
@ -653,7 +653,7 @@ def extract_loop_body_with_args(fn, args, var_ranges, normalize=False):
|
||||
|
||||
def extract_input_node_reduction_ranges(
|
||||
input_node: "torch._inductor.ir.IRNode",
|
||||
) -> tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
|
||||
) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
|
||||
"""
|
||||
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
|
||||
It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
|
||||
@ -663,8 +663,8 @@ def extract_input_node_reduction_ranges(
|
||||
|
||||
from .ir import ComputedBuffer, ExternKernel, Loops
|
||||
|
||||
size: Optional[List[sympy.Expr]]
|
||||
reduction_size: Optional[List[sympy.Expr]]
|
||||
size: Optional[list[sympy.Expr]]
|
||||
reduction_size: Optional[list[sympy.Expr]]
|
||||
|
||||
if isinstance(input_node.get_defining_op(), ComputedBuffer):
|
||||
# Input node has already been realized. Return its size and reduction_size.
|
||||
@ -683,11 +683,11 @@ def extract_input_node_reduction_ranges(
|
||||
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
|
||||
# Is there a way to check whether there are permutations inbetween?
|
||||
reads = input_node.get_reads()
|
||||
reduction_size: Optional[List[sympy.Expr]] = None
|
||||
size: Optional[List[sympy.Expr]] = None
|
||||
reduction_size: Optional[list[sympy.Expr]] = None
|
||||
size: Optional[list[sympy.Expr]] = None
|
||||
while reduction_size is None and len(reads) > 0:
|
||||
seen: OrderedSet[str] = OrderedSet()
|
||||
new_reads: List[Dep] = []
|
||||
new_reads: list[Dep] = []
|
||||
for read in reads:
|
||||
if not isinstance(read, MemoryDep):
|
||||
continue
|
||||
|
@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import functools
from typing import Callable, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union

import sympy
@ -4,7 +4,7 @@ import os
import tempfile
import textwrap
from functools import lru_cache
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback

@ -29,7 +29,7 @@ else:

class OperatorIssue(RuntimeError):
@staticmethod
def operator_str(target: Any, args: List[Any], kwargs: dict[str, Any]) -> str:
def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str:
lines = [f"target: {target}"] + [
f"args[{i}]: {arg}" for i, arg in enumerate(args)
]
@ -39,13 +39,13 @@ class OperatorIssue(RuntimeError):


class MissingOperatorWithoutDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")


class MissingOperatorWithDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(
f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
@ -62,7 +62,7 @@ class MissingOperatorWithDecomp(OperatorIssue):

class LoweringException(OperatorIssue):
def __init__(
self, exc: Exception, target: Any, args: List[Any], kwargs: dict[str, Any]
self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any]
) -> None:
super().__init__(
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
@ -1,5 +1,4 @@
import json
from typing import List

from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
@ -17,7 +16,7 @@ def serialize_extern_kernel_node(


def extern_node_json_serializer(
extern_kernel_nodes: List[inductor_ExternKernelNode],
extern_kernel_nodes: list[inductor_ExternKernelNode],
) -> str:
serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
@ -4,7 +4,7 @@ from __future__ import annotations
import itertools
import logging
import weakref
from typing import Any, List, Optional
from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
@ -28,7 +28,7 @@ def replace_params_with_constants(
gm: torch.fx.GraphModule,
flat_params: list[Any],
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
"""
Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
Returns a list of indices representing the input parameters that were not converted to constants.
@ -66,8 +66,8 @@ def replace_params_with_constants(
def freeze(
dynamo_gm: torch.fx.GraphModule,
aot_autograd_gm: torch.fx.GraphModule,
example_inputs: List[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, List[int]]:
example_inputs: list[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, list[int]]:
"""
Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
@ -6,21 +6,17 @@ import signal
|
||||
import string
|
||||
import sys
|
||||
import traceback
|
||||
from collections.abc import KeysView
|
||||
from enum import Enum
|
||||
from functools import partial, wraps
|
||||
from types import FrameType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
get_args,
|
||||
get_origin,
|
||||
KeysView,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Tuple,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
@ -92,14 +88,14 @@ class TypeExemplars:
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def example(t: Type[T]) -> Optional[T]:
|
||||
def example(t: type[T]) -> Optional[T]:
|
||||
"""
|
||||
Return an example of a class.
|
||||
"""
|
||||
return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None)
|
||||
|
||||
@staticmethod
|
||||
def contains(t: Type[T]) -> bool:
|
||||
def contains(t: type[T]) -> bool:
|
||||
return t.__name__ in TypeExemplars.TYPE_EXEMPLARS
|
||||
|
||||
|
||||
@ -136,7 +132,7 @@ class Status(Enum):
|
||||
|
||||
# Sometime the types of configs aren't expressive enough to be captured by python type system, so the options can be
|
||||
# manually specified here:
|
||||
TYPE_OVERRIDES: Dict[str, List[Any]] = {
|
||||
TYPE_OVERRIDES: dict[str, list[Any]] = {
|
||||
"post_grad_fusion_options": [
|
||||
{
|
||||
"batch_linear_post_grad": {
|
||||
@ -160,7 +156,7 @@ TYPE_OVERRIDES: Dict[str, List[Any]] = {
|
||||
"autoheuristic_collect": ["pad_mm", "mixed_mm"],
|
||||
"autoheuristic_use": ["pad_mm", "mixed_mm"],
|
||||
}
|
||||
SamplingType = Callable[[str, Type[Any], Any], Any]
|
||||
SamplingType = Callable[[str, type[Any], Any], Any]
|
||||
|
||||
|
||||
class SamplingMethod(Enum):
|
||||
@ -178,7 +174,7 @@ class SamplingMethod(Enum):
|
||||
|
||||
@staticmethod
|
||||
def _generate_value_for_type(
|
||||
random_sample: bool, field_name: str, type_hint: Type[Any], default: Any
|
||||
random_sample: bool, field_name: str, type_hint: type[Any], default: Any
|
||||
) -> Any:
|
||||
"""
|
||||
Generates a value of a type based on the setting.
|
||||
@ -304,9 +300,11 @@ class SamplingMethod(Enum):
|
||||
if random_sample:
|
||||
return random.choice(type_hint.__args__)
|
||||
else:
|
||||
return random.choice(
|
||||
[t for t in type_hint.__args__ if t != default]
|
||||
)
|
||||
choices = [t for t in type_hint.__args__ if t != default]
|
||||
if choices:
|
||||
return random.choice(choices)
|
||||
else:
|
||||
return default
|
||||
except AttributeError as err:
|
||||
raise ValueError("Literal type with no args") from err
|
||||
elif is_optional_type(type_hint):
|
||||
@ -374,7 +372,7 @@ class Default:
|
||||
DEFAULT = Default()
|
||||
|
||||
# The combination of config settings being set (based on their strings)
|
||||
ComboType = Tuple[str, ...]
|
||||
ComboType = tuple[str, ...]
|
||||
|
||||
|
||||
class ResultType:
|
||||
@ -382,7 +380,7 @@ class ResultType:
|
||||
The mapping of the combo strings to the result status after running the config fuzzer.
|
||||
"""
|
||||
|
||||
_vals: Dict[ComboType, Status]
|
||||
_vals: dict[ComboType, Status]
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"ResultType[{self._vals}]"
|
||||
@ -416,7 +414,7 @@ class ResultType:
|
||||
|
||||
|
||||
# Type that maps config strings to their default value
|
||||
ConfigType = Dict[str, Any]
|
||||
ConfigType = dict[str, Any]
|
||||
# Callable that returns a bool
|
||||
FactoryOutputType = Callable[[], bool]
|
||||
# input function factory
|
||||
@ -504,10 +502,10 @@ class ConfigFuzzer:
|
||||
return
|
||||
self.seed = seed
|
||||
self.test_timeout = test_timeout
|
||||
self.detailed_results: Dict[ComboType, Dict[str, Any]] = {}
|
||||
self.detailed_results: dict[ComboType, dict[str, Any]] = {}
|
||||
self.config_module = config_module
|
||||
self.test_model_fn_factory = test_model_fn_factory
|
||||
self.fields: Dict[str, _ConfigEntry] = self.config_module._config
|
||||
self.fields: dict[str, _ConfigEntry] = self.config_module._config
|
||||
self.sample = SamplingMethod.dispatch(sm)
|
||||
|
||||
if default is None:
|
||||
@ -587,7 +585,7 @@ class ConfigFuzzer:
|
||||
}
|
||||
return ret
|
||||
|
||||
def reproduce(self, configs: List[ConfigType]) -> ResultType:
|
||||
def reproduce(self, configs: list[ConfigType]) -> ResultType:
|
||||
"""entrypoint to reproduce any failure"""
|
||||
results = ResultType()
|
||||
for conf in configs:
|
||||
@ -675,7 +673,7 @@ class ConfigFuzzer:
|
||||
for field, value in config.items():
|
||||
print(f"{field} = {value}")
|
||||
|
||||
def get_error_info(exc: Exception) -> Dict[str, Any]:
|
||||
def get_error_info(exc: Exception) -> dict[str, Any]:
|
||||
return {
|
||||
"exception": str(exc),
|
||||
"traceback": traceback.format_exc(),
|
||||
@ -741,7 +739,7 @@ class ConfigFuzzer:
|
||||
else:
|
||||
return handle_return("Function succeeded", Status.PASSED, False, None)
|
||||
|
||||
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> List[ConfigType]:
|
||||
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]:
|
||||
"""
|
||||
Test configs and bisect to minimal failing configuration.
|
||||
"""
|
||||
@ -749,7 +747,7 @@ class ConfigFuzzer:
|
||||
random.seed(self.seed)
|
||||
self._reset_configs()
|
||||
results = ResultType()
|
||||
ret: List[ConfigType] = []
|
||||
ret: list[ConfigType] = []
|
||||
|
||||
for attempt in range(num_attempts):
|
||||
print(f"Random attempt {attempt + 1}/{num_attempts}")
|
||||
@ -783,7 +781,7 @@ class ConfigFuzzer:
|
||||
return self._bisect_failing_config_helper(results, list(failing_config.items()))
|
||||
|
||||
def _bisect_failing_config_helper(
|
||||
self, results: ResultType, failing_config: List[Tuple[str, Any]]
|
||||
self, results: ResultType, failing_config: list[tuple[str, Any]]
|
||||
) -> Optional[ConfigType]:
|
||||
"""
|
||||
Bisect a failing configuration to find minimal set of configs that cause failure.
|
||||
@ -795,7 +793,7 @@ class ConfigFuzzer:
|
||||
if not failing_config:
|
||||
return None
|
||||
|
||||
def test(x: List[Tuple[str, Any]]) -> Status:
|
||||
def test(x: list[tuple[str, Any]]) -> Status:
|
||||
d = dict(x)
|
||||
result = self.test_config(results, d)
|
||||
return result
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: allow-untyped-defs
|
||||
import operator
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, DefaultDict, Dict, Optional, Type
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import sympy
|
||||
|
||||
@ -24,9 +24,9 @@ from .virtualized import V
|
||||
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
|
||||
# Works for length 2 patterns with 1 module and 1 function/method.
|
||||
def matches_module_function_pattern(
|
||||
pattern: tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
|
||||
pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]],
|
||||
node: torch.fx.node.Node,
|
||||
modules: Dict[str, torch.nn.modules.Module],
|
||||
modules: dict[str, torch.nn.modules.Module],
|
||||
) -> bool:
|
||||
if len(node.args) == 0:
|
||||
return False
|
||||
@ -86,7 +86,7 @@ class FakeTensorUpdater:
|
||||
return (node, node.target, id(node.args), id(node.kwargs))
|
||||
|
||||
def incremental_update(self):
|
||||
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
|
||||
existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
|
||||
for node in self.graph.nodes:
|
||||
existing_storages[get_node_storage(node)] += 1
|
||||
|
||||
@ -208,7 +208,7 @@ def get_fake(x):
|
||||
return x
|
||||
|
||||
|
||||
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], Dict[str, Any]]:
|
||||
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]:
|
||||
"""
|
||||
First value returns a boolean if any of the input nodes don't have a faketensor.
|
||||
"""
|
||||
|
@ -8,22 +8,10 @@ import re
|
||||
import sys
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator, Sequence
|
||||
from contextlib import contextmanager
|
||||
from types import ModuleType
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
DefaultDict,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
NoReturn,
|
||||
Optional,
|
||||
Sequence,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import sympy
|
||||
from sympy import Expr
|
||||
@ -198,8 +186,8 @@ def getattr_recursive(
|
||||
return attr_itr
|
||||
|
||||
|
||||
def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
|
||||
ret: Dict[Node, tuple[int, ...]] = {}
|
||||
def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
|
||||
ret: dict[Node, tuple[int, ...]] = {}
|
||||
output_node = g.find_nodes(op="output")[0]
|
||||
|
||||
if "user_visible_output_idxs" not in output_node.meta:
|
||||
@ -212,7 +200,7 @@ def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
|
||||
|
||||
|
||||
def mark_nodes_dislike_padding(
|
||||
g: Graph, user_visible_output_strides: Dict[Node, tuple[int, ...]]
|
||||
g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
|
||||
) -> None:
|
||||
"""
|
||||
Nodes like convolution/convolution_backward want its input to be dense.
|
||||
@ -282,7 +270,7 @@ def mark_nodes_dislike_padding(
|
||||
|
||||
|
||||
class GraphLowering(torch.fx.Interpreter):
|
||||
graph_outputs: List[ir.IRNode]
|
||||
graph_outputs: list[ir.IRNode]
|
||||
|
||||
def symbolic_sizes_strides(
|
||||
self, ex: torch.Tensor
|
||||
@ -323,7 +311,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
|
||||
def static_sizes_strides(
|
||||
self, ex: torch.Tensor
|
||||
) -> tuple[List[sympy.Expr], List[sympy.Expr]]:
|
||||
) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
|
||||
"""
|
||||
Primarily used to weights
|
||||
"""
|
||||
@ -341,12 +329,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
aot_mode: bool = False,
|
||||
layout_opt: Optional[bool] = None,
|
||||
extern_node_serializer: Optional[
|
||||
Callable[[List[ir.ExternKernelNode]], Any]
|
||||
Callable[[list[ir.ExternKernelNode]], Any]
|
||||
] = None,
|
||||
is_inference: bool = False,
|
||||
is_backward: bool = False,
|
||||
is_const_graph: bool = False,
|
||||
const_output_index: Optional[Dict[str, int]] = None,
|
||||
const_output_index: Optional[dict[str, int]] = None,
|
||||
const_code: Optional[str] = None,
|
||||
const_module: Optional["GraphLowering"] = None,
|
||||
name: Optional[str] = None,
|
||||
@ -379,14 +367,14 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# you don't start adding new ones in the lowering process
|
||||
shape_env.freeze_runtime_asserts()
|
||||
# We're going to mutate ras_by_symbol as we finish generating them
|
||||
self.ras_by_symbol: Dict[
|
||||
Optional[sympy.Symbol], List[RuntimeAssert]
|
||||
self.ras_by_symbol: dict[
|
||||
Optional[sympy.Symbol], list[RuntimeAssert]
|
||||
] = shape_env.deferred_runtime_asserts.copy()
|
||||
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
|
||||
self.sizevars = SizeVarAllocator(shape_env)
|
||||
self.graph_input_names: List[str] = []
|
||||
self.graph_inputs: Dict[str, TensorBox] = {}
|
||||
self.graph_inputs_original: Dict[str, InputBuffer] = {}
|
||||
self.graph_input_names: list[str] = []
|
||||
self.graph_inputs: dict[str, TensorBox] = {}
|
||||
self.graph_inputs_original: dict[str, InputBuffer] = {}
|
||||
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
|
||||
self.device_types: OrderedSet[str] = (
|
||||
const_module.device_types if const_module else OrderedSet()
|
||||
@ -395,9 +383,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
const_module.device_idxs if const_module else OrderedSet()
|
||||
)
|
||||
self.device_type = "cpu"
|
||||
self.buffers: List[ir.Buffer] = []
|
||||
self.operations: List[ir.Operation] = []
|
||||
self.const_output_index: Dict[str, int] = (
|
||||
self.buffers: list[ir.Buffer] = []
|
||||
self.operations: list[ir.Operation] = []
|
||||
self.const_output_index: dict[str, int] = (
|
||||
const_output_index if const_output_index else {}
|
||||
)
|
||||
self.folded_constants: OrderedSet[str] = (
|
||||
@ -405,12 +393,12 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if const_output_index
|
||||
else OrderedSet()
|
||||
)
|
||||
self.constants: Dict[str, torch.Tensor] = (
|
||||
self.constants: dict[str, torch.Tensor] = (
|
||||
const_module.constants if const_module else {}
|
||||
)
|
||||
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
|
||||
self.seen_subgraphs: Dict[str, ir.Subgraph] = {}
|
||||
self.constant_reprs: Dict[str, str] = {}
|
||||
self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
|
||||
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
|
||||
self.constant_reprs: dict[str, str] = {}
|
||||
self.removed_operations = OrderedSet[str]()
|
||||
self.removed_buffers = OrderedSet[str]()
|
||||
self.removed_inplace_buffers = OrderedSet[str]()
|
||||
@ -420,23 +408,23 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
|
||||
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
|
||||
# See `ProxyExecutor Design Note` in ir.py for more details
|
||||
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
|
||||
self.extern_kernel_nodes: list[ir.ExternKernelNode] = []
|
||||
|
||||
from torch._inductor.extern_node_serializer import extern_node_json_serializer
|
||||
|
||||
self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
|
||||
self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
|
||||
extern_node_serializer
|
||||
if config.is_fbcode() and extern_node_serializer
|
||||
else extern_node_json_serializer
|
||||
)
|
||||
|
||||
self.current_node: torch.fx.Node = None # type: ignore[assignment]
|
||||
self.lists: Dict[str, List[str]] = {}
|
||||
self.lists: dict[str, list[str]] = {}
|
||||
self.mutated_inputs = OrderedSet[str]()
|
||||
self.mutated_input_idxs: List[int] = []
|
||||
self.name_to_buffer: Dict[str, ir.Buffer] = {}
|
||||
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
|
||||
self.name_to_op: Dict[str, ir.Operation] = {}
|
||||
self.mutated_input_idxs: list[int] = []
|
||||
self.name_to_buffer: dict[str, ir.Buffer] = {}
|
||||
self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
|
||||
self.name_to_op: dict[str, ir.Operation] = {}
|
||||
self.creation_time = time.time()
|
||||
self.name = name # type: ignore[assignment]
|
||||
self.cpp_wrapper = cpp_wrapper
|
||||
@ -445,7 +433,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# which sub-kernel is picked. Copy cpp_wrapper to another variable
|
||||
# since cpp_wrapper flag is OrderedSet to false for the first pass of codegen.
|
||||
self.record_multi_kernel_choice = cpp_wrapper
|
||||
self.multi_kernel_to_choice: Dict[str, str] = {}
|
||||
self.multi_kernel_to_choice: dict[str, str] = {}
|
||||
|
||||
self.aot_mode = aot_mode
|
||||
self.graph_id = graph_id
|
||||
@ -464,7 +452,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
|
||||
self.cache_key: str = "" # This is the cache key for the compiled artifact
|
||||
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
||||
self.cache_linemap: List[
|
||||
self.cache_linemap: list[
|
||||
tuple[int, str]
|
||||
] = (
|
||||
[]
|
||||
@ -473,18 +461,18 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.disable_cudagraphs_reason: Optional[str] = None
|
||||
|
||||
# only keeping one node per device for stack trace purposes
|
||||
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
|
||||
self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
|
||||
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
|
||||
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr]
|
||||
"dynamo_flat_name_to_original_fqn", {}
|
||||
)
|
||||
self.allocated_constant_name: Dict[str, str] = (
|
||||
self.allocated_constant_name: dict[str, str] = (
|
||||
const_module.allocated_constant_name if const_module is not None else {}
|
||||
)
|
||||
init_backend_registration()
|
||||
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
|
||||
|
||||
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
|
||||
self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
|
||||
self.aligned_inputs = OrderedSet[str]()
|
||||
self.no_fuse_buffer_names = OrderedSet[str]()
|
||||
|
||||
@ -599,7 +587,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
if is_inference:
|
||||
from torch.utils.flop_counter import FlopCounterMode
|
||||
|
||||
flop_counts: Dict[str, float] = defaultdict(float)
|
||||
flop_counts: dict[str, float] = defaultdict(float)
|
||||
for node in conv_nodes:
|
||||
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
|
||||
node
|
||||
@ -702,7 +690,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
def make_subgraph(
|
||||
self,
|
||||
gm: torch.fx.GraphModule,
|
||||
example_inputs: List[torch.Tensor],
|
||||
example_inputs: list[torch.Tensor],
|
||||
subgraph_name: str,
|
||||
) -> "SubgraphLowering":
|
||||
"""
|
||||
@ -886,7 +874,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
buffer.name = name
|
||||
return name
|
||||
|
||||
def register_operation_list(self, operation_names: List[str]) -> str:
|
||||
def register_operation_list(self, operation_names: list[str]) -> str:
|
||||
name = self.qualify_name("list_" + "_".join(operation_names))
|
||||
self.lists[name] = operation_names
|
||||
return name
|
||||
@ -995,7 +983,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
)
|
||||
|
||||
def placeholder(
|
||||
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
||||
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
||||
) -> Union[Expr, TensorBox, None]:
|
||||
self.placeholder_idx += 1
|
||||
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
@ -1072,7 +1060,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.aligned_inputs.add(target)
|
||||
return tensor
|
||||
|
||||
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
||||
def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
||||
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
||||
return super().call_function(target, args, kwargs)
|
||||
|
||||
@ -1155,7 +1143,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
return len(t.shape) == 1 and t.shape[0] <= 8
|
||||
|
||||
def get_attr(
|
||||
self, target: str, args: tuple[()], kwargs: Dict[str, object] # type: ignore[override]
|
||||
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
|
||||
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
||||
# this is a constant
|
||||
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
||||
@ -1203,7 +1191,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
raise AssertionError
|
||||
|
||||
def output(
|
||||
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
||||
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
||||
) -> None:
|
||||
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||
if not isinstance(result, (tuple, list)):
|
||||
@ -1306,9 +1294,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self,
|
||||
fx_node: torch.fx.Node,
|
||||
old_args: tuple[Any],
|
||||
old_kwargs: Dict[str, Any],
|
||||
old_kwargs: dict[str, Any],
|
||||
new_args: tuple[Any],
|
||||
new_kwargs: Dict[str, Any],
|
||||
new_kwargs: dict[str, Any],
|
||||
) -> None:
|
||||
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
|
||||
|
||||
@ -1803,7 +1791,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
self.const_module.wrapper_code.src_to_kernel
|
||||
)
|
||||
|
||||
def codegen_with_cpp_wrapper(self) -> tuple[str, List[tuple[int, Node]]]:
|
||||
def codegen_with_cpp_wrapper(self) -> tuple[str, list[tuple[int, Node]]]:
|
||||
"""
|
||||
For GPU, Triton kernels are autotuned and stored as cubin files
|
||||
"""
|
||||
@ -1902,7 +1890,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
# cpu
|
||||
return self.codegen()
|
||||
|
||||
def codegen(self) -> tuple[str, List[tuple[int, Node]]]:
|
||||
def codegen(self) -> tuple[str, list[tuple[int, Node]]]:
|
||||
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
|
||||
from .scheduler import Scheduler
|
||||
|
||||
@ -1949,7 +1937,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
def count_bytes(
|
||||
self,
|
||||
) -> tuple[
|
||||
int, List[tuple[BaseSchedulerNode, int]], List[tuple[BaseSchedulerNode, float]]
|
||||
int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
|
||||
]:
|
||||
total_bytes = 0
|
||||
node_counts = []
|
||||
@ -2041,7 +2029,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
|
||||
return mod
|
||||
|
||||
def get_output_names(self) -> List[str]:
|
||||
def get_output_names(self) -> list[str]:
|
||||
names = []
|
||||
shape_counter = itertools.count(0)
|
||||
none_counter = itertools.count(0)
|
||||
|
@ -1,13 +1,13 @@
# mypy: allow-untyped-defs
import contextlib
from typing import Callable, List, TYPE_CHECKING
from typing import Callable, TYPE_CHECKING


if TYPE_CHECKING:
import torch

# Executed in the order they're registered
INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = []


@contextlib.contextmanager
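Editor's note: the hunk above only shows the header of the hooks module, ending at a bare @contextlib.contextmanager decorator. As a hedged, standalone sketch of the pattern such a registry is typically paired with (not the file's actual body; names are local to the sketch), a context manager registers a callback for the duration of a block and callbacks run in registration order:

import contextlib
from typing import Callable

HOOKS: list[Callable[[str, object], None]] = []  # executed in registration order

@contextlib.contextmanager
def temporary_hook(fn: Callable[[str, object], None]):
    # Register for the dynamic extent of the with-block, then unregister.
    HOOKS.append(fn)
    try:
        yield
    finally:
        HOOKS.remove(fn)

def run_hooks(name: str, value: object) -> None:
    for hook in HOOKS:
        hook(name, value)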
@ -22,7 +22,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
|
||||
"""
|
||||
import itertools
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, Literal, Optional, overload, Union
|
||||
from typing import Any, Callable, Literal, Optional, overload, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import sympy
|
||||
@ -196,8 +196,8 @@ class IndexPropagation:
|
||||
def __init__(
|
||||
self,
|
||||
inner: Any,
|
||||
iter_ranges: Dict[sympy.Symbol, sympy.Expr],
|
||||
indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
|
||||
iter_ranges: dict[sympy.Symbol, sympy.Expr],
|
||||
indirect_var_ranges: dict[sympy.Symbol, sympy.Expr],
|
||||
) -> None:
|
||||
self._inner = inner
|
||||
self.shape_env = V.graph.sizevars.shape_env
|
||||
@ -248,18 +248,18 @@ class IndexPropagation:
|
||||
self,
|
||||
name: Literal["indirect_indexing"],
|
||||
args: tuple[Any, ...],
|
||||
kwargs: Dict[str, Any],
|
||||
kwargs: dict[str, Any],
|
||||
) -> IndexPropVar:
|
||||
...
|
||||
|
||||
@overload
|
||||
def fallback(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
...
|
||||
|
||||
def fallback(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
# Fallback to the wrapped handler
|
||||
new_args = [self.unwrap(a) for a in args]
|
||||
@ -267,7 +267,7 @@ class IndexPropagation:
|
||||
return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))
|
||||
|
||||
def propagate_sympy(
|
||||
self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
|
||||
self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
|
||||
) -> IndexPropResult:
|
||||
# Build a new SymPy expression from this ops call
|
||||
def unwrap(a: Union[Any, IndexPropVar]) -> Any:
|
||||
|
@ -2,12 +2,16 @@
from __future__ import annotations

import logging
from typing import Optional, Sequence
from typing import Optional, TYPE_CHECKING

import torch
from torch import _prims, Tensor


if TYPE_CHECKING:
from collections.abc import Sequence


log = logging.getLogger(__name__)
@ -8,6 +8,7 @@ import logging
import textwrap
import traceback
import typing
from collections.abc import Generator, Iterable, Sequence
from contextlib import nullcontext
from enum import Enum
from functools import partial
@ -16,14 +17,9 @@ from typing import (
Callable,
ClassVar,
ContextManager,
Dict,
Generator,
Iterable,
List,
Literal,
Optional,
overload,
Sequence,
TYPE_CHECKING,
TypeVar,
Union,
@ -165,11 +161,11 @@ e.g. it may be a graph input or compile time constant.
_NodeOrNodes: TypeAlias = Union[
int,
"TensorBox",
Dict[str, "TensorBox"],
dict[str, "TensorBox"],
"Symbol",
"IRNode",
Sequence[
Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
],
]

@ -426,7 +422,7 @@ class IRNode:

# NB: These are kinda weird,
origins: OrderedSet[Any] = dataclasses.field(init=False)
traceback: Optional[List[str]] = dataclasses.field(init=False)
traceback: Optional[list[str]] = dataclasses.field(init=False)
origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)

@staticmethod
@ -455,7 +451,7 @@ class IRNode:
def get_read_names(self) -> OrderedSet[str]:
return OrderedSet(dep.name for dep in self.get_reads())

def get_traceback(self) -> Optional[List[str]]:
def get_traceback(self) -> Optional[list[str]]:
return self.traceback

def get_origin_node(self) -> Optional[torch.fx.Node]:
@ -604,18 +600,18 @@ class IRNode:
raise NotImplementedError(type(self).__name__)

def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False
self, order: list[int], allow_padding: bool = False
) -> None:
raise NotImplementedError(type(self).__name__)

def freeze_layout_with_fill_order(self, order: List[int]) -> None:
def freeze_layout_with_fill_order(self, order: list[int]) -> None:
raise NotImplementedError(type(self).__name__)

def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
raise NotImplementedError(type(self).__name__)

def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False
self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None:
raise NotImplementedError(type(self).__name__)

@ -703,7 +699,7 @@ class Operation:
def get_reads(self) -> OrderedSet[Dep]:
return self.get_read_writes().reads

def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
raise NotImplementedError

def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -936,7 +932,7 @@ class Scatter(Pointwise):
)


REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = {
REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
"any": ops_wrapper("logical_or"),
"max": ops_wrapper("maximum"),
"min": ops_wrapper("minimum"),
@ -1575,8 +1571,8 @@ class Reduction(Loops):
wrapper_fn: Callable[..., Any],
original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr],
new_ranges: List[Expr],
new_reduction_ranges: List[Integer],
new_ranges: list[Expr],
new_reduction_ranges: list[Integer],
reduction_type: str,
split: _IntLike,
reduction_hint: ReductionHint,
@ -1678,8 +1674,8 @@ class Reduction(Loops):
inner_fn: Callable[..., Any],
original_ranges: Sequence[Expr],
original_reduction_ranges: Sequence[Expr],
new_ranges: List[Integer],
new_reduction_ranges: List[Integer],
new_ranges: list[Integer],
new_reduction_ranges: list[Integer],
reduction_type: str,
reduction_hint: ReductionHint,
) -> TensorBox:
@ -1767,8 +1763,8 @@ class WelfordReduction(Reduction):
device: torch.device,
dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer],
reduction_ranges: List[Integer],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
) -> Sequence[TensorBox]:
@ -1893,8 +1889,8 @@ class WelfordReduction(Reduction):
device: torch.device,
dtype: torch.dtype,
inner_fns: Sequence[Callable[..., Any]],
ranges: List[Integer],
reduction_ranges: List[Integer],
ranges: list[Integer],
reduction_ranges: list[Integer],
reduction_type: str,
split: _IntLike,
reduction_hint: ReductionHint,
@ -1983,8 +1979,8 @@ class WelfordReduction(Reduction):

@ir_dataclass
class Scan(Loops):
scan_ranges: List[Integer]
size: List[Integer]
scan_ranges: list[Integer]
size: list[Integer]
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
reduction_hint: ReductionHint
@ -2055,7 +2051,7 @@ class Scan(Loops):
device: torch.device,
dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
size: List[Integer],
size: list[Integer],
axis: int,
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
@ -2114,7 +2110,7 @@ class Scan(Loops):
else:
scan_type = SplitScan

def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> List[Expr]:
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
assert len(scan_index) == len(scan_ranges)
assert len(index) == len(pointwise_ranges)
return [*index[:axis], *scan_index, *index[axis:]]
@ -2152,8 +2148,8 @@ class Scan(Loops):
dtype: torch.dtype,
inner_fn: Callable[[Sequence[Expr]], OpsValue],
axis: int,
pointwise_ranges: List[Integer],
scan_ranges: List[Integer],
pointwise_ranges: list[Integer],
scan_ranges: list[Integer],
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
scan_numel: Expr,
) -> tuple[ReductionHint, _IntLike]:
@ -2182,8 +2178,8 @@ class SplitScan(Scan):
@ir_dataclass
class Sort(Loops):
# Sorts a tuple of key, value pairs
sort_ranges: List[Integer]
size: List[Integer]
sort_ranges: list[Integer]
size: list[Integer]
reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
reduction_hint: ReductionHint
output_index: int
@ -2251,8 +2247,8 @@ class Sort(Loops):
cls,
device: torch.device,
dtypes: tuple[torch.dtype, ...],
inner_fns: tuple[Callable[[List[Expr]], Any], ...],
size: List[Integer],
inner_fns: tuple[Callable[[list[Expr]], Any], ...],
size: list[Integer],
axis: int,
stable: bool,
descending: bool,
@ -2293,7 +2289,7 @@ class Sort(Loops):
for output_index in range(len(dtypes))
]

def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> List[Expr]:
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
assert len(sort_index) == len(sort_ranges)
assert len(index) == len(pointwise_ranges)
return [*index[:axis], *sort_index, *index[axis:]]
@ -2509,7 +2505,7 @@ class BaseView(IRNode):

@ir_dataclass
class ExpandView(BaseView):
size: List[Expr]
size: list[Expr]

@staticmethod
def _normalize_size(x, new_size): # type: ignore[no-untyped-def]
@ -2588,7 +2584,7 @@ class ExpandView(BaseView):

@ir_dataclass
class PermuteView(BaseView):
dims: List[Expr]
dims: list[Expr]

@classmethod
def create(cls, x, dims): # type: ignore[no-untyped-def]
@ -2676,7 +2672,7 @@ class SqueezeView(BaseView):
not_one = [i for i, s in enumerate(size) if s != 1]
length = len(size)

def reindex(index: List[sympy.Expr]) -> tuple[sympy.Expr, ...]:
def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]:
assert len(index) == len(not_one), f"{index} {not_one}"
new_index = [sympy.S.Zero] * length
for idx, s in zip(not_one, index):
@ -2691,7 +2687,7 @@ class SqueezeView(BaseView):

@ir_dataclass
class GenericView(BaseView):
size: List[Expr]
size: list[Expr]
reindex: Callable[..., Any]

def make_reindexer(self): # type: ignore[no-untyped-def]
@ -3159,8 +3155,8 @@ class Layout(OutputSpec):
self,
device: torch.device,
dtype: torch.dtype,
size: List[Expr],
stride: Optional[List[Expr]] = None,
size: list[Expr],
stride: Optional[list[Expr]] = None,
offset: Expr = Integer(0),
) -> None:
if stride is None:
@ -3169,8 +3165,8 @@ class Layout(OutputSpec):
self.dtype = dtype
assert len(size) == len(stride), f"size={size}, stride={stride}"
assert all(isinstance(s, (Expr, int)) for s in size)
self.size: List[Expr] = size
self.stride: List[Expr] = stride
self.size: list[Expr] = size
self.stride: list[Expr] = stride
self.offset: Expr = offset

def __str__(self) -> str:
@ -3594,8 +3590,8 @@ class NoneLayout(OutputSpec):
# dependencies manually in scheduler

device: Optional[torch.device]
size: List[int] = dataclasses.field(default_factory=lambda: [0])
stride: List[int] = dataclasses.field(default_factory=lambda: [0])
size: list[int] = dataclasses.field(default_factory=lambda: [0])
stride: list[int] = dataclasses.field(default_factory=lambda: [0])

def storage_size(self) -> int:
return 0
@ -3620,7 +3616,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
V.graph.mark_buffer_mutated(name)

@property
def stride(self) -> List[Expr]:
def stride(self) -> list[Expr]:
return self.real_layout().stride

@stride.setter
@ -3725,7 +3721,7 @@ class Buffer(IRNode):
def get_size(self) -> Sequence[Expr]:
return [*self.get_layout().size]

def get_stride(self) -> List[Expr]:
def get_stride(self) -> list[Expr]:
return [*self.get_layout().stride]

def get_offset(self) -> Expr:
@ -3816,7 +3812,7 @@ class Buffer(IRNode):
@ir_dataclass(frozen=False)
class OperationBuffer(Buffer, Operation):
# An operation that produces a single output buffer
def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return [self]

def get_defining_op(self) -> Operation:
@ -3977,7 +3973,7 @@ class ComputedBuffer(OperationBuffer):
assert isinstance(self.data, Pointwise)
return partial(self.data.store_output, self.name, indexer)

def get_fill_order(self) -> Optional[List[int]]:
def get_fill_order(self) -> Optional[list[int]]:
"""
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
@ -4028,9 +4024,9 @@ class ComputedBuffer(OperationBuffer):
def get_default_sizes_body(
self,
) -> tuple[
tuple[List[sympy.Expr], List[sympy.Expr]],
tuple[list[sympy.Expr], list[sympy.Expr]],
LoopBody,
tuple[List[sympy.Expr], List[sympy.Expr]],
tuple[list[sympy.Expr], list[sympy.Expr]],
]:
args, var_ranges = dependencies.index_vars_squeeze(
self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
@ -4043,7 +4039,7 @@ class ComputedBuffer(OperationBuffer):
*args,
)
index_vars = []
reduce_vars: List[Any] = []
reduce_vars: list[Any] = []
index_size = []
reduce_size = []
for v, s in var_ranges.items():
@ -4059,9 +4055,9 @@ class ComputedBuffer(OperationBuffer):

def simplify_and_reorder(
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
) -> tuple[tuple[List[sympy.Expr], List[sympy.Expr]], LoopBody]:
) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]:
"""
This is a main place where we do loop transformations in a
backend-agnostic way.
@ -4282,7 +4278,7 @@ class TemplateBuffer(OperationBuffer):

def simplify_and_reorder( # type: ignore[no-untyped-def]
self,
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
):
return (
@ -4314,7 +4310,7 @@ class TritonTemplateBuffer(TemplateBuffer):
"""
super().__init__(layout, inputs, make_kernel_render)
self.mutated_inputs = mutated_inputs
self.outputs: List[Buffer] = [self]
self.outputs: list[Buffer] = [self]
if mutated_inputs is not None:
# Ensure that the mutated inputs are only allowed for certain nodes
allowed_set = (
@ -4335,7 +4331,7 @@ class TritonTemplateBuffer(TemplateBuffer):
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
)

def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return self.outputs

def get_allowed_prologue_inps(self) -> OrderedSet[str]:
@ -4346,7 +4342,7 @@ class TritonTemplateBuffer(TemplateBuffer):
return out


PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]


class ChoiceCaller:
@ -4361,7 +4357,7 @@ class ChoiceCaller:
def __init__(
self,
name: str,
input_nodes: List[Buffer],
input_nodes: list[Buffer],
layout: Layout,
description: str,
) -> None:
@ -4389,7 +4385,7 @@ class ChoiceCaller:
def output_node(self) -> TensorBox:
raise NotImplementedError

def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
"""Information returned here is logged to the autotune log file when that is enabled."""
return {}
@ -4414,9 +4410,9 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
def __init__(
self,
layout: Layout,
inputs: List[IRNode],
choice_timings: Callable[[], Dict[ChoiceCaller, float]],
unfiltered_choices: List[ChoiceCaller],
inputs: list[IRNode],
choice_timings: Callable[[], dict[ChoiceCaller, float]],
unfiltered_choices: list[ChoiceCaller],
allowed_prologue_inps: OrderedSet[str],
) -> None:
super().__init__(
@ -4426,7 +4422,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
allowed_prologue_inps=allowed_prologue_inps,
)
self._choice_timings_fn = choice_timings
self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None
self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
self.original_inputs = inputs
self._output_plannable = all(
isinstance(choice, TritonTemplateCallerBase)
@ -4445,7 +4441,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
return self._output_plannable

@property
def choice_timings(self) -> Dict[ChoiceCaller, float]:
def choice_timings(self) -> dict[ChoiceCaller, float]:
if self._choice_timings is None:
self._choice_timings = self._choice_timings_fn()
return self._choice_timings
@ -4496,7 +4492,7 @@ class CppTemplateBuffer(TemplateBuffer):
super().__init__(layout, inputs, make_kernel_render)
self.template = template
self.choice = choice
self.outputs: Optional[List[Buffer]] = None
self.outputs: Optional[list[Buffer]] = None

def get_layout(self) -> Layout:
if isinstance(self.layout, MultiOutputLayout):
@ -4512,7 +4508,7 @@ class CppTemplateBuffer(TemplateBuffer):

@ir_dataclass(frozen=False)
class InputsKernel(OperationBuffer):
inputs: List[Buffer]
inputs: list[Buffer]

def get_read_writes(self) -> dependencies.ReadWrites:
reads = OrderedSet[dependencies.Dep]()
@ -4752,7 +4748,7 @@ class ConcatKernel(NopKernel):
@ir_dataclass(frozen=False)
class ExternKernel(InputsKernel):
constant_args: tuple[Any, ...] = ()
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
output_view: Optional[ReinterpretView] = None
python_kernel_name: Optional[str] = None
cpp_kernel_name: Optional[str] = None
@ -4764,12 +4760,12 @@ class ExternKernel(InputsKernel):
op_overload: Optional[
Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
] = None
arg_properties: Optional[List[Dict[str, Any]]] = None
kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None
unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
arg_properties: Optional[list[dict[str, Any]]] = None
kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
default_factory=dict
)
mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list)
mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)

def __init__( # type: ignore[no-untyped-def]
self,
@ -4801,7 +4797,7 @@ class ExternKernel(InputsKernel):
self.mutation_outputs = []
self.fx_node = V.graph.current_node

def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return [self, *self.mutation_outputs]

def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -4919,10 +4915,10 @@ class ExternKernel(InputsKernel):
cls, kernel, *args, **kwargs
) -> tuple[
Any,
List[Any],
List[Any],
list[Any],
list[Any],
Callable[[Any, Any], Any],
Optional[Dict[sympy.Symbol, pytree.KeyPath]],
Optional[dict[sympy.Symbol, pytree.KeyPath]],
]:
binded_args = {"args": args, "kwargs": kwargs}

@ -4930,7 +4926,7 @@ class ExternKernel(InputsKernel):

is_arg_tensor = []
tensor_args = []
non_tensor_args: List[Any] = []
non_tensor_args: list[Any] = []
for arg in args_flat:
is_arg_tensor.append(isinstance(arg, IRNode))
if is_arg_tensor[-1]:
@ -4963,7 +4959,7 @@ class ExternKernel(InputsKernel):
# Rerun fake tensor propagation, because Inductor may have changed the
# strides of inputs and we need to determine accurately what the
# output stride will be.
example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = []
example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []

# We need to retain the constant values of fake tensors that we originally
# propagated the graph with, because for some operators running without a
@ -4984,7 +4980,7 @@ class ExternKernel(InputsKernel):
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
example_output = kernel(*new_args, **new_kwargs)

unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None
unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
if shape_env := V.fake_mode.shape_env:
rebind_unbacked(shape_env, V.current_node, example_output)
unbacked_bindings = compute_unbacked_bindings(
@ -5309,7 +5305,7 @@ class ExternKernel(InputsKernel):
)
return args

def codegen_const_args(self, names: Optional[List[str]] = None): # type: ignore[no-untyped-def]
def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def]
if V.graph.cpp_wrapper:
result = []
# Aten ops follow the convention that tensor args are before non-tensor args,
@ -5635,14 +5631,14 @@ class TMADescriptor(ExternKernel):

# as TMA descriptors are immutable,
# we can dedup them by the input args
_CACHE: Dict[Any, TMADescriptor] = {}
_CACHE: dict[Any, TMADescriptor] = {}

@classmethod
def create( # type: ignore[no-untyped-def]
cls,
tensor: IRNode,
dims: List[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]],
dims: list[Union[int, torch.SymInt]],
block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None,
):
key = (id(tensor), dims, block_dims, element_size)
@ -5653,8 +5649,8 @@ class TMADescriptor(ExternKernel):
def __init__(
self,
tensor: IRNode,
dims: List[Union[int, torch.SymInt]],
block_dims: List[Union[int, torch.SymInt]],
dims: list[Union[int, torch.SymInt]],
block_dims: list[Union[int, torch.SymInt]],
element_size: Optional[int] = None,
) -> None:
assert len(dims) in (1, 2)
@ -5707,8 +5703,8 @@ class UserDefinedTritonKernel(ExternKernel):

kernel = kernel_side_table.get_kernel(self.kernel_idx)
configs = []
restore_value_args: List[str] = []
reset_to_zero_args: List[str] = []
restore_value_args: list[str] = []
reset_to_zero_args: list[str] = []
if isinstance(kernel, Autotuner):
# https://github.com/triton-lang/triton/pull/5083
# changes kernel.restore_idx to kernel.restore_value
@ -5871,7 +5867,7 @@ class UserDefinedTritonKernel(ExternKernel):
]
V.graph.register_operation(self)

def get_outputs(self) -> List[Buffer]:
def get_outputs(self) -> list[Buffer]:
return list(self.mutation_outputs)

def get_device(self) -> Optional[torch.device]:
@ -6333,9 +6329,9 @@ class FallbackKernel(ExternKernelAlloc):
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]

# args that are aliased
self.alias_names: List[str] = []
self.alias_names: list[str] = []
# args that are mutated AND returned from the op
self.mutation_names: List[str] = []
self.mutation_names: list[str] = []

if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
# We assume here that HOPs with FallbackKernel are functional.
@ -6834,7 +6830,7 @@ class MultiOutput(ExternKernel):
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
)

def __init__(self, layout: OutputSpec, input, indices: List[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@ -6903,18 +6899,18 @@ class MutableBox(IRNode):
return self.data.freeze_layout()

def freeze_layout_with_stride_order(
self, order: List[int], allow_padding: bool = False
self, order: list[int], allow_padding: bool = False
) -> None:
return self.data.freeze_layout_with_stride_order(order, allow_padding)

def freeze_layout_with_fill_order(self, order: List[int]) -> None:
def freeze_layout_with_fill_order(self, order: list[int]) -> None:
return self.data.freeze_layout_with_fill_order(order)

def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
return self.data.freeze_layout_with_same_order(stride)

def freeze_layout_with_exact_strides(
self, exact_strides: List[_IntLike], allow_padding: bool = False
self, exact_strides: list[_IntLike], allow_padding: bool = False
) -> None:
return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)

@ -7119,11 +7115,11 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
@ir_dataclass(frozen=False)
class InvokeSubgraph(ExternKernel):
subgraph: Optional[Subgraph] = None
operands: Optional[List[TensorBox]] = None
outputs: Optional[List[MultiOutput]] = None
operands: Optional[list[TensorBox]] = None
outputs: Optional[list[MultiOutput]] = None

def __init__(
self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout
self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout
) -> None:
super().__init__(
name=None,
@ -7212,15 +7208,15 @@ class InvokeSubgraph(ExternKernel):
@ir_dataclass(frozen=False)
class Conditional(ExternKernel):
predicate: Optional[IRNode] = None
operands: Optional[List[TensorBox]] = None
operands: Optional[list[TensorBox]] = None
true_subgraph: Optional[Subgraph] = None
false_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None
outputs: Optional[list[MultiOutput]] = None

def __init__(
self,
predicate: IRNode,
operands: List[TensorBox],
operands: list[TensorBox],
true_subgraph: Subgraph,
false_subgraph: Subgraph,
layout: MultiOutputLayout,
@ -7250,7 +7246,7 @@ class Conditional(ExternKernel):
predicate: TensorBox,
true_fn: Subgraph,
false_fn: Subgraph,
operands: List[TensorBox],
operands: list[TensorBox],
):
predicate = cls.realize_input(predicate)
operands = [cls.realize_input(x) for x in operands]
@ -7332,16 +7328,16 @@ class Conditional(ExternKernel):

@ir_dataclass(frozen=False)
class WhileLoop(ExternKernel):
carried_inputs: Optional[List[TensorBox]] = None
additional_inputs: Optional[List[TensorBox]] = None
carried_inputs: Optional[list[TensorBox]] = None
additional_inputs: Optional[list[TensorBox]] = None
cond_subgraph: Optional[Subgraph] = None
body_subgraph: Optional[Subgraph] = None
outputs: Optional[List[MultiOutput]] = None
outputs: Optional[list[MultiOutput]] = None

def __init__(
self,
carried_inputs: List[TensorBox],
additional_inputs: List[TensorBox],
carried_inputs: list[TensorBox],
additional_inputs: list[TensorBox],
cond_subgraph: Subgraph,
body_subgraph: Subgraph,
layout: MultiOutputLayout,
@ -7365,8 +7361,8 @@ class WhileLoop(ExternKernel):
cls,
cond_fn: Subgraph,
body_fn: Subgraph,
carried_inputs: List[TensorBox],
additional_inputs: List[TensorBox],
carried_inputs: list[TensorBox],
additional_inputs: list[TensorBox],
):
carried_inputs = [cls.realize_input(x) for x in carried_inputs]
additional_inputs = [cls.realize_input(x) for x in additional_inputs]
@ -7411,8 +7407,8 @@ class WhileLoop(ExternKernel):
for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):

def _guard_list_equals(
lhs_exprs: List[Union[int, sympy.expr]],
rhs_exprs: List[Union[int, sympy.expr]],
lhs_exprs: list[Union[int, sympy.expr]],
rhs_exprs: list[Union[int, sympy.expr]],
) -> None:
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
V.graph.sizevars.guard_equals(lhs, rhs)
@ -7549,7 +7545,7 @@ class _CollectiveKernel(FallbackKernel):
# mutation of the input buffers.
@classmethod
def create_inplace( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
) -> None:
with V.graph.fake_mode:
(
@ -7610,7 +7606,7 @@ class _CollectiveKernel(FallbackKernel):
# usage in the user program.
@classmethod
def create_out_of_place( # type: ignore[no-untyped-def]
cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
):
with V.graph.fake_mode:
(
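The hunks above all apply the same mechanical rewrite: container annotations imported from typing (List, Dict) become the PEP 585 built-in generics (list, dict), abstract collections move to collections.abc, and names with no builtin equivalent (Optional, Union, Callable, TYPE_CHECKING) keep coming from typing. A minimal sketch of the pattern on Python 3.9+, using hypothetical helper functions that are not part of this diff:

from typing import Optional  # Optional/Union still live in typing after the rewrite


def chunk(xs: list[int], size: int) -> list[list[int]]:
    # Built-in list is subscriptable at runtime from Python 3.9 onward.
    return [xs[i : i + size] for i in range(0, len(xs), size)]


def index_rows(rows: list[dict[str, str]], key: str = "name") -> dict[str, dict[str, str]]:
    # dict replaces typing.Dict in the same way.
    return {row[key]: row for row in rows}


cache: Optional[dict[str, int]] = None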