Use union syntax in torch/_inductor runtime and fx_passes (#165652)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165652
Approved by: https://github.com/aorenste
Author: Oguz Ulgen
Date: 2025-10-16 11:04:06 -07:00
Committed by: PyTorch MergeBot
Parent: fb06e49ce8
Commit: 7d0f872cb3

15 changed files with 86 additions and 92 deletions
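For context, the change is mechanical: PEP 604 lets `X | Y` stand in for `typing.Union[X, Y]` (and `X | None` for `Optional[X]`). A minimal sketch of the equivalence, using illustrative names rather than code from this diff:

    from typing import Optional, Union

    # On Python 3.10+ the two spellings produce types that compare equal:
    assert (str | bytes) == Union[str, bytes]
    assert (int | None) == Optional[int]

    # So a signature like code_hash's below can be rewritten with no
    # behavioral change; only the annotation syntax differs:
    def code_hash_sketch(code: str | bytes, extra: str | bytes = "") -> str:
        return "..."  # body elided; the annotations are the point here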


@@ -339,7 +339,7 @@ def sha256_hash(data: bytes) -> str:
     return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


-def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
+def code_hash(code: str | bytes, extra: str | bytes = "") -> str:
     hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
     if extra:
         extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
@@ -361,9 +361,7 @@ def get_path(
     return basename, subdir, path


-def get_hash(
-    content: Union[str, bytes], extra: str = "", hash_type: str = "code"
-) -> str:
+def get_hash(content: str | bytes, extra: str = "", hash_type: str = "code") -> str:
     if hash_type in {"amdgcn", "code", "ptx", "spv"}:
         return code_hash(content, extra)
     if hash_type in {"cubin", "hsaco", "spv"}:
@@ -409,7 +407,7 @@ class WritableTempFile:


 def write(
-    content: Union[str, bytes],
+    content: str | bytes,
     extension: str,
     extra: str = "",
     hash_type: str = "code",
@@ -436,7 +434,7 @@ def write_text(text: str) -> str:

 def write_atomic(
     path_: str,
-    content: Union[str, bytes],
+    content: str | bytes,
     make_dirs: bool = False,
     encode_utf_8: bool = False,
 ) -> None:
@@ -547,7 +545,7 @@ class FxGraphCachePickler(pickle.Pickler):

     def _reduce_tensor(
         self, t: Tensor
-    ) -> tuple[Callable[[T], T], tuple[Union[TensorMetadata, TensorMetadataAndValues]]]:
+    ) -> tuple[Callable[[T], T], tuple[TensorMetadata | TensorMetadataAndValues]]:
         """
         Custom reducer to pickle Tensors. If we see tensors, we know they're constants
         stored as attributes on the GraphModule.
@@ -943,7 +941,7 @@ class FxGraphHashDetails:
         raise AssertionError(f"unknown config type: {str(type(custom_pass))}")

     def _get_custom_pass_detail(
-        self, custom_pass: Union[CustomGraphPassType, CustomGraphModulePass]
+        self, custom_pass: CustomGraphPassType | CustomGraphModulePass
     ) -> Any | None:
         if not custom_pass:
             return None
@@ -1058,7 +1056,7 @@ class GuardedCache(Generic[T]):
         key: str,
         local: bool,
         remote_cache: RemoteCache[JsonDataTy] | None,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool],
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool],
         hints: list[int],
     ) -> tuple[T | None, bytes | None, dict[str, str]]:
         """
@@ -1292,7 +1290,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         local: bool,
         remote_cache: RemoteCache[JsonDataTy] | None,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
         | None = None,
     ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
@@ -1543,7 +1541,7 @@ class FxGraphCache(GuardedCache[CompiledFxGraph]):
         remote_cache: RemoteCache[JsonDataTy] | None,
         is_backward: bool,
         constants: CompiledFxGraphConstants,
-        evaluate_guards: Callable[[str, Union[list[int], list[torch.SymInt]]], bool]
+        evaluate_guards: Callable[[str, list[int] | list[torch.SymInt]], bool]
         | None = None,
     ) -> tuple[CompiledFxGraph | None, dict[str, Any]]:
         """
@@ -1723,12 +1721,12 @@ end
         *,
         device_type: str,
         additional_files: list[str],
-    ) -> Union[list[Union[str, Weights]], str]:
+    ) -> list[Union[str, Weights]] | str:
         """
         Returns the .so path, or returns a list of files that were generated if
         config.aot_inductor.package=True.
         """
-        generated_files: list[Union[str, Weights]] = additional_files  # type: ignore[assignment]
+        generated_files: list[str | Weights] = additional_files  # type: ignore[assignment]

         _set_gpu_runtime_env()  # cpp_extension consults the env
@@ -2342,7 +2340,7 @@ end
             f.write(json.dumps(qual_name_to_id))
         generated_files.append(constants_config_json)

-        gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
+        gpu_codecache: ROCmCodeCache | CUDACodeCache = (
             ROCmCodeCache() if torch.version.hip else CUDACodeCache()
         )
         gpu_kernels_o = gpu_codecache.aot_kernels_o.copy()
@@ -2555,7 +2553,7 @@ end
 _libgomp: CDLL | None = None


-def custom_op_wrapper(op: str, *args: Any) -> Union[list[c_void_p], c_void_p, None]:
+def custom_op_wrapper(op: str, *args: Any) -> list[c_void_p] | c_void_p | None:
     # This function will be called from generated cpp wrapper code in the JIT mode.
     # Because tensors will be passed in as AtenTensorHandle, we need to explicitly convert them.
     def convert_arg(arg: Any) -> Any:
@@ -2698,16 +2696,16 @@ class CppCodeCache:
     """Compiles and caches C++ libraries. Users of this class supply the source code to
     be compiled, while compilation flags are set by CppBuilder."""

-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags: dict[str, Any] = {}

     @staticmethod
-    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
+    def _load_library_inner(path: str, key: str) -> CDLL | ModuleType:
         return cdll.LoadLibrary(path)

     @classmethod
-    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
+    def _load_library(cls, path: str, key: str) -> CDLL | ModuleType:
         try:
             result = cls._load_library_inner(path, key)
             result.key = key  # type: ignore[union-attr]
@@ -2910,7 +2908,7 @@ def _worker_compile_cpp(
 # Customized Python binding for cpp kernels
 @clear_on_fresh_cache
 class CppPythonBindingsCodeCache(CppCodeCache):
-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         # kernels have no dependency on libtorch
@@ -3092,7 +3090,7 @@ class CppPythonBindingsCodeCache(CppCodeCache):

 @clear_on_fresh_cache
 class CppWrapperCodeCache(CppPythonBindingsCodeCache):
-    cache: dict[str, Callable[[], Union[CDLL, ModuleType]]] = {}
+    cache: dict[str, Callable[[], CDLL | ModuleType]] = {}
     cache_clear = staticmethod(cache.clear)
     cpp_compile_command_flags = {
         "include_pytorch": True,
@@ -3161,7 +3159,7 @@ class CppWrapperCodeCache(CppPythonBindingsCodeCache):

 @clear_on_fresh_cache
 class HalideCodeCache(CppPythonBindingsCodeCache):
-    cache: dict[str, Callable[[], Union[ModuleType, CDLL]]] = {}
+    cache: dict[str, Callable[[], ModuleType | CDLL]] = {}
     cache_clear = staticmethod(cache.clear)
     _standalone_runtime_path: str | None = None
     prefix = textwrap.dedent(
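A note on the `cache` attributes above: unlike local annotations, class-level annotations are evaluated when the class body executes (unless the module enables PEP 563 via `from __future__ import annotations`), so the `CDLL | ModuleType` spelling needs runtime support for `|` on types, i.e. Python 3.10+. A small illustration with hypothetical names:

    class Demo:
        # Evaluated at class-creation time and stored in __annotations__,
        # so `int | None` must be valid at runtime here (Python 3.10+):
        cache: dict[str, int | None] = {}

    print(Demo.__annotations__)  # {'cache': dict[str, int | None]}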


@@ -1,7 +1,6 @@
 # mypy: allow-untyped-defs
 import functools
 from collections import deque
-from typing import Union

 import torch
 from torch.utils._ordered_set import OrderedSet
@@ -514,7 +513,7 @@ def build_subgraph_buffer(

     def create_placeholder(
         name: str, dtype: torch.dtype, device: torch.device
-    ) -> Union[TensorBox, ShapeAsConstantBuffer]:
+    ) -> TensorBox | ShapeAsConstantBuffer:
         """
         Creates a placeholder input buffers for producing subgraph_output
         """


@@ -7,7 +7,7 @@ import operator
 from collections.abc import Generator
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, cast, Union
+from typing import Any, Callable, cast

 import torch
 import torch.fx as fx
@@ -39,12 +39,12 @@ def move_block_before(block: list[fx.Node], target_node: fx.Node) -> None:

 def call_function(
     graph: fx.Graph,
-    target: Union[str, Callable[..., Any]],
+    target: str | Callable[..., Any],
     args: tuple[fx.node.Argument, ...] | None = None,
     kwargs: dict[str, fx.node.Argument] | None = None,
 ) -> fx.Node:
     # We accept target as a str to avoid typing error as the type of
-    # a node.target is Union[str, Callable[..., Any]].
+    # a node.target is str | Callable[..., Any].
     # This also allows us to avoid writing check for every call.
     if isinstance(target, str):
         raise RuntimeError(f"Call function should not get a str target {target=}")
@@ -62,7 +62,7 @@ def call_function(

 @dataclass(unsafe_hash=True)
 class CommBlock:
-    shape: Union[torch.Size, list[torch.Size]]
+    shape: torch.Size | list[torch.Size]
     node_list: list[fx.Node]
     inputs: list[fx.Node]
     wait_nodes: list[fx.Node]
@@ -128,7 +128,7 @@ def get_comm_block(comm_node: fx.Node) -> CommBlock | None:
             break

     tensor_meta = input_nodes[0].meta["tensor_meta"]
-    shape: Union[torch.Size, list[torch.Size]]
+    shape: torch.Size | list[torch.Size]
     if isinstance(tensor_meta, TensorMetadata):
         shape = tensor_meta.shape
     elif isinstance(tensor_meta, (list, tuple)):
@@ -571,7 +571,7 @@ def schedule_comm_wait(graph: fx.Graph) -> None:


 def fuse_ddp_communication(
-    graph: fx.Graph, passes: list[Union[Callable[..., None], str]], bucket_size_mb: int
+    graph: fx.Graph, passes: list[Callable[..., None] | str], bucket_size_mb: int
 ) -> None:
     for i, pa in enumerate(passes):
         with GraphTransformObserver(


@@ -1,6 +1,6 @@
 # mypy: allow-untyped-defs
 from dataclasses import dataclass
-from typing import Any, Union
+from typing import Any

 import torch
 from torch import SymBool, SymFloat, SymInt
@@ -14,7 +14,7 @@ class _SymExprHash:
     Hash for a py_sym_types that will use the underlying sympy expression
     """

-    sym_obj: Union[SymInt, SymFloat, SymBool]
+    sym_obj: SymInt | SymFloat | SymBool

     def __hash__(self) -> int:
         return hash((type(self.sym_obj), self.sym_obj.node.expr))


@@ -6,7 +6,7 @@ import operator
 import typing
 from collections import Counter
 from collections.abc import Sequence
-from typing import Any, Union
+from typing import Any

 import torch
 import torch._guards
@@ -706,8 +706,8 @@ def pointless_convert(match: Match, arg, dtype1: torch.dtype, dtype2: torch.dtype):


 def definitely_equal(
-    old_sizes: Sequence[Union[torch.SymInt, int]],
-    new_sizes: Sequence[Union[torch.SymInt, torch.fx.Node, int]],
+    old_sizes: Sequence[torch.SymInt | int],
+    new_sizes: Sequence[torch.SymInt | torch.fx.Node | int],
 ) -> bool:
     """
     Leverage guard_or_true/false to compare if two lists of int/symint are equal.
@@ -906,7 +906,7 @@ def mul_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
     if dtype is not None:
         inp = inp.to(dtype)

-    sign: Union[int, float, torch.Tensor]
+    sign: int | float | torch.Tensor
     if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
         sign = 1 if other >= 0 else -1
     else:
@@ -936,7 +936,7 @@ def div_softmax_pattern(match: Match, *, inp, other, dim, keepdim, dtype=None):
     if dtype is not None:
         inp = inp.to(dtype)

-    sign: Union[int, float, torch.Tensor]
+    sign: int | float | torch.Tensor
     if isinstance(other, (int, float, torch.SymInt, torch.SymFloat)):
         sign = 1 if other >= 0 else -1
     else:
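One nuance in the two hunks above (my observation, not the PR's): `sign: int | float | torch.Tensor` is a bare local-variable annotation, and PEP 526 specifies that local annotations are never evaluated at runtime, so the new spelling cannot fail on any interpreter that parses it; only static checkers consume it. A tiny sketch:

    def pick_sign(other: float) -> float:
        # The annotation below is recorded for type checkers only; the
        # expression `int | float` is never executed (PEP 526):
        sign: int | float
        sign = 1 if other >= 0 else -1
        return sign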


@@ -2,7 +2,7 @@ import itertools
 import logging
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import Callable, Union
+from typing import Callable

 import torch
 import torch.fx as fx
@@ -143,7 +143,7 @@ class GraphAliasTracker:
         return self.node_to_storages_last_used[node]


-def _size_of_default(num_bytes: Union[int, torch.SymInt]) -> int:
+def _size_of_default(num_bytes: int | torch.SymInt) -> int:
     return hint_int(num_bytes, fallback=torch._inductor.config.unbacked_symint_fallback)
@@ -154,7 +154,7 @@ def device_filter(device: torch.device) -> bool:
 def build_memory_profile(
     graph: fx.Graph,
     is_releasable: Callable[[fx.Node], bool],
-    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
 ) -> list[int]:
     """
     Function to estimate the memory profile of an input FX graph.
@@ -165,7 +165,7 @@ def build_memory_profile(
     - is_releasable (Callable[[fx.Node], bool]): A function that
      determines if a node's memory can be released (e.g. primal nodes
      cannot be released).
-    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
      byte counts (possibly symbolic) to concrete integers.

     Returns:
@@ -216,7 +216,7 @@ def build_memory_profile(
 def get_fwd_bwd_interactions(
     fwd_graph: fx.Graph,
     bwd_graph: fx.Graph,
-    size_of: Callable[[Union[int, torch.SymInt]], int] | None = None,
+    size_of: Callable[[int | torch.SymInt], int] | None = None,
 ) -> tuple[int, OrderedSet[str]]:
     """
     Analyze the interactions between the forward (fwd) and backward (bwd) graphs
@@ -225,7 +225,7 @@ def get_fwd_bwd_interactions(
     Args:
     - fwd_graph (fx.Graph): The forward graph representing the forward pass.
     - bwd_graph (fx.Graph): The backward graph representing the backward pass.
-    - size_of (Callable[[Union[int, torch.SymInt]], int]): A function that converts
+    - size_of (Callable[[int | torch.SymInt], int]): A function that converts
      byte counts (possibly symbolic) to concrete integers.

     Returns:


@@ -6,7 +6,7 @@ import sys
 from collections import Counter, defaultdict
 from collections.abc import Iterable
 from dataclasses import dataclass
-from typing import Any, Callable, Union
+from typing import Any, Callable

 import torch
 import torch.fx as fx
@@ -82,7 +82,7 @@ def is_compute_node(n: fx.Node) -> bool:
     )


-def get_hint(x: Union[int, torch.SymInt]) -> int | None:
+def get_hint(x: int | torch.SymInt) -> int | None:
     if isinstance(x, int):
         return x
     assert isinstance(x, torch.SymInt)


@@ -3,7 +3,7 @@ import itertools
 import operator
 import typing
 from collections.abc import Sequence
-from typing import Any, Callable, Union
+from typing import Any, Callable

 import torch
 import torch._inductor.runtime.runtime_utils
@@ -118,7 +118,7 @@ def should_pad_common(mat1: Tensor, mat2: Tensor, input: Tensor | None = None) -> bool:
     )


-def get_padded_length(x: Union[int, torch.SymInt], alignment_size: int) -> int:
+def get_padded_length(x: int | torch.SymInt, alignment_size: int) -> int:
     # we don't pad x if it is symbolic
     if isinstance(x, torch.SymInt) or alignment_size == 0 or x % alignment_size == 0:
         return 0
@@ -438,7 +438,7 @@ def _should_pad_bench(
         return False

     def realize_symbols(
-        ds: Union[torch.Size, tuple[torch.SymInt, ...]],
+        ds: torch.Size | tuple[torch.SymInt, ...],
     ) -> list[int]:
         return [d if isinstance(d, int) else d.node.hint for d in ds]


@@ -5,7 +5,7 @@ import itertools
 import logging
 import operator
 from collections import Counter, defaultdict
-from typing import Any, Callable, TypeVar, Union
+from typing import Any, Callable, TypeVar
 from typing_extensions import ParamSpec

 import torch
@@ -437,7 +437,7 @@ def decompose_map_to_while_loop(gm: torch.fx.GraphModule):


 def resolve_shape_to_proxy(
-    shape: list[Union[int, torch.SymInt]], bound_symbols: dict[Any, Any]
+    shape: list[int | torch.SymInt], bound_symbols: dict[Any, Any]
 ):
     """
     Given a list of symints/ints, this function returns a calculated expression of bound_symbols' values.
@@ -1123,8 +1123,8 @@ def remove_noop_ops(graph: torch.fx.Graph):
     Removes both operations that are essentially aten.clone and operations that are essentially aten.alias from the graph.
     """
     inputs = OrderedSet[torch.fx.Node]()
-    input_storages = OrderedSet[Union[int, None]]()
-    output_storages = OrderedSet[Union[int, None]]()
+    input_storages = OrderedSet[int | None]()
+    output_storages = OrderedSet[int | None]()

     for node in graph.find_nodes(op="placeholder"):
         inputs.add(node)
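By contrast, the `OrderedSet[int | None]()` lines above are ordinary runtime expressions, not annotations: `int | None` is evaluated eagerly and must build a `types.UnionType`, which exists only on Python 3.10+. A quick sketch of what actually executes (the version note is mine, not the PR's):

    import types

    u = int | None  # evaluated immediately; requires Python 3.10+
    assert isinstance(u, types.UnionType)
    # OrderedSet[int | None]() then subscripts the generic with this union
    # and calls it, exactly as OrderedSet[Union[int, None]]() did before.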


@@ -5,7 +5,7 @@ import operator
 from collections import defaultdict
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, Callable, cast, Union
+from typing import Any, Callable, cast

 import torch
 import torch.fx.node
@@ -578,7 +578,7 @@ def reinplace_inplaceable_ops_core(graph: torch.fx.Graph) -> None:
         old_tensors_to_clone, kwargs, node_name, trigger
     ):
         tensors_to_clone: list[str] = []
-        storage_of_reinplaced_args = OrderedSet[Union[int, None]]()
+        storage_of_reinplaced_args = OrderedSet[int | None]()

         # Those used to count possibly_missed_reinplacing_opportunities
         missed_nodes = []


@@ -5,7 +5,7 @@ import operator
 import os
 from collections import defaultdict
 from collections.abc import Sequence
-from typing import Any, Callable, Union
+from typing import Any, Callable
 from typing_extensions import TypeAlias

 import torch
@@ -725,14 +725,14 @@ class SplitCatSimplifier:

     def get_user_input_list(
         self, split_node: torch.fx.Node, next_users: list[torch.fx.Node]
-    ) -> list[list[Union[torch.fx.Node, _Range]]]:
+    ) -> list[list[torch.fx.Node | _Range]]:
         """
         Returns list of inputs to the following user nodes, in order. The outer list represents the user node. The inner
         list represents the inputs to that particular node. This list can either contain
         - a tuple representing the ranges of get_items that should go into the cat (closed interval)
         - torch.fx.Node representing "other" inputs (which are not coming from our split)
         """
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]] = []
+        user_inputs_list: list[list[torch.fx.Node | _Range]] = []
         for user in next_users:
             if user.target in (torch.cat, torch.stack):
                 user_inputs_list.append(self.get_merged_user_inputs(split_node, user))
@@ -742,7 +742,7 @@ class SplitCatSimplifier:

     def get_merged_user_inputs(
         self, split_node: torch.fx.Node, cat_node: torch.fx.Node
-    ) -> list[Union[torch.fx.Node, _Range]]:
+    ) -> list[torch.fx.Node | _Range]:
         user_inputs = get_arg_value(cat_node, 0, "tensors")
         simplified_user_inputs = []
         split_users = OrderedSet(split_node.users.keys())
@@ -769,8 +769,8 @@ class SplitCatSimplifier:
         return node_input

     def merge_consecutive_inputs(
-        self, inputs: list[Union[torch.fx.Node, int]]
-    ) -> list[Union[torch.fx.Node, _Range]]:
+        self, inputs: list[torch.fx.Node | int]
+    ) -> list[torch.fx.Node | _Range]:
         """
         Merge consecutive inputs going into a user node.
@@ -801,7 +801,7 @@ class SplitCatSimplifier:
         self,
         split_sections,
         next_users,
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[_Range] | None:
         ranges = OrderedSet[Any]()
         for user_inputs in user_inputs_list:
@@ -847,7 +847,7 @@ class SplitCatSimplifier:
         self,
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.
@@ -901,7 +901,7 @@ class SplitCatSimplifier:
         graph: torch.fx.Graph,
         split_node: torch.fx.Node,
         split_sections: list[int],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
         split_ranges: list[_Range],
     ) -> list[list[torch.fx.Node]]:
         """
@@ -1177,7 +1177,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         self,
         split_sections: list[int],
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[_Range] | None:
         simplified_split_ranges = super().get_simplified_split_ranges(
             split_sections, next_users, user_inputs_list
@@ -1190,7 +1190,7 @@ class UnbindCatRemover(SplitCatSimplifier):
         self,
         split_node: torch.fx.Node,
         next_users: list[torch.fx.Node],
-        user_inputs_list: list[list[Union[torch.fx.Node, _Range]]],
+        user_inputs_list: list[list[torch.fx.Node | _Range]],
     ) -> list[list[_TransformParam]] | None:
         """
         Figure out what transforms are needed for each input to each cat node.


@@ -4,7 +4,7 @@ import time
 from functools import cached_property, wraps
 from itertools import chain
 from statistics import median
-from typing import Any, Callable, Union
+from typing import Any, Callable
 from typing_extensions import Concatenate, ParamSpec, Self, TypeVar

 import torch
@@ -31,8 +31,8 @@ def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]:
         return fn

     def distort(
-        ms: Union[list[float], tuple[float], float],
-    ) -> Union[list[float], tuple[float], float]:
+        ms: list[float] | tuple[float] | float,
+    ) -> list[float] | tuple[float] | float:
         if isinstance(ms, (list, tuple)):
             return type(ms)(distort(val) for val in ms)  # type: ignore[misc]
@@ -50,7 +50,7 @@ def may_distort_benchmarking_result(fn: Callable[..., Any]) -> Callable[..., Any]:
     @functools.wraps(fn)
     def wrapper(
         *args: list[Any], **kwargs: dict[str, Any]
-    ) -> Union[list[float], tuple[float], float]:
+    ) -> list[float] | tuple[float] | float:
         ms = fn(*args, **kwargs)
         return distort(ms)
@@ -276,7 +276,7 @@ class InductorBenchmarker(TritonBenchmarker):  # noqa: docstring_linter
         grad_to_none: list[torch.Tensor] | None = None,
         is_vetted_benchmarking: bool = False,
         **kwargs: Any,
-    ) -> Union[float, list[float]]:
+    ) -> float | list[float]:
         """Benchmark a GPU callable using a custom benchmarking implementation.

         Arguments:


@@ -5,7 +5,6 @@ import collections
 import functools
 import typing
 from enum import auto, Enum
-from typing import Union

 from torch.utils._triton import has_triton_package
@@ -202,7 +201,7 @@ class HalideMeta(typing.NamedTuple):
     argtypes: list[HalideInputSpec]
     target: str
     scheduler: str | None = None
-    scheduler_flags: dict[str, Union[int, str]] | None = None
+    scheduler_flags: dict[str, int | str] | None = None
     cuda_device: int | None = None

     def args(self) -> list[str]:


@@ -1,7 +1,7 @@
 from __future__ import annotations

 import inspect
-from typing import Any, Union
+from typing import Any

 import torch
@@ -37,7 +37,7 @@ if triton is not None:

     def GPUTarget(
         backend: str,
-        arch: Union[int, str],
+        arch: int | str,
         warp_size: int,
     ) -> Any:
         if torch.version.hip:
@@ -138,7 +138,7 @@ else:
     HAS_TRITON = False


-def cc_warp_size(cc: Union[str, int]) -> int:
+def cc_warp_size(cc: str | int) -> int:
     if torch.version.hip:
         cc_str = str(cc)
         if "gfx10" in cc_str or "gfx11" in cc_str:


@@ -3447,9 +3447,9 @@ class GridExpr:
     inductor_meta: dict[str, Any]
     mode: Literal["python", "cpp"] = "python"
     prefix: list[str] = dataclasses.field(default_factory=list)
-    x_grid: Union[str, int] = 1
-    y_grid: Union[str, int] = 1
-    z_grid: Union[str, int] = 1
+    x_grid: str | int = 1
+    y_grid: str | int = 1
+    z_grid: str | int = 1

     def __post_init__(self) -> None:
         assert self.mode in ("python", "cpp")
@@ -3457,9 +3457,7 @@
     def generate(self, meta: dict[str, int]) -> None:
         raise NotImplementedError

-    def ceildiv(
-        self, numel: Union[str, int], block: Union[None, int, str]
-    ) -> Union[str, int]:
+    def ceildiv(self, numel: str | int, block: None | int | str) -> str | int:
         if block is None or block == 1:
             return numel
         if isinstance(numel, int) and isinstance(block, int):
@@ -3471,7 +3469,7 @@
         # For cpp code gen
         return f"(({numel} + ({block} - 1)) / ({block}))"

-    def maximum(self, seq: list[Union[int, str]]) -> Union[int, str]:
+    def maximum(self, seq: list[int | str]) -> int | str:
         """Codegen for max function with constant folding, constants are represented as int"""
         items = self._constant_fold(max, seq)
         if len(items) <= 1:
@@ -3480,7 +3478,7 @@
             return f"max({', '.join(map(str, items))})"
         return functools.reduce(lambda x, y: f"std::max({x}, {y})", items)

-    def summation(self, seq: list[Union[int, str]]) -> Union[int, str]:
+    def summation(self, seq: list[int | str]) -> int | str:
         """Codegen for sum function with constant folding, constants are represented as int"""
         items = self._constant_fold(sum, seq)
         if len(items) <= 1:
@@ -3488,16 +3486,16 @@
         return " + ".join(map(str, items))

     def _constant_fold(
-        self, fn: Callable[[list[int]], int], seq: list[Union[int, str]]
-    ) -> list[Union[int, str]]:
+        self, fn: Callable[[list[int]], int], seq: list[int | str]
+    ) -> list[int | str]:
         """Constant fold through a commutative fn where ints are constants"""
-        items: list[Union[int, str]] = [x for x in seq if not isinstance(x, int)]
+        items: list[int | str] = [x for x in seq if not isinstance(x, int)]
         const_items = [x for x in seq if isinstance(x, int)]
         if const_items:
             items.append(fn(const_items))
         return items

-    def assign_tmp(self, name: str, expr: Union[str, int]) -> str:
+    def assign_tmp(self, name: str, expr: str | int) -> str:
         # Grid functions are one per kernel, so name collisions are fine
         if self.mode == "python":
             return f"{name} = {expr}"
@@ -3508,7 +3506,7 @@
     @staticmethod
     def from_meta(
         inductor_meta: dict[str, Any],
-        cfg: Union[Config, dict[str, int]],
+        cfg: Config | dict[str, int],
         mode: Literal["python", "cpp"] = "python",
     ) -> GridExpr:
         grid_cls = globals()[inductor_meta["grid_type"]]
@@ -3632,20 +3630,20 @@ class ComboKernelGrid(GridExpr):

     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
-    ) -> Union[str, int]:
+    ) -> str | int:
         raise NotImplementedError


 class SequentialComboKernelGrid(ComboKernelGrid):
     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
-    ) -> Union[str, int]:
+    ) -> str | int:
         assert len(xnumels) == len(no_x_dims)
         return self.summation(
             [
@@ -3658,7 +3656,7 @@ class SequentialComboKernelGrid(ComboKernelGrid):
 class RoundRobinComboKernelGrid(ComboKernelGrid):
     def combo_x_grid(
         self,
-        xnumels: list[Union[int, str]],
+        xnumels: list[int | str],
         no_x_dims: list[bool],
         meta: dict[str, int],
     ) -> str:
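The GridExpr helpers in this last file lean on `int | str` so a grid dimension can be either a concrete integer or a codegen expression, folding the constant part eagerly. A standalone sketch of that technique (simplified; not the inductor class itself, and it assumes a non-empty input list):

    from collections.abc import Callable


    def constant_fold(fn: Callable[[list[int]], int], seq: list[int | str]) -> list[int | str]:
        # Fold every integer entry through fn; keep symbolic (string) entries.
        items: list[int | str] = [x for x in seq if not isinstance(x, int)]
        const_items = [x for x in seq if isinstance(x, int)]
        if const_items:
            items.append(fn(const_items))
        return items


    def summation(seq: list[int | str]) -> int | str:
        items = constant_fold(sum, seq)
        if len(items) <= 1:
            return items[0]
        return " + ".join(map(str, items))


    print(summation([2, "xnumel", 3]))  # -> "xnumel + 5"
    print(summation([4, 6]))            # -> 10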