Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Use union syntax in torch/_inductor runtime and fx_passes (#165652)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/165652
Approved by: https://github.com/aorenste
Committed by: PyTorch MergeBot
Parent: fb06e49ce8
Commit: 7d0f872cb3
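
The change is mechanical: every typing.Union[X, Y] annotation is rewritten to the PEP 604 form X | Y, and Union imports that become unused are dropped (other typing imports stay). A minimal sketch of the pattern, using a hypothetical pad_to helper rather than any function from this diff:

# Sketch of the PEP 604 rewrite this PR applies; pad_to is a
# hypothetical example, not a function touched in the diff below.
from __future__ import annotations  # X | Y annotations parse even before Python 3.10

import torch

# Before:  from typing import Union
#          def pad_to(x: Union[int, torch.SymInt], align: Union[int, None]) -> int: ...
# After: the | operator, with no typing.Union import needed.
def pad_to(x: int | torch.SymInt, align: int | None) -> int:
    # Leave symbolic sizes alone; round concrete ints up to the alignment.
    if isinstance(x, torch.SymInt) or not align:
        return 0
    return (align - x % align) % align

Note that X | Y only works as a runtime value (for example in an isinstance check) on Python 3.10+; in annotation position it parses on older interpreters once "from __future__ import annotations" makes annotations lazily evaluated, which is why a syntax-only migration like this is low risk.
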
@@ -339,7 +339,7 @@ def sha256_hash(data: bytes) -> str:
     return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


-def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
+def code_hash(code: str | bytes, extra: str | bytes = "") -> str:
     hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra:
         extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
@@ -361,9 +361,7 @@ def get_path(
     return basename, subdir, path


-def get_hash(
-    content: Union[str, bytes], extra: str = "", hash_type: str = "code"
-) -> str:
+def get_hash(content: str | bytes, extra: str = "", hash_type: str = "code") -> str:
     if hash_type in {"amdgcn", "code", "ptx", "spv"}:
         return code_hash(content, extra)
     if hash_type in {"cubin", "hsaco", "spv"}:
@@ -409,7 +407,7 @@ class WritableTempFile:


 def write(
-    content: Union[str, bytes],
+    content: str | bytes,
     extension: str,
     extra: str = "",
     hash_type: str = "code",
@@ -436,7 +434,7 @@ def write_text(text: str) -> str:

 def write_atomic(
     path_: str,
-    content: Union[str, bytes],
+    content: str | bytes,
     make_dirs: bool = False,
     encode_utf_8: bool = False,
 ) -> None:
@@ -547,7 +545,7 @@ class FxGraphCachePickler(pickle.Pickler):

     def _reduce_tensor(
         self, t: Tensor
-    ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
+    ) -> tuple[Callable[[T], T], tuple[TensorMetadata | TensorMetadataAndValues]]:
         """
         Custom reducer to pickle Tensors. If we see tensors, we know they're constants
         stored as attributes on the GraphModule.
@@ -943,7 +941,7 @@ class FxGraphHashDetails:
             raise AssertionError(f"unknown config type: {str(type(custom_pass))}")

     def _get_custom_pass_detail(
-        self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
+        self, custom_pass: CustomGraphPassType | CustomGraphModulePass
     ) -> Any | None:
         if not custom_pass:
             return None
@@ -1058,7 +1056,7 @@ class GuardedCache(Generic[T]):
         key: str,
         local: bool,
         remote_cache: RemoteCache[JsonDataTy] | None,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool],
         hints: list[int],
     ) -> tuple[T | None, bytes | None, dict[str, str]]:
         """
@@ -1292,7 +1290,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         local: bool,
         remote_cache: RemoteCache[JsonDataTy] | None,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
         | None = None,
     ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
@@ -1543,7 +1541,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         remote_cache: RemoteCache[JsonDataTy] | None,
         is_backward: bool,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
         | None = None,
     ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
@@ -1723,12 +1721,12 @@ class AotCodeCompiler:
         *,
         device_type: str,
         additional_files: list[str],
-    ) -> Union[list[Union[str, Weights]], str]:
+    ) -> list[Union[str, Weights]] | str:
         """
         Returns the .so path, or returns a list of files that were generated if
         config.aot_inductor.package=True.
         """
-        generated_files: list[Union[str, Weights]] = additional_files  # type: ignore[assignment]
+        generated_files: list[str | Weights] = additional_files  # type: ignore[assignment]

         _set_gpu_runtime_env()  # cpp_extension consults the env

@@ -2342,7 +2340,7 @@ end
             f.write(json.dumps(qual_name_to_id))
         generated_files.append(constants_config_json)

-        gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
+        gpu_codecache: ROCmCodeCache | CUDACodeCache = (
             ROCmCodeCache() if torch.version.hip else CUDACodeCache()
         )
         gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
@@ -2555,7 +2553,7 @@ end
 _libgomp: CDLL | None = None


-def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
+def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
     # This function will be called from generated cpp wrapper code in the JIT mode.
     # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them.
     def convert_arg(arg: Any) -> Any:
@@ -2698,16 +2696,16 @@ class CppCodeCache:
     """Compiles and caches C++ libraries. Users of this class supply the source code to
     be compiled, while compilation flags are set by CppBuilder."""

-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags: dict[str, Any] = {}

     @staticmethod
-    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
+    def _load_library_inner(path: str, key: str) -> CDLL | ModuleType:
         return cdll.LoadLibrary(path)

     @classmethod
-    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
+    def _load_library(cls, path: str, key: str) -> CDLL | ModuleType:
         try:
             result = cls._load_library_inner(path, key)
             result.key = key  # type: ignore[union-attr]
@@ -2910,7 +2908,7 @@ def _worker_compile_cpp(
 # Customized Python binding for cpp kernels
 @clear_on_fresh_cache
 class CppPythonBindingsCodeCache(CppCodeCache):
-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         # kernels have no dependency on libtorch
@@ -3092,7 +3090,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):

 @clear_on_fresh_cache
 class CppWrapperCodeCache(CppPythonBindingsCodeCache):
-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         "include_pytorch": True,
@@ -3161,7 +3159,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):

 @clear_on_fresh_cache
 class HalideCodeCache(CppPythonBindingsCodeCache):
-    cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
+    cache: dict[str, Callable[[], ModuleType | CDLL]] = {}
     cache_clear = staticmethod(cache.clear)
     _standalone_runtime_path: str | None = None
     prefix = textwrap.dedent(

@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
 from collections import deque
-from typing import Union

 import torch
 from torch.utils._ordered_set import OrderedSet
@@ -514,7 +513,7 @@ def build_subgraph_buffer(

 def create_placeholder(
     name: str, dtype: torch.dtype, device: torch.device
-) -> Union[TensorBox, ShapeAsConstantBuffer]:
+) -> TensorBox | ShapeAsConstantBuffer:
     """
     Creates a placeholder input buffers for producing subgraph_output
     """

@@ -7,7 +7,7 @@ import operator
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, cast, Union
+from typing import Any, Callable, cast

 import torch
 import torch.fx as fx
@@ -39,12 +39,12 @@ def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:

 def call_function(
     graph: fx.Graph,
-    target: Union[str, Callable[..., Any]],
+    target: str | Callable[..., Any],
     args: tuple[fx.node.Argument, ...] | None = None,
     kwargs: dict[str, fx.node.Argument] | None = None,
 ) -> fx.Node:
     # We accept target as a str to avoid typing error as the type of
-    # a node.target is Union[str, Callable[..., Any]].
+    # a node.target is str | Callable[..., Any].
     # This also allows us to avoid writing check for every call.
     if isinstance(target, str):
         raise RuntimeError(f"Call function should not get a str target {target=}")
@@ -62,7 +62,7 @@ def call_function(

 @dataclass(unsafe_hash=True)
 class CommBlock:
-    shape: Union[torch.Size, list[torch.Size]]
+    shape: torch.Size | list[torch.Size]
     node_list: list[fx.Node]
     inputs: list[fx.Node]
     wait_nodes: list[fx.Node]
@@ -128,7 +128,7 @@ def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
             break

     tensor_meta = input_nodes[0].meta["tensor_meta"]
-    shape: Union[torch.Size, list[torch.Size]]
+    shape: torch.Size | list[torch.Size]
     if isinstance(tensor_meta, TensorMetadata):
         shape = tensor_meta.shape
     elif isinstance(tensor_meta, (list, tuple)):
@@ -571,7 +571,7 @@ def schedule_comm_wait(graph: fx.Graph) -> None:


 def fuse_ddp_communication(
-    graph: fx.Graph, passes: list[Union[Callable[..., None], str]], bucket_size_mb: int
+    graph: fx.Graph, passes: list[Callable[..., None] | str], bucket_size_mb: int
 ) -> None:
     for i, pa in enumerate(passes):
         with GraphTransformObserver(

@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any

 import torch
 from torch import SymBool, SymFloat, SymInt
@@ -14,7 +14,7 @@ class _SymExprHash:
     Hash for a py_sym_types that will use the underlying sympy expression
     """

-    sym_obj: Union[SymInt, SymFloat, SymBool]
+    sym_obj: SymInt | SymFloat | SymBool

     def __hash__(self) -> int:
         return hash((type(self.sym_obj), self.sym_obj.node.expr))

@@ -6,7 +6,7 @@ import operator
 import typing
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Union
+from typing import Any

 import torch
 import torch._guards
@@ -706,8 +706,8 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtyp


 def definitely_equal(
-    old_sizes: Sequence[Union[torch.SymInt, int]],
-    new_sizes: Sequence[Union[torch.SymInt, torch.fx.Node, int]],
+    old_sizes: Sequence[torch.SymInt | int],
+    new_sizes: Sequence[torch.SymInt | torch.fx.Node | int],
 ) -> bool:
     """
     Leverage guard_or_true/false to compare if two lists of int/symint are equal.
@@ -906,7 +906,7 @@ def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
     if dtype is not None:
         inp = inp.to(dtype)

-    sign: Union[int, float, torch.Tensor]
+    sign: int | float | torch.Tensor
     if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
         sign = 1 if other >= 0 else -1
     else:
@@ -936,7 +936,7 @@ def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
     if dtype is not None:
         inp = inp.to(dtype)

-    sign: Union[int, float, torch.Tensor]
+    sign: int | float | torch.Tensor
     if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
         sign = 1 if other >= 0 else -1
     else:

@@ -2,7 +2,7 @@ import itertools
 import logging
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Callable, Union
+from typing import Callable

 import torch
 import torch.fx as fx
@@ -143,7 +143,7 @@ class GraphAliasTracker:
         return self.node_to_storages_last_used[node]


-def _size_of_default(num_bytes: Union[int, torch.SymInt]) -> int:
+def _size_of_default(num_bytes: int | torch.SymInt) -> int:
     return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback)


@@ -154,7 +154,7 @@ def device_filter(device: torch.device) -> bool:
 def build_memory_profile(
     graph: fx.Graph,
     is_releasable: Callable[[fx.Node], bool],
-    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
 ) -> list[int]:
     """
     Function to estimate the memory profile of an input FX graph.
@@ -165,7 +165,7 @@ def build_memory_profile(
     - is_releasable (Callable[[fx.Node], bool]): A function that
       determines if a node's memory can be released (e.g. primal nodes
       cannot be released).
-    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
       byte counts (possibly symbolic) to concrete integers.

     Returns:
@@ -216,7 +216,7 @@ def build_memory_profile(
 def get_fwd_bwd_interactions(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
-    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
 ) -> tuple[int, OrderedSet[str]]:
     """
     Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -225,7 +225,7 @@ def get_fwd_bwd_interactions(
     Args:
     - fwd_graph (fx.Graph): The forward graph representing the forward pass.
     - bwd_graph (fx.Graph): The backward graph representing the backward pass.
-    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
       byte counts (possibly symbolic) to concrete integers.

     Returns:

@@ -6,7 +6,7 @@ import sys
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Union
+from typing import Any, Callable

 import torch
 import torch.fx as fx
@@ -82,7 +82,7 @@ def is_compute_node(n: fx.Node) -> bool:
 )


-def get_hint(x: Union[int, torch.SymInt]) -> int | None:
+def get_hint(x: int | torch.SymInt) -> int | None:
     if isinstance(x, int):
         return x
     assert isinstance(x, torch.SymInt)

@@ -3,7 +3,7 @@ import itertools
 import operator
 import typing
 from collections.abc import Sequence
-from typing import Any, Callable, Union
+from typing import Any, Callable

 import torch
 import torch._inductor.runtime.runtime_utils
@@ -118,7 +118,7 @@ def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -
 )


-def get_padded_length(x: Union[int, torch.SymInt], alignment_size: int) -> int:
+def get_padded_length(x: int | torch.SymInt, alignment_size: int) -> int:
     # we don't pad x if it is symbolic
     if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
         return 0
@@ -438,7 +438,7 @@ def _should_pad_bench(
         return False

     def realize_symbols(
-        ds: Union[torch.Size, tuple[torch.SymInt, ...]],
+        ds: torch.Size | tuple[torch.SymInt, ...],
     ) -> list[int]:
         return [d if isinstance(d, int) else d.node.hint for d in ds]

@@ -5,7 +5,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, Callable, TypeVar
 from typing_extensions import ParamSpec

 import torch
@@ -437,7 +437,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):


 def resolve_shape_to_proxy(
-    shape: list[Union[int, torch.SymInt]], bound_symbols: dict[Any, Any]
+    shape: list[int | torch.SymInt], bound_symbols: dict[Any, Any]
 ):
     """
     Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values.
@@ -1123,8 +1123,8 @@ def remove_noop_ops(graph: torch.fx.Graph):
     Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.
     """
     inputs = OrderedSet[torch.fx.Node]()
-    input_storages = OrderedSet[Union[int, None]]()
-    output_storages = OrderedSet[Union[int, None]]()
+    input_storages = OrderedSet[int | None]()
+    output_storages = OrderedSet[int | None]()

     for node in graph.find_nodes(op="placeholder"):
         inputs.add(node)

@@ -5,7 +5,7 @@ import operator
 from collections import defaultdict
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Union
+from typing import Any, Callable, cast

 import torch
 import torch.fx.node
@@ -578,7 +578,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
             old_tensors_to_clone, kwargs, node_name, trigger
         ):
             tensors_to_clone: list[str] = []
-            storage_of_reinplaced_args = OrderedSet[Union[int, None]]()
+            storage_of_reinplaced_args = OrderedSet[int | None]()

             # Those used to count possibly_missed_reinplacing_opportunities
             missed_nodes = []

@@ -5,7 +5,7 @@ import operator
 import os
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Any, Callable, Union
+from typing import Any, Callable
 from typing_extensions import TypeAlias

 import torch
@@ -725,14 +725,14 @@ class SplitCatSimplifier:

     def get_user_input_list(
         self, split_node: torch.fx.Node, next_users: list[torch.fx.Node]
-    ) -> list[list[Union[torch.fx.Node, _Range]]]:
+    ) -> list[list[torch.fx.Node | _Range]]:
         """
         Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner
         list represents the inputs to that particular node. This list can either contain
         - a tuple representing the ranges of get_items that should go into the cat (closed interval)
         - torch.fx.Node representing "other" inputs (which are not coming from our split)
         """
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]] = []
+        user_inputs_list: list[list[torch.fx.Node | _Range]] = []
         for user in next_users:
             if user.target in (torch.cat, torch.stack):
                 user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
@@ -742,7 +742,7 @@ class SplitCatSimplifier:

     def get_merged_user_inputs(
         self, split_node: torch.fx.Node, cat_node: torch.fx.Node
-    ) -> list[Union[torch.fx.Node, _Range]]:
+    ) -> list[torch.fx.Node | _Range]:
         user_inputs = get_arg_value(cat_node, 0, "tensors")
         simplified_user_inputs = []
         split_users = OrderedSet(split_node.users.keys())
@@ -769,8 +769,8 @@ class SplitCatSimplifier:
         return node_input

     def merge_consecutive_inputs(
-        self, inputs: list[Union[torch.fx.Node, int]]
-    ) -> list[Union[torch.fx.Node, _Range]]:
+        self, inputs: list[torch.fx.Node | int]
+    ) -> list[torch.fx.Node | _Range]:
         """
         Merge consecutive inputs going into a user node.

@@ -801,7 +801,7 @@ class SplitCatSimplifier:
         self,
         split_sections,
         next_users,
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[_Range] | None:
         ranges = OrderedSet[Any]()
         for user_inputs in user_inputs_list:
@@ -847,7 +847,7 @@ class SplitCatSimplifier:
         self,
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.
@@ -901,7 +901,7 @@ class SplitCatSimplifier:
         graph: torch.fx.Graph,
         split_node: torch.fx.Node,
         split_sections: list[int],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
         split_ranges: list[_Range],
     ) -> list[list[torch.fx.Node]]:
         """
@@ -1177,7 +1177,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         self,
         split_sections: list[int],
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[_Range] | None:
         simplified_split_ranges = super().get_simplified_split_ranges(
             split_sections, next_users, user_inputs_list
@@ -1190,7 +1190,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         self,
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.

@@ -4,7 +4,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Union
+from typing import Any, Callable
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
@@ -31,8 +31,8 @@ def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any
         return fn

     def distort(
-        ms: Union[list[float], tuple[float], float],
-    ) -> Union[list[float], tuple[float], float]:
+        ms: list[float] | tuple[float] | float,
+    ) -> list[float] | tuple[float] | float:
         if isinstance(ms, (list, tuple)):
             return type(ms)(distort(val) for val in ms)  # type: ignore[misc]

@@ -50,7 +50,7 @@ def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any
     @functools.wraps(fn)
     def wrapper(
         *args: list[Any], **kwargs: dict[str, Any]
-    ) -> Union[list[float], tuple[float], float]:
+    ) -> list[float] | tuple[float] | float:
         ms = fn(*args, **kwargs)

         return distort(ms)
@@ -276,7 +276,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
         grad_to_none: list[torch.Tensor] | None = None,
         is_vetted_benchmarking: bool = False,
         **kwargs: Any,
-    ) -> Union[float, list[float]]:
+    ) -> float | list[float]:
         """Benchmark a GPU callable using a custom benchmarking implementation.

         Arguments:

@@ -5,7 +5,6 @@ import collections
 import functools
 import typing
 from enum import auto, Enum
-from typing import Union

 from torch.utils._triton import has_triton_package

@@ -202,7 +201,7 @@ class HalideMeta(typing.NamedTuple):
     argtypes: list[HalideInputSpec]
     target: str
     scheduler: str | None = None
-    scheduler_flags: dict[str, Union[int, str]] | None = None
+    scheduler_flags: dict[str, int | str] | None = None
     cuda_device: int | None = None

     def args(self) -> list[str]:

@@ -1,7 +1,7 @@
 from __future__ import annotations

 import inspect
-from typing import Any, Union
+from typing import Any

 import torch

@@ -37,7 +37,7 @@ if triton is not None:

     def GPUTarget(
         backend: str,
-        arch: Union[int, str],
+        arch: int | str,
         warp_size: int,
     ) -> Any:
         if torch.version.hip:
@@ -138,7 +138,7 @@ else:
     HAS_TRITON = False


-def cc_warp_size(cc: Union[str, int]) -> int:
+def cc_warp_size(cc: str | int) -> int:
     if torch.version.hip:
         cc_str = str(cc)
         if "gfx10" in cc_str or "gfx11" in cc_str:

@@ -3447,9 +3447,9 @@ class GridExpr:
     inductor_meta: dict[str, Any]
     mode: Literal["python", "cpp"] = "python"
     prefix: list[str] = dataclasses.field(default_factory=list)
-    x_grid: Union[str, int] = 1
-    y_grid: Union[str, int] = 1
-    z_grid: Union[str, int] = 1
+    x_grid: str | int = 1
+    y_grid: str | int = 1
+    z_grid: str | int = 1

     def __post_init__(self) -> None:
         assert self.mode in ("python", "cpp")
@@ -3457,9 +3457,7 @@ class GridExpr:
     def generate(self, meta: dict[str, int]) -> None:
         raise NotImplementedError

-    def ceildiv(
-        self, numel: Union[str, int], block: Union[None, int, str]
-    ) -> Union[str, int]:
+    def ceildiv(self, numel: str | int, block: None | int | str) -> str | int:
         if block is None or block == 1:
             return numel
         if isinstance(numel, int) and isinstance(block, int):
@@ -3471,7 +3469,7 @@ class GridExpr:
         # For cpp code gen
         return f"(({numel} + ({block} - 1)) / ({block}))"

-    def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
+    def maximum(self, seq: list[int | str]) -> int | str:
         """Codegen for max function with constant folding, constants are represented as int"""
         items = self._constant_fold(max, seq)
         if len(items) <= 1:
@@ -3480,7 +3478,7 @@ class GridExpr:
             return f"max({', '.join(map(str, items))})"
         return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)

-    def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
+    def summation(self, seq: list[int | str]) -> int | str:
         """Codegen for sum function with constant folding, constants are represented as int"""
         items = self._constant_fold(sum, seq)
         if len(items) <= 1:
@@ -3488,16 +3486,16 @@ class GridExpr:
         return " + ".join(map(str, items))

     def _constant_fold(
-        self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
-    ) -> list[Union[int, str]]:
+        self, fn: Callable[[list[int]], int], seq: list[int | str]
+    ) -> list[int | str]:
         """Constant fold through a commutative fn where ints are constants"""
-        items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
+        items: list[int | str] = [x for x in seq if not isinstance(x, int)]
         const_items = [x for x in seq if isinstance(x, int)]
         if const_items:
             items.append(fn(const_items))
         return items

-    def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
+    def assign_tmp(self, name: str, expr: str | int) -> str:
         # Grid functions are one per kernel, so name collisions are fine
         if self.mode == "python":
             return f"{name} = {expr}"
@@ -3508,7 +3506,7 @@ class GridExpr:
     @staticmethod
     def from_meta(
         inductor_meta: dict[str, Any],
-        cfg: Union[Config, dict[str, int]],
+        cfg: Config | dict[str, int],
         mode: Literal["python", "cpp"] = "python",
     ) -> GridExpr:
         grid_cls = globals()[inductor_meta["grid_type"]]
@@ -3632,20 +3630,20 @@ class ComboKernelGrid(GridExpr):

     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
-    ) -> Union[str, int]:
+    ) -> str | int:
         raise NotImplementedError


 class SequentialComboKernelGrid(ComboKernelGrid):
     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
-    ) -> Union[str, int]:
+    ) -> str | int:
         assert len(xnumels) == len(no_x_dims)
         return self.summation(
             [
@@ -3658,7 +3656,7 @@ class SequentialComboKernelGrid(ComboKernelGrid):
 class RoundRobinComboKernelGrid(ComboKernelGrid):
     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
     ) -> str: