PEP585 update - torch/_inductor/[_-i]* (#145137)

See #145101 for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145137
Approved by: https://github.com/bobrenjc93
Committed by: PyTorch MergeBot
Parent: cede43e06b
Commit: 893ca1dfe1
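For context, PEP 585 (Python 3.9+) lets the built-in container types be used directly as generics, so annotations written with `typing.List[...]` and `typing.Dict[...]` can instead use `list[...]` and `dict[...]` with fewer imports from `typing`. A minimal sketch of the style this commit applies (a toy function for illustration, not code from the PR):

from typing import Optional

# Before (PEP 484 style), this would have needed: from typing import Dict, List
# def summarize(items: List[str], extra: Optional[Dict[str, int]] = None) -> Dict[str, int]: ...

# After (PEP 585 style): built-in generics, fewer typing imports.
def summarize(items: list[str], extra: Optional[dict[str, int]] = None) -> dict[str, int]:
    """Count occurrences of each item, merging in any pre-existing counts."""
    counts: dict[str, int] = dict(extra or {})
    for item in items:
        counts[item] = counts.get(item, 0) + 1
    return counts


print(summarize(["a", "b", "a"]))  # {'a': 2, 'b': 1}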
@@ -28,8 +28,8 @@ log = logging.getLogger(__name__)

def compile(
    gm: torch.fx.GraphModule,
-    example_inputs: List[InputType],
-    options: Optional[Dict[str, Any]] = None,
+    example_inputs: list[InputType],
+    options: Optional[dict[str, Any]] = None,
):
    """
    Compile a given FX graph with TorchInductor. This allows compiling
@@ -54,7 +54,7 @@ def aoti_compile_and_package(
    _deprecated_unused_kwargs=None,
    *,
    package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
) -> str:
    """
    Compiles the exported program with AOTInductor, and packages it into a .pt2
@@ -131,11 +131,11 @@ def _aoti_compile_and_package_inner(
    gm: torch.nn.Module,
    # flat_example_inputs: List[Any],
    args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
    *,
    load_and_run: bool = False,
    package_path: Optional[Union[str, io.BytesIO]] = None,
-    inductor_configs: Optional[Dict[str, Any]] = None,
+    inductor_configs: Optional[dict[str, Any]] = None,
):
    """
    See docstring for aoti_compile_and_package.
@@ -199,10 +199,10 @@ def aoti_load_package(path: Union[str, io.BytesIO]) -> Any:  # type: ignore[type
def aot_compile(
    gm: torch.fx.GraphModule,
    args: tuple[Any],
-    kwargs: Optional[Dict[str, Any]] = None,
+    kwargs: Optional[dict[str, Any]] = None,
    *,
-    options: Optional[Dict[str, Any]] = None,
-) -> Union[str, List[str]]:
+    options: Optional[dict[str, Any]] = None,
+) -> Union[str, list[str]]:
    """
    Ahead-of-time compile a given FX graph with TorchInductor into a shared library.

@@ -232,7 +232,7 @@ def aot_compile(

def list_mode_options(
    mode: Optional[str] = None, dynamic: Optional[bool] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
    r"""Returns a dictionary describing the optimizations that each of the available
    modes passed to `torch.compile()` performs.

@@ -245,7 +245,7 @@ def list_mode_options(
        >>> torch._inductor.list_mode_options()
    """

-    mode_options: Dict[str, Dict[str, bool]] = {
+    mode_options: dict[str, dict[str, bool]] = {
        "default": {},
        # enable cudagraphs
        "reduce-overhead": {
@@ -267,7 +267,7 @@ def list_mode_options(
    return mode_options[mode] if mode else mode_options  # type: ignore[return-value]


-def list_options() -> List[str]:
+def list_options() -> list[str]:
    r"""Returns a dictionary describing the optimizations and debug configurations
    that are available to `torch.compile()`.

@@ -280,7 +280,7 @@ def list_options() -> List[str]:

    from torch._inductor import config

-    current_config: Dict[str, Any] = config.get_config_copy()
+    current_config: dict[str, Any] = config.get_config_copy()

    return list(current_config.keys())

@@ -2,7 +2,7 @@ import json
import logging
import os
from pathlib import Path
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional
from unittest import mock

import torch
@@ -31,7 +31,7 @@ def aoti_eager_op_conf_lock(op_func_name_with_overload: str) -> Any:

def load_aoti_eager_cache(
    ns: str, op_func_name_with_overload: str, device_type: str
-) -> List[Optional[Dict[str, Any]]]:
+) -> list[Optional[dict[str, Any]]]:
    device_kernel_cache = aoti_eager_cache_dir(ns, device_type)
    op_conf = device_kernel_cache / f"{op_func_name_with_overload}.json"
    if not op_conf.exists():
@@ -81,7 +81,7 @@ def load_aoti_eager_cache(
        return []


-def supported_builtin_dtype_torch_dtype() -> Dict[type, torch.dtype]:
+def supported_builtin_dtype_torch_dtype() -> dict[type, torch.dtype]:
    return {int: torch.int32, float: torch.float, bool: torch.bool}


@@ -90,8 +90,8 @@ def supported_scalar_types() -> tuple[type, ...]:
    return tuple(type_to_torch_dtype.keys())


-def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any]:
-    metadata: Dict[str, Any] = {}
+def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> dict[str, Any]:
+    metadata: dict[str, Any] = {}
    metadata["is_dynamic"] = dynamic

    assert isinstance(input, torch.Tensor)
@@ -110,21 +110,21 @@ def extract_tensor_metadata(dynamic: bool, input: torch.Tensor) -> Dict[str, Any

def extract_tensor_list_metadata(
    dynamic: bool,
-    input: List[torch.Tensor],
-) -> Dict[str, Any]:
+    input: list[torch.Tensor],
+) -> dict[str, Any]:
    metadata_list = []
    for item in input:
        assert isinstance(item, torch.Tensor)
        metadata_list.append(extract_tensor_metadata(dynamic, item))

-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["tensor_list"] = metadata_list
    return metadata


-def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
+def extract_scalar_metadata(device_type: str, input: Any) -> dict[str, Any]:
    assert isinstance(input, supported_scalar_types())
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["is_dynamic"] = False
    # Scalar tensor
    metadata["device_type"] = device_type
@@ -135,31 +135,31 @@ def extract_scalar_metadata(device_type: str, input: Any) -> Dict[str, Any]:
    return metadata


-def extract_string_metadata(input: str) -> Dict[str, Any]:
+def extract_string_metadata(input: str) -> dict[str, Any]:
    assert isinstance(input, str)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["string_value"] = input
    return metadata


-def extract_dtype_metadata(input: torch.dtype) -> Dict[str, Any]:
+def extract_dtype_metadata(input: torch.dtype) -> dict[str, Any]:
    assert isinstance(input, torch.dtype)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["dtype_value"] = f"{input}"
    return metadata


-def extract_device_metadata(input: torch.device) -> Dict[str, Any]:
+def extract_device_metadata(input: torch.device) -> dict[str, Any]:
    assert isinstance(input, torch.device)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["device_type_value"] = f"{input.type}"
    metadata["device_index_value"] = input.index
    return metadata


-def extract_layout_metadata(input: torch.layout) -> Dict[str, Any]:
+def extract_layout_metadata(input: torch.layout) -> dict[str, Any]:
    assert isinstance(input, torch.layout)
-    metadata: Dict[str, Any] = {}
+    metadata: dict[str, Any] = {}
    metadata["layout_value"] = f"{input}"
    return metadata

@@ -171,10 +171,10 @@ def aoti_compile_with_persistent_cache(
    dynamic: bool,
    f: Callable[..., Any],
    args: tuple[Any],
-    kwargs: Dict[str, Any],
+    kwargs: dict[str, Any],
    *,
-    dynamic_shapes: Optional[Dict[str, Any]] = None,
-    options: Optional[Dict[str, Any]] = None,
+    dynamic_shapes: Optional[dict[str, Any]] = None,
+    options: Optional[dict[str, Any]] = None,
    remove_runtime_assertions: bool = False,
    disable_constraint_solver: bool = False,
) -> str:
@@ -261,7 +261,7 @@ def aoti_compile_with_persistent_cache(
            metadata["arg_order"] = idx
            kernel_metadata_items.append(metadata)

-        kernel_meta_info: Dict[str, Any] = {}
+        kernel_meta_info: dict[str, Any] = {}
        kernel_meta_info["meta_info"] = kernel_metadata_items
        kernel_meta_info["kernel_path"] = (
            Path(kernel_lib_path).relative_to(persistent_cache).as_posix()
@@ -11,7 +11,7 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
-from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
+from typing import Any, Callable, Optional, TYPE_CHECKING

import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
@@ -250,7 +250,7 @@ class AsyncCompile:
        get_result = CppCodeCache.load_async(source_code, submit_fn=self.submit)
        return LambdaFuture(lambda: get_result().kernel)

-    def cpp_pybinding(self, argtypes: List[str], source_code: str):
+    def cpp_pybinding(self, argtypes: list[str], source_code: str):
        kernel_code_log.info("CPP+Bindings Kernel:\n%s", source_code)
        if get_compile_threads() <= 1:
            return CppPythonBindingsCodeCache.load_pybinding(argtypes, source_code)
@@ -299,7 +299,7 @@ class AsyncCompile:
        )
        return LambdaFuture(get_result)

-    def wait(self, scope: Dict[str, Any]) -> None:
+    def wait(self, scope: dict[str, Any]) -> None:
        with dynamo_timed(
            "async_compile.wait",
            log_pt2_compile_event=True,
@@ -1,7 +1,7 @@
import json
import os
from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

import torch
from torch._inductor.autoheuristic.autoheuristic_utils import (
@@ -50,16 +50,16 @@ class AutoHeuristic:
    a heuristic (see torchgen/autoheuristic/).
    """

-    collected_feedback: Dict[Choice, Feedback]
+    collected_feedback: dict[Choice, Feedback]

    def __init__(
        self,
        fallback: Callable[[], Choice],
-        choices: List[Choice],
+        choices: list[Choice],
        feedback: Optional[LocalFeedback],
        context: AHContext,
        name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
@@ -135,8 +135,8 @@ class AutoHeuristic:
        return self.fallback()

    def get_top_k_choices(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[Choice]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[Choice]]:
        if not self.satisfies_precondition():
            return None
        if torch._inductor.config.use_autoheuristic(self.name):
@@ -223,11 +223,11 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
    def __init__(
        self,
        fallback: Callable[[], Optional[ChoiceCaller]],
-        choices: List[ChoiceCaller],
-        input_nodes: List[Any],
+        choices: list[ChoiceCaller],
+        input_nodes: list[Any],
        context: AHContext,
        name: str,
-        augment_context: Optional[List[AHOperation]] = None,
+        augment_context: Optional[list[AHOperation]] = None,
        precondition: Optional[Callable[[AHMetadata, AHContext], bool]] = None,
    ) -> None:
        """
@@ -237,7 +237,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
        have to be used here.
        """
        self.input_nodes = input_nodes
-        self.choicestr2choice: Dict[str, ChoiceCaller] = {}
+        self.choicestr2choice: dict[str, ChoiceCaller] = {}
        for choice in choices:
            self.choicestr2choice[choice.autoheuristic_id()] = choice
        choices_str = list(self.choicestr2choice.keys())
@@ -266,7 +266,7 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
            self.register_global_feedback(input_nodes, choices)

    def register_global_feedback(
-        self, input_nodes: List[Any], choices: List[ChoiceCaller]
+        self, input_nodes: list[Any], choices: list[ChoiceCaller]
    ) -> None:
        """
        Registers a callback in select_algorithm, which is called with the timing of each choice.
@@ -281,10 +281,10 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
        def store_global_feedback(
            ah_inputs_key: str,
            ah_precompile_key: str,
-            timings: Dict[ChoiceCaller, float],
+            timings: dict[ChoiceCaller, float],
            name: str,
-            input_nodes: List[Any],
-            choices: List[ChoiceCaller],
+            input_nodes: list[Any],
+            choices: list[ChoiceCaller],
        ) -> None:
            current_inputs_key = create_inputs_key(input_nodes)
            if current_inputs_key != ah_inputs_key:
@@ -307,8 +307,8 @@ class AutoHeuristicSelectAlgorithm(AutoHeuristic):
        return self.choicestr2choice.get(choice, None)

    def get_top_k_choices_caller(
-        self, top_k: int, always_included: Optional[List[str]] = None
-    ) -> Optional[List[ChoiceCaller]]:
+        self, top_k: int, always_included: Optional[list[str]] = None
+    ) -> Optional[list[ChoiceCaller]]:
        choices = self.get_top_k_choices(top_k, always_included)
        if choices is None:
            return None
@@ -1,5 +1,5 @@
import functools
-from typing import Any, Callable, Dict, List
+from typing import Any, Callable

import torch

@@ -51,8 +51,8 @@ class AHContext:
    information that will help to learn a heuristic.
    """

-    features: List[AHFeature]
-    context_dict: Dict[str, Value]
+    features: list[AHFeature]
+    context_dict: dict[str, Value]

    def __init__(self) -> None:
        self.features = []
@@ -64,7 +64,7 @@ class AHContext:
        self.features.append(AHFeature(name, value, is_categorical=is_categorical))
        self.context_dict[name] = value

-    def get_numerical_and_categorical_features(self) -> tuple[List[str], List[str]]:
+    def get_numerical_and_categorical_features(self) -> tuple[list[str], list[str]]:
        numerical_features = []
        categorical_features = []
        for feature in self.features:
@@ -84,7 +84,7 @@ class AHContext:
    def get_value(self, name: str) -> Value:
        return self.context_dict[name]

-    def apply_operations(self, operations: List[AHOperation]) -> None:
+    def apply_operations(self, operations: list[AHOperation]) -> None:
        for op in operations:
            op.apply_operation(self.context_dict)

@@ -94,7 +94,7 @@ class AHMetadata:
        self,
        shared_memory: Any,
        device_capa: tuple[int, int],
-        choices: List[Choice],
+        choices: list[Choice],
        name: str,
    ) -> None:
        # use amount of shared_memory and device_capability to identify GPU
@@ -104,7 +104,7 @@ class AHMetadata:
        self.choices = choices
        self.name = name

-    def to_dict(self) -> Dict[str, Value]:
+    def to_dict(self) -> dict[str, Value]:
        return {
            "shared_memory": self.shared_memory,
            "device_capa": self.device_capa,
@@ -147,7 +147,7 @@ def get_mixedmm_precondition(metadata: AHMetadata, context: AHContext) -> bool:
    return mat1_iscontig and not mat2_iscontig


-def get_mult_dims_ops() -> List[AHOperation]:
+def get_mult_dims_ops() -> list[AHOperation]:
    m_times_k_op = AHOperation("m*k", lambda data: data["m"] * data["k"])
    m_times_n_op = AHOperation("m*n", lambda data: data["m"] * data["n"])
    k_times_n_op = AHOperation("k*n", lambda data: data["k"] * data["n"])
@@ -163,7 +163,7 @@ def get_arith_intensity(data: Any) -> float:
    return m * k * n / (m * k + k * n + m * n)


-def pad_mm_operations() -> List[AHOperation]:
+def pad_mm_operations() -> list[AHOperation]:
    mult_dims_ops = get_mult_dims_ops()
    k_div_m_times_n_op = AHOperation(
        "k/(m*n)", lambda data: data["k"] / (data["m"] * data["n"])
@@ -200,7 +200,7 @@ def between_op(data: Any, dim: str, lower: int, upper: int) -> bool:
    return data[dim] >= lower and data[dim] <= upper


-def between_ops() -> List[AHOperation]:
+def between_ops() -> list[AHOperation]:
    dims = ["m", "k", "n"]
    limits = [(1, 16), (17, 32), (33, 64), (65, 128), (129, 256)]
    ah_operations = []
@@ -221,13 +221,13 @@ def pow2_op(data: Any, dim: str, exponent: int) -> bool:
    return data[dim] == 2**exponent


-def mm_operations() -> List[AHOperation]:
+def mm_operations() -> list[AHOperation]:
    mult_dims_ops = get_mult_dims_ops()
    arith_intensity_op = AHOperation("arith_intensity", get_arith_intensity)
    return mult_dims_ops + [arith_intensity_op]


-def mixed_mm_operations() -> List[AHOperation]:
+def mixed_mm_operations() -> list[AHOperation]:
    return mm_operations() + between_ops()


@@ -235,7 +235,7 @@ def is_multiple(data: Any, dim: str, mult: int) -> bool:
    return data[dim] % mult == 0


-def get_dims_multiple_ops() -> List[AHOperation]:
+def get_dims_multiple_ops() -> list[AHOperation]:
    multiples = [2, 4, 8, 16, 32]
    dims = ["m", "k", "n"]
    dims_multiple_ops = []
@@ -249,7 +249,7 @@ def get_dims_multiple_ops() -> List[AHOperation]:
    return dims_multiple_ops


-def get_dims_need_padding_ops() -> List[AHOperation]:
+def get_dims_need_padding_ops() -> list[AHOperation]:
    def mat1_innermost_needs_padding_fn(data: Any) -> bool:
        mat1_stride_0 = data["mat1_stride_0"]
        mat1_stride_1 = data["mat1_stride_1"]
@@ -303,7 +303,7 @@ def get_dims_need_padding_ops() -> List[AHOperation]:
    return [mat1_innermost_op, mat2_innermost_op, num_dims_op]


-def get_is_contig_ops() -> List[AHOperation]:
+def get_is_contig_ops() -> list[AHOperation]:
    def mat1_is_contig_fn(data: Any) -> bool:
        stride_0 = data["mat1_stride_0"]
        stride_1 = data["mat1_stride_1"]
@@ -2,7 +2,7 @@ import importlib
import inspect
import pkgutil
from collections import defaultdict
-from typing import Any, Dict, List, Optional
+from typing import Any, Optional

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
@@ -14,7 +14,7 @@ from torch._inductor.autoheuristic.learnedheuristic_interface import LearnedHeur

def find_and_instantiate_subclasses(
    package_name: str, base_class: Any
-) -> List[LearnedHeuristic]:
+) -> list[LearnedHeuristic]:
    instances = []

    package = importlib.import_module(package_name)
@@ -49,7 +49,7 @@ class LearnedHeuristicController:
    a way to get the decision of a learned heuristic.
    """

-    existing_heuristics: Dict[str, List[LearnedHeuristic]] = defaultdict(list)
+    existing_heuristics: dict[str, list[LearnedHeuristic]] = defaultdict(list)
    """
    A dictionary that stores all the learned heuristics for each optimization.
    The key is the optimization name, and the value is a list of LearnedHeuristic objects.
@@ -69,7 +69,7 @@ class LearnedHeuristicController:
        self.metadata = metadata
        self.context = context

-    def get_heuristics(self, name: str) -> List[LearnedHeuristic]:
+    def get_heuristics(self, name: str) -> list[LearnedHeuristic]:
        """
        Returns a list of learned heuristics for the given optimization name.
        """
@@ -105,7 +105,7 @@ class LearnedHeuristicController:
                return heuristic.get_decision(self.context, self.metadata.choices)
        return None

-    def get_decisions_ranked(self, top_k: int) -> Optional[List[Choice]]:
+    def get_decisions_ranked(self, top_k: int) -> Optional[list[Choice]]:
        heuristics = self.get_heuristics(self.metadata.name)
        for heuristic in heuristics:
            if heuristic.check_precondition(self.metadata, self.context):
@@ -1,4 +1,4 @@
-from typing import List, Optional
+from typing import Optional

from torch._inductor.autoheuristic.autoheuristic_utils import (
    AHContext,
@@ -23,7 +23,7 @@ class LearnedHeuristic:
        return True

    def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
    ) -> Optional[Choice]:
        return None

@@ -33,7 +33,7 @@ class LearnedHeuristic:
    def get_name(self) -> str:
        return ""

-    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
        return None


@@ -45,7 +45,7 @@ class LearnedHeuristicRegression(LearnedHeuristic):
        return 1.0

    def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
    ) -> Optional[Choice]:
        choice2feedback = {}
        for choice in choices:
@@ -68,7 +68,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
        return None

    def get_decision(
-        self, context: AHContext, choices: List[Choice]
+        self, context: AHContext, choices: list[Choice]
    ) -> Optional[Choice]:
        best_choices = self.get_best_choices(context)
        if not best_choices:
@@ -78,7 +78,7 @@ class LearnedHeuristicDecision(LearnedHeuristic):
            return None
        return self.get_choice(best_choice_idx)

-    def get_decisions_ranked(self, context: AHContext) -> Optional[List[str]]:
+    def get_decisions_ranked(self, context: AHContext) -> Optional[list[str]]:
        feedback_idx_list = self.get_best_choices(context)
        if feedback_idx_list is None:
            return None
@@ -88,5 +88,5 @@ class LearnedHeuristicDecision(LearnedHeuristic):
        choices = [choice for choice in choices if choice is not None]
        return choices

-    def get_best_choices(self, context: AHContext) -> Optional[List[tuple[float, int]]]:
+    def get_best_choices(self, context: AHContext) -> Optional[list[tuple[float, int]]]:
        return []
@@ -10,19 +10,10 @@ import os
import queue
import time
import warnings
+from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
from ctypes import byref, c_size_t, c_void_p, CDLL
-from typing import (
-    Any,
-    Callable,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Sequence,
-    TYPE_CHECKING,
-    Union,
-)
+from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
@@ -396,8 +387,8 @@ class TuningProcessPool:

    def benchmark(
        self,
-        choices: List[TritonTemplateCaller],
-    ) -> Dict[TritonTemplateCaller, float]:
+        choices: list[TritonTemplateCaller],
+    ) -> dict[TritonTemplateCaller, float]:
        """
        Benchmark each choice in a separate process.
        """
@@ -432,9 +423,9 @@ class TensorMeta:
    @classmethod
    def from_irnodes(
        cls, irnodes: Union[LayoutOrBuffer, Sequence[LayoutOrBuffer]]
-    ) -> Union[TensorMeta, List[TensorMeta]]:
+    ) -> Union[TensorMeta, list[TensorMeta]]:
        if isinstance(irnodes, Sequence):
-            result: List[Any] = [cls.from_irnodes(x) for x in irnodes]
+            result: list[Any] = [cls.from_irnodes(x) for x in irnodes]
            assert all(isinstance(x, TensorMeta) for x in result)
            return result

@@ -488,8 +479,8 @@ class BenchmarkRequest:
    def __init__(
        self,
        kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
    ) -> None:
        # the kernel name defined in the module
@@ -640,12 +631,12 @@ class TritonBenchmarkRequest(BenchmarkRequest):
    def __init__(
        self,
        kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        module_path: str,  # the path of the module defining the triton kernel
        module_cache_key: str,
-        grid: List[int],
+        grid: list[int],
        num_stages: int,
        num_warps: int,
        matrix_instr_nonkdim: int = 0,  # only used for hip to choose the shape of mfma instruction.
@@ -770,8 +761,8 @@ class CUDABenchmarkRequest(GPUDeviceBenchmarkMixin, BenchmarkRequest):
    def __init__(
        self,
        kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
@@ -889,8 +880,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):
    def __init__(
        self,
        kernel_name: str,
-        input_tensor_meta: Union[TensorMeta, List[TensorMeta]],
-        output_tensor_meta: Union[TensorMeta, List[TensorMeta]],
+        input_tensor_meta: Union[TensorMeta, list[TensorMeta]],
+        output_tensor_meta: Union[TensorMeta, list[TensorMeta]],
        extra_args: Iterable[Any],
        source_code: str,
    ) -> None:
@@ -946,8 +937,8 @@ class CppBenchmarkRequest(CPUDeviceBenchmarkMixin, BenchmarkRequest):


def benchmark_in_sub_process(
-    choices: List[TritonTemplateCaller],
-) -> Dict[TritonTemplateCaller, float]:
+    choices: list[TritonTemplateCaller],
+) -> dict[TritonTemplateCaller, float]:
    """
    Do benchmarking in a subprocess and return the perf number (latency).
    """
@@ -1,7 +1,7 @@
import logging
import operator
from functools import partial
-from typing import Any, Callable, Dict, Union
+from typing import Any, Callable, Union

from sympy import Expr

@@ -43,7 +43,7 @@ class BoundVars:
            or "masked_subblock" in node.target
        )
        # To access this variable call `get_bounds()`
-        self._bounds: Dict[torch.fx.Node, ValueRanges[Expr]] = {}
+        self._bounds: dict[torch.fx.Node, ValueRanges[Expr]] = {}

    def __repr__(self) -> str:
        return (
@@ -55,7 +55,7 @@ class BoundVars:
        )

    @cache_on_self
-    def get_bounds(self) -> Dict[torch.fx.Node, ValueRanges[Expr]]:
+    def get_bounds(self) -> dict[torch.fx.Node, ValueRanges[Expr]]:
        submodules = self.swap_submodules(self.loop_body.submodules)

        # Initialize the environment with the unbounded variables
@@ -74,9 +74,9 @@ class BoundVars:
        return self._bounds

    def swap_submodules(
-        self, submodules: Dict[str, Callable[..., Any]]
-    ) -> Dict[str, Callable[..., ValueRanges[Expr]]]:
-        result: Dict[str, Callable[..., ValueRanges[Expr]]] = {}
+        self, submodules: dict[str, Callable[..., Any]]
+    ) -> dict[str, Callable[..., ValueRanges[Expr]]]:
+        result: dict[str, Callable[..., ValueRanges[Expr]]] = {}
        for key in submodules.keys():
            if key == "get_index":
                result[key] = self.get_index
@@ -111,10 +111,10 @@ class BoundVars:
    def masked_subblock(
        self,
        subblock: LoopBodyBlock,
-        env: Dict[torch.fx.Node, ValueRanges[Expr]],
+        env: dict[torch.fx.Node, ValueRanges[Expr]],
        mask: Any,
        value: Any,
-        submodules: Dict[str, Callable[..., Any]],
+        submodules: dict[str, Callable[..., Any]],
    ) -> ValueRanges[Expr]:
        interp = InterpreterShim(subblock.graph, submodules)
        interp.run(V.get_ops_handler(), initial_env=env)
@@ -1,7 +1,7 @@
from __future__ import annotations

import typing
-from typing import Any, Dict, List, Type, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

import sympy

@@ -42,11 +42,11 @@ class InductorChoices:

    def triton_kernel_kwargs(
        self,
-        kernel_cls: Type[TritonKernel],
+        kernel_cls: type[TritonKernel],
        features: SIMDKernelFeatures,
-        groups: List[sympy.Expr],
-        kernel_kwargs: Dict[str, Any],
-    ) -> Dict[str, Any]:
+        groups: list[sympy.Expr],
+        kernel_kwargs: dict[str, Any],
+    ) -> dict[str, Any]:
        """Hook to change the kwargs passed to TritonKernel, used to apply fixed configurations"""
        return kernel_kwargs

@@ -36,13 +36,8 @@ from typing import (
    Any,
    Callable,
    cast,
-    Dict,
-    Generator,
-    List,
    NoReturn,
    Optional,
-    Sequence,
-    Tuple,
    TYPE_CHECKING,
    TypeVar,
    Union,
@@ -127,7 +122,7 @@ else:


if TYPE_CHECKING:
-    from collections.abc import KeysView
+    from collections.abc import Generator, KeysView, Sequence
    from concurrent.futures import Future

    from .compile_fx import _CompileFxKwargs, CompiledFxGraph
@@ -168,7 +163,7 @@ def get_global_cache_path_impl(global_cache_dir: str) -> Optional[Path]:
class CacheBase:
    @staticmethod
    @functools.lru_cache(None)
-    def get_system() -> Dict[str, Any]:
+    def get_system() -> dict[str, Any]:
        try:
            from triton.compiler.compiler import triton_key

@@ -179,7 +174,7 @@ class CacheBase:
            triton_version = None

        try:
-            system: Dict[str, Any] = {
+            system: dict[str, Any] = {
                "device": {"name": None},
                "version": {
                    "triton": triton_version,
@@ -217,7 +212,7 @@ class CacheBase:
    def __init__(self) -> None:
        self.system = CacheBase.get_system()

-    def get_local_cache(self) -> Dict[str, Any]:
+    def get_local_cache(self) -> dict[str, Any]:
        local_cache_path = self.get_local_cache_path()
        if not local_cache_path.is_file():
            return {}
@@ -225,7 +220,7 @@ class CacheBase:
            local_cache = json.load(local_cache_fp)
        return local_cache["cache"]

-    def update_local_cache(self, local_cache: Dict[str, Any]) -> None:
+    def update_local_cache(self, local_cache: dict[str, Any]) -> None:
        local_cache_path = self.get_local_cache_path()
        write_atomic(
            str(local_cache_path),
@@ -235,7 +230,7 @@ class CacheBase:


class LocalCache(CacheBase):
-    def lookup(self, *keys: str) -> Optional[Dict[str, Any]]:
+    def lookup(self, *keys: str) -> Optional[dict[str, Any]]:
        cache = self.get_local_cache()

        sub_cache = cache
@@ -261,7 +256,7 @@ class LocalCache(CacheBase):

class PersistentCache(CacheBase):
    @functools.lru_cache(None)  # noqa: B019
-    def get_global_cache(self) -> Dict[str, Any]:
+    def get_global_cache(self) -> dict[str, Any]:
        global_cache_path = self.get_global_cache_path()
        if global_cache_path is None or not global_cache_path.is_file():
            return {}
@@ -271,11 +266,11 @@ class PersistentCache(CacheBase):

    def lookup(
        self,
-        choices: List[ChoiceCaller],
+        choices: list[ChoiceCaller],
        op: str,
        inputs: str,
-        benchmark: Optional[Callable[[Any], Dict[ChoiceCaller, float]]],
-    ) -> Dict[ChoiceCaller, float]:
+        benchmark: Optional[Callable[[Any], dict[ChoiceCaller, float]]],
+    ) -> dict[ChoiceCaller, float]:
        """
        Check to see if we have benchmarked the given choice callers. For each
        choice caller:
@@ -296,7 +291,7 @@ class PersistentCache(CacheBase):
        )
        timings = {}

-        def check_cache(cache: Dict[str, Any], callback: Any = None) -> bool:
+        def check_cache(cache: dict[str, Any], callback: Any = None) -> bool:
            """Check if `cache` contains data for all the choices"""
            hit = True
            for choice in choices:
@@ -456,7 +451,7 @@ class TensorMetadataAndValues:
    """

    tensor_metadata: TensorMetadata
-    values: List[Any]
+    values: list[Any]


def _ident(x: T) -> T:
@@ -584,7 +579,7 @@ class FxGraphCachePickler(pickle.Pickler):

    def _reduce_graph_module(
        self, gm: torch.fx.GraphModule
-    ) -> tuple[Any, tuple[Dict[str, Any], str]]:
+    ) -> tuple[Any, tuple[dict[str, Any], str]]:
        """
        Custom reducer for graph module to handle irrelevant data for user
        defined triton kernels
@@ -624,7 +619,7 @@ class FxGraphCachePickler(pickle.Pickler):
        serialized_data = self.dumps(obj)
        return sha256_hash(serialized_data)

-    def debug_lines(self, inp: FxGraphHashDetails) -> List[str]:
+    def debug_lines(self, inp: FxGraphHashDetails) -> list[str]:
        """
        Get a printable string describing in more detail all the attributes
        comprising an object. Useful for debugging when one graph hashes
@@ -659,7 +654,7 @@ class FxGraphCachePickler(pickle.Pickler):


def build_code_hash(
-    roots: List[str] | None, prefix: str, hasher: hashlib._Hash
+    roots: list[str] | None, prefix: str, hasher: hashlib._Hash
) -> None:
    for lib in sorted(pkgutil.iter_modules(roots, prefix), key=lambda x: x.name):
        spec = lib.module_finder.find_spec(lib.name, None)
@@ -721,7 +716,7 @@ class OrderedSetHolder:
    of set kwargs.
    """

-    items: List[Any]
+    items: list[Any]


class BypassFxGraphCache(Exception):
@@ -753,7 +748,7 @@ class FxGraphHashDetails:
        # Order kwargs so hashing is stable to changes in kwarg order. Although
        # it's technically a _CompileFxKwargs we don't actually need it typed as
        # such since we're just using it to generate a hash.
-        self.fx_kwargs: Dict[str, object] = {}
+        self.fx_kwargs: dict[str, object] = {}
        for k, v in sorted(fx_kwargs.items()):
            if k not in self.EXCLUDED_KWARGS:
                if type(v) in (set, OrderedSet):  # noqa: set_linter
@@ -774,7 +769,7 @@ class FxGraphHashDetails:

        # Node meta will not be part of gm's reduce function, so lets remember
        # the kernel source code separately
-        self.user_defined_triton_source: List[Any] = []
+        self.user_defined_triton_source: list[Any] = []
        if gm is not None:
            for module in gm.modules():
                if not isinstance(module, torch.fx.GraphModule):
@@ -856,7 +851,7 @@ def compiled_fx_graph_hash(
    example_inputs: Sequence[InputType],
    fx_kwargs: _CompileFxKwargs,
    inputs_to_check: Sequence[int],
-) -> tuple[str, List[str]]:
+) -> tuple[str, list[str]]:
    """
    Generate a unique hash of the FX graph for caching.
    """
@@ -952,7 +947,7 @@ class FxGraphCache:
        return os.path.join(FxGraphCache._get_tmp_dir(), key[1:3], key)

    @staticmethod
-    def _filter_backed_symints(inputs: Sequence[InputType]) -> List[torch.SymInt]:
+    def _filter_backed_symints(inputs: Sequence[InputType]) -> list[torch.SymInt]:
        """
        Get the backed SymInt objects from the input list. Note that we can never
        have guards that depend on unbacked symint.
@@ -976,7 +971,7 @@ class FxGraphCache:
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Lookup a compiled graph in the cache by key. On a hit, return the
        deserialized CompiledFxGraph object. On a miss, return None.
@@ -988,7 +983,7 @@ class FxGraphCache:
        hints = [hint_int(s) for s in symints]

        def iterate_over_candidates() -> (
-            Generator[Tuple[CompiledFxGraph, bytes], None, None]
+            Generator[tuple[CompiledFxGraph, bytes], None, None]
        ):
            if local:
                subdir = FxGraphCache._get_tmp_dir_for_key(key)
@@ -1021,7 +1016,7 @@ class FxGraphCache:
        # their guards to determine whether there's a hit.
        graph = None
        pickled_content = None
-        cache_info: Dict[str, Any] = dict()
+        cache_info: dict[str, Any] = dict()

        for candidate, pickled_content in iterate_over_candidates():
            if not candidate.guards_expr:
@@ -1234,7 +1229,7 @@ class FxGraphCache:
        fx_kwargs: _CompileFxKwargs,
        inputs_to_check: Sequence[int],
        remote: bool,
-    ) -> tuple[Optional[tuple[str, List[str]]], Dict[str, Any]]:
+    ) -> tuple[Optional[tuple[str, list[str]]], dict[str, Any]]:
        """
        Checks that the inductor input is cacheable, then computes
        and returns the cache key for the input.
@@ -1280,13 +1275,13 @@ class FxGraphCache:
    @staticmethod
    def load_with_key(
        key: str,
-        debug_lines: List[str],
+        debug_lines: list[str],
        example_inputs: Sequence[InputType],
        local: bool,
        remote_cache: Optional[RemoteCache[JsonDataTy]],
        is_backward: bool,
        constants: CompiledFxGraphConstants,
-    ) -> tuple[Optional[CompiledFxGraph], Dict[str, Any]]:
+    ) -> tuple[Optional[CompiledFxGraph], dict[str, Any]]:
        """
        Lookup the graph with the given key, and return results and metadata.
        Doesn't do any logging on its own, because AOTAutograd handles a cache miss
@@ -1373,11 +1368,11 @@ def split_aot_inductor_output_path(path: str) -> tuple[str, str]:

@clear_on_fresh_inductor_cache
class CudaKernelParamCache:
-    cache: Dict[str, Dict[str, str]] = {}
+    cache: dict[str, dict[str, str]] = {}
    cache_clear = staticmethod(cache.clear)

    @classmethod
-    def set(cls, key: str, params: Dict[str, str], cubin: str, bin_type: str) -> None:
+    def set(cls, key: str, params: dict[str, str], cubin: str, bin_type: str) -> None:
        _, path = write(
            cubin,
            bin_type,
@@ -1391,7 +1386,7 @@ class CudaKernelParamCache:
        cls.cache[key] = params

    @classmethod
-    def get(cls, key: str) -> Optional[Dict[str, str]]:
+    def get(cls, key: str) -> Optional[dict[str, str]]:
        return cls.cache.get(key, None)

    @classmethod
@@ -1407,8 +1402,8 @@ class AotCodeCompiler:
        source_code: str,
        serialized_extern_kernel_nodes: Optional[str],
        device_type: str,
-        additional_files: List[str],
-    ) -> Union[List[str], str]:
+        additional_files: list[str],
+    ) -> Union[list[str], str]:
        """
        Returns the .so path, or returns a list of files that were generated if
        config.aot_inductor.package=True.
@@ -1842,14 +1837,14 @@ def cpp_prefix() -> str:
# Given a path to an input cpp file and an output path,
# Attempts to compile the file, storing the output in "output_path"
def compile_file(
-    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
+    input_path: Union[str, list[str]], output_path: str, cmd: list[str]
) -> None:
    with dynamo_timed("compile_file"):
        return _compile_file(input_path, output_path, cmd)


def _compile_file(
-    input_path: Union[str, List[str]], output_path: str, cmd: List[str]
+    input_path: Union[str, list[str]], output_path: str, cmd: list[str]
) -> None:
    input_paths = [input_path] if isinstance(input_path, str) else input_path
    input_files = [
@@ -1948,9 +1943,9 @@ def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p]:

@clear_on_fresh_inductor_cache
class CppCodeCache:
-    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
    cache_clear = staticmethod(cache.clear)
-    cpp_compile_command_flags: Dict[str, Any] = {}
+    cpp_compile_command_flags: dict[str, Any] = {}

    @staticmethod
    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
@@ -2093,7 +2088,7 @@ def _worker_compile_cpp(
# Customized Python binding for cpp kernels
@clear_on_fresh_inductor_cache
class CppPythonBindingsCodeCache(CppCodeCache):
-    cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
    cache_clear = staticmethod(cache.clear)
    cpp_compile_command_flags = {
        # kernels have no dependency on libtorch
@@ -2212,7 +2207,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
    @classmethod
    def load_pybinding_async(
        cls,
-        argtypes: List[str],
+        argtypes: list[str],
        source_code: str,
        device_type: str = "cpu",
        num_outputs: int = -1,
|
||||||
@ -2269,7 +2264,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):
|
|||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_inductor_cache
|
||||||
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
||||||
cache: Dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
cpp_compile_command_flags = {
|
cpp_compile_command_flags = {
|
||||||
"include_pytorch": True,
|
"include_pytorch": True,
|
||||||
@ -2335,7 +2330,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):
|
|||||||
|
|
||||||
@clear_on_fresh_inductor_cache
|
@clear_on_fresh_inductor_cache
|
||||||
class HalideCodeCache(CppPythonBindingsCodeCache):
|
class HalideCodeCache(CppPythonBindingsCodeCache):
|
||||||
cache: Dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
_standalone_runtime_path: Optional[str] = None
|
_standalone_runtime_path: Optional[str] = None
|
||||||
prefix = textwrap.dedent(
|
prefix = textwrap.dedent(
|
||||||
@ -2412,7 +2407,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
|||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> List[str]:
|
def _codegen_buffer(cls, name: str, arg: HalideInputSpec, cuda: bool) -> list[str]:
|
||||||
assert arg.shape is not None
|
assert arg.shape is not None
|
||||||
assert arg.stride is not None and len(arg.shape) == len(arg.stride)
|
assert arg.stride is not None and len(arg.shape) == len(arg.stride)
|
||||||
assert arg.offset is not None
|
assert arg.offset is not None
|
||||||
@ -2573,7 +2568,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
|||||||
donefile = str(dirpath / "done")
|
donefile = str(dirpath / "done")
|
||||||
lockfile = str(dirpath / "lock")
|
lockfile = str(dirpath / "lock")
|
||||||
need_compile = not os.path.exists(donefile)
|
need_compile = not os.path.exists(donefile)
|
||||||
jobs: List[Any] = []
|
jobs: list[Any] = []
|
||||||
if need_compile:
|
if need_compile:
|
||||||
write_atomic(genfile, source_code)
|
write_atomic(genfile, source_code)
|
||||||
cmd = [
|
cmd = [
|
||||||
@ -2685,7 +2680,7 @@ class HalideCodeCache(CppPythonBindingsCodeCache):
|
|||||||
return sofile
|
return sofile
|
||||||
|
|
||||||
|
|
||||||
def _worker_task_halide(lockfile: str, jobs: List[partial[Any]]) -> None:
|
def _worker_task_halide(lockfile: str, jobs: list[partial[Any]]) -> None:
|
||||||
from torch.utils._filelock import FileLock
|
from torch.utils._filelock import FileLock
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@ -2733,8 +2728,8 @@ class PyCodeCache:
|
|||||||
# clearing the cache. Note also that we may load the same path more
|
# clearing the cache. Note also that we may load the same path more
|
||||||
# than once, but attach different attributes, i.e., due to different
|
# than once, but attach different attributes, i.e., due to different
|
||||||
# constant values.
|
# constant values.
|
||||||
modules: List[ModuleType] = []
|
modules: list[ModuleType] = []
|
||||||
linemaps: Dict[str, List[tuple[Any, ...]]] = {}
|
linemaps: dict[str, list[tuple[Any, ...]]] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
|
def write(cls, source_code: str, extra: str = "") -> tuple[str, str]:
|
||||||
@ -2745,8 +2740,8 @@ class PyCodeCache:
|
|||||||
cls,
|
cls,
|
||||||
source_code: str,
|
source_code: str,
|
||||||
extra: str = "",
|
extra: str = "",
|
||||||
linemap: Optional[List[tuple[int, str]]] = None,
|
linemap: Optional[list[tuple[int, str]]] = None,
|
||||||
attrs: Optional[Dict[str, Any]] = None,
|
attrs: Optional[dict[str, Any]] = None,
|
||||||
) -> ModuleType:
|
) -> ModuleType:
|
||||||
key, path = write(source_code, "py", extra=extra)
|
key, path = write(source_code, "py", extra=extra)
|
||||||
return cls.load_by_key_path(key, path, linemap, attrs)
|
return cls.load_by_key_path(key, path, linemap, attrs)
|
||||||
@ -2756,8 +2751,8 @@ class PyCodeCache:
|
|||||||
cls,
|
cls,
|
||||||
key: str,
|
key: str,
|
||||||
path: str,
|
path: str,
|
||||||
linemap: Optional[List[tuple[int, str]]] = None,
|
linemap: Optional[list[tuple[int, str]]] = None,
|
||||||
attrs: Optional[Dict[str, Any]] = None,
|
attrs: Optional[dict[str, Any]] = None,
|
||||||
) -> ModuleType:
|
) -> ModuleType:
|
||||||
if linemap is None:
|
if linemap is None:
|
||||||
linemap = []
|
linemap = []
|
||||||
@ -2798,7 +2793,7 @@ class PyCodeCache:
|
|||||||
@functools.lru_cache(None)
|
@functools.lru_cache(None)
|
||||||
def stack_frames_for_code(
|
def stack_frames_for_code(
|
||||||
cls, path: str, lineno: int
|
cls, path: str, lineno: int
|
||||||
) -> Optional[List[Dict[str, Any]]]:
|
) -> Optional[list[dict[str, Any]]]:
|
||||||
if path not in cls.linemaps:
|
if path not in cls.linemaps:
|
||||||
return None
|
return None
|
||||||
# [(starting_line, <fx node>), ...]
|
# [(starting_line, <fx node>), ...]
|
||||||
@ -2810,7 +2805,7 @@ class PyCodeCache:
|
|||||||
if not entry:
|
if not entry:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def parse_stack_trace(stack_trace: str) -> List[Dict[str, Any]]:
|
def parse_stack_trace(stack_trace: str) -> list[dict[str, Any]]:
|
||||||
# ideally fx stores stack traces as data rather than a string
|
# ideally fx stores stack traces as data rather than a string
|
||||||
# but this is not along a performance critical path
|
# but this is not along a performance critical path
|
||||||
regex = r'File "(.+)", line (\d+), in (.+)\n'
|
regex = r'File "(.+)", line (\d+), in (.+)\n'
|
||||||
@ -2841,7 +2836,7 @@ def _cuda_compiler() -> Optional[str]:
|
|||||||
return "nvcc"
|
return "nvcc"
|
||||||
|
|
||||||
|
|
||||||
def _cutlass_include_paths() -> List[str]:
|
def _cutlass_include_paths() -> list[str]:
|
||||||
if config.is_fbcode():
|
if config.is_fbcode():
|
||||||
from libfb.py import parutil
|
from libfb.py import parutil
|
||||||
|
|
||||||
@ -2857,14 +2852,14 @@ def _cutlass_include_paths() -> List[str]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _cuda_lib_options() -> List[str]:
|
def _cuda_lib_options() -> list[str]:
|
||||||
_set_gpu_runtime_env() # cpp_extension consults the env
|
_set_gpu_runtime_env() # cpp_extension consults the env
|
||||||
from torch.utils import cpp_extension
|
from torch.utils import cpp_extension
|
||||||
|
|
||||||
lpaths = cpp_extension.library_paths(device_type="cuda") + [
|
lpaths = cpp_extension.library_paths(device_type="cuda") + [
|
||||||
sysconfig.get_config_var("LIBDIR")
|
sysconfig.get_config_var("LIBDIR")
|
||||||
]
|
]
|
||||||
extra_ldflags: List[str] = []
|
extra_ldflags: list[str] = []
|
||||||
if is_linux():
|
if is_linux():
|
||||||
_transform_cuda_paths(lpaths)
|
_transform_cuda_paths(lpaths)
|
||||||
for path in lpaths:
|
for path in lpaths:
|
||||||
@ -2880,7 +2875,7 @@ def _cuda_lib_options() -> List[str]:
|
|||||||
return extra_ldflags
|
return extra_ldflags
|
||||||
|
|
||||||
|
|
||||||
def _nvcc_host_compiler_options() -> List[str]:
|
def _nvcc_host_compiler_options() -> list[str]:
|
||||||
return [
|
return [
|
||||||
"-fPIC",
|
"-fPIC",
|
||||||
"-fno-strict-aliasing",
|
"-fno-strict-aliasing",
|
||||||
@ -2889,7 +2884,7 @@ def _nvcc_host_compiler_options() -> List[str]:
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def _nvcc_compiler_options() -> List[str]:
|
def _nvcc_compiler_options() -> list[str]:
|
||||||
arch = cuda_env.get_cuda_arch()
|
arch = cuda_env.get_cuda_arch()
|
||||||
if arch == "90":
|
if arch == "90":
|
||||||
# Required by cutlass compilation.
|
# Required by cutlass compilation.
|
||||||
@ -2934,10 +2929,10 @@ def _nvcc_compiler_options() -> List[str]:
|
|||||||
|
|
||||||
|
|
||||||
def cuda_compile_command(
|
def cuda_compile_command(
|
||||||
src_files: List[str],
|
src_files: list[str],
|
||||||
dst_file: str,
|
dst_file: str,
|
||||||
dst_file_ext: str,
|
dst_file_ext: str,
|
||||||
extra_args: Optional[List[str]] = None,
|
extra_args: Optional[list[str]] = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
if extra_args is None:
|
if extra_args is None:
|
||||||
extra_args = []
|
extra_args = []
|
||||||
@ -3052,7 +3047,7 @@ class CUDACodeCache:
|
|||||||
input_path: str
|
input_path: str
|
||||||
output_path: str
|
output_path: str
|
||||||
|
|
||||||
cache: Dict[str, CacheEntry] = {}
|
cache: dict[str, CacheEntry] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
_SOURCE_CODE_SUFFIX = "cu"
|
_SOURCE_CODE_SUFFIX = "cu"
|
||||||
|
|
||||||
@ -3073,7 +3068,7 @@ class CUDACodeCache:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compile(
|
def compile(
|
||||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
|
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||||
) -> tuple[str, str, str]:
|
) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Compiles CUDA source_code into a file with dst_file_ext extension.
|
Compiles CUDA source_code into a file with dst_file_ext extension.
|
||||||
@ -3137,7 +3132,7 @@ class ROCmCodeCache:
|
|||||||
input_path: str
|
input_path: str
|
||||||
output_path: str
|
output_path: str
|
||||||
|
|
||||||
cache: Dict[str, CacheEntry] = {}
|
cache: dict[str, CacheEntry] = {}
|
||||||
cache_clear = staticmethod(cache.clear)
|
cache_clear = staticmethod(cache.clear)
|
||||||
_SOURCE_CODE_SUFFIX = "cpp"
|
_SOURCE_CODE_SUFFIX = "cpp"
|
||||||
_logged_compiler_version = False
|
_logged_compiler_version = False
|
||||||
@ -3159,7 +3154,7 @@ class ROCmCodeCache:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def compile(
|
def compile(
|
||||||
cls, source_code: str, dst_file_ext: str, extra_args: Optional[List[str]] = None
|
cls, source_code: str, dst_file_ext: str, extra_args: Optional[list[str]] = None
|
||||||
) -> tuple[str, str, str]:
|
) -> tuple[str, str, str]:
|
||||||
"""
|
"""
|
||||||
Compiles source_code into a file with dst_file_ext extension,
|
Compiles source_code into a file with dst_file_ext extension,
|
||||||
|
@@ -7,7 +7,7 @@ import logging
 import operator
 import sys
 from collections import defaultdict
-from typing import Any, Dict, List, TYPE_CHECKING
+from typing import Any, TYPE_CHECKING

 import torch
 from torch.multiprocessing.reductions import StorageWeakRef
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
 from .scheduler import BaseSchedulerNode


-def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
+def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
 """
 Greedily schedules waits as late as possible.
 """
@@ -42,7 +42,7 @@ def sink_waits(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
 )


-def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:
+def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
 """
 Greedily schedules comms as early as possible.
 """
@@ -52,8 +52,8 @@ def raise_comms(snodes: List[BaseSchedulerNode]) -> List[BaseSchedulerNode]:


 def reorder_compute_for_overlap(
-snodes: List[BaseSchedulerNode],
+snodes: list[BaseSchedulerNode],
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
 """
 This achieves the following overall scheduling procedure:
 Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
@@ -71,11 +71,11 @@ def reorder_compute_for_overlap(


 def _schedule_for_comm(
-snodes: List[BaseSchedulerNode],
+snodes: list[BaseSchedulerNode],
 raise_comms: bool,
 sink_waits: bool,
 reorder_for_overlap: bool,
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
 """
 Schedule `snodes` for various comm optimization objectives.

@@ -149,13 +149,13 @@ def _schedule_for_comm(
 def __lt__(self, other):
 return self.score < other.score

-unmet_deps: Dict[BaseSchedulerNode, OrderedSet[str]] = {
+unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
 snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
 for snode in snodes
 }

-ready: List[Runnable] = []
+ready: list[Runnable] = []
-buffer_users: Dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
+buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
 snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}

 for snode, deps in unmet_deps.items():
@@ -226,8 +226,8 @@ def _schedule_for_comm(


 def decide_global_ordering_of_comms(
-nodes: List[BaseSchedulerNode], name_to_buf, name_to_fused_node
+nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
 """
 Decide global ordering of comms, by just enforcing the ordering that's in the input graph
 (might not be the same ordering as the eager mode program).
@@ -303,8 +303,8 @@ def visualize_overlap(order):


 def reorder_compute_and_comm_for_overlap(
-snodes: List[BaseSchedulerNode],
+snodes: list[BaseSchedulerNode],
-) -> List[BaseSchedulerNode]:
+) -> list[BaseSchedulerNode]:
 order = snodes

 for p in config.reorder_for_compute_comm_overlap_passes:
@@ -653,10 +653,10 @@ def get_op_idx(snode):


 def enforce_comm_ordering_for_fsdp(
-snodes: List[torch._inductor.scheduler.BaseSchedulerNode],
+snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
-name_to_buf: Dict[str, torch._inductor.scheduler.SchedulerBuffer],
+name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
-name_to_fused_node: Dict[str, BaseSchedulerNode],
+name_to_fused_node: dict[str, BaseSchedulerNode],
-) -> List[torch._inductor.scheduler.BaseSchedulerNode]:
+) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
 from . import scheduler

 new_order: list[BaseSchedulerNode] = []
@@ -16,11 +16,7 @@ from typing import (
 Any,
 Callable,
 ContextManager,
-Dict,
-Generator,
-List,
 Optional,
-Sequence,
 TYPE_CHECKING,
 TypeVar,
 Union,
@@ -121,6 +117,8 @@ from .virtualized import V


 if TYPE_CHECKING:
+from collections.abc import Generator, Sequence

 from torch._inductor.output_code import _StrideExprStr
 from torch._ops import OpOverload

@@ -157,7 +155,7 @@ static_inputs_log = torch._logging.getArtifactLogger(
 )


-def get_static_input_idxs(num_fixed: int) -> List[int]:
+def get_static_input_idxs(num_fixed: int) -> list[int]:
 # If we are inlining NNModules, we treat all torch.nn.Parameters as static for the purposes
 # of cudagraphs. Rather than copying these into cudagraph-owned memory
 # like we do for normal inputs on each run, we will re-record a cudagraph if these
@@ -208,7 +206,7 @@ def _unlift_graph(
 ) -> GraphModule:
 from torch.export.unflatten import _assign_attr, _AttrKind

-state_dict: Dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
+state_dict: dict[str, Union[torch.nn.parameter.Parameter, torch.Tensor]] = {}
 for name, param in mod.named_parameters(remove_duplicate=False):
 state_dict[name] = param
 _assign_attr(
@@ -227,7 +225,7 @@ def _unlift_graph(
 )

 placeholder_nodes = gm.graph.find_nodes(op="placeholder")
-lifted_inputs: List[Optional[FQN]] = []
+lifted_inputs: list[Optional[FQN]] = []

 # In AOTI, module parameters and buffers are not lifted as graph inputs.
 # As a result, mutation to buffers has side effect which makes their initial
@@ -343,9 +341,9 @@ def _recursive_post_grad_passes(gm: GraphModule, is_inference: bool = False) ->
 def split_const_gm(
 gm: GraphModule,
 skip_constructor: bool = True,
-lifted_constant_names: Optional[List[str]] = None,
+lifted_constant_names: Optional[list[str]] = None,
 skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
-) -> tuple[GraphModule, Dict[str, int]]:
+) -> tuple[GraphModule, dict[str, int]]:
 """
 This function takes an GraphModule input "gm".
 The gm will be split into 2 components,
@@ -488,8 +486,8 @@ def fake_tensor_prop(

 # pass config dict back to user
 def get_patched_config_dict(
-config_patches: Optional[Union[str, Dict[str, Any]]] = None
+config_patches: Optional[Union[str, dict[str, Any]]] = None
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
 with config.patch(config_patches):
 return config.get_config_copy()

@@ -515,7 +513,7 @@ class _CompileFxKwargs(TypedDict, total=False):
 aot_mode: bool
 is_inference: bool
 layout_opt: Optional[bool]
-extern_node_serializer: Optional[Callable[[List[ExternKernelNode]], Any]]
+extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]]
 boxed_forward_device_index: Optional[BoxedDeviceIndex]


@@ -822,7 +820,7 @@ class _InProcessFxCompile(FxCompile):
 aot_mode: bool = V.aot_compilation
 is_inference: bool = graph_kwargs.get("is_inference", False)
 extern_node_serializer: Optional[
-Callable[[List[ExternKernelNode]], Any]
+Callable[[list[ExternKernelNode]], Any]
 ] = graph_kwargs.get("extern_node_serializer", None)
 boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
 "boxed_forward_device_index", None
@@ -997,7 +995,7 @@ class _InProcessFxCompile(FxCompile):
 metrics_helper = metrics.CachedMetricsHelper()
 with V.set_graph_handler(graph):
 graph.run(*example_inputs)
-output_strides: List[Optional[tuple[_StrideExprStr, ...]]] = []
+output_strides: list[Optional[tuple[_StrideExprStr, ...]]] = []
 if graph.graph_outputs is not None:
 # We'll put the output strides in the compiled graph so we
 # can later return them to the caller via TracingContext
@@ -1189,7 +1187,7 @@ def cudagraphify(
 static_input_idxs: Sequence[int] = (),
 *,
 device_index: int,
-stack_traces: List[Optional[str]],
+stack_traces: list[Optional[str]],
 is_backward: bool,
 is_inference: bool,
 constants: tuple[torch.Tensor, ...] = (),
@@ -1240,7 +1238,7 @@ def static_input(x: torch.Tensor) -> torch.Tensor:
 def index_expanded_dims_and_copy_(
 dst: torch.Tensor,
 src: torch.Tensor,
-expanded_dims: List[int],
+expanded_dims: list[int],
 ) -> None:
 "Index into expanded dimensions of both dst and src then copy_"
 dst = index_expanded_dims(dst, expanded_dims)
@@ -1250,9 +1248,9 @@ def index_expanded_dims_and_copy_(

 def cudagraphify_impl(
 model: Callable[..., Any],
-inputs: List[torch.Tensor],
+inputs: list[torch.Tensor],
 static_input_idxs: Sequence[int] = (),
-) -> Callable[[List[InputType]], Any]:
+) -> Callable[[list[InputType]], Any]:
 """
 Assumes inputs[static_input_idxs[i]] are always the same memory address
 """
@@ -1304,7 +1302,7 @@ def cudagraphify_impl(

 if config.size_asserts:

-def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
+def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
 assert len(static_inputs) == len(new_inputs)
 for idx, (dst, src, expanded_dims) in enumerate(
 zip(static_inputs, new_inputs, inps_expanded_dims)
@@ -1328,7 +1326,7 @@ def cudagraphify_impl(
 idx for idx in range(len(static_inputs)) if idx not in static_input_idxs
 ]

-def run(new_inputs: List[InputType]) -> Callable[[List[InputType]], Any]:
+def run(new_inputs: list[InputType]) -> Callable[[list[InputType]], Any]:
 for idx in copy_indices:
 expanded_dims = inps_expanded_dims[idx]
 src = new_inputs[idx]
@@ -1343,16 +1341,16 @@ def cudagraphify_impl(

 def compile_fx_aot(
 model_: GraphModule,
-example_inputs_: List[InputType],
+example_inputs_: list[InputType],
 inner_compile: _CompileFxCallable = compile_fx_inner,
-config_patches: Optional[Dict[str, str]] = None,
+config_patches: Optional[dict[str, str]] = None,
-) -> Union[List[str], str]:
+) -> Union[list[str], str]:
 assert isinstance(model_, GraphModule), model_

 # [See NOTE] Unwrapping subclasses AOT
 unwrap_tensor_subclass_parameters(model_)

-config_patches: Dict[str, Any] = (
+config_patches: dict[str, Any] = (
 {"cpp_wrapper": True}
 if config_patches is None
 else {**config_patches, "cpp_wrapper": True}
@@ -1409,7 +1407,7 @@ def fw_compiler_freezing(
 cudagraphs: BoxedBool,
 graph_id: int,
 forward_device: BoxedDeviceIndex,
-) -> Callable[[List[object]], Sequence[torch.Tensor]]:
+) -> Callable[[list[object]], Sequence[torch.Tensor]]:
 from torch._inductor.freezing import convert_conv_weights_to_channels_last, freeze

 # partition_fn won't be called
@@ -1492,7 +1490,7 @@ def fw_compiler_freezing(
 if V.aot_compilation:
 return optimized_function

-def wrapper(args: List[object]) -> Sequence[torch.Tensor]:
+def wrapper(args: list[object]) -> Sequence[torch.Tensor]:
 args_new = [
 args[i - unwrapped_args_offsets[min(i, max_offset_idx)]]
 for i in preserved_arg_indices
@@ -1505,7 +1503,7 @@ def fw_compiler_freezing(
 return wrapper


-def get_cpp_wrapper_config() -> Dict[str, object]:
+def get_cpp_wrapper_config() -> dict[str, object]:
 return {
 # Set autotune_at_compile_time to True as default if the option is not explicitly set
 "triton.autotune_at_compile_time": config.triton.autotune_at_compile_time
@@ -1551,9 +1549,9 @@ def compile_fx(
 model_: GraphModule,
 example_inputs_: Sequence[InputType],
 inner_compile: Callable[..., OutputCode] = compile_fx_inner,
-config_patches: Optional[Dict[str, Any]] = None,
+config_patches: Optional[dict[str, Any]] = None,
-decompositions: Optional[Dict[OpOverload, Callable[..., Any]]] = None,
+decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
-) -> Union[Callable[[List[object]], Sequence[torch.Tensor]], str, List[str]]:
+) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
 """
 Main entry point for compiling given FX graph. Despite the fact that this
 lives in :mod:`torch._inductor`, this function is responsible for calling
@@ -2017,11 +2015,11 @@ def _check_triton_bf16_support(graph: GraphLowering) -> None:

 def _aoti_flatten_inputs(
 gm: torch.fx.GraphModule,
-args: Union[List[Any], tuple[Any, ...]],
+args: Union[list[Any], tuple[Any, ...]],
-kwargs: Optional[Dict[str, Any]] = None,
+kwargs: Optional[dict[str, Any]] = None,
 *,
-options: Optional[Dict[str, Any]] = None,
+options: Optional[dict[str, Any]] = None,
-) -> tuple[List[Any], Dict[str, Any]]:
+) -> tuple[list[Any], dict[str, Any]]:
 """
 Flatten the inputs to the graph module and return the flat inputs and options.
 Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
@@ -5,7 +5,7 @@ import importlib
 import logging
 import os
 import sys
-from typing import Type, TypeVar
+from typing import TypeVar

 from torch._inductor.async_compile import pre_fork_setup
 from torch._inductor.compile_worker.subproc_pool import (
@@ -32,7 +32,7 @@ except ImportError:
 pass


-def _lookup_and_create_type(base: Type[_T], qname: str) -> _T:
+def _lookup_and_create_type(base: type[_T], qname: str) -> _T:
 """
 Given a base type and qualified name: import & lookup that name, check
 that it's of the given type and then instantiate it.
@@ -13,7 +13,7 @@ import typing
 from concurrent.futures import Future, ProcessPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from enum import Enum
-from typing import Any, BinaryIO, Callable, Dict, Optional, TypeVar
+from typing import Any, BinaryIO, Callable, Optional, TypeVar
 from typing_extensions import Never, ParamSpec

 # _thread_safe_fork is needed because the subprocesses in the pool can read
@@ -158,7 +158,7 @@ class SubprocPool:
 self.read_thread = threading.Thread(target=self._read_thread, daemon=True)

 self.futures_lock = threading.Lock()
-self.pending_futures: Dict[int, Future[Any]] = {}
+self.pending_futures: dict[int, Future[Any]] = {}
 self.job_id_count = itertools.count()

 self.running = True
@@ -7,7 +7,7 @@ import shutil
 import sys
 import tempfile
 from dataclasses import dataclass, field
-from typing import Callable, Dict, List, Optional
+from typing import Callable, Optional

 from torch._inductor.runtime.cache_dir_utils import cache_dir

@@ -43,7 +43,7 @@ class ConfigChange(BinarySubsystem):


 # Dictionary of backend -> subsystems
-BACKENDS: Dict[str, List[Subsystem]] = {
+BACKENDS: dict[str, list[Subsystem]] = {
 # run dynamo without aot_autograd
 "eager": [],
 # run dynamo with aot_autograd, but no partitioner or decomps
@@ -68,8 +68,8 @@ BACKENDS: Dict[str, List[Subsystem]] = {
 ], # TODO - add more - fusions ?
 }

-subsystem_call_counter: Dict[str, int] = collections.Counter()
+subsystem_call_counter: dict[str, int] = collections.Counter()
-call_counter_debug_info: Dict[int, str] = {}
+call_counter_debug_info: dict[int, str] = {}


 def reset_counters() -> None:
@@ -123,13 +123,13 @@ class CompilerBisector:
 return f"{cache_dir() if not cls.in_process_cache else cls.in_process_cache}/{SUBDIR_NAME}"

 @classmethod
-def write_lines_to_file(cls, file_path: str, lines: List[str]) -> None:
+def write_lines_to_file(cls, file_path: str, lines: list[str]) -> None:
 os.makedirs(os.path.dirname(file_path), exist_ok=True)
 with open(file_path, "w") as file:
 file.writelines(lines)

 @classmethod
-def read_lines_from_file(cls, file_path: str) -> List[str]:
+def read_lines_from_file(cls, file_path: str) -> list[str]:
 if os.path.exists(file_path):
 with open(file_path) as file:
 return file.readlines()
@@ -154,7 +154,7 @@ class CompilerBisector:

 @classmethod
 def set_config_values(
-cls, backend: str, subsystem: str, config_data: Dict[str, object]
+cls, backend: str, subsystem: str, config_data: dict[str, object]
 ) -> None:
 file_path = os.path.join(cls.get_dir(), backend, f"{subsystem}_config.txt")
 lines = [f"{k}={v}\n" for k, v in config_data.items()]
@@ -267,7 +267,7 @@ class CompilerBisector:
 cls.write_lines_to_file(file_path, lines)

 @classmethod
-def get_config_change(cls, config_name: str) -> Optional[Dict[str, object]]:
+def get_config_change(cls, config_name: str) -> Optional[dict[str, object]]:
 backend = cls.get_backend()
 subsystem = cls.get_subsystem()

@@ -1,16 +1,6 @@
 import os  # noqa: C101
 import sys
-from typing import (
+from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
-Any,
-Callable,
-Dict,
-List,
-Literal,
-Optional,
-Tuple,
-TYPE_CHECKING,
-Union,
-)

 import torch
 import torch._inductor.custom_graph_pass
@@ -193,8 +183,8 @@ pre_grad_custom_pass: Optional[Callable[[torch.fx.graph.Graph], None]] = None
 # hence custom IR passes built on top of it might break in the future.
 _pre_fusion_custom_pass: Optional[
 Callable[
-[List["torch._inductor.scheduler.BaseSchedulerNode"]],
+[list["torch._inductor.scheduler.BaseSchedulerNode"]],
-List["torch._inductor.scheduler.BaseSchedulerNode"],
+list["torch._inductor.scheduler.BaseSchedulerNode"],
 ]
 ] = None

@@ -231,11 +221,11 @@ batch_fusion = True
 # merge_splits_pass
 # mutate_cat_pass
 # split_cat_pass
-pre_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
+pre_grad_fusion_options: dict[str, dict[str, Any]] = {}

 # Post grad fusion and options, set to empty dict to disable fusion.
 # Call `torch._inductor.fx_passes.group_batch_fusion.list_group_batch_fusions(False)` to see available fusions.
-post_grad_fusion_options: Dict[str, Dict[str, Any]] = {}
+post_grad_fusion_options: dict[str, dict[str, Any]] = {}

 # enable reordering pass for improving memory locality
 reorder_for_locality = True
@@ -257,7 +247,7 @@ use_mixed_mm = True
 # floating point numbers,about 16 decimal digits for double precision floating point numbers)
 # according to PyTorch documentation.
 # https://pytorch.org/docs/stable/notes/numerical_accuracy.html#batched-computations-or-slice-computations
-fx_passes_numeric_check: Dict[str, Any] = {
+fx_passes_numeric_check: dict[str, Any] = {
 "pre_grad": False,
 "precision": 1e-4,
 "num_iterations": 1,
@@ -287,12 +277,12 @@ reorder_for_compute_comm_overlap = False
 # for built-in passes, use string name; for user-defined passes, pass in the function handle
 # WARNING: Inductor scheduler IR is at prototype stage and subject to change,
 # hence custom IR passes built on top of it might break in the future.
-reorder_for_compute_comm_overlap_passes: List[
+reorder_for_compute_comm_overlap_passes: list[
 Union[
 str,
 Callable[
-[List["torch._inductor.scheduler.BaseSchedulerNode"]],
+[list["torch._inductor.scheduler.BaseSchedulerNode"]],
-List["torch._inductor.scheduler.BaseSchedulerNode"],
+list["torch._inductor.scheduler.BaseSchedulerNode"],
 ],
 ]
 ] = [
@@ -618,7 +608,7 @@ _fuse_ddp_bucket_size = 25
 # overlapping. At this moment, this pass performs better than
 # reorder_for_compute_comm_overlap_passes but we will add the logic of
 # "schedule_comm_wait" in the future and remove the one here.
-_fuse_ddp_communication_passes: List[Union[Callable[..., None], str]] = [
+_fuse_ddp_communication_passes: list[Union[Callable[..., None], str]] = [
 "fuse_ddp_with_concat_op",
 "schedule_comm_wait",
 ]
@@ -852,7 +842,7 @@ class cpp:
 simdlen: Optional[int] = None
 min_chunk_size = int(os.environ.get("TORCHINDUCTOR_CPP_MIN_CHUNK_SIZE", "4096"))

-cxx: Tuple[None, str] = (
+cxx: tuple[Literal[None], str] = (
 None, # download gcc12 from conda-forge if conda is installed
 os.environ.get("CXX", "clang++" if sys.platform == "darwin" else "g++"),
 ) # type: ignore[assignment]
@@ -1157,7 +1147,7 @@ class aot_inductor:

 # Dictionary of metadata users might want to save to pass to the runtime.
 # TODO: Move this somewhere else, since it's no longer really a config
-metadata: Dict[str, str] = {}
+metadata: dict[str, str] = {}

 # fbcode only. Whether to raise error if C++ codegen is too big to optimize
 raise_error_on_ignored_optimization: bool = (
@@ -1168,7 +1158,7 @@ class aot_inductor:
 dump_aoti_minifier: bool = os.environ.get("DUMP_AOTI_MINIFIER", "0") == "1"

 # Dictionary of presets that can be passed in
-presets: Dict[str, Any] = {}
+presets: dict[str, Any] = {}

 # Kill switch for allowing temporary tensors to be allocated as stack arrays. Tests
 # should be run with this flag both on and off to make sure we have coverage.
@@ -1265,11 +1255,11 @@ class cuda:
 class rocm:
 # Offload arch list for device code compilation, e.g. ["gfx941", "gfx942"].
 # If empty, the `native` arch is used
-arch: List[str] = []
+arch: list[str] = []

 # Enable the CK backend for CDNA2 and CDNA3 only (for now)
 # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors
-ck_supported_arch: List[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]
+ck_supported_arch: list[str] = ["gfx90a", "gfx940", "gfx941", "gfx942"]

 # Optimization level, use to balance compilation speed and runtime performance.
 # The type will not necessarily be comprehensive and won't be enforced at runtime.
@@ -1415,7 +1405,7 @@ class trace:
 log_inductor_triton_kernel_to_post_grad_node_info: bool = True


-_save_config_ignore: List[str] = [
+_save_config_ignore: list[str] = [
 # workaround: "Can't pickle <function ...>"
 "trace.upload_tar",
 "joint_custom_pre_pass",
@@ -1423,7 +1413,7 @@ _save_config_ignore: List[str] = [
 "pre_grad_custom_pass",
 ]

-_cache_config_ignore_prefix: List[str] = [
+_cache_config_ignore_prefix: list[str] = [
 # trace functions are not relevant to config caching
 "trace",
 # uses absolute path
@@ -1439,7 +1429,7 @@ _cache_config_ignore_prefix: List[str] = [
 ]

 # External callable for matmul tuning candidates
-external_matmul: List[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []
+external_matmul: list[Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None]] = []


 class test_configs:
@@ -1,5 +1,5 @@
 import collections
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Optional

 import torch
 import torch.utils._pytree as pytree
@@ -57,7 +57,7 @@ def replace_node_with_constant(


 def is_const_source(
-node: torch.fx.Node, lifted_constant_names: Optional[List[str]]
+node: torch.fx.Node, lifted_constant_names: Optional[list[str]]
 ) -> bool:
 return node.op == "get_attr" or node.name in (lifted_constant_names or ())

@@ -67,12 +67,12 @@ class ConstantFolder(torch.fx.Interpreter):
 self,
 gm: torch.fx.GraphModule,
 skip_constructors: bool = False,
-lifted_constant_names: Optional[List[str]] = None,
+lifted_constant_names: Optional[list[str]] = None,
 skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> None:
 super().__init__(gm)
-self.node_replacements: Dict[torch.fx.Node, Any] = {}
+self.node_replacements: dict[torch.fx.Node, Any] = {}
-self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
+self.replaced_uses: dict[torch.fx.Node, int] = collections.Counter()
 self.unknown_value = object()
 self.skip_constructors: bool = skip_constructors

@@ -141,7 +141,7 @@ class ConstantFolder(torch.fx.Interpreter):
 return True
 return False

-def node_to_last_non_output_use(self) -> Dict[torch.fx.Node, List[torch.fx.Node]]:
+def node_to_last_non_output_use(self) -> dict[torch.fx.Node, list[torch.fx.Node]]:
 last_non_output_use = collections.defaultdict(list)
 seen_uses = OrderedSet[torch.fx.Node]()
 output_node = next(iter(reversed(self.module.graph.nodes))) # type: ignore[arg-type, union-attr]
@@ -264,11 +264,11 @@ class ConstantFolder(torch.fx.Interpreter):
 self.node_replacements[node] = tensor

 def run(self) -> Any: # type: ignore[override]
-env: Dict[torch.fx.Node, Any] = {}
+env: dict[torch.fx.Node, Any] = {}
 self.insert_placerholder_values(env)
 return super().run(initial_env=env)

-def insert_placerholder_values(self, env: Dict[torch.fx.Node, Any]) -> None:
+def insert_placerholder_values(self, env: dict[torch.fx.Node, Any]) -> None:
 for n in self.module.graph.find_nodes(op="placeholder"): # type: ignore[operator, union-attr]
 env[n] = self.unknown_value # type: ignore[assignment]
 if self.lifted_constant_names is None:
@@ -309,7 +309,7 @@ def constant_fold(
 def constant_graph_tag(
 gm: torch.fx.GraphModule,
 skip_constructors: bool = True,
-lifted_constant_names: Optional[List[str]] = None,
+lifted_constant_names: Optional[list[str]] = None,
 skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> None:
 with torch.utils._python_dispatch._disable_current_modes():
@@ -337,7 +337,7 @@ def constant_graph_tag(
 def run_and_get_constant_graph(
 gm: torch.fx.GraphModule,
 skip_constructors: bool = True,
-lifted_constant_names: Optional[List[str]] = None,
+lifted_constant_names: Optional[list[str]] = None,
 skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None,
 ) -> torch.fx.GraphModule:
 """
@@ -367,7 +367,7 @@ def run_and_get_constant_graph(

 new_graph = torch.fx.Graph()

-node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
+node_remapping: dict[torch.fx.Node, torch.fx.Node] = {}
 output_nodes = []
 for node in gm.graph.nodes:
 if node.meta[META_TAG] == MODULE_TAG:
@@ -16,10 +16,11 @@ import sys
 import sysconfig
 import textwrap
 import warnings
+from collections.abc import Sequence
 from ctypes import cdll
 from ctypes.util import find_library
 from pathlib import Path
-from typing import Any, List, Optional, Sequence, Union
+from typing import Any, Optional, Union

 import torch
 from torch._dynamo.utils import dynamo_timed
@@ -285,12 +286,12 @@ def get_compiler_version_info(compiler: str) -> str:


 # =============================== cpp builder ===============================
-def _append_list(dest_list: List[str], src_list: List[str]) -> None:
+def _append_list(dest_list: list[str], src_list: list[str]) -> None:
 dest_list.extend(copy.deepcopy(item) for item in src_list)


-def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
+def _remove_duplication_in_list(orig_list: list[str]) -> list[str]:
-new_list: List[str] = []
+new_list: list[str] = []
 for item in orig_list:
 if item not in new_list:
 new_list.append(item)
@@ -362,26 +363,26 @@ class BuildOptionsBase:
 def __init__(
 self,
 compiler: str = "",
-definitions: Optional[List[str]] = None,
+definitions: Optional[list[str]] = None,
-include_dirs: Optional[List[str]] = None,
+include_dirs: Optional[list[str]] = None,
-cflags: Optional[List[str]] = None,
+cflags: Optional[list[str]] = None,
-ldflags: Optional[List[str]] = None,
+ldflags: Optional[list[str]] = None,
-libraries_dirs: Optional[List[str]] = None,
+libraries_dirs: Optional[list[str]] = None,
-libraries: Optional[List[str]] = None,
+libraries: Optional[list[str]] = None,
-passthrough_args: Optional[List[str]] = None,
+passthrough_args: Optional[list[str]] = None,
 aot_mode: bool = False,
 use_absolute_path: bool = False,
 compile_only: bool = False,
 ) -> None:
 self._compiler = compiler
-self._definations: List[str] = definitions or []
+self._definations: list[str] = definitions or []
-self._include_dirs: List[str] = include_dirs or []
+self._include_dirs: list[str] = include_dirs or []
-self._cflags: List[str] = cflags or []
+self._cflags: list[str] = cflags or []
-self._ldflags: List[str] = ldflags or []
+self._ldflags: list[str] = ldflags or []
-self._libraries_dirs: List[str] = libraries_dirs or []
+self._libraries_dirs: list[str] = libraries_dirs or []
-self._libraries: List[str] = libraries or []
+self._libraries: list[str] = libraries or []
 # Some args is hard to abstract to OS compatable, passthrough it directly.
-self._passthrough_args: List[str] = passthrough_args or []
+self._passthrough_args: list[str] = passthrough_args or []

 self._aot_mode: bool = aot_mode
 self._use_absolute_path: bool = use_absolute_path
@@ -408,25 +409,25 @@ class BuildOptionsBase:
 def get_compiler(self) -> str:
 return self._compiler

-def get_definations(self) -> List[str]:
+def get_definations(self) -> list[str]:
 return self._definations

-def get_include_dirs(self) -> List[str]:
+def get_include_dirs(self) -> list[str]:
 return self._include_dirs

-def get_cflags(self) -> List[str]:
+def get_cflags(self) -> list[str]:
 return self._cflags

-def get_ldflags(self) -> List[str]:
+def get_ldflags(self) -> list[str]:
 return self._ldflags

-def get_libraries_dirs(self) -> List[str]:
+def get_libraries_dirs(self) -> list[str]:
 return self._libraries_dirs

-def get_libraries(self) -> List[str]:
+def get_libraries(self) -> list[str]:
 return self._libraries

-def get_passthrough_args(self) -> List[str]:
+def get_passthrough_args(self) -> list[str]:
 return self._passthrough_args

 def get_aot_mode(self) -> bool:
@@ -457,14 +458,14 @@ class BuildOptionsBase:
 json.dump(attrs, f)


-def _get_warning_all_cflag(warning_all: bool = True) -> List[str]:
+def _get_warning_all_cflag(warning_all: bool = True) -> list[str]:
 if not _IS_WINDOWS:
 return ["Wall"] if warning_all else []
 else:
 return []


-def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
+def _get_cpp_std_cflag(std_num: str = "c++17") -> list[str]:
 if _IS_WINDOWS:
 """
 On Windows, only c++20 can support `std::enable_if_t`.
@@ -479,7 +480,7 @@ def _get_cpp_std_cflag(std_num: str = "c++17") -> List[str]:
 return [f"std={std_num}"]


-def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
+def _get_os_related_cpp_cflags(cpp_compiler: str) -> list[str]:
 if _IS_WINDOWS:
 cflags = [
 "wd4819",
@@ -506,7 +507,7 @@ def _get_os_related_cpp_cflags(cpp_compiler: str) -> List[str]:
 return cflags


-def _get_ffast_math_flags() -> List[str]:
+def _get_ffast_math_flags() -> list[str]:
 # ffast-math is equivalent to these flags as in
 # https://github.com/gcc-mirror/gcc/blob/4700ad1c78ccd7767f846802fca148b2ea9a1852/gcc/opts.cc#L3458-L3468
 # however gcc<13 sets the FTZ/DAZ flags for runtime on x86 even if we have
@@ -527,7 +528,7 @@ def _get_ffast_math_flags() -> List[str]:
 return flags


-def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
+def _get_optimization_cflags(cpp_compiler: str) -> list[str]:
 if _IS_WINDOWS:
 return ["O2"]
 else:
@@ -554,7 +555,7 @@ def _get_optimization_cflags(cpp_compiler: str) -> List[str]:
 return cflags


-def _get_shared_cflag(compile_only: bool) -> List[str]:
+def _get_shared_cflag(compile_only: bool) -> list[str]:
 if _IS_WINDOWS:
 """
 MSVC `/MD` using python `ucrtbase.dll` lib as runtime.
@@ -578,14 +579,14 @@ def get_cpp_options(
 compile_only: bool,
 warning_all: bool = True,
 extra_flags: Sequence[str] = (),
-) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
-definations: List[str] = []
+definations: list[str] = []
-include_dirs: List[str] = []
+include_dirs: list[str] = []
-cflags: List[str] = []
+cflags: list[str] = []
-ldflags: List[str] = []
+ldflags: list[str] = []
-libraries_dirs: List[str] = []
+libraries_dirs: list[str] = []
-libraries: List[str] = []
+libraries: list[str] = []
-passthrough_args: List[str] = []
+passthrough_args: list[str] = []

 cflags = (
 _get_shared_cflag(compile_only)
@@ -657,22 +658,22 @@ class CppOptions(BuildOptionsBase):
 self._finalize_options()


-def _get_glibcxx_abi_build_flags() -> List[str]:
+def _get_glibcxx_abi_build_flags() -> list[str]:
 if not _IS_WINDOWS:
 return ["-D_GLIBCXX_USE_CXX11_ABI=" + str(int(torch._C._GLIBCXX_USE_CXX11_ABI))]
 else:
 return []


-def _get_torch_cpp_wrapper_defination() -> List[str]:
+def _get_torch_cpp_wrapper_defination() -> list[str]:
 return ["TORCH_INDUCTOR_CPP_WRAPPER", "STANDALONE_TORCH_HEADER"]


-def _use_custom_generated_macros() -> List[str]:
+def _use_custom_generated_macros() -> list[str]:
 return [" C10_USING_CUSTOM_GENERATED_MACROS"]


-def _use_fb_internal_macros() -> List[str]:
+def _use_fb_internal_macros() -> list[str]:
 if not _IS_WINDOWS:
 if config.is_fbcode():
 fb_internal_macros = [
@@ -697,12 +698,12 @@ def _setup_standard_sys_libs(
 cpp_compiler: str,
 aot_mode: bool,
 use_absolute_path: bool,
-) -> tuple[List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str]]:
 from torch._inductor.codecache import _LINKER_SCRIPT

-cflags: List[str] = []
+cflags: list[str] = []
-include_dirs: List[str] = []
+include_dirs: list[str] = []
-passthrough_args: List[str] = []
+passthrough_args: list[str] = []
 if _IS_WINDOWS:
 return cflags, include_dirs, passthrough_args

@@ -737,9 +738,9 @@ def _setup_standard_sys_libs(
 return cflags, include_dirs, passthrough_args


-def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]]:
+def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[list[str], list[str]]:
-macros: List[str] = []
+macros: list[str] = []
-build_flags: List[str] = []
+build_flags: list[str] = []
 if vec_isa != invalid_vec_isa:
 # Add Windows support later.
 macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())
@@ -759,7 +760,7 @@ def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> tuple[List[str], List[str]

 def _get_torch_related_args(
 include_pytorch: bool, aot_mode: bool
-) -> tuple[List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str]]:
 from torch.utils.cpp_extension import _TORCH_PATH, TORCH_LIB_PATH

 include_dirs = [
@@ -783,7 +784,7 @@ def _get_torch_related_args(
 return include_dirs, libraries_dirs, libraries


-def _get_python_include_dirs() -> List[str]:
+def _get_python_include_dirs() -> list[str]:
 include_dir = Path(sysconfig.get_path("include"))
 # On Darwin Python executable from a framework can return
 # non-existing /Library/Python/... include path, in which case
@@ -796,7 +797,7 @@ def _get_python_include_dirs() -> List[str]:
 return [str(include_dir)]


-def _get_python_related_args() -> tuple[List[str], List[str]]:
+def _get_python_related_args() -> tuple[list[str], list[str]]:
 python_include_dirs = _get_python_include_dirs()
 python_include_path = sysconfig.get_path(
 "include", scheme="nt" if _IS_WINDOWS else "posix_prefix"
@@ -893,13 +894,13 @@ def perload_icx_libomp_win(cpp_compiler: str) -> None:

 def _get_openmp_args(
 cpp_compiler: str,
-) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str]]:
-cflags: List[str] = []
+cflags: list[str] = []
-ldflags: List[str] = []
+ldflags: list[str] = []
-include_dir_paths: List[str] = []
+include_dir_paths: list[str] = []
-lib_dir_paths: List[str] = []
+lib_dir_paths: list[str] = []
-libs: List[str] = []
+libs: list[str] = []
-passthrough_args: List[str] = []
+passthrough_args: list[str] = []
 if _IS_MACOS:
 # Per https://mac.r-project.org/openmp/ right way to pass `openmp` flags to MacOS is via `-Xclang`
 cflags.append("Xclang")
@@ -998,7 +999,7 @@ def _get_openmp_args(
 return cflags, ldflags, include_dir_paths, lib_dir_paths, libs, passthrough_args


-def get_mmap_self_macro(use_mmap_weights: bool) -> List[str]:
+def get_mmap_self_macro(use_mmap_weights: bool) -> list[str]:
 macros = []
 if use_mmap_weights:
 macros.append(" USE_MMAP_SELF")
@@ -1013,14 +1014,14 @@ def get_cpp_torch_options(
 compile_only: bool,
 use_absolute_path: bool,
 use_mmap_weights: bool,
-) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
-definations: List[str] = []
+definations: list[str] = []
-include_dirs: List[str] = []
+include_dirs: list[str] = []
-cflags: List[str] = []
+cflags: list[str] = []
-ldflags: List[str] = []
+ldflags: list[str] = []
-libraries_dirs: List[str] = []
+libraries_dirs: list[str] = []
-libraries: List[str] = []
+libraries: list[str] = []
-passthrough_args: List[str] = []
+passthrough_args: list[str] = []

 torch_cpp_wrapper_definations = _get_torch_cpp_wrapper_defination()
 use_custom_generated_macros_definations = _use_custom_generated_macros()
@@ -1163,7 +1164,7 @@ def _set_gpu_runtime_env() -> None:
 os.environ["CUDA_HOME"] = build_paths.sdk_home


-def _transform_cuda_paths(lpaths: List[str]) -> None:
+def _transform_cuda_paths(lpaths: list[str]) -> None:
 # This handles two cases:
 # 1. Cases where libs are in (e.g.) lib/cuda-12 and lib/cuda-12/stubs
 # 2. Linux machines may have CUDA installed under either lib64/ or lib/
@@ -1186,14 +1187,14 @@ def get_cpp_torch_device_options(
 device_type: str,
 aot_mode: bool = False,
 compile_only: bool = False,
-) -> tuple[List[str], List[str], List[str], List[str], List[str], List[str], List[str]]:
+) -> tuple[list[str], list[str], list[str], list[str], list[str], list[str], list[str]]:
-definations: List[str] = []
+definations: list[str] = []
-include_dirs: List[str] = []
+include_dirs: list[str] = []
-cflags: List[str] = []
+cflags: list[str] = []
-ldflags: List[str] = []
+ldflags: list[str] = []
-libraries_dirs: List[str] = []
+libraries_dirs: list[str] = []
-libraries: List[str] = []
+libraries: list[str] = []
-passthrough_args: List[str] = []
+passthrough_args: list[str] = []
 if (
 config.is_fbcode()
 and "CUDA_HOME" not in os.environ
@@ -1287,13 +1288,13 @@ class CppTorchDeviceOptions(CppTorchOptions):
 extra_flags=extra_flags,
 )

-device_definations: List[str] = []
+device_definations: list[str] = []
-device_include_dirs: List[str] = []
+device_include_dirs: list[str] = []
-device_cflags: List[str] = []
+device_cflags: list[str] = []
-device_ldflags: List[str] = []
+device_ldflags: list[str] = []
-device_libraries_dirs: List[str] = []
+device_libraries_dirs: list[str] = []
-device_libraries: List[str] = []
+device_libraries: list[str] = []
-device_passthrough_args: List[str] = []
+device_passthrough_args: list[str] = []

 (
 device_definations,
@@ -1379,7 +1380,7 @@ class CppBuilder:
 def __init__(
 self,
 name: str,
-sources: Union[str, List[str]],
+sources: Union[str, list[str]],
 BuildOption: BuildOptionsBase,
 output_dir: str = "",
 ) -> None:
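The import hunk above also swaps typing.Sequence for collections.abc.Sequence. A small sketch of why that works (hypothetical function, not from the patch): the typing aliases for the container ABCs are deprecated since Python 3.9, and the collections.abc versions are subscriptable at runtime.

# Hypothetical example: annotating against the abstract Sequence type.
from collections.abc import Sequence


def total(flags: Sequence[int]) -> int:
    # Accepts any sequence (list, tuple, ...) without importing typing.Sequence.
    return sum(flags)


assert total([1, 2, 3]) == 6
assert total((4, 5)) == 9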
@@ -7,7 +7,7 @@ import re
 import subprocess
 import sys
 import warnings
-from typing import Any, Callable, Dict, List, Union
+from typing import Any, Callable, Union

 import torch
 from torch._inductor import config
@@ -33,9 +33,9 @@ def _get_isa_dry_compile_fingerprint(isa_flags: str) -> str:

 class VecISA:
 _bit_width: int
-_macro: List[str]
+_macro: list[str]
 _arch_flags: str
-_dtype_nelements: Dict[torch.dtype, int]
+_dtype_nelements: dict[torch.dtype, int]

 # Note [Checking for Vectorized Support in Inductor]
 # TorchInductor CPU vectorization reuses PyTorch vectorization utility functions
@@ -79,7 +79,7 @@ cdll.LoadLibrary("__lib_path__")
 def nelements(self, dtype: torch.dtype = torch.float) -> int:
 return self._dtype_nelements[dtype]

-def build_macro(self) -> List[str]:
+def build_macro(self) -> list[str]:
 return self._macro

 def build_arch_flags(self) -> str:
@@ -300,11 +300,11 @@ class InvalidVecISA(VecISA):
 __hash__: Callable[[VecISA], Any] = VecISA.__hash__


-def x86_isa_checker() -> List[str]:
+def x86_isa_checker() -> list[str]:
-supported_isa: List[str] = []
+supported_isa: list[str] = []

 def _check_and_append_supported_isa(
-dest: List[str], isa_supported: bool, isa_name: str
+dest: list[str], isa_supported: bool, isa_name: str
 ) -> None:
 if isa_supported:
 dest.append(isa_name)
@@ -333,7 +333,7 @@ supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON(), VecSVE()]

 def get_isa_from_cpu_capability(
 capability: Union[str, None],
-vec_isa_list: List[VecISA],
+vec_isa_list: list[VecISA],
 invalid_vec_isa: InvalidVecISA,
 ):
 # AMX setting is not supported in eager
@@ -364,8 +364,8 @@ def get_isa_from_cpu_capability(
 # might have too much redundant content that is useless for ISA check. Hence,
 # we only cache some key isa information.
 @functools.lru_cache(None)
-def valid_vec_isa_list() -> List[VecISA]:
+def valid_vec_isa_list() -> list[VecISA]:
-isa_list: List[VecISA] = []
+isa_list: list[VecISA] = []
 if sys.platform == "darwin" and platform.processor() == "arm":
 isa_list.append(VecNEON())

@@ -411,7 +411,7 @@ def pick_vec_isa() -> VecISA:
 if config.is_fbcode() and (platform.machine() in ["x86_64", "AMD64"]):
 return VecAVX2()

-_valid_vec_isa_list: List[VecISA] = valid_vec_isa_list()
+_valid_vec_isa_list: list[VecISA] = valid_vec_isa_list()
 if not _valid_vec_isa_list:
 return invalid_vec_isa

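The class-attribute annotations touched above (_macro, _dtype_nelements) follow the same pattern. A minimal stand-in class, assuming nothing beyond the standard library and using illustrative names, shows that such annotations need no typing import on 3.9+.

# Hypothetical stand-in for a VecISA-style class; names are illustrative.
class ISAInfo:
    _bit_width: int
    _macro: list[str]
    _dtype_nelements: dict[str, int]

    def __init__(self, bit_width: int, macro: list[str], nelements: dict[str, int]) -> None:
        self._bit_width = bit_width
        self._macro = macro
        self._dtype_nelements = nelements

    def nelements(self, dtype: str = "float") -> int:
        # Look up the vector width (in elements) for a dtype name.
        return self._dtype_nelements[dtype]


info = ISAInfo(256, ["CPU_CAPABILITY_AVX2"], {"float": 8, "bfloat16": 16})
assert info.nelements() == 8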
@@ -54,13 +54,7 @@ from typing import (
 Callable,
 cast,
 ContextManager,
-Dict,
-Generator,
-Iterator,
-List,
 Optional,
-Sequence,
-Type,
 TYPE_CHECKING,
 TypeVar,
 Union,
@@ -99,6 +93,8 @@ from torch.utils.weak import TensorWeakRef


 if TYPE_CHECKING:
+from collections.abc import Generator, Iterator, Sequence

 from torch._inductor.utils import InputType
 from torch.types import _bool

@@ -357,12 +353,12 @@ def get_manager(

 def cudagraphify_impl(
 model: ModelType,
-inputs: List[InputType],
+inputs: list[InputType],
 static_input_idxs: Sequence[int],
 *args: Any,
 **kwargs: Any,
 ) -> ModelType:
-fn_cache: Dict[tuple[int, ...], Callable[..., Any]] = {}
+fn_cache: dict[tuple[int, ...], Callable[..., Any]] = {}

 # Detect int inputs: we need to index on these
 int_key = [i for i, v in enumerate(inputs) if isinstance(v, int)]
@@ -372,7 +368,7 @@ def cudagraphify_impl(

 del inputs

-def deferred_cudagraphify(inputs: List[InputType]) -> OutputType:
+def deferred_cudagraphify(inputs: list[InputType]) -> OutputType:
 nonlocal has_warn

 int_key = get_ints(inputs)
@@ -405,7 +401,7 @@ def cudagraphify_impl(

 def cudagraphify(
 model: ModelType,
-inputs: List[InputType],
+inputs: list[InputType],
 static_input_idxs: Sequence[int] = (),
 *,
 device_index: int,
@@ -466,7 +462,7 @@ class StorageWeakRefWrapper:

 @classmethod
 def from_weakref_and_data_ptr(
-cls: Type[S],
+cls: type[S],
 cdata: Any,
 data_ptr: int,
 extra_ref_check: Optional[Callable[[], bool]] = None,
@@ -561,9 +557,9 @@ def map_to_ref(t: Optional[Tensor]) -> Optional[StorageWeakRefWrapper]:
 PathOutputIndex = tuple[int, int]

 # For each node in the path, for each output, is the output alive
-PathLiveness = List[List[bool]]
+PathLiveness = list[list[bool]]

-StackTraces = List[Optional[str]]
+StackTraces = list[Optional[str]]


 class CUDAWarmupNode:
@@ -600,8 +596,8 @@ class CUDAWarmupNode:
 self.wrapped_function = wrapped_function
 self.parent: Optional[Union[CUDAGraphNode, CUDAWarmupNode]] = parent
 self.cuda_graphs_pool = cuda_graphs_pool
-self.outputs_weakrefs: List[Optional[StorageWeakRefWrapper]] = []
+self.outputs_weakrefs: list[Optional[StorageWeakRefWrapper]] = []
-self.tensor_weakrefs: List[Optional[TensorWeakRef]] = []
+self.tensor_weakrefs: list[Optional[TensorWeakRef]] = []
 self.existing_cuda_graph = existing_cuda_graph
 self.has_run = False
 self.device_index = device_index
@@ -619,7 +615,7 @@ class CUDAWarmupNode:
 [t.data_ptr() for t in self.path_live_weakrefs() if t()]
 )

-def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
+def get_non_cudagraph_inps() -> list[weakref.ReferenceType[UntypedStorage]]:
 non_cudagraph_inps = [
 weakref.ref(t.untyped_storage())
 for t in itertools.chain(new_inputs, self.wrapped_function.constants)
@@ -707,9 +703,9 @@ class CUDAWarmupNode:


 # Aliases for List that say what the indices denote
-InputList = List  # input indexes
+InputList = list  # input indexes
-OutputList = List  # output indexes
+OutputList = list  # output indexes
-LevelList = List  # levels (distance from root of tree)
+LevelList = list  # levels (distance from root of tree)


 class OutputAliasInfo:
@@ -772,7 +768,7 @@ class CUDAGraphNode:
 wrapped_function: WrappedFunction,
 id: GraphID,
 parent: Optional[CUDAGraphNode],
-inputs: List[InputType],
+inputs: list[InputType],
 cuda_graphs_pool: tuple[int, int],
 device_index: int,
 stack_traces: Optional[StackTraces],
@@ -800,7 +796,7 @@ class CUDAGraphNode:

 # A single wrapped function may be recorded multiple times if memory patterns or
 # invariants change from one execution to the next
-self.children: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
+self.children: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)

 # StorageWeakRef maintains whether the Storage C++ object remains allocated,
 # not whether the corresponding memory has been deallocated. In order
@@ -825,7 +821,7 @@ class CUDAGraphNode:
 self.tensor_weakrefs: OutputList[Optional[TensorWeakRef]] = []

 # tensors which are outputs of previous graphs in the tree
-self.cudagraph_managed_idxs: List[int] = [
+self.cudagraph_managed_idxs: list[int] = [
 idx
 for idx, t in enumerate(inputs)
 if isinstance(t, torch.Tensor) and self._is_cuda_graph_recorded_tensor(t)
@@ -845,7 +841,7 @@ class CUDAGraphNode:
 # and also aliases an output of the current CUDAGraphNode
 self.preserved_aliased_inputs: InputList[bool] = [False] * len(inputs)

-self.static_input_idxs: List[int] = list(
+self.static_input_idxs: list[int] = list(
 OrderedSet(wrapped_function.static_input_idxs)
 | OrderedSet(self.cudagraph_managed_idxs)
 )
@@ -866,8 +862,8 @@ class CUDAGraphNode:

 def maybe_get_static_data_ptr(
 idx: int,
-inputs: List[InputType],
+inputs: list[InputType],
-static_input_idxs: List[int],
+static_input_idxs: list[int],
 ) -> Optional[int]:
 inp = inputs[idx]
 if isinstance(inp, torch.Tensor) and idx in static_input_idxs:
@@ -888,7 +884,7 @@ class CUDAGraphNode:
 # fresh allocations.

 # precompute expanded dims to avoid computing in the hot path
-self.expanded_dims: List[List[int]] = [
+self.expanded_dims: list[list[int]] = [
 get_expanded_dims(x)
 if isinstance(x, torch.Tensor) and idx not in self.static_input_idxs
 else []
@@ -903,11 +899,11 @@ class CUDAGraphNode:
 # List of Tuples of (depth, output_index) that index into node at depth
 # number of nodes from root and output_index of outputs. Will index into
 # path_weakrefs.
-self.expected_dead_indices_before_graph: List[PathOutputIndex] = []
+self.expected_dead_indices_before_graph: list[PathOutputIndex] = []
-self.expected_dead_indices_after_graph: List[PathOutputIndex] = []
+self.expected_dead_indices_after_graph: list[PathOutputIndex] = []

 # all live indices after graph recording
-self.live_indices_after_graph: List[PathOutputIndex] = []
+self.live_indices_after_graph: list[PathOutputIndex] = []

 if self.parent is not None:
 previous_liveness = self.parent.recorded_liveness_after_graph
@@ -934,7 +930,7 @@ class CUDAGraphNode:
 # we reconstruct tensors at the correct data pointers of our inputs which are
 # non owning and do not prevent deallocation. On subsequent executions, input values
 # will be copied over to these tensors.
-self.reconstructed_inputs: List[InputType] = [
+self.reconstructed_inputs: list[InputType] = [
 self._reconstruct_from_tensor_metadata(self._tensor_metadata(x))
 if isinstance(x, torch.Tensor)
 else x
@@ -983,7 +979,7 @@ class CUDAGraphNode:
 self.recording_outputs: Optional[OutputType] = self._record(
 wrapped_function.model, recording_inputs
 )
-self.outputs_metadata: OutputList[Union[Dict[str, Any], int, None]] = []
+self.outputs_metadata: OutputList[Union[dict[str, Any], int, None]] = []

 # As with inputs, we do not want to keep the outputs permanently alive because that would prevent
 # their memory being reclaimed in subsequent cuda graph recordings. We record the tensor metadata
@@ -1001,7 +997,7 @@ class CUDAGraphNode:
 self.graph.replay()

 def _copy_inputs_and_remove_from_src(
-self, dsts: List[InputType], srcs: List[InputType]
+self, dsts: list[InputType], srcs: list[InputType]
 ) -> None:
 dst_tensors = []
 src_tensors = []
@@ -1016,7 +1012,7 @@ class CUDAGraphNode:
 if dst_tensors:
 torch._foreach_copy_(dst_tensors, src_tensors)

-def check_static_inputs_are_stable(self, new_inputs: List[InputType]) -> None:
+def check_static_inputs_are_stable(self, new_inputs: list[InputType]) -> None:
 # avoid checking managed tensor static points since we already checked those in check_invariants
 if (
 not self.rerecord_if_static_inputs_change
@@ -1036,7 +1032,7 @@ class CUDAGraphNode:
 )
 torch._check(False, lambda: error_msg)

-def run_first_inputs(self, new_inputs: List[InputType]) -> OutputType:
+def run_first_inputs(self, new_inputs: list[InputType]) -> OutputType:
 if config.triton.fast_path_cudagraph_asserts:
 self.debug_check_invariants_before_invocation()

@@ -1048,7 +1044,7 @@ class CUDAGraphNode:
 assert outputs is not None
 return outputs

-def run(self, new_inputs: List[InputType]) -> OutputType:
+def run(self, new_inputs: list[InputType]) -> OutputType:
 self.check_static_inputs_are_stable(new_inputs)

 self._copy_inputs_and_remove_from_src(self.reconstructed_inputs, new_inputs)
@@ -1130,7 +1126,7 @@ class CUDAGraphNode:
 def prepare_alias_info_for_tensor_construction(
 self,
 out_alias_info: Optional[OutputAliasInfo],
-metadata: Union[Dict[str, Any], int, None],
+metadata: Union[dict[str, Any], int, None],
 ) -> Union[UntypedStorage, None, int]:
 if (
 isinstance(metadata, (int, type(None)))
@@ -1149,7 +1145,7 @@ class CUDAGraphNode:

 def prepare_storages_for_construction(
 self,
-) -> List[Union[UntypedStorage, None, int]]:
+) -> list[Union[UntypedStorage, None, int]]:
 output_storages = []
 for output_storage_alias, metadata in zip(
 self.output_storage_alias, self.outputs_metadata
@@ -1173,7 +1169,7 @@ class CUDAGraphNode:
 return False
 return True

-def _record(self, model: ModelType, inputs: List[InputType]) -> OutputType:
+def _record(self, model: ModelType, inputs: list[InputType]) -> OutputType:
 "Record the model"

 def static_input_iter() -> Generator[torch.Tensor, None, None]:
@@ -1185,7 +1181,7 @@ class CUDAGraphNode:
 yield _inp

 # see: output_is_alias_of_persistent_static_inputs above
-static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper] = {
+static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper] = {
 inp.untyped_storage().data_ptr(): StorageWeakRefWrapper(inp)
 for inp in itertools.chain(
 static_input_iter(), self.wrapped_function.constants
@@ -1229,7 +1225,7 @@ class CUDAGraphNode:
 def _add_first_outputs(
 self,
 outputs: OutputType,
-static_input_persistent_storage_ptrs: Dict[int, StorageWeakRefWrapper],
+static_input_persistent_storage_ptrs: dict[int, StorageWeakRefWrapper],
 ) -> None:
 "Add the outputs from the first invocation of the node and set up metadata"

@@ -1243,7 +1239,7 @@ class CUDAGraphNode:

 assert len(self.outputs_weakrefs) == 0
 # index from data pointer to index in outputs
-output_new_storages_index: Dict[StorageDataPtr, int] = {}
+output_new_storages_index: dict[StorageDataPtr, int] = {}

 self.unaliased_in_all_paths = [False for _ in range(len(outputs))]
 self.static_output_tensors = [None for _ in range(len(outputs))]
@@ -1431,8 +1427,8 @@ class CUDAGraphNode:

 @staticmethod
 def _check_liveness(
-indices: List[PathOutputIndex],
+indices: list[PathOutputIndex],
-output_refs: List[List[Optional[StorageWeakRefWrapper]]],
+output_refs: list[list[Optional[StorageWeakRefWrapper]]],
 ) -> bool:
 "Check that all of the indices specified are dead references"
 for depth, output_index in indices:
@@ -1448,8 +1444,8 @@ class CUDAGraphNode:

 @staticmethod
 def _get_different_indices(
-prev: List[List[bool]], curr: List[List[bool]]
+prev: list[list[bool]], curr: list[list[bool]]
-) -> List[PathOutputIndex]:
+) -> list[PathOutputIndex]:
 "Find indices where the two lists differ."
 dead_indices = []
 assert len(prev) <= len(curr)
@@ -1463,8 +1459,8 @@ class CUDAGraphNode:

 @staticmethod
 def _get_liveness(
-weakrefs: List[List[Optional[StorageWeakRefWrapper]]],
+weakrefs: list[list[Optional[StorageWeakRefWrapper]]],
-) -> List[List[bool]]:
+) -> list[list[bool]]:
 "Maps weakrefs to true if the reference is alive and false otherwise"
 if len(weakrefs) == 0:
 return []
@@ -1472,7 +1468,7 @@ class CUDAGraphNode:
 return [pytree.tree_map(is_live, outputs) for outputs in weakrefs]

 def debug_assert_invariants(
-self, expected_liveness: List[List[bool]], newly_dead: List[PathOutputIndex]
+self, expected_liveness: list[list[bool]], newly_dead: list[PathOutputIndex]
 ) -> None:
 if not config.triton.fast_path_cudagraph_asserts:
 return
@@ -1520,7 +1516,7 @@ class CUDAGraphNode:
 self.recorded_liveness_before_graph, self.expected_dead_indices_after_graph
 )

-def data_ptrs_dead_since_invocation(self) -> List[int]:
+def data_ptrs_dead_since_invocation(self) -> list[int]:
 """
 Since this node was invoked, return data ptrs of all tensor outputs that have died
 in the current executing tree path.
@@ -1568,7 +1564,7 @@ class CUDAGraphNode:
 @staticmethod
 def _tensor_metadata(
 x: torch.Tensor, ignore_storage_offset: bool = True
-) -> Dict[str, Any]:
+) -> dict[str, Any]:
 assert isinstance(x, torch.Tensor)
 # We ignore the storage offset for inputs, but not for outputs
 # TODO: - should we make the storage resizable ?
@@ -1583,19 +1579,19 @@ class CUDAGraphNode:
 }

 def _reconstruct_from_tensor_metadata(
-self, metadata: Dict[str, Any], storage: Optional[UntypedStorage] = None
+self, metadata: dict[str, Any], storage: Optional[UntypedStorage] = None
 ) -> Tensor:
 s = self.create_storage(metadata) if storage is None else storage
 return torch._C._construct_CUDA_Tensor_From_Storage_And_Metadata(metadata, s) # type: ignore[arg-type]

-def create_storage(self, metadata: Dict[str, Any]) -> torch.types.Storage:
+def create_storage(self, metadata: dict[str, Any]) -> torch.types.Storage:
 return torch._C._construct_storage_from_data_pointer(
 metadata["data_ptr"], metadata["device"], metadata["nbytes"]
 )

 def _allocate_and_copy_recording_inputs(
-self, inputs: List[InputType]
+self, inputs: list[InputType]
-) -> List[InputType]:
+) -> list[InputType]:
 """
 Allocate inputs for non static, non cudagraph managed tensors in the memory pool
 and copy over the tensor values.
@@ -1603,7 +1599,7 @@ class CUDAGraphNode:

 torch.cuda.synchronize()
 self.stream.wait_stream(torch.cuda.current_stream())
-recording_inputs: List[InputType] = []
+recording_inputs: list[InputType] = []

 with warnings.catch_warnings(record=True), torch.cuda.device(
 self.device
@@ -1627,7 +1623,7 @@ class CUDAGraphNode:
 return recording_inputs

 def check_invariants(
-self, inputs: List[InputType]
+self, inputs: list[InputType]
 ) -> tuple[CheckInvariantStatus, Callable[..., str]]:
 """
 Checks if this node can be run. The same pattern of tensor liveness, static inputs,
@@ -1714,7 +1710,7 @@ def get_cudagraph_segments(pool_id: tuple[int, int]) -> Any:
 return [segment for segment in segments if segment["segment_pool_id"] == pool_id]


-def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[int]:
+def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> list[int]:
 blocks = []

 for segment in get_cudagraph_segments(pool_id):
@@ -1728,7 +1724,7 @@ def get_block_addrs(pool_id: tuple[int, int], live_only: bool = True) -> List[in
 return blocks


-def format_tb(frames: List[Any]) -> str:
+def format_tb(frames: list[Any]) -> str:
 formatted_traceback = [
 traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
 for entry in frames
@@ -1740,7 +1736,7 @@ def format_tb(frames: List[Any]) -> str:
 def check_memory_pool(
 device: int,
 pool_id: tuple[int, int],
-live_storages_ptrs: List[StorageWeakRefWrapper],
+live_storages_ptrs: list[StorageWeakRefWrapper],
 ) -> None:
 assert all(
 isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
@@ -1839,12 +1835,12 @@ class CUDAGraphTreeManager:
 # when they are first invoked, none of their inputs are outputs are outputs
 # of another node, nor are there any live outputs of another node whose
 # liveness would create a dependency.
-self.roots: Dict[FunctionID, List[CUDAGraphNode]] = defaultdict(list)
+self.roots: dict[FunctionID, list[CUDAGraphNode]] = defaultdict(list)

 # mapping from function id to wrapped function
-self.ids_to_funcs: Dict[FunctionID, WrappedFunction] = {}
+self.ids_to_funcs: dict[FunctionID, WrappedFunction] = {}

-self.ids_to_stack_traces: Dict[FunctionID, Optional[StackTraces]] = {}
+self.ids_to_stack_traces: dict[FunctionID, Optional[StackTraces]] = {}

 self.warmed_up_functions: OrderedSet[FunctionID] = OrderedSet()
 # if we fail to increment generation, and are stuck warming up,
@@ -1883,14 +1879,14 @@ class CUDAGraphTreeManager:

 # mapping from graph_id to (function id to mutation type hint) since we are
 # specializing on a particular combination of Parent Node -> Function ID.
-self.non_cudagraph_managed_mutation_hint: Dict[
+self.non_cudagraph_managed_mutation_hint: dict[
-Optional[GraphID], Dict[FunctionID, bool]
+Optional[GraphID], dict[FunctionID, bool]
 ] = defaultdict(dict)
 self.warmup_node_counter = itertools.count(start=-1, step=-1)

 # mapping from graph_id to (function id to re-record count). We fall back to
 # eager function if a function is re-recorded frequently on a node.
-self.num_rerecord: Dict[Optional[GraphID], Dict[FunctionID, int]] = defaultdict(
+self.num_rerecord: dict[Optional[GraphID], dict[FunctionID, int]] = defaultdict(
 lambda: defaultdict(lambda: 0)
 )

@@ -1916,7 +1912,7 @@ class CUDAGraphTreeManager:
 # number of instances we had to checkpoint the function
 self.debug_checkpointing_counter = 0

-self.id_to_mode: Dict[FunctionID, CompilationMode] = {}
+self.id_to_mode: dict[FunctionID, CompilationMode] = {}

 # Note: [Backward Generation Handling]
 # We generally perform a sequence of forward executions followed by backward executions.
@@ -1943,7 +1939,7 @@ class CUDAGraphTreeManager:
 )
 )

-def run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
+def run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
 assert self.graph is not None, "Running CUDAGraph after shutdown"
 self.mode = self.id_to_mode[function_id]
 out = self._run(new_inputs, function_id)
@@ -1971,7 +1967,7 @@ class CUDAGraphTreeManager:
 return GraphID(next(self.warmup_node_counter))

 def _update_non_cudagraph_managed_mutation(
-self, function_id: FunctionID, inputs: List[InputType]
+self, function_id: FunctionID, inputs: list[InputType]
 ) -> None:
 node_id = self._get_node_id()
 if maybe_mutation_str := check_for_mutation(
@@ -2007,7 +2003,7 @@ class CUDAGraphTreeManager:
 > torch._inductor.config.triton.cudagraph_unexpected_rerecord_limit
 )

-def _run(self, new_inputs: List[InputType], function_id: FunctionID) -> OutputType:
+def _run(self, new_inputs: list[InputType], function_id: FunctionID) -> OutputType:
 # we will try to end the current execution lazily, since
 # we dont want to do unnecessary checking of the existing outputs
 # on the hot path, but both recording and warmup only happen once
@@ -2152,7 +2148,7 @@ class CUDAGraphTreeManager:
 self.current_node = None

 def record_function(
-self, new_inputs: List[InputType], function_id: FunctionID
+self, new_inputs: list[InputType], function_id: FunctionID
 ) -> OutputType:
 assert not isinstance(self.current_node, CUDAWarmupNode)
 graph_id = self.new_graph_id()
@@ -2183,7 +2179,7 @@ class CUDAGraphTreeManager:
 return node.run_first_inputs(new_inputs)

 def execute_node(
-self, node: CUDAGraphNode, new_inputs: List[InputType]
+self, node: CUDAGraphNode, new_inputs: list[InputType]
 ) -> OutputType:
 self.current_node = node
 self.path_state = ExecutionState.EXECUTION
@@ -2191,7 +2187,7 @@ class CUDAGraphTreeManager:
 return node.run(new_inputs)

 def run_eager(
-self, new_inputs: List[InputType], function_id: FunctionID
+self, new_inputs: list[InputType], function_id: FunctionID
 ) -> OutputType:
 # this is only stored on current node, because when we start a new path,
 # we will deallocate it
@@ -2229,7 +2225,7 @@ class CUDAGraphTreeManager:
 def add_function(
 self,
 model: ModelType,
-inputs: List[InputType],
+inputs: list[InputType],
 static_input_idxs: Sequence[int],
 stack_traces: Optional[StackTraces],
 mode: CompilationMode,
@@ -2409,7 +2405,7 @@ class CUDAGraphTreeManager:
 # TODO: we could also allow the these weak refs to continue to be allocated,
 # but that adds some complications.

-stor_stack_trace: Dict[int, Optional[str]] = {}
+stor_stack_trace: dict[int, Optional[str]] = {}
 for node in self.current_node._path_from_root:
 assert node.stack_traces is not None
 assert len(node.tensor_weakrefs) == len(node.stack_traces)
@@ -2475,7 +2471,7 @@ class CUDAGraphTreeManager:
 assert state is not None and device is not None

 # currently we deallocate on instead of allowing stale recordings
-stale_storages: List[int] = []
+stale_storages: list[int] = []

 # remove cached tensors, otherwise they would prevent memory from being
 # reclaimed in subsequent recordings
@@ -2506,7 +2502,7 @@ class CUDAGraphTreeManager:

 def live_cudagraph_pool_storages_in_curr_execution(
 self,
-) -> List[StorageWeakRefPointer]:
+) -> list[StorageWeakRefPointer]:
 if self.current_node is None:
 return []
 # explicitly ignoring previous recorded outputs from past path
@ -3,7 +3,7 @@ from __future__ import annotations

import dataclasses
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

import torch
from torch._dynamo.utils import counters
@ -11,14 +11,18 @@ from torch._inductor.utils import InputType
from torch.utils._ordered_set import OrderedSet

if TYPE_CHECKING:
from collections.abc import Sequence

perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
static_inputs_log = torch._logging.getArtifactLogger(
__name__, "cudagraph_static_inputs"
)

OutputType = List[Optional[Union[int, torch.Tensor]]]
OutputType = list[Optional[Union[int, torch.Tensor]]]
ModelType = Callable[[List[InputType]], OutputType]
ModelType = Callable[[list[InputType]], OutputType]

@dataclasses.dataclass(frozen=True)
@ -38,7 +42,7 @@ class PlaceholderInfo:
name: str
stack_trace: Optional[str]
# This field is recursive, but never cyclic (since a node never uses itself)
users: List[PlaceholderInfo]
users: list[PlaceholderInfo]
mutating_use_stack_trace: Optional[str]

@ -92,7 +96,7 @@ def to_placeholder_info(placeholder_node: torch.fx.Node) -> PlaceholderInfo:
return PlaceholderInfo(name, stack_trace, users, mutating_use_stack_trace)

def get_placeholder_info(graph: torch.fx.Graph) -> List[PlaceholderInfo]:
def get_placeholder_info(graph: torch.fx.Graph) -> list[PlaceholderInfo]:
return [
to_placeholder_info(node) for node in graph.nodes if node.op == "placeholder"
]
@ -123,7 +127,7 @@ def get_mutation_stack_trace(

def check_for_mutation(
func: WrappedFunction,
inputs: List[InputType],
inputs: list[InputType],
is_cuda_graph_recorded_tensor: Callable[[torch.Tensor], bool],
) -> Optional[str]:
# doesnt work for non-trees because the warmup run would apply mutation twice
@ -160,7 +164,7 @@ def _get_use_stack_trace(node) -> Optional[str]:

def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
@ -180,7 +184,7 @@ def check_multiple_devices_or_any_cpu_nodes(

def check_lowering_disable_cudagraph(
device_node_mapping: Dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node]
):
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)

@ -262,7 +266,7 @@ class CheckInvariantStatus(Enum):

def log_data_ptr_mismatch(
placeholders: Sequence[PlaceholderInfo],
inputs: List[InputType],
inputs: list[InputType],
recorded_data_ptr: Sequence[Optional[int]],
target_idxs: Sequence[int],
mismatch: CheckInvariantStatus,
@ -292,7 +296,7 @@ def log_data_ptr_mismatch(

def maybe_warning_due_to_dynamic_shape(
fn_cache: Dict[tuple[int, ...], Callable[..., Any]],
fn_cache: dict[tuple[int, ...], Callable[..., Any]],
new_int_key: Any,
) -> bool:
num_cudagraphs = len(fn_cache.keys()) + 1
@ -327,5 +331,5 @@ class CudagraphCachedInfo:
"""

placeholders: Sequence[PlaceholderInfo]
stack_traces: List[Optional[str]]
stack_traces: list[Optional[str]]
cudagraph_fail_reasons: List[str]
cudagraph_fail_reasons: list[str]
@ -12,7 +12,8 @@ import pickle
import pstats
import shutil
import subprocess
from typing import Any, Callable, Dict, IO, Iterator, List, Optional, Type, Union
from collections.abc import Iterator
from typing import Any, Callable, IO, Optional, Union
from unittest.mock import patch

import torch
@ -39,7 +40,7 @@ from .virtualized import V

log = logging.getLogger(__name__)

SchedulerNodeList = List[Any]
SchedulerNodeList = list[Any]
BufMeta = collections.namedtuple("BufMeta", ["name", "n_origin"])
GRAPHVIZ_COMMAND_SCALABLE = ["dot", "-Gnslimit=2", "-Gnslimit1=2", "-Gmaxiter=5000"]

@ -54,7 +55,7 @@ def has_dot() -> bool:

def draw_buffers(
nodes: List[BaseSchedulerNode],
nodes: list[BaseSchedulerNode],
print_graph: bool = False,
fname: Optional[str] = None,
) -> None:
@ -99,7 +100,7 @@ def draw_buffers(
)

def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:
def create_fx_from_snodes(snodes: list[BaseSchedulerNode]) -> fx.Graph:
"""
Creates a FX Graph from a list of SchedulerNode objects.
"""
@ -199,7 +200,7 @@ def create_fx_from_snodes(snodes: List[BaseSchedulerNode]) -> fx.Graph:

def update_orig_fx_node_name_to_buf_name(
nodes: Optional[SchedulerNodeList],
node_name_to_buf_name: Dict[str, str],
node_name_to_buf_name: dict[str, str],
parent_buf_name: Optional[str] = None,
n_origins: int = 0,
) -> None:
@ -233,8 +234,8 @@ def update_orig_fx_node_name_to_buf_name(

def get_node_name_to_buf_meta(
node_name_to_buf_name: Dict[str, str]
node_name_to_buf_name: dict[str, str]
) -> Dict[str, BufMeta]:
) -> dict[str, BufMeta]:
buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items():
if buf_name not in buf_name_to_n_node:
@ -256,7 +257,7 @@ def annotate_orig_fx_with_snodes(
"""
Creates a FX Graph from a list of SchedulerNode objects.
"""
node_name_to_buf_name: Dict[str, str] = {}
node_name_to_buf_name: dict[str, str] = {}
update_orig_fx_node_name_to_buf_name(snodes, node_name_to_buf_name)
if node_name_to_buf_name is None:
return
@ -309,7 +310,7 @@ def enable_aot_logging() -> Iterator[None]:

class DebugContext:
_counter = itertools.count()
_inductor_triton_kernel_to_post_grad_node_info: Dict[str, List[str]] = {}
_inductor_triton_kernel_to_post_grad_node_info: dict[str, list[str]] = {}

@staticmethod
def create_debug_dir(folder_name: str) -> Optional[str]:
@ -425,7 +426,7 @@ class DebugContext:

def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[Any],
) -> None:
@ -474,7 +475,7 @@ class DebugFormatter:
def fx_graph(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
) -> None:
with self.fopen("fx_graph_runnable.py") as fd:
save_dir = None
@ -504,7 +505,7 @@ class DebugFormatter:
def fx_graph_transformed(
self,
gm: torch.fx.GraphModule,
inputs: List[torch.Tensor],
inputs: list[torch.Tensor],
) -> None:
with self.fopen("fx_graph_transformed.py") as fd:
fd.write(gm.print_readable(print_output=False))
@ -557,14 +558,14 @@ class DebugFormatter:
def log_autotuning_results(
self,
name: str,
input_nodes: List[ir.IRNode],
input_nodes: list[ir.IRNode],
timings: Dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
timings: dict["ChoiceCaller", float], # type: ignore[name-defined] # noqa: F821
elapse: float,
precompile_elapse: float,
) -> None:
from .ir import FixedLayout

def build_node_info(node: ir.IRNode) -> Dict[str, str]:
def build_node_info(node: ir.IRNode) -> dict[str, str]:
if hasattr(node, "name"):
node_name = node.name
else:
@ -725,7 +726,7 @@ def aot_inductor_minifier_wrapper(
func: Callable[..., str],
exported_program: torch.export.ExportedProgram,
*,
inductor_configs: Dict[str, Any],
inductor_configs: dict[str, Any],
package_path: Optional[Union[str, io.BytesIO]] = None,
) -> str:
from torch._inductor import config
@ -4,7 +4,7 @@ import logging
import math
import sys
import typing
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from typing_extensions import ParamSpec

import torch
@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)

def register_decomposition(
ops: List[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions:
@ -170,7 +170,7 @@ def clamp(

@register_decomposition([aten.full])
def full(
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
fill_value: torch.types.Number,
**kwargs: Any,
) -> torch.Tensor:
@ -205,8 +205,8 @@ def index_add(
# cool with strides and everything goes to empty_strided)
@register_decomposition([aten.empty_permuted.default])
def empty_permuted(
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
physical_layout: List[int],
physical_layout: list[int],
**kwargs: Any,
) -> torch.Tensor:
perm = [0] * len(size)
@ -220,14 +220,14 @@ def convolution_backward(
grad_output: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
bias_sizes: List[int],
bias_sizes: list[int],
stride: Union[int, List[int]],
stride: Union[int, list[int]],
padding: Union[int, List[int]],
padding: Union[int, list[int]],
dilation: Union[int, List[int]],
dilation: Union[int, list[int]],
transposed: bool,
output_padding: List[int],
output_padding: list[int],
groups: int,
output_mask: List[bool],
output_mask: list[bool],
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if not output_mask[2] or not is_gpu(grad_output.device.type):
return NotImplemented
@ -345,7 +345,7 @@ def mm(
# don't remove ALL empty tensors, only the naughty ones)
@register_decomposition([aten.cat.default])
def cat(
tensors: List[torch.Tensor],
tensors: list[torch.Tensor],
dim: int = 0,
) -> torch.Tensor:
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@ -515,7 +515,7 @@ def narrow_copy(
@register_decomposition([aten.view_copy.default])
def view_copy_default(
self: torch.Tensor,
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
) -> torch.Tensor:
return aten.view(self, size).clone()

@ -639,7 +639,7 @@ def randint_like_low(
@register_decomposition(aten.randint.default)
def randint(
high: int,
size: List[Union[int, torch.SymInt]],
size: list[Union[int, torch.SymInt]],
**kwargs: Any,
) -> torch.Tensor:
return aten.randint.low(0, high, size, **kwargs)
@ -731,11 +731,11 @@ def grid_sampler_2d(

@register_decomposition(aten._foreach_addcmul.Scalar)
def _foreach_addcmul_scalar(
self: List[torch.Tensor],
self: list[torch.Tensor],
left_tensors: List[torch.Tensor],
left_tensors: list[torch.Tensor],
right_tensors: List[torch.Tensor],
right_tensors: list[torch.Tensor],
scalar: float = 1,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
self, aten._foreach_mul.List(left_tensors, right_tensors), alpha=scalar
)
@ -743,11 +743,11 @@ def _foreach_addcmul_scalar(

@register_decomposition(aten._foreach_addcdiv.Scalar)
def _foreach_addcdiv_scalar(
self: List[torch.Tensor],
self: list[torch.Tensor],
left_tensors: List[torch.Tensor],
left_tensors: list[torch.Tensor],
right_tensors: List[torch.Tensor],
right_tensors: list[torch.Tensor],
scalar: float = 1,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
self, aten._foreach_div.List(left_tensors, right_tensors), alpha=scalar
)
@ -755,10 +755,10 @@ def _foreach_addcdiv_scalar(

@register_decomposition(aten._foreach_lerp.Scalar)
def _foreach_lerp_scalar(
start_tensors: List[torch.Tensor],
start_tensors: list[torch.Tensor],
end_tensors: List[torch.Tensor],
end_tensors: list[torch.Tensor],
weight: torch.types.Number,
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.Scalar(
@ -769,10 +769,10 @@ def _foreach_lerp_scalar(

@register_decomposition(aten._foreach_lerp.ScalarList)
def _foreach_lerp_scalarlist(
start_tensors: List[torch.Tensor],
start_tensors: list[torch.Tensor],
end_tensors: List[torch.Tensor],
end_tensors: list[torch.Tensor],
scalars: List[torch.types.Number],
scalars: list[torch.types.Number],
) -> List[torch.Tensor]:
) -> list[torch.Tensor]:
return aten._foreach_add.List(
start_tensors,
aten._foreach_mul.ScalarList(
@ -814,13 +814,13 @@ def miopen_batch_norm(

@functools.lru_cache(None)
def fast_random_decomps() -> Dict[Any, Callable[..., Any]]:
def fast_random_decomps() -> dict[Any, Callable[..., Any]]:
return {**decompositions, **extra_random_decomps}

# TODO(aakhundov): replace this (and the above) Any by more
# specific type and fix all the cascading mypy errors
def select_decomp_table() -> Dict[Any, Callable[..., Any]]:
def select_decomp_table() -> dict[Any, Callable[..., Any]]:
"""decomps can change based on config"""
if config.fallback_random:
return decompositions
@ -965,10 +965,10 @@ def index_reduce(
@register_decomposition(aten.max_pool2d_with_indices)
def max_pool2d_with_indices(
x: torch.Tensor,
kernel_size: List[int],
kernel_size: list[int],
stride: Optional[Union[int, List[int]]] = None,
stride: Optional[Union[int, list[int]]] = None,
padding: Union[int, List[int]] = 0,
padding: Union[int, list[int]] = 0,
dilation: Union[int, List[int]] = 1,
dilation: Union[int, list[int]] = 1,
ceil_mode: bool = False,
) -> tuple[torch.Tensor, torch.Tensor]:
if dilation == 1:
@ -1015,7 +1015,7 @@ def max_pool2d_with_indices(

@register_decomposition(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(
x: torch.Tensor, output_size: List[int]
x: torch.Tensor, output_size: list[int]
) -> tuple[torch.Tensor, torch.Tensor]:
*batch, h_in, w_in = x.shape
h_out, w_out = output_size
@ -4,8 +4,8 @@ import dataclasses
import itertools
import logging
import re
import typing
from collections.abc import Sequence
from typing import Any, Callable, Dict, List, Optional, Sequence, TypeVar, Union
from typing import Any, Callable, Optional, TypeVar, Union
from unittest.mock import patch

import sympy
@ -38,7 +38,7 @@ class Dep(abc.ABC):
index: sympy.Expr

@abc.abstractmethod
def rename(self, renames: Dict[str, str]) -> "Dep":
def rename(self, renames: dict[str, str]) -> "Dep":
pass

@abc.abstractmethod
@ -197,7 +197,7 @@ class MemoryDep(Dep):
return out

@property
def ranges(self) -> Dict[sympy.Symbol, sympy.Expr]:
def ranges(self) -> dict[sympy.Symbol, sympy.Expr]:
"""{c0: 128, c1: 512, ...}"""
return dict(zip(self.var_names, self.size))

@ -221,7 +221,7 @@ class MemoryDep(Dep):
numel = numel * size
return numel # type: ignore[return-value]

def rename(self, renames: Dict[str, str]) -> "MemoryDep":
def rename(self, renames: dict[str, str]) -> "MemoryDep":
if self.name in renames:
return MemoryDep(
renames[self.name],
@ -299,7 +299,7 @@ class StarDep(Dep):
def get_numel(self) -> sympy.Expr:
return V.graph.get_numel(self.name) # type: ignore[return-value]

def rename(self, renames: Dict[str, str]) -> "StarDep":
def rename(self, renames: dict[str, str]) -> "StarDep":
if self.name in renames:
return StarDep(renames[self.name], self.mode)
return self
@ -347,7 +347,7 @@ class WeakDep(Dep):
def get_numel(self) -> sympy.Expr:
return sympy.S.One

def rename(self, renames: Dict[str, str]) -> "WeakDep":
def rename(self, renames: dict[str, str]) -> "WeakDep":
if self.name in renames:
return WeakDep(renames[self.name], self.mutating_buf)
return self
@ -374,10 +374,10 @@ class ReadWrites:
reads: OrderedSet[Dep]
writes: OrderedSet[Dep]
index_exprs: OrderedSet[IndexExprDep]
range_vars: Optional[List[sympy.Expr]] = None
range_vars: Optional[list[sympy.Expr]] = None
var_ranges: Optional[VarRanges] = None

def rename(self, renames: typing.Dict[str, str]) -> "ReadWrites":
def rename(self, renames: dict[str, str]) -> "ReadWrites":
return ReadWrites(
OrderedSet(dep.rename(renames) for dep in self.reads),
OrderedSet(dep.rename(renames) for dep in self.writes),
@ -405,7 +405,7 @@ class ReadWrites:
return ReadWrites(reads - writes, writes, index_exprs)

@staticmethod
def merge_list(read_writes: List["ReadWrites"]):
def merge_list(read_writes: list["ReadWrites"]):
all_writes = OrderedSet.union(*[rw.writes for rw in read_writes])
all_reads = OrderedSet.union(*[rw.reads for rw in read_writes]) - all_writes
all_index_exprs = OrderedSet.union(*[rw.index_exprs for rw in read_writes])
@ -564,7 +564,7 @@ def var_builder(prefix: str) -> tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy

def index_vars_no_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str):
var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
args: list[list[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
return args, var_ranges

@ -572,8 +572,8 @@ def index_vars_squeeze(*argsizes: Sequence[sympy.Expr], prefix: str = "d"):
from .ir import SqueezeView

var_ranges, add_var = var_builder(prefix)
args: List[List[sympy.Expr]] = []
args: list[list[sympy.Expr]] = []
new_sizes: List[List[sympy.Expr]] = []
new_sizes: list[list[sympy.Expr]] = []
for size in argsizes:
new_size, reindex = SqueezeView.squeezer(size)
new_sizes.append(new_size)
@ -653,7 +653,7 @@ def extract_loop_body_with_args(fn, args, var_ranges, normalize=False):

def extract_input_node_reduction_ranges(
input_node: "torch._inductor.ir.IRNode",
) -> tuple[Optional[List[sympy.Expr]], Optional[List[sympy.Expr]]]:
) -> tuple[Optional[list[sympy.Expr]], Optional[list[sympy.Expr]]]:
"""
Returns the size and reduction size of all inputs, if the sizes and reduction_sizes (if exist) are all the same.
It's possible that a node has multiple inputs, some are Reduction nodes and others are Pointwise nodes.
@ -663,8 +663,8 @@ def extract_input_node_reduction_ranges(

from .ir import ComputedBuffer, ExternKernel, Loops

size: Optional[List[sympy.Expr]]
size: Optional[list[sympy.Expr]]
reduction_size: Optional[List[sympy.Expr]]
reduction_size: Optional[list[sympy.Expr]]

if isinstance(input_node.get_defining_op(), ComputedBuffer):
# Input node has already been realized. Return its size and reduction_size.
@ -683,11 +683,11 @@ def extract_input_node_reduction_ranges(
# The current method still uses reduction ranges from the dependent realized node, which is not ideal.
# Is there a way to check whether there are permutations inbetween?
reads = input_node.get_reads()
reduction_size: Optional[List[sympy.Expr]] = None
reduction_size: Optional[list[sympy.Expr]] = None
size: Optional[List[sympy.Expr]] = None
size: Optional[list[sympy.Expr]] = None
while reduction_size is None and len(reads) > 0:
seen: OrderedSet[str] = OrderedSet()
new_reads: List[Dep] = []
new_reads: list[Dep] = []
for read in reads:
if not isinstance(read, MemoryDep):
continue
@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
import functools
from typing import Callable, Optional, Protocol, Sequence, TYPE_CHECKING, TypeVar, Union
from collections.abc import Sequence
from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union

import sympy

@ -4,7 +4,7 @@ import os
import tempfile
import textwrap
from functools import lru_cache
from typing import Any, List, Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING

from torch._dynamo.exc import BackendCompilerFailed, ShortenTraceback

@ -29,7 +29,7 @@ else:

class OperatorIssue(RuntimeError):
@staticmethod
def operator_str(target: Any, args: List[Any], kwargs: dict[str, Any]) -> str:
def operator_str(target: Any, args: list[Any], kwargs: dict[str, Any]) -> str:
lines = [f"target: {target}"] + [
f"args[{i}]: {arg}" for i, arg in enumerate(args)
]
@ -39,13 +39,13 @@ class OperatorIssue(RuntimeError):

class MissingOperatorWithoutDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(f"missing lowering\n{self.operator_str(target, args, kwargs)}")

class MissingOperatorWithDecomp(OperatorIssue):
def __init__(self, target: Any, args: List[Any], kwargs: dict[str, Any]) -> None:
def __init__(self, target: Any, args: list[Any], kwargs: dict[str, Any]) -> None:
_record_missing_op(target)
super().__init__(
f"missing decomposition\n{self.operator_str(target, args, kwargs)}"
@ -62,7 +62,7 @@ class MissingOperatorWithDecomp(OperatorIssue):

class LoweringException(OperatorIssue):
def __init__(
self, exc: Exception, target: Any, args: List[Any], kwargs: dict[str, Any]
self, exc: Exception, target: Any, args: list[Any], kwargs: dict[str, Any]
) -> None:
super().__init__(
f"{type(exc).__name__}: {exc}\n{self.operator_str(target, args, kwargs)}"
@ -1,5 +1,4 @@
import json
from typing import List

from torch._export.serde.aoti_schema import ExternKernelNode, ExternKernelNodes, Node
from torch._export.serde.serialize import _dataclass_to_dict, EnumEncoder
@ -17,7 +16,7 @@ def serialize_extern_kernel_node(

def extern_node_json_serializer(
extern_kernel_nodes: List[inductor_ExternKernelNode],
extern_kernel_nodes: list[inductor_ExternKernelNode],
) -> str:
serialized_nodes = ExternKernelNodes(
nodes=[serialize_extern_kernel_node(node) for node in extern_kernel_nodes]
@ -4,7 +4,7 @@ from __future__ import annotations
import itertools
import logging
import weakref
from typing import Any, List, Optional
from typing import Any, Optional

import torch
import torch.utils._pytree as pytree
@ -28,7 +28,7 @@ def replace_params_with_constants(
gm: torch.fx.GraphModule,
flat_params: list[Any],
fw_metadata: torch._functorch.aot_autograd.ViewAndMutationMeta,
) -> List[int]:
) -> list[int]:
"""
Replaces the parameters of a PyTorch GraphModule with constants wherever possible.
Returns a list of indices representing the input parameters that were not converted to constants.
@ -66,8 +66,8 @@ def replace_params_with_constants(
def freeze(
dynamo_gm: torch.fx.GraphModule,
aot_autograd_gm: torch.fx.GraphModule,
example_inputs: List[torch._subclasses.FakeTensor],
example_inputs: list[torch._subclasses.FakeTensor],
) -> tuple[torch.fx.GraphModule, List[int]]:
) -> tuple[torch.fx.GraphModule, list[int]]:
"""
Inlines parameters that are not mutated into constants and optimizes the graph through constant propagation
and other techniques. If enabled, the function also discards the original parameters of the module for memory efficiency.
@ -6,21 +6,17 @@ import signal
import string
import sys
import traceback
from collections.abc import KeysView
from enum import Enum
from functools import partial, wraps
from types import FrameType
from typing import (
Any,
Callable,
Dict,
get_args,
get_origin,
KeysView,
List,
Literal,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
@ -92,14 +88,14 @@ class TypeExemplars:
}

@staticmethod
def example(t: Type[T]) -> Optional[T]:
def example(t: type[T]) -> Optional[T]:
"""
Return an example of a class.
"""
return TypeExemplars.TYPE_EXEMPLARS.get(t.__name__, None)

@staticmethod
def contains(t: Type[T]) -> bool:
def contains(t: type[T]) -> bool:
return t.__name__ in TypeExemplars.TYPE_EXEMPLARS

@ -136,7 +132,7 @@ class Status(Enum):

# Sometime the types of configs aren't expressive enough to be captured by python type system, so the options can be
# manually specified here:
TYPE_OVERRIDES: Dict[str, List[Any]] = {
TYPE_OVERRIDES: dict[str, list[Any]] = {
"post_grad_fusion_options": [
{
"batch_linear_post_grad": {
@ -160,7 +156,7 @@ TYPE_OVERRIDES: Dict[str, List[Any]] = {
"autoheuristic_collect": ["pad_mm", "mixed_mm"],
"autoheuristic_use": ["pad_mm", "mixed_mm"],
}
SamplingType = Callable[[str, Type[Any], Any], Any]
SamplingType = Callable[[str, type[Any], Any], Any]

class SamplingMethod(Enum):
@ -178,7 +174,7 @@ class SamplingMethod(Enum):

@staticmethod
def _generate_value_for_type(
random_sample: bool, field_name: str, type_hint: Type[Any], default: Any
random_sample: bool, field_name: str, type_hint: type[Any], default: Any
) -> Any:
"""
Generates a value of a type based on the setting.
@ -304,9 +300,11 @@ class SamplingMethod(Enum):
if random_sample:
return random.choice(type_hint.__args__)
else:
return random.choice(
choices = [t for t in type_hint.__args__ if t != default]
[t for t in type_hint.__args__ if t != default]
if choices:
)
return random.choice(choices)
else:
return default
except AttributeError as err:
raise ValueError("Literal type with no args") from err
elif is_optional_type(type_hint):
@ -374,7 +372,7 @@ class Default:
DEFAULT = Default()

# The combination of config settings being set (based on their strings)
ComboType = Tuple[str, ...]
ComboType = tuple[str, ...]

class ResultType:
@ -382,7 +380,7 @@ class ResultType:
The mapping of the combo strings to the result status after running the config fuzzer.
"""

_vals: Dict[ComboType, Status]
_vals: dict[ComboType, Status]

def __repr__(self) -> str:
return f"ResultType[{self._vals}]"
@ -416,7 +414,7 @@ class ResultType:

# Type that maps config strings to their default value
ConfigType = Dict[str, Any]
ConfigType = dict[str, Any]
# Callable that returns a bool
FactoryOutputType = Callable[[], bool]
# input function factory
@ -504,10 +502,10 @@ class ConfigFuzzer:
return
self.seed = seed
self.test_timeout = test_timeout
self.detailed_results: Dict[ComboType, Dict[str, Any]] = {}
self.detailed_results: dict[ComboType, dict[str, Any]] = {}
self.config_module = config_module
self.test_model_fn_factory = test_model_fn_factory
self.fields: Dict[str, _ConfigEntry] = self.config_module._config
self.fields: dict[str, _ConfigEntry] = self.config_module._config
self.sample = SamplingMethod.dispatch(sm)

if default is None:
@ -587,7 +585,7 @@ class ConfigFuzzer:
}
return ret

def reproduce(self, configs: List[ConfigType]) -> ResultType:
def reproduce(self, configs: list[ConfigType]) -> ResultType:
"""entrypoint to reproduce any failure"""
results = ResultType()
for conf in configs:
@ -675,7 +673,7 @@ class ConfigFuzzer:
for field, value in config.items():
print(f"{field} = {value}")

def get_error_info(exc: Exception) -> Dict[str, Any]:
def get_error_info(exc: Exception) -> dict[str, Any]:
return {
"exception": str(exc),
"traceback": traceback.format_exc(),
@ -741,7 +739,7 @@ class ConfigFuzzer:
else:
return handle_return("Function succeeded", Status.PASSED, False, None)

def bisect(self, num_attempts: int = 100, p: float = 0.5) -> List[ConfigType]:
def bisect(self, num_attempts: int = 100, p: float = 0.5) -> list[ConfigType]:
"""
Test configs and bisect to minimal failing configuration.
"""
@ -749,7 +747,7 @@ class ConfigFuzzer:
random.seed(self.seed)
self._reset_configs()
results = ResultType()
ret: List[ConfigType] = []
ret: list[ConfigType] = []

for attempt in range(num_attempts):
print(f"Random attempt {attempt + 1}/{num_attempts}")
@ -783,7 +781,7 @@ class ConfigFuzzer:
return self._bisect_failing_config_helper(results, list(failing_config.items()))

def _bisect_failing_config_helper(
self, results: ResultType, failing_config: List[Tuple[str, Any]]
self, results: ResultType, failing_config: list[tuple[str, Any]]
) -> Optional[ConfigType]:
"""
Bisect a failing configuration to find minimal set of configs that cause failure.
@ -795,7 +793,7 @@ class ConfigFuzzer:
if not failing_config:
return None

def test(x: List[Tuple[str, Any]]) -> Status:
def test(x: list[tuple[str, Any]]) -> Status:
d = dict(x)
result = self.test_config(results, d)
return result
@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import operator
from collections import defaultdict
from typing import Any, Callable, DefaultDict, Dict, Optional, Type
from typing import Any, Callable, Optional

import sympy

@ -24,9 +24,9 @@ from .virtualized import V
# Check the pattern: (nn.module, F.function/torch.Tensor.method) matched.
# Works for length 2 patterns with 1 module and 1 function/method.
def matches_module_function_pattern(
pattern: tuple[Type[torch.nn.modules.Module], Callable[..., Any]],
pattern: tuple[type[torch.nn.modules.Module], Callable[..., Any]],
node: torch.fx.node.Node,
modules: Dict[str, torch.nn.modules.Module],
modules: dict[str, torch.nn.modules.Module],
) -> bool:
if len(node.args) == 0:
return False
@ -86,7 +86,7 @@ class FakeTensorUpdater:
return (node, node.target, id(node.args), id(node.kwargs))

def incremental_update(self):
existing_storages: DefaultDict[Optional[int], int] = defaultdict(int)
existing_storages: defaultdict[Optional[int], int] = defaultdict(int)
for node in self.graph.nodes:
existing_storages[get_node_storage(node)] += 1

@ -208,7 +208,7 @@ def get_fake(x):
return x

def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], Dict[str, Any]]:
def get_fake_args_kwargs(x: torch.fx.Node) -> tuple[bool, tuple[Any], dict[str, Any]]:
"""
First value returns a boolean if any of the input nodes don't have a faketensor.
"""
@ -8,22 +8,10 @@ import re
import sys
import time
from collections import defaultdict
from collections.abc import Iterable, Iterator, Sequence
from contextlib import contextmanager
from types import ModuleType
from typing import (
from typing import Any, Callable, NoReturn, Optional, TYPE_CHECKING, Union
Any,
Callable,
DefaultDict,
Dict,
Iterable,
Iterator,
List,
NoReturn,
Optional,
Sequence,
TYPE_CHECKING,
Union,
)

import sympy
from sympy import Expr
@ -198,8 +186,8 @@ def getattr_recursive(
return attr_itr

def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:
def get_user_visible_output_strides(g: Graph) -> dict[Node, tuple[int, ...]]:
ret: Dict[Node, tuple[int, ...]] = {}
ret: dict[Node, tuple[int, ...]] = {}
output_node = g.find_nodes(op="output")[0]

if "user_visible_output_idxs" not in output_node.meta:
@ -212,7 +200,7 @@ def get_user_visible_output_strides(g: Graph) -> Dict[Node, tuple[int, ...]]:

def mark_nodes_dislike_padding(
g: Graph, user_visible_output_strides: Dict[Node, tuple[int, ...]]
g: Graph, user_visible_output_strides: dict[Node, tuple[int, ...]]
) -> None:
"""
Nodes like convolution/convolution_backward want its input to be dense.
@ -282,7 +270,7 @@ def mark_nodes_dislike_padding(

class GraphLowering(torch.fx.Interpreter):
graph_outputs: List[ir.IRNode]
graph_outputs: list[ir.IRNode]

def symbolic_sizes_strides(
self, ex: torch.Tensor
@ -323,7 +311,7 @@ class GraphLowering(torch.fx.Interpreter):

def static_sizes_strides(
self, ex: torch.Tensor
) -> tuple[List[sympy.Expr], List[sympy.Expr]]:
) -> tuple[list[sympy.Expr], list[sympy.Expr]]:
"""
Primarily used to weights
"""
@ -341,12 +329,12 @@ class GraphLowering(torch.fx.Interpreter):
aot_mode: bool = False,
layout_opt: Optional[bool] = None,
extern_node_serializer: Optional[
Callable[[List[ir.ExternKernelNode]], Any]
Callable[[list[ir.ExternKernelNode]], Any]
] = None,
is_inference: bool = False,
is_backward: bool = False,
is_const_graph: bool = False,
const_output_index: Optional[Dict[str, int]] = None,
const_output_index: Optional[dict[str, int]] = None,
const_code: Optional[str] = None,
const_module: Optional["GraphLowering"] = None,
name: Optional[str] = None,
@ -379,14 +367,14 @@ class GraphLowering(torch.fx.Interpreter):
# you don't start adding new ones in the lowering process
shape_env.freeze_runtime_asserts()
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: Dict[
self.ras_by_symbol: dict[
Optional[sympy.Symbol], List[RuntimeAssert]
Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: List[str] = []
self.graph_input_names: list[str] = []
self.graph_inputs: Dict[str, TensorBox] = {}
self.graph_inputs: dict[str, TensorBox] = {}
self.graph_inputs_original: Dict[str, InputBuffer] = {}
self.graph_inputs_original: dict[str, InputBuffer] = {}
self.zero_dim_cpu_tensor_list = OrderedSet[str]()
self.device_types: OrderedSet[str] = (
const_module.device_types if const_module else OrderedSet()
@ -395,9 +383,9 @@ class GraphLowering(torch.fx.Interpreter):
const_module.device_idxs if const_module else OrderedSet()
)
self.device_type = "cpu"
self.buffers: List[ir.Buffer] = []
self.buffers: list[ir.Buffer] = []
self.operations: List[ir.Operation] = []
self.operations: list[ir.Operation] = []
self.const_output_index: Dict[str, int] = (
self.const_output_index: dict[str, int] = (
const_output_index if const_output_index else {}
)
self.folded_constants: OrderedSet[str] = (
@ -405,12 +393,12 @@ class GraphLowering(torch.fx.Interpreter):
if const_output_index
else OrderedSet()
)
self.constants: Dict[str, torch.Tensor] = (
self.constants: dict[str, torch.Tensor] = (
const_module.constants if const_module else {}
)
self.torchbind_constants: Dict[str, torch._C.ScriptObject] = {}
self.torchbind_constants: dict[str, torch._C.ScriptObject] = {}
self.seen_subgraphs: Dict[str, ir.Subgraph] = {}
self.seen_subgraphs: dict[str, ir.Subgraph] = {}
self.constant_reprs: Dict[str, str] = {}
self.constant_reprs: dict[str, str] = {}
self.removed_operations = OrderedSet[str]()
self.removed_buffers = OrderedSet[str]()
self.removed_inplace_buffers = OrderedSet[str]()
@ -420,23 +408,23 @@ class GraphLowering(torch.fx.Interpreter):
self.device_ops: DeviceOpOverrides = None # type: ignore[assignment]
self.wrapper_code: PythonWrapperCodegen = None # type: ignore[assignment]
# See `ProxyExecutor Design Note` in ir.py for more details
self.extern_kernel_nodes: List[ir.ExternKernelNode] = []
self.extern_kernel_nodes: list[ir.ExternKernelNode] = []

from torch._inductor.extern_node_serializer import extern_node_json_serializer

self.extern_node_serializer: Callable[[List[ir.ExternKernelNode]], Any] = (
self.extern_node_serializer: Callable[[list[ir.ExternKernelNode]], Any] = (
extern_node_serializer
if config.is_fbcode() and extern_node_serializer
else extern_node_json_serializer
)

self.current_node: torch.fx.Node = None # type: ignore[assignment]
self.lists: Dict[str, List[str]] = {}
self.lists: dict[str, list[str]] = {}
self.mutated_inputs = OrderedSet[str]()
self.mutated_input_idxs: List[int] = []
self.mutated_input_idxs: list[int] = []
self.name_to_buffer: Dict[str, ir.Buffer] = {}
self.name_to_buffer: dict[str, ir.Buffer] = {}
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
self.name_to_users: defaultdict[str, list[ir.IRNode]] = defaultdict(list)
self.name_to_op: Dict[str, ir.Operation] = {}
self.name_to_op: dict[str, ir.Operation] = {}
self.creation_time = time.time()
self.name = name # type: ignore[assignment]
self.cpp_wrapper = cpp_wrapper
@ -445,7 +433,7 @@ class GraphLowering(torch.fx.Interpreter):
# which sub-kernel is picked. Copy cpp_wrapper to another variable
# since cpp_wrapper flag is OrderedSet to false for the first pass of codegen.
self.record_multi_kernel_choice = cpp_wrapper
self.multi_kernel_to_choice: Dict[str, str] = {}
self.multi_kernel_to_choice: dict[str, str] = {}

self.aot_mode = aot_mode
self.graph_id = graph_id
@ -464,7 +452,7 @@ class GraphLowering(torch.fx.Interpreter):
mark_nodes_dislike_padding(gm.graph, self.user_visible_output_strides)
self.cache_key: str = "" # This is the cache key for the compiled artifact
|
||||||
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
||||||
self.cache_linemap: List[
|
self.cache_linemap: list[
|
||||||
tuple[int, str]
|
tuple[int, str]
|
||||||
] = (
|
] = (
|
||||||
[]
|
[]
|
||||||
@ -473,18 +461,18 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
self.disable_cudagraphs_reason: Optional[str] = None
|
self.disable_cudagraphs_reason: Optional[str] = None
|
||||||
|
|
||||||
# only keeping one node per device for stack trace purposes
|
# only keeping one node per device for stack trace purposes
|
||||||
self.device_node_mapping: Dict[torch.device, torch.fx.Node] = {}
|
self.device_node_mapping: dict[torch.device, torch.fx.Node] = {}
|
||||||
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
|
self.orig_gm: torch.fx.GraphModule = gm.__copy__()
|
||||||
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr]
|
self.dynamo_flat_name_to_original_fqn = self.module.meta.get( # type: ignore[operator, union-attr]
|
||||||
"dynamo_flat_name_to_original_fqn", {}
|
"dynamo_flat_name_to_original_fqn", {}
|
||||||
)
|
)
|
||||||
self.allocated_constant_name: Dict[str, str] = (
|
self.allocated_constant_name: dict[str, str] = (
|
||||||
const_module.allocated_constant_name if const_module is not None else {}
|
const_module.allocated_constant_name if const_module is not None else {}
|
||||||
)
|
)
|
||||||
init_backend_registration()
|
init_backend_registration()
|
||||||
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
|
self.get_backend_features = functools.lru_cache(None)(get_backend_features)
|
||||||
|
|
||||||
self.effectful_ops: Dict[_EffectType, ir.Buffer] = {}
|
self.effectful_ops: dict[_EffectType, ir.Buffer] = {}
|
||||||
self.aligned_inputs = OrderedSet[str]()
|
self.aligned_inputs = OrderedSet[str]()
|
||||||
self.no_fuse_buffer_names = OrderedSet[str]()
|
self.no_fuse_buffer_names = OrderedSet[str]()
|
||||||
|
|
||||||
@ -599,7 +587,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
if is_inference:
|
if is_inference:
|
||||||
from torch.utils.flop_counter import FlopCounterMode
|
from torch.utils.flop_counter import FlopCounterMode
|
||||||
|
|
||||||
flop_counts: Dict[str, float] = defaultdict(float)
|
flop_counts: dict[str, float] = defaultdict(float)
|
||||||
for node in conv_nodes:
|
for node in conv_nodes:
|
||||||
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
|
success, args, kwargs = torch._inductor.fx_utils.get_fake_args_kwargs(
|
||||||
node
|
node
|
||||||
@ -702,7 +690,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
def make_subgraph(
|
def make_subgraph(
|
||||||
self,
|
self,
|
||||||
gm: torch.fx.GraphModule,
|
gm: torch.fx.GraphModule,
|
||||||
example_inputs: List[torch.Tensor],
|
example_inputs: list[torch.Tensor],
|
||||||
subgraph_name: str,
|
subgraph_name: str,
|
||||||
) -> "SubgraphLowering":
|
) -> "SubgraphLowering":
|
||||||
"""
|
"""
|
||||||
@ -886,7 +874,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
buffer.name = name
|
buffer.name = name
|
||||||
return name
|
return name
|
||||||
|
|
||||||
def register_operation_list(self, operation_names: List[str]) -> str:
|
def register_operation_list(self, operation_names: list[str]) -> str:
|
||||||
name = self.qualify_name("list_" + "_".join(operation_names))
|
name = self.qualify_name("list_" + "_".join(operation_names))
|
||||||
self.lists[name] = operation_names
|
self.lists[name] = operation_names
|
||||||
return name
|
return name
|
||||||
@ -995,7 +983,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def placeholder(
|
def placeholder(
|
||||||
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
||||||
) -> Union[Expr, TensorBox, None]:
|
) -> Union[Expr, TensorBox, None]:
|
||||||
self.placeholder_idx += 1
|
self.placeholder_idx += 1
|
||||||
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||||
@ -1072,7 +1060,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
self.aligned_inputs.add(target)
|
self.aligned_inputs.add(target)
|
||||||
return tensor
|
return tensor
|
||||||
|
|
||||||
def call_function(self, target: Callable, args: Any, kwargs: Dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
def call_function(self, target: Callable, args: Any, kwargs: dict[str, Any]) -> Any: # type: ignore[type-arg, override]
|
||||||
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
if target is operator.getitem and isinstance(args[0], (list, tuple, dict)):
|
||||||
return super().call_function(target, args, kwargs)
|
return super().call_function(target, args, kwargs)
|
||||||
|
|
||||||
@ -1155,7 +1143,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
return len(t.shape) == 1 and t.shape[0] <= 8
|
return len(t.shape) == 1 and t.shape[0] <= 8
|
||||||
|
|
||||||
def get_attr(
|
def get_attr(
|
||||||
self, target: str, args: tuple[()], kwargs: Dict[str, object] # type: ignore[override]
|
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
|
||||||
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
||||||
# this is a constant
|
# this is a constant
|
||||||
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
||||||
@ -1203,7 +1191,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
raise AssertionError
|
raise AssertionError
|
||||||
|
|
||||||
def output(
|
def output(
|
||||||
self, target: str, args: tuple[object], kwargs: Dict[str, object] # type: ignore[override]
|
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
||||||
) -> None:
|
) -> None:
|
||||||
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||||
if not isinstance(result, (tuple, list)):
|
if not isinstance(result, (tuple, list)):
|
||||||
@ -1306,9 +1294,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
self,
|
self,
|
||||||
fx_node: torch.fx.Node,
|
fx_node: torch.fx.Node,
|
||||||
old_args: tuple[Any],
|
old_args: tuple[Any],
|
||||||
old_kwargs: Dict[str, Any],
|
old_kwargs: dict[str, Any],
|
||||||
new_args: tuple[Any],
|
new_args: tuple[Any],
|
||||||
new_kwargs: Dict[str, Any],
|
new_kwargs: dict[str, Any],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
|
"""Propagate mutations on new_args/new_kwargs back to old_args/old_kwargs.
|
||||||
|
|
||||||
@ -1803,7 +1791,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
self.const_module.wrapper_code.src_to_kernel
|
self.const_module.wrapper_code.src_to_kernel
|
||||||
)
|
)
|
||||||
|
|
||||||
def codegen_with_cpp_wrapper(self) -> tuple[str, List[tuple[int, Node]]]:
|
def codegen_with_cpp_wrapper(self) -> tuple[str, list[tuple[int, Node]]]:
|
||||||
"""
|
"""
|
||||||
For GPU, Triton kernels are autotuned and stored as cubin files
|
For GPU, Triton kernels are autotuned and stored as cubin files
|
||||||
"""
|
"""
|
||||||
@ -1902,7 +1890,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
# cpu
|
# cpu
|
||||||
return self.codegen()
|
return self.codegen()
|
||||||
|
|
||||||
def codegen(self) -> tuple[str, List[tuple[int, Node]]]:
|
def codegen(self) -> tuple[str, list[tuple[int, Node]]]:
|
||||||
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
|
with dynamo_timed("GraphLowering.codegen", log_pt2_compile_event=True):
|
||||||
from .scheduler import Scheduler
|
from .scheduler import Scheduler
|
||||||
|
|
||||||
@ -1949,7 +1937,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
def count_bytes(
|
def count_bytes(
|
||||||
self,
|
self,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
int, List[tuple[BaseSchedulerNode, int]], List[tuple[BaseSchedulerNode, float]]
|
int, list[tuple[BaseSchedulerNode, int]], list[tuple[BaseSchedulerNode, float]]
|
||||||
]:
|
]:
|
||||||
total_bytes = 0
|
total_bytes = 0
|
||||||
node_counts = []
|
node_counts = []
|
||||||
@ -2041,7 +2029,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||||||
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
|
V.debug.copy(os.path.splitext(mod.__file__)[0] + ".debug")
|
||||||
return mod
|
return mod
|
||||||
|
|
||||||
def get_output_names(self) -> List[str]:
|
def get_output_names(self) -> list[str]:
|
||||||
names = []
|
names = []
|
||||||
shape_counter = itertools.count(0)
|
shape_counter = itertools.count(0)
|
||||||
none_counter = itertools.count(0)
|
none_counter = itertools.count(0)
|
||||||
|
@@ -1,13 +1,13 @@
 # mypy: allow-untyped-defs
 import contextlib
-from typing import Callable, List, TYPE_CHECKING
+from typing import Callable, TYPE_CHECKING


 if TYPE_CHECKING:
     import torch

 # Executed in the order they're registered
-INTERMEDIATE_HOOKS: List[Callable[[str, "torch.Tensor"], None]] = []
+INTERMEDIATE_HOOKS: list[Callable[[str, "torch.Tensor"], None]] = []


 @contextlib.contextmanager
@@ -22,7 +22,7 @@ SymPy expressions yet, despite sympy.Min and sympy.Max existing.
 """
 import itertools
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, Literal, Optional, overload, Union
+from typing import Any, Callable, Literal, Optional, overload, Union
 from typing_extensions import TypeAlias

 import sympy
@@ -196,8 +196,8 @@ class IndexPropagation:
 def __init__(
     self,
     inner: Any,
-    iter_ranges: Dict[sympy.Symbol, sympy.Expr],
+    iter_ranges: dict[sympy.Symbol, sympy.Expr],
-    indirect_var_ranges: Dict[sympy.Symbol, sympy.Expr],
+    indirect_var_ranges: dict[sympy.Symbol, sympy.Expr],
 ) -> None:
     self._inner = inner
     self.shape_env = V.graph.sizevars.shape_env
@@ -248,18 +248,18 @@ class IndexPropagation:
     self,
     name: Literal["indirect_indexing"],
     args: tuple[Any, ...],
-    kwargs: Dict[str, Any],
+    kwargs: dict[str, Any],
 ) -> IndexPropVar:
     ...

 @overload
 def fallback(
-    self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
+    self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> IndexPropResult:
     ...

 def fallback(
-    self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
+    self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> IndexPropResult:
     # Fallback to the wrapped handler
     new_args = [self.unwrap(a) for a in args]
@@ -267,7 +267,7 @@ class IndexPropagation:
     return self.wrap(getattr(self._inner, name)(*new_args, **new_kwargs))

 def propagate_sympy(
-    self, name: str, args: tuple[Any, ...], kwargs: Dict[str, Any]
+    self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]
 ) -> IndexPropResult:
     # Build a new SymPy expression from this ops call
     def unwrap(a: Union[Any, IndexPropVar]) -> Any:
@@ -2,12 +2,16 @@
 from __future__ import annotations

 import logging
-from typing import Optional, Sequence
+from typing import Optional, TYPE_CHECKING

 import torch
 from torch import _prims, Tensor


+if TYPE_CHECKING:
+    from collections.abc import Sequence
+
+
 log = logging.getLogger(__name__)

@@ -8,6 +8,7 @@ import logging
 import textwrap
 import traceback
 import typing
+from collections.abc import Generator, Iterable, Sequence
 from contextlib import nullcontext
 from enum import Enum
 from functools import partial
@@ -16,14 +17,9 @@ from typing import (
     Callable,
     ClassVar,
     ContextManager,
-    Dict,
-    Generator,
-    Iterable,
-    List,
     Literal,
     Optional,
     overload,
-    Sequence,
     TYPE_CHECKING,
     TypeVar,
     Union,
@@ -165,11 +161,11 @@ e.g. it may be a graph input or compile time constant.
 _NodeOrNodes: TypeAlias = Union[
     int,
     "TensorBox",
-    Dict[str, "TensorBox"],
+    dict[str, "TensorBox"],
     "Symbol",
     "IRNode",
     Sequence[
-        Optional[Union[int, Dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
+        Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
     ],
 ]

@@ -426,7 +422,7 @@ class IRNode:

 # NB: These are kinda weird,
 origins: OrderedSet[Any] = dataclasses.field(init=False)
-traceback: Optional[List[str]] = dataclasses.field(init=False)
+traceback: Optional[list[str]] = dataclasses.field(init=False)
 origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)

 @staticmethod
@@ -455,7 +451,7 @@ class IRNode:
 def get_read_names(self) -> OrderedSet[str]:
     return OrderedSet(dep.name for dep in self.get_reads())

-def get_traceback(self) -> Optional[List[str]]:
+def get_traceback(self) -> Optional[list[str]]:
     return self.traceback

 def get_origin_node(self) -> Optional[torch.fx.Node]:
@@ -604,18 +600,18 @@ class IRNode:
     raise NotImplementedError(type(self).__name__)

 def freeze_layout_with_stride_order(
-    self, order: List[int], allow_padding: bool = False
+    self, order: list[int], allow_padding: bool = False
 ) -> None:
     raise NotImplementedError(type(self).__name__)

-def freeze_layout_with_fill_order(self, order: List[int]) -> None:
+def freeze_layout_with_fill_order(self, order: list[int]) -> None:
     raise NotImplementedError(type(self).__name__)

-def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
+def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
     raise NotImplementedError(type(self).__name__)

 def freeze_layout_with_exact_strides(
-    self, exact_strides: List[_IntLike], allow_padding: bool = False
+    self, exact_strides: list[_IntLike], allow_padding: bool = False
 ) -> None:
     raise NotImplementedError(type(self).__name__)

@@ -703,7 +699,7 @@ class Operation:
 def get_reads(self) -> OrderedSet[Dep]:
     return self.get_read_writes().reads

-def get_outputs(self) -> List[Buffer]:
+def get_outputs(self) -> list[Buffer]:
     raise NotImplementedError

 def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
@ -936,7 +932,7 @@ class Scatter(Pointwise):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
REDUCTION_COMBINE_FN: Dict[str, Callable[..., OpsValue]] = {
|
REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
|
||||||
"any": ops_wrapper("logical_or"),
|
"any": ops_wrapper("logical_or"),
|
||||||
"max": ops_wrapper("maximum"),
|
"max": ops_wrapper("maximum"),
|
||||||
"min": ops_wrapper("minimum"),
|
"min": ops_wrapper("minimum"),
|
||||||
@ -1575,8 +1571,8 @@ class Reduction(Loops):
|
|||||||
wrapper_fn: Callable[..., Any],
|
wrapper_fn: Callable[..., Any],
|
||||||
original_ranges: Sequence[Expr],
|
original_ranges: Sequence[Expr],
|
||||||
original_reduction_ranges: Sequence[Expr],
|
original_reduction_ranges: Sequence[Expr],
|
||||||
new_ranges: List[Expr],
|
new_ranges: list[Expr],
|
||||||
new_reduction_ranges: List[Integer],
|
new_reduction_ranges: list[Integer],
|
||||||
reduction_type: str,
|
reduction_type: str,
|
||||||
split: _IntLike,
|
split: _IntLike,
|
||||||
reduction_hint: ReductionHint,
|
reduction_hint: ReductionHint,
|
||||||
@ -1678,8 +1674,8 @@ class Reduction(Loops):
|
|||||||
inner_fn: Callable[..., Any],
|
inner_fn: Callable[..., Any],
|
||||||
original_ranges: Sequence[Expr],
|
original_ranges: Sequence[Expr],
|
||||||
original_reduction_ranges: Sequence[Expr],
|
original_reduction_ranges: Sequence[Expr],
|
||||||
new_ranges: List[Integer],
|
new_ranges: list[Integer],
|
||||||
new_reduction_ranges: List[Integer],
|
new_reduction_ranges: list[Integer],
|
||||||
reduction_type: str,
|
reduction_type: str,
|
||||||
reduction_hint: ReductionHint,
|
reduction_hint: ReductionHint,
|
||||||
) -> TensorBox:
|
) -> TensorBox:
|
||||||
@ -1767,8 +1763,8 @@ class WelfordReduction(Reduction):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
inner_fns: Sequence[Callable[..., Any]],
|
inner_fns: Sequence[Callable[..., Any]],
|
||||||
ranges: List[Integer],
|
ranges: list[Integer],
|
||||||
reduction_ranges: List[Integer],
|
reduction_ranges: list[Integer],
|
||||||
reduction_type: str,
|
reduction_type: str,
|
||||||
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
||||||
) -> Sequence[TensorBox]:
|
) -> Sequence[TensorBox]:
|
||||||
@ -1893,8 +1889,8 @@ class WelfordReduction(Reduction):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
inner_fns: Sequence[Callable[..., Any]],
|
inner_fns: Sequence[Callable[..., Any]],
|
||||||
ranges: List[Integer],
|
ranges: list[Integer],
|
||||||
reduction_ranges: List[Integer],
|
reduction_ranges: list[Integer],
|
||||||
reduction_type: str,
|
reduction_type: str,
|
||||||
split: _IntLike,
|
split: _IntLike,
|
||||||
reduction_hint: ReductionHint,
|
reduction_hint: ReductionHint,
|
||||||
@ -1983,8 +1979,8 @@ class WelfordReduction(Reduction):
|
|||||||
|
|
||||||
@ir_dataclass
|
@ir_dataclass
|
||||||
class Scan(Loops):
|
class Scan(Loops):
|
||||||
scan_ranges: List[Integer]
|
scan_ranges: list[Integer]
|
||||||
size: List[Integer]
|
size: list[Integer]
|
||||||
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
|
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
|
||||||
reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
|
reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
|
||||||
reduction_hint: ReductionHint
|
reduction_hint: ReductionHint
|
||||||
@ -2055,7 +2051,7 @@ class Scan(Loops):
|
|||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtypes: tuple[torch.dtype, ...],
|
dtypes: tuple[torch.dtype, ...],
|
||||||
inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
|
inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
|
||||||
size: List[Integer],
|
size: list[Integer],
|
||||||
axis: int,
|
axis: int,
|
||||||
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
|
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
|
||||||
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
reduction_hint: ReductionHint = ReductionHint.DEFAULT,
|
||||||
@ -2114,7 +2110,7 @@ class Scan(Loops):
|
|||||||
else:
|
else:
|
||||||
scan_type = SplitScan
|
scan_type = SplitScan
|
||||||
|
|
||||||
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> List[Expr]:
|
def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
|
||||||
assert len(scan_index) == len(scan_ranges)
|
assert len(scan_index) == len(scan_ranges)
|
||||||
assert len(index) == len(pointwise_ranges)
|
assert len(index) == len(pointwise_ranges)
|
||||||
return [*index[:axis], *scan_index, *index[axis:]]
|
return [*index[:axis], *scan_index, *index[axis:]]
|
||||||
@ -2152,8 +2148,8 @@ class Scan(Loops):
|
|||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
inner_fn: Callable[[Sequence[Expr]], OpsValue],
|
inner_fn: Callable[[Sequence[Expr]], OpsValue],
|
||||||
axis: int,
|
axis: int,
|
||||||
pointwise_ranges: List[Integer],
|
pointwise_ranges: list[Integer],
|
||||||
scan_ranges: List[Integer],
|
scan_ranges: list[Integer],
|
||||||
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
|
combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
|
||||||
scan_numel: Expr,
|
scan_numel: Expr,
|
||||||
) -> tuple[ReductionHint, _IntLike]:
|
) -> tuple[ReductionHint, _IntLike]:
|
||||||
@ -2182,8 +2178,8 @@ class SplitScan(Scan):
|
|||||||
@ir_dataclass
|
@ir_dataclass
|
||||||
class Sort(Loops):
|
class Sort(Loops):
|
||||||
# Sorts a tuple of key, value pairs
|
# Sorts a tuple of key, value pairs
|
||||||
sort_ranges: List[Integer]
|
sort_ranges: list[Integer]
|
||||||
size: List[Integer]
|
size: list[Integer]
|
||||||
reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
|
reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
|
||||||
reduction_hint: ReductionHint
|
reduction_hint: ReductionHint
|
||||||
output_index: int
|
output_index: int
|
||||||
@ -2251,8 +2247,8 @@ class Sort(Loops):
|
|||||||
cls,
|
cls,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtypes: tuple[torch.dtype, ...],
|
dtypes: tuple[torch.dtype, ...],
|
||||||
inner_fns: tuple[Callable[[List[Expr]], Any], ...],
|
inner_fns: tuple[Callable[[list[Expr]], Any], ...],
|
||||||
size: List[Integer],
|
size: list[Integer],
|
||||||
axis: int,
|
axis: int,
|
||||||
stable: bool,
|
stable: bool,
|
||||||
descending: bool,
|
descending: bool,
|
||||||
@ -2293,7 +2289,7 @@ class Sort(Loops):
|
|||||||
for output_index in range(len(dtypes))
|
for output_index in range(len(dtypes))
|
||||||
]
|
]
|
||||||
|
|
||||||
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> List[Expr]:
|
def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
|
||||||
assert len(sort_index) == len(sort_ranges)
|
assert len(sort_index) == len(sort_ranges)
|
||||||
assert len(index) == len(pointwise_ranges)
|
assert len(index) == len(pointwise_ranges)
|
||||||
return [*index[:axis], *sort_index, *index[axis:]]
|
return [*index[:axis], *sort_index, *index[axis:]]
|
||||||
@ -2509,7 +2505,7 @@ class BaseView(IRNode):
|
|||||||
|
|
||||||
@ir_dataclass
|
@ir_dataclass
|
||||||
class ExpandView(BaseView):
|
class ExpandView(BaseView):
|
||||||
size: List[Expr]
|
size: list[Expr]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _normalize_size(x, new_size): # type: ignore[no-untyped-def]
|
def _normalize_size(x, new_size): # type: ignore[no-untyped-def]
|
||||||
@ -2588,7 +2584,7 @@ class ExpandView(BaseView):
|
|||||||
|
|
||||||
@ir_dataclass
|
@ir_dataclass
|
||||||
class PermuteView(BaseView):
|
class PermuteView(BaseView):
|
||||||
dims: List[Expr]
|
dims: list[Expr]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, x, dims): # type: ignore[no-untyped-def]
|
def create(cls, x, dims): # type: ignore[no-untyped-def]
|
||||||
@ -2676,7 +2672,7 @@ class SqueezeView(BaseView):
|
|||||||
not_one = [i for i, s in enumerate(size) if s != 1]
|
not_one = [i for i, s in enumerate(size) if s != 1]
|
||||||
length = len(size)
|
length = len(size)
|
||||||
|
|
||||||
def reindex(index: List[sympy.Expr]) -> tuple[sympy.Expr, ...]:
|
def reindex(index: list[sympy.Expr]) -> tuple[sympy.Expr, ...]:
|
||||||
assert len(index) == len(not_one), f"{index} {not_one}"
|
assert len(index) == len(not_one), f"{index} {not_one}"
|
||||||
new_index = [sympy.S.Zero] * length
|
new_index = [sympy.S.Zero] * length
|
||||||
for idx, s in zip(not_one, index):
|
for idx, s in zip(not_one, index):
|
||||||
@ -2691,7 +2687,7 @@ class SqueezeView(BaseView):
|
|||||||
|
|
||||||
@ir_dataclass
|
@ir_dataclass
|
||||||
class GenericView(BaseView):
|
class GenericView(BaseView):
|
||||||
size: List[Expr]
|
size: list[Expr]
|
||||||
reindex: Callable[..., Any]
|
reindex: Callable[..., Any]
|
||||||
|
|
||||||
def make_reindexer(self): # type: ignore[no-untyped-def]
|
def make_reindexer(self): # type: ignore[no-untyped-def]
|
||||||
@ -3159,8 +3155,8 @@ class Layout(OutputSpec):
|
|||||||
self,
|
self,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
size: List[Expr],
|
size: list[Expr],
|
||||||
stride: Optional[List[Expr]] = None,
|
stride: Optional[list[Expr]] = None,
|
||||||
offset: Expr = Integer(0),
|
offset: Expr = Integer(0),
|
||||||
) -> None:
|
) -> None:
|
||||||
if stride is None:
|
if stride is None:
|
||||||
@ -3169,8 +3165,8 @@ class Layout(OutputSpec):
|
|||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
assert len(size) == len(stride), f"size={size}, stride={stride}"
|
assert len(size) == len(stride), f"size={size}, stride={stride}"
|
||||||
assert all(isinstance(s, (Expr, int)) for s in size)
|
assert all(isinstance(s, (Expr, int)) for s in size)
|
||||||
self.size: List[Expr] = size
|
self.size: list[Expr] = size
|
||||||
self.stride: List[Expr] = stride
|
self.stride: list[Expr] = stride
|
||||||
self.offset: Expr = offset
|
self.offset: Expr = offset
|
||||||
|
|
||||||
def __str__(self) -> str:
|
def __str__(self) -> str:
|
||||||
@ -3594,8 +3590,8 @@ class NoneLayout(OutputSpec):
|
|||||||
# dependencies manually in scheduler
|
# dependencies manually in scheduler
|
||||||
|
|
||||||
device: Optional[torch.device]
|
device: Optional[torch.device]
|
||||||
size: List[int] = dataclasses.field(default_factory=lambda: [0])
|
size: list[int] = dataclasses.field(default_factory=lambda: [0])
|
||||||
stride: List[int] = dataclasses.field(default_factory=lambda: [0])
|
stride: list[int] = dataclasses.field(default_factory=lambda: [0])
|
||||||
|
|
||||||
def storage_size(self) -> int:
|
def storage_size(self) -> int:
|
||||||
return 0
|
return 0
|
||||||
@ -3620,7 +3616,7 @@ class MutationLayoutSHOULDREMOVE(Layout):
|
|||||||
V.graph.mark_buffer_mutated(name)
|
V.graph.mark_buffer_mutated(name)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def stride(self) -> List[Expr]:
|
def stride(self) -> list[Expr]:
|
||||||
return self.real_layout().stride
|
return self.real_layout().stride
|
||||||
|
|
||||||
@stride.setter
|
@stride.setter
|
||||||
@ -3725,7 +3721,7 @@ class Buffer(IRNode):
|
|||||||
def get_size(self) -> Sequence[Expr]:
|
def get_size(self) -> Sequence[Expr]:
|
||||||
return [*self.get_layout().size]
|
return [*self.get_layout().size]
|
||||||
|
|
||||||
def get_stride(self) -> List[Expr]:
|
def get_stride(self) -> list[Expr]:
|
||||||
return [*self.get_layout().stride]
|
return [*self.get_layout().stride]
|
||||||
|
|
||||||
def get_offset(self) -> Expr:
|
def get_offset(self) -> Expr:
|
||||||
@ -3816,7 +3812,7 @@ class Buffer(IRNode):
|
|||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class OperationBuffer(Buffer, Operation):
|
class OperationBuffer(Buffer, Operation):
|
||||||
# An operation that produces a single output buffer
|
# An operation that produces a single output buffer
|
||||||
def get_outputs(self) -> List[Buffer]:
|
def get_outputs(self) -> list[Buffer]:
|
||||||
return [self]
|
return [self]
|
||||||
|
|
||||||
def get_defining_op(self) -> Operation:
|
def get_defining_op(self) -> Operation:
|
||||||
@ -3977,7 +3973,7 @@ class ComputedBuffer(OperationBuffer):
|
|||||||
assert isinstance(self.data, Pointwise)
|
assert isinstance(self.data, Pointwise)
|
||||||
return partial(self.data.store_output, self.name, indexer)
|
return partial(self.data.store_output, self.name, indexer)
|
||||||
|
|
||||||
def get_fill_order(self) -> Optional[List[int]]:
|
def get_fill_order(self) -> Optional[list[int]]:
|
||||||
"""
|
"""
|
||||||
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
|
If our layout is still flexible, try to determine the stride order based on stride orders of reads.
|
||||||
|
|
||||||
@ -4028,9 +4024,9 @@ class ComputedBuffer(OperationBuffer):
|
|||||||
def get_default_sizes_body(
|
def get_default_sizes_body(
|
||||||
self,
|
self,
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
tuple[List[sympy.Expr], List[sympy.Expr]],
|
tuple[list[sympy.Expr], list[sympy.Expr]],
|
||||||
LoopBody,
|
LoopBody,
|
||||||
tuple[List[sympy.Expr], List[sympy.Expr]],
|
tuple[list[sympy.Expr], list[sympy.Expr]],
|
||||||
]:
|
]:
|
||||||
args, var_ranges = dependencies.index_vars_squeeze(
|
args, var_ranges = dependencies.index_vars_squeeze(
|
||||||
self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
|
self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
|
||||||
@ -4043,7 +4039,7 @@ class ComputedBuffer(OperationBuffer):
|
|||||||
*args,
|
*args,
|
||||||
)
|
)
|
||||||
index_vars = []
|
index_vars = []
|
||||||
reduce_vars: List[Any] = []
|
reduce_vars: list[Any] = []
|
||||||
index_size = []
|
index_size = []
|
||||||
reduce_size = []
|
reduce_size = []
|
||||||
for v, s in var_ranges.items():
|
for v, s in var_ranges.items():
|
||||||
@ -4059,9 +4055,9 @@ class ComputedBuffer(OperationBuffer):
|
|||||||
|
|
||||||
def simplify_and_reorder(
|
def simplify_and_reorder(
|
||||||
self,
|
self,
|
||||||
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
|
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
|
||||||
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
||||||
) -> tuple[tuple[List[sympy.Expr], List[sympy.Expr]], LoopBody]:
|
) -> tuple[tuple[list[sympy.Expr], list[sympy.Expr]], LoopBody]:
|
||||||
"""
|
"""
|
||||||
This is a main place where we do loop transformations in a
|
This is a main place where we do loop transformations in a
|
||||||
backend-agnostic way.
|
backend-agnostic way.
|
||||||
@ -4282,7 +4278,7 @@ class TemplateBuffer(OperationBuffer):
|
|||||||
|
|
||||||
def simplify_and_reorder( # type: ignore[no-untyped-def]
|
def simplify_and_reorder( # type: ignore[no-untyped-def]
|
||||||
self,
|
self,
|
||||||
extra_indexing_constraints: Optional[tuple[Dict[Any, Any], List[Any]]] = None,
|
extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
|
||||||
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
|
||||||
):
|
):
|
||||||
return (
|
return (
|
||||||
@ -4314,7 +4310,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
|||||||
"""
|
"""
|
||||||
super().__init__(layout, inputs, make_kernel_render)
|
super().__init__(layout, inputs, make_kernel_render)
|
||||||
self.mutated_inputs = mutated_inputs
|
self.mutated_inputs = mutated_inputs
|
||||||
self.outputs: List[Buffer] = [self]
|
self.outputs: list[Buffer] = [self]
|
||||||
if mutated_inputs is not None:
|
if mutated_inputs is not None:
|
||||||
# Ensure that the mutated inputs are only allowed for certain nodes
|
# Ensure that the mutated inputs are only allowed for certain nodes
|
||||||
allowed_set = (
|
allowed_set = (
|
||||||
@ -4335,7 +4331,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
|||||||
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
|
allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_outputs(self) -> List[Buffer]:
|
def get_outputs(self) -> list[Buffer]:
|
||||||
return self.outputs
|
return self.outputs
|
||||||
|
|
||||||
def get_allowed_prologue_inps(self) -> OrderedSet[str]:
|
def get_allowed_prologue_inps(self) -> OrderedSet[str]:
|
||||||
@ -4346,7 +4342,7 @@ class TritonTemplateBuffer(TemplateBuffer):
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
PrimitiveInfoType = Union[int, float, bool, str, List[Union[int, str, float, bool]]]
|
PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
|
||||||
|
|
||||||
|
|
||||||
class ChoiceCaller:
|
class ChoiceCaller:
|
||||||
@ -4361,7 +4357,7 @@ class ChoiceCaller:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
name: str,
|
name: str,
|
||||||
input_nodes: List[Buffer],
|
input_nodes: list[Buffer],
|
||||||
layout: Layout,
|
layout: Layout,
|
||||||
description: str,
|
description: str,
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -4389,7 +4385,7 @@ class ChoiceCaller:
|
|||||||
def output_node(self) -> TensorBox:
|
def output_node(self) -> TensorBox:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def info_dict(self) -> Dict[str, Union[PrimitiveInfoType, List[PrimitiveInfoType]]]:
|
def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
|
||||||
"""Information returned here is logged to the autotune log file when that is enabled."""
|
"""Information returned here is logged to the autotune log file when that is enabled."""
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
@ -4414,9 +4410,9 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
layout: Layout,
|
layout: Layout,
|
||||||
inputs: List[IRNode],
|
inputs: list[IRNode],
|
||||||
choice_timings: Callable[[], Dict[ChoiceCaller, float]],
|
choice_timings: Callable[[], dict[ChoiceCaller, float]],
|
||||||
unfiltered_choices: List[ChoiceCaller],
|
unfiltered_choices: list[ChoiceCaller],
|
||||||
allowed_prologue_inps: OrderedSet[str],
|
allowed_prologue_inps: OrderedSet[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
@ -4426,7 +4422,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
|||||||
allowed_prologue_inps=allowed_prologue_inps,
|
allowed_prologue_inps=allowed_prologue_inps,
|
||||||
)
|
)
|
||||||
self._choice_timings_fn = choice_timings
|
self._choice_timings_fn = choice_timings
|
||||||
self._choice_timings: Optional[Dict[ChoiceCaller, float]] = None
|
self._choice_timings: Optional[dict[ChoiceCaller, float]] = None
|
||||||
self.original_inputs = inputs
|
self.original_inputs = inputs
|
||||||
self._output_plannable = all(
|
self._output_plannable = all(
|
||||||
isinstance(choice, TritonTemplateCallerBase)
|
isinstance(choice, TritonTemplateCallerBase)
|
||||||
@ -4445,7 +4441,7 @@ class MultiTemplateBuffer(TritonTemplateBuffer):
|
|||||||
return self._output_plannable
|
return self._output_plannable
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def choice_timings(self) -> Dict[ChoiceCaller, float]:
|
def choice_timings(self) -> dict[ChoiceCaller, float]:
|
||||||
if self._choice_timings is None:
|
if self._choice_timings is None:
|
||||||
self._choice_timings = self._choice_timings_fn()
|
self._choice_timings = self._choice_timings_fn()
|
||||||
return self._choice_timings
|
return self._choice_timings
|
||||||
@ -4496,7 +4492,7 @@ class CppTemplateBuffer(TemplateBuffer):
|
|||||||
super().__init__(layout, inputs, make_kernel_render)
|
super().__init__(layout, inputs, make_kernel_render)
|
||||||
self.template = template
|
self.template = template
|
||||||
self.choice = choice
|
self.choice = choice
|
||||||
self.outputs: Optional[List[Buffer]] = None
|
self.outputs: Optional[list[Buffer]] = None
|
||||||
|
|
||||||
def get_layout(self) -> Layout:
|
def get_layout(self) -> Layout:
|
||||||
if isinstance(self.layout, MultiOutputLayout):
|
if isinstance(self.layout, MultiOutputLayout):
|
||||||
@ -4512,7 +4508,7 @@ class CppTemplateBuffer(TemplateBuffer):
|
|||||||
|
|
||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class InputsKernel(OperationBuffer):
|
class InputsKernel(OperationBuffer):
|
||||||
inputs: List[Buffer]
|
inputs: list[Buffer]
|
||||||
|
|
||||||
def get_read_writes(self) -> dependencies.ReadWrites:
|
def get_read_writes(self) -> dependencies.ReadWrites:
|
||||||
reads = OrderedSet[dependencies.Dep]()
|
reads = OrderedSet[dependencies.Dep]()
|
||||||
@ -4752,7 +4748,7 @@ class ConcatKernel(NopKernel):
|
|||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class ExternKernel(InputsKernel):
|
class ExternKernel(InputsKernel):
|
||||||
constant_args: tuple[Any, ...] = ()
|
constant_args: tuple[Any, ...] = ()
|
||||||
kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
|
kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
|
||||||
output_view: Optional[ReinterpretView] = None
|
output_view: Optional[ReinterpretView] = None
|
||||||
python_kernel_name: Optional[str] = None
|
python_kernel_name: Optional[str] = None
|
||||||
cpp_kernel_name: Optional[str] = None
|
cpp_kernel_name: Optional[str] = None
|
||||||
@ -4764,12 +4760,12 @@ class ExternKernel(InputsKernel):
|
|||||||
op_overload: Optional[
|
op_overload: Optional[
|
||||||
Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
|
Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
|
||||||
] = None
|
] = None
|
||||||
arg_properties: Optional[List[Dict[str, Any]]] = None
|
arg_properties: Optional[list[dict[str, Any]]] = None
|
||||||
kwarg_properties: Optional[Dict[str, Dict[str, Any]]] = None
|
kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
|
||||||
unbacked_bindings: Dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
|
unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
|
||||||
default_factory=dict
|
default_factory=dict
|
||||||
)
|
)
|
||||||
mutation_outputs: List[MutationOutput] = dataclasses.field(default_factory=list)
|
mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
|
||||||
|
|
||||||
def __init__( # type: ignore[no-untyped-def]
|
def __init__( # type: ignore[no-untyped-def]
|
||||||
self,
|
self,
|
||||||
@ -4801,7 +4797,7 @@ class ExternKernel(InputsKernel):
|
|||||||
self.mutation_outputs = []
|
self.mutation_outputs = []
|
||||||
self.fx_node = V.graph.current_node
|
self.fx_node = V.graph.current_node
|
||||||
|
|
||||||
def get_outputs(self) -> List[Buffer]:
|
def get_outputs(self) -> list[Buffer]:
|
||||||
return [self, *self.mutation_outputs]
|
return [self, *self.mutation_outputs]
|
||||||
|
|
||||||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||||
@ -4919,10 +4915,10 @@ class ExternKernel(InputsKernel):
|
|||||||
cls, kernel, *args, **kwargs
|
cls, kernel, *args, **kwargs
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
Any,
|
Any,
|
||||||
List[Any],
|
list[Any],
|
||||||
List[Any],
|
list[Any],
|
||||||
Callable[[Any, Any], Any],
|
Callable[[Any, Any], Any],
|
||||||
Optional[Dict[sympy.Symbol, pytree.KeyPath]],
|
Optional[dict[sympy.Symbol, pytree.KeyPath]],
|
||||||
]:
|
]:
|
||||||
binded_args = {"args": args, "kwargs": kwargs}
|
binded_args = {"args": args, "kwargs": kwargs}
|
||||||
|
|
||||||
@ -4930,7 +4926,7 @@ class ExternKernel(InputsKernel):
|
|||||||
|
|
||||||
is_arg_tensor = []
|
is_arg_tensor = []
|
||||||
tensor_args = []
|
tensor_args = []
|
||||||
non_tensor_args: List[Any] = []
|
non_tensor_args: list[Any] = []
|
||||||
for arg in args_flat:
|
for arg in args_flat:
|
||||||
is_arg_tensor.append(isinstance(arg, IRNode))
|
is_arg_tensor.append(isinstance(arg, IRNode))
|
||||||
if is_arg_tensor[-1]:
|
if is_arg_tensor[-1]:
|
||||||
@ -4963,7 +4959,7 @@ class ExternKernel(InputsKernel):
|
|||||||
# Rerun fake tensor propagation, because Inductor may have changed the
|
# Rerun fake tensor propagation, because Inductor may have changed the
|
||||||
# strides of inputs and we need to determine accurately what the
|
# strides of inputs and we need to determine accurately what the
|
||||||
# output stride will be.
|
# output stride will be.
|
||||||
example_args: List[Union[torch.Tensor, torch._C.ScriptObject]] = []
|
example_args: list[Union[torch.Tensor, torch._C.ScriptObject]] = []
|
||||||
|
|
||||||
# We need to retain the constant values of fake tensors that we originally
|
# We need to retain the constant values of fake tensors that we originally
|
||||||
# propagated the graph with, because for some operators running without a
|
# propagated the graph with, because for some operators running without a
|
||||||
@ -4984,7 +4980,7 @@ class ExternKernel(InputsKernel):
|
|||||||
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
|
new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
|
||||||
example_output = kernel(*new_args, **new_kwargs)
|
example_output = kernel(*new_args, **new_kwargs)
|
||||||
|
|
||||||
unbacked_bindings: Optional[Dict[sympy.Symbol, pytree.KeyPath]] = None
|
unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
|
||||||
if shape_env := V.fake_mode.shape_env:
|
if shape_env := V.fake_mode.shape_env:
|
||||||
rebind_unbacked(shape_env, V.current_node, example_output)
|
rebind_unbacked(shape_env, V.current_node, example_output)
|
||||||
unbacked_bindings = compute_unbacked_bindings(
|
unbacked_bindings = compute_unbacked_bindings(
|
||||||
@ -5309,7 +5305,7 @@ class ExternKernel(InputsKernel):
|
|||||||
)
|
)
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def codegen_const_args(self, names: Optional[List[str]] = None): # type: ignore[no-untyped-def]
|
def codegen_const_args(self, names: Optional[list[str]] = None): # type: ignore[no-untyped-def]
|
||||||
if V.graph.cpp_wrapper:
|
if V.graph.cpp_wrapper:
|
||||||
result = []
|
result = []
|
||||||
# Aten ops follow the convention that tensor args are before non-tensor args,
|
# Aten ops follow the convention that tensor args are before non-tensor args,
|
||||||
@ -5635,14 +5631,14 @@ class TMADescriptor(ExternKernel):
|
|||||||
|
|
||||||
# as TMA descriptors are immutable,
|
# as TMA descriptors are immutable,
|
||||||
# we can dedup them by the input args
|
# we can dedup them by the input args
|
||||||
_CACHE: Dict[Any, TMADescriptor] = {}
|
_CACHE: dict[Any, TMADescriptor] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create( # type: ignore[no-untyped-def]
|
def create( # type: ignore[no-untyped-def]
|
||||||
cls,
|
cls,
|
||||||
tensor: IRNode,
|
tensor: IRNode,
|
||||||
dims: List[Union[int, torch.SymInt]],
|
dims: list[Union[int, torch.SymInt]],
|
||||||
block_dims: List[Union[int, torch.SymInt]],
|
block_dims: list[Union[int, torch.SymInt]],
|
||||||
element_size: Optional[int] = None,
|
element_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
key = (id(tensor), dims, block_dims, element_size)
|
key = (id(tensor), dims, block_dims, element_size)
|
||||||
@ -5653,8 +5649,8 @@ class TMADescriptor(ExternKernel):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tensor: IRNode,
|
tensor: IRNode,
|
||||||
dims: List[Union[int, torch.SymInt]],
|
dims: list[Union[int, torch.SymInt]],
|
||||||
block_dims: List[Union[int, torch.SymInt]],
|
block_dims: list[Union[int, torch.SymInt]],
|
||||||
element_size: Optional[int] = None,
|
element_size: Optional[int] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
assert len(dims) in (1, 2)
|
assert len(dims) in (1, 2)
|
||||||
@ -5707,8 +5703,8 @@ class UserDefinedTritonKernel(ExternKernel):
|
|||||||
|
|
||||||
kernel = kernel_side_table.get_kernel(self.kernel_idx)
|
kernel = kernel_side_table.get_kernel(self.kernel_idx)
|
||||||
configs = []
|
configs = []
|
||||||
restore_value_args: List[str] = []
|
restore_value_args: list[str] = []
|
||||||
reset_to_zero_args: List[str] = []
|
reset_to_zero_args: list[str] = []
|
||||||
if isinstance(kernel, Autotuner):
|
if isinstance(kernel, Autotuner):
|
||||||
# https://github.com/triton-lang/triton/pull/5083
|
# https://github.com/triton-lang/triton/pull/5083
|
||||||
# changes kernel.restore_idx to kernel.restore_value
|
# changes kernel.restore_idx to kernel.restore_value
|
||||||
@ -5871,7 +5867,7 @@ class UserDefinedTritonKernel(ExternKernel):
|
|||||||
]
|
]
|
||||||
V.graph.register_operation(self)
|
V.graph.register_operation(self)
|
||||||
|
|
||||||
def get_outputs(self) -> List[Buffer]:
|
def get_outputs(self) -> list[Buffer]:
|
||||||
return list(self.mutation_outputs)
|
return list(self.mutation_outputs)
|
||||||
|
|
||||||
def get_device(self) -> Optional[torch.device]:
|
def get_device(self) -> Optional[torch.device]:
|
||||||
@ -6333,9 +6329,9 @@ class FallbackKernel(ExternKernelAlloc):
|
|||||||
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]
|
V.graph.warn_fallback(self.python_kernel_name) # type: ignore[arg-type]
|
||||||
|
|
||||||
# args that are aliased
|
# args that are aliased
|
||||||
self.alias_names: List[str] = []
|
self.alias_names: list[str] = []
|
||||||
# args that are mutated AND returned from the op
|
# args that are mutated AND returned from the op
|
||||||
self.mutation_names: List[str] = []
|
self.mutation_names: list[str] = []
|
||||||
|
|
||||||
if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
|
if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
|
||||||
# We assume here that HOPs with FallbackKernel are functional.
|
# We assume here that HOPs with FallbackKernel are functional.
|
||||||
@ -6834,7 +6830,7 @@ class MultiOutput(ExternKernel):
|
|||||||
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
|
self.codegen_list_tuple_access(self.inputs[0].get_name(), self.indices),
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, layout: OutputSpec, input, indices: List[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
|
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
|
||||||
super().__init__(None, layout, [input], ())
|
super().__init__(None, layout, [input], ())
|
||||||
self.name = V.graph.register_buffer(self)
|
self.name = V.graph.register_buffer(self)
|
||||||
V.graph.register_operation(self)
|
V.graph.register_operation(self)
|
||||||
@ -6903,18 +6899,18 @@ class MutableBox(IRNode):
|
|||||||
return self.data.freeze_layout()
|
return self.data.freeze_layout()
|
||||||
|
|
||||||
def freeze_layout_with_stride_order(
|
def freeze_layout_with_stride_order(
|
||||||
self, order: List[int], allow_padding: bool = False
|
self, order: list[int], allow_padding: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
return self.data.freeze_layout_with_stride_order(order, allow_padding)
|
return self.data.freeze_layout_with_stride_order(order, allow_padding)
|
||||||
|
|
||||||
def freeze_layout_with_fill_order(self, order: List[int]) -> None:
|
def freeze_layout_with_fill_order(self, order: list[int]) -> None:
|
||||||
return self.data.freeze_layout_with_fill_order(order)
|
return self.data.freeze_layout_with_fill_order(order)
|
||||||
|
|
||||||
def freeze_layout_with_same_order(self, stride: List[_IntLike]) -> None:
|
def freeze_layout_with_same_order(self, stride: list[_IntLike]) -> None:
|
||||||
return self.data.freeze_layout_with_same_order(stride)
|
return self.data.freeze_layout_with_same_order(stride)
|
||||||
|
|
||||||
def freeze_layout_with_exact_strides(
|
def freeze_layout_with_exact_strides(
|
||||||
self, exact_strides: List[_IntLike], allow_padding: bool = False
|
self, exact_strides: list[_IntLike], allow_padding: bool = False
|
||||||
) -> None:
|
) -> None:
|
||||||
return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
|
return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
|
||||||
|
|
||||||
@ -7119,11 +7115,11 @@ def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
|
|||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class InvokeSubgraph(ExternKernel):
|
class InvokeSubgraph(ExternKernel):
|
||||||
subgraph: Optional[Subgraph] = None
|
subgraph: Optional[Subgraph] = None
|
||||||
operands: Optional[List[TensorBox]] = None
|
operands: Optional[list[TensorBox]] = None
|
||||||
outputs: Optional[List[MultiOutput]] = None
|
outputs: Optional[list[MultiOutput]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, subgraph: Subgraph, operands: List[TensorBox], layout: MultiOutputLayout
|
self, subgraph: Subgraph, operands: list[TensorBox], layout: MultiOutputLayout
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
name=None,
|
name=None,
|
||||||
@ -7212,15 +7208,15 @@ class InvokeSubgraph(ExternKernel):
|
|||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class Conditional(ExternKernel):
|
class Conditional(ExternKernel):
|
||||||
predicate: Optional[IRNode] = None
|
predicate: Optional[IRNode] = None
|
||||||
operands: Optional[List[TensorBox]] = None
|
operands: Optional[list[TensorBox]] = None
|
||||||
true_subgraph: Optional[Subgraph] = None
|
true_subgraph: Optional[Subgraph] = None
|
||||||
false_subgraph: Optional[Subgraph] = None
|
false_subgraph: Optional[Subgraph] = None
|
||||||
outputs: Optional[List[MultiOutput]] = None
|
outputs: Optional[list[MultiOutput]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
predicate: IRNode,
|
predicate: IRNode,
|
||||||
operands: List[TensorBox],
|
operands: list[TensorBox],
|
||||||
true_subgraph: Subgraph,
|
true_subgraph: Subgraph,
|
||||||
false_subgraph: Subgraph,
|
false_subgraph: Subgraph,
|
||||||
layout: MultiOutputLayout,
|
layout: MultiOutputLayout,
|
||||||
@ -7250,7 +7246,7 @@ class Conditional(ExternKernel):
|
|||||||
predicate: TensorBox,
|
predicate: TensorBox,
|
||||||
true_fn: Subgraph,
|
true_fn: Subgraph,
|
||||||
false_fn: Subgraph,
|
false_fn: Subgraph,
|
||||||
operands: List[TensorBox],
|
operands: list[TensorBox],
|
||||||
):
|
):
|
||||||
predicate = cls.realize_input(predicate)
|
predicate = cls.realize_input(predicate)
|
||||||
operands = [cls.realize_input(x) for x in operands]
|
operands = [cls.realize_input(x) for x in operands]
|
||||||
@ -7332,16 +7328,16 @@ class Conditional(ExternKernel):
|
|||||||
|
|
||||||
@ir_dataclass(frozen=False)
|
@ir_dataclass(frozen=False)
|
||||||
class WhileLoop(ExternKernel):
|
class WhileLoop(ExternKernel):
|
||||||
carried_inputs: Optional[List[TensorBox]] = None
|
carried_inputs: Optional[list[TensorBox]] = None
|
||||||
additional_inputs: Optional[List[TensorBox]] = None
|
additional_inputs: Optional[list[TensorBox]] = None
|
||||||
cond_subgraph: Optional[Subgraph] = None
|
cond_subgraph: Optional[Subgraph] = None
|
||||||
body_subgraph: Optional[Subgraph] = None
|
body_subgraph: Optional[Subgraph] = None
|
||||||
outputs: Optional[List[MultiOutput]] = None
|
outputs: Optional[list[MultiOutput]] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
carried_inputs: List[TensorBox],
|
carried_inputs: list[TensorBox],
|
||||||
additional_inputs: List[TensorBox],
|
additional_inputs: list[TensorBox],
|
||||||
cond_subgraph: Subgraph,
|
cond_subgraph: Subgraph,
|
||||||
body_subgraph: Subgraph,
|
body_subgraph: Subgraph,
|
||||||
layout: MultiOutputLayout,
|
layout: MultiOutputLayout,
|
||||||
@ -7365,8 +7361,8 @@ class WhileLoop(ExternKernel):
|
|||||||
cls,
|
cls,
|
||||||
cond_fn: Subgraph,
|
cond_fn: Subgraph,
|
||||||
body_fn: Subgraph,
|
body_fn: Subgraph,
|
||||||
carried_inputs: List[TensorBox],
|
carried_inputs: list[TensorBox],
|
||||||
additional_inputs: List[TensorBox],
|
additional_inputs: list[TensorBox],
|
||||||
):
|
):
|
||||||
carried_inputs = [cls.realize_input(x) for x in carried_inputs]
|
carried_inputs = [cls.realize_input(x) for x in carried_inputs]
|
||||||
additional_inputs = [cls.realize_input(x) for x in additional_inputs]
|
additional_inputs = [cls.realize_input(x) for x in additional_inputs]
|
||||||
@ -7411,8 +7407,8 @@ class WhileLoop(ExternKernel):
|
|||||||
for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
|
for i, (op, bo) in enumerate(zip(carried_inputs, body_outputs)):
|
||||||
|
|
||||||
def _guard_list_equals(
|
def _guard_list_equals(
|
||||||
lhs_exprs: List[Union[int, sympy.expr]],
|
lhs_exprs: list[Union[int, sympy.expr]],
|
||||||
rhs_exprs: List[Union[int, sympy.expr]],
|
rhs_exprs: list[Union[int, sympy.expr]],
|
||||||
) -> None:
|
) -> None:
|
||||||
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
|
for lhs, rhs in zip(lhs_exprs, rhs_exprs):
|
||||||
V.graph.sizevars.guard_equals(lhs, rhs)
|
V.graph.sizevars.guard_equals(lhs, rhs)
|
||||||
@@ -7549,7 +7545,7 @@ class _CollectiveKernel(FallbackKernel):
 # mutation of the input buffers.
 @classmethod
 def create_inplace(  # type: ignore[no-untyped-def]
-    cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
+    cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
 ) -> None:
     with V.graph.fake_mode:
         (
@@ -7610,7 +7606,7 @@ class _CollectiveKernel(FallbackKernel):
 # usage in the user program.
 @classmethod
 def create_out_of_place(  # type: ignore[no-untyped-def]
-    cls, kernel, inputs: Union[TensorBox, List[TensorBox]], *args, **kwargs
+    cls, kernel, inputs: Union[TensorBox, list[TensorBox]], *args, **kwargs
 ):
     with V.graph.fake_mode:
         (
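
The hunks above all apply one mechanical rewrite: typing-module aliases such as Dict, List, and DefaultDict are replaced with the equivalent built-in generics available since Python 3.9, and names that are no longer needed are dropped from the typing imports (Sequence, Generator, and Iterable move to collections.abc). A minimal, hypothetical sketch of the before/after annotation style, not taken from the PyTorch sources:

    # Hypothetical example: only the annotation spelling differs,
    # runtime behavior is identical on Python 3.9+.
    from typing import Dict, List, Optional  # old-style aliases

    OldStyle = Optional[Dict[str, List[int]]]  # pre-3.9 spelling
    NewStyle = Optional[dict[str, list[int]]]  # built-in generics

    def count_items(groups: dict[str, list[int]]) -> int:
        # Annotations are not evaluated at call time, so both
        # spellings type-check and run the same way.
        return sum(len(v) for v in groups.values())

    assert count_items({"a": [1, 2], "b": [3]}) == 3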