[inductor] Add typing to common.CSE (#145993)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993
Approved by: https://github.com/yanboliang
ghstack dependencies: #145913, #145914, #145915, #145916
Author: Jason Ansel
Date: 2025-01-31 21:12:32 -08:00
Committed by: PyTorch MergeBot
Parent: 68cf36d5ab
Commit: 8c657ae4be
10 changed files with 192 additions and 118 deletions
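
The core of the change is making CSE generic over the backend's CSEVariable subclass and over the (possibly augmented) cache-key type, using TypeVar defaults (PEP 696, via typing_extensions) so that a bare CSE still means CSE[CSEVariable, str]. The following is a minimal standalone sketch of that pattern; every Fake* name is a toy stand-in for illustration, not the actual torch._inductor code in torch/_inductor/codegen/common.py:

    # Toy sketch of the generic pattern this PR applies to common.CSE.
    from typing import Generic, Union, cast
    from typing_extensions import TypeVar  # supports default= (PEP 696)

    class FakeCSEVariable:
        def __init__(self, name: str) -> None:
            self.name = name

    CSEVariableType = TypeVar(
        "CSEVariableType", bound=FakeCSEVariable, default=FakeCSEVariable
    )
    AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)

    class FakeCSE(Generic[CSEVariableType, AugmentedKeyT]):
        def __init__(self) -> None:
            self._cache: dict[AugmentedKeyT, CSEVariableType] = {}

        def augment_key(self, cache_key: str) -> AugmentedKeyT:
            # Default backend: the cache key is used unchanged.
            return cast(AugmentedKeyT, cache_key)

        def put(self, cache_key: str, val: CSEVariableType) -> None:
            self._cache[self.augment_key(cache_key)] = val

    class FakeTritonVariable(FakeCSEVariable):
        pass

    # A backend pins both parameters, mirroring TritonCSE later in the diff:
    class FakeTritonCSE(FakeCSE[FakeTritonVariable, Union[str, tuple[str, str]]]):
        def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]:
            return (cache_key, "some_mask")  # fold backend state into the key

    # Thanks to the defaults, plain FakeCSE() still type-checks as
    # FakeCSE[FakeCSEVariable, str], so untyped call sites keep working.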

View File

@ -10,6 +10,7 @@ import logging
import math
import operator
import re
import typing
from enum import auto, Enum
from itertools import chain
from typing import (
@ -17,12 +18,15 @@ from typing import (
Callable,
cast,
ClassVar,
Generic,
Iterator,
MutableMapping,
NamedTuple,
Optional,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeVar
import sympy
@ -44,6 +48,7 @@ from ..utils import (
generate_assert,
IndentedBuffer,
ir_dataclass,
ScopedDict,
sympy_dot,
sympy_subs,
unique,
@ -52,11 +57,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V
if TYPE_CHECKING:
from typing import Never, TypeVar
from ..ir import FixedLayout
from ..ir import FixedLayout, IRNode
from ..loop_body import LoopBody
from ..scheduler import BaseScheduling, Scheduler
from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
from .wrapper import PythonWrapperCodegen
_T = TypeVar("_T")
@ -1336,6 +1339,18 @@ class KernelArgs:
self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
)
def arg_name(self, name: str) -> Optional[str]:
"""
Returns inner name of a given outer name.
"""
inplaced = self.inplace_buffers.get(name, None)
if inplaced is not None and not isinstance(inplaced, RemovedArg):
return inplaced.inner_name
output_name = self.output_buffers.get(name, None)
if output_name is not None and not isinstance(output_name, RemovedArg):
return output_name
return self.input_buffers.get(name, None)
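
For reference, the resolution order that the new KernelArgs.arg_name follows can be shown with a short standalone sketch; the buffer names below are hypothetical and the RemovedArg handling is elided:

    # Toy illustration of the lookup order used above:
    # inplace buffers, then output buffers, then input buffers.
    class ToyInplaced:
        inner_name = "in_out_ptr0"

    inplace_buffers = {"buf3": ToyInplaced()}
    output_buffers = {"buf1": "out_ptr0"}
    input_buffers = {"arg0_1": "in_ptr0"}

    def toy_arg_name(name):
        if name in inplace_buffers:
            return inplace_buffers[name].inner_name
        if name in output_buffers:
            return output_buffers[name]
        return input_buffers.get(name)

    assert toy_arg_name("buf3") == "in_out_ptr0"   # inplace wins
    assert toy_arg_name("buf1") == "out_ptr0"
    assert toy_arg_name("arg0_1") == "in_ptr0"
    assert toy_arg_name("unknown") is None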
def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
return buf
@ -1477,17 +1492,18 @@ class CSEVariable:
def __init__(
self,
name,
name: str,
bounds: ValueRanges[Any],
dtype: Optional[torch.dtype] = None,
):
super().__init__()
assert isinstance(bounds, ValueRanges)
self.name = name
self.bounds = bounds
self.use_count = 1 # track how many times this expression is used
self.dtype = dtype
def __str__(self):
def __str__(self) -> str:
return self.name
def __hash__(self) -> int:
@ -1496,45 +1512,62 @@ class CSEVariable:
def __eq__(self, other) -> bool:
return type(other) == type(self) and other.name == self.name
def update_on_args(self, name, args, kwargs):
def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
pass
def __repr__(self):
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.name!r})"
class CSE:
AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
if TYPE_CHECKING:
ReductionCacheKey = tuple[
torch.dtype,
ReductionType,
Union[CSEVariable, tuple[CSEVariable, ...]],
]
class CSE(Generic[CSEVariableType, AugmentedKeyT]):
"""Common subexpression elimination"""
def __init__(
self,
prefix="",
suffix="",
name_prefix="tmp",
iter_buffers=None,
store_cache=None,
reduction_cache=None,
varname_map=None,
prefix: str = "",
suffix: str = "",
name_prefix: str = "tmp",
iter_buffers: Optional[itertools.count[int]] = None,
store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
reduction_cache: Optional[
MutableMapping[ReductionCacheKey, CSEVariableType]
] = None,
varname_map: Optional[dict[str, CSEVariableType]] = None,
):
self.prefix = prefix
self.suffix = suffix
self._cache = {}
self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
self.name_prefix = name_prefix
self.store_cache = store_cache or {}
self.reduction_cache = reduction_cache or {}
self.iter_buffer_ids = iter_buffers or itertools.count()
self.invalidated_stores = OrderedSet[str]()
self.varname_map = varname_map or {}
self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
reduction_cache or {}
)
self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
self.invalidated_stores: OrderedSet[str] = OrderedSet()
self.varname_map: dict[str, CSEVariableType] = varname_map or {}
def invalidate(self, keep_vars: Union[OrderedSet[str], OrderedSet[Never]]):
for name, tmp in list(self.store_cache.items()):
def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
for name, tmp in [*self.store_cache.items()]:
if tmp not in keep_vars:
del self.store_cache[name]
self.invalidated_stores.add(name)
if keep_vars:
self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
else:
self._cache = {}
def clone(self):
# Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
def clone(self) -> typing.Self:
return type(self)(
prefix=self.prefix,
suffix=self.suffix,
@ -1542,22 +1575,23 @@ class CSE:
iter_buffers=self.iter_buffer_ids,
store_cache=self.store_cache,
varname_map=self.varname_map,
reduction_cache=self.reduction_cache,
)
def augment_key(self, cache_key: object) -> object:
def augment_key(self, cache_key: str) -> AugmentedKeyT:
"Override this method to augment cache key with backend specifics"
return cache_key
return cast(AugmentedKeyT, cache_key)
def put(self, cache_key: object, val: CSEVariable) -> None:
def put(self, cache_key: str, val: CSEVariableType) -> None:
self._cache[self.augment_key(cache_key)] = val
def contains(self, cache_key) -> bool:
def contains(self, cache_key: str) -> bool:
return self.augment_key(cache_key) in self._cache
def try_get(self, cache_key: object) -> Optional[CSEVariable]:
def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
return self._cache.get(self.augment_key(cache_key), None)
def get(self, cache_key: object) -> CSEVariable:
def get(self, cache_key: str) -> CSEVariableType:
return self._cache[self.augment_key(cache_key)]
def generate(
@ -1566,10 +1600,10 @@ class CSE:
expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
*,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
write=True,
assignment=True,
write: bool = True,
assignment: bool = True,
dtype: Optional[torch.dtype] = None,
) -> CSEVariable:
) -> CSEVariableType:
if isinstance(expr, OpsValue):
expr = expr.value
@ -1580,7 +1614,7 @@ class CSE:
# with the loose ValueRanges.unknown(), so we need to tighten the bounds
expr.bounds = expr.bounds.tighten(bounds)
expr.use_count += 1
return expr
return cast(CSEVariableType, expr)
elif isinstance(expr, IndentedBuffer):
cache_key = expr.getvalue()
elif isinstance(expr, DeferredLineBase):
@ -1623,7 +1657,7 @@ class CSE:
self,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
dtype: Optional[torch.dtype] = None,
) -> CSEVariable:
) -> CSEVariableType:
var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
var = V.kernel.create_cse_var(var_name, bounds, dtype)
self.varname_map[var_name] = var
@ -1634,7 +1668,7 @@ class CSE:
name: str,
bounds: ValueRanges[Any] = ValueRanges.unknown(),
dtype: Optional[torch.dtype] = None,
) -> CSEVariable:
) -> CSEVariableType:
torch._check_value(
name not in self.varname_map, lambda: f"duplicate name: {name}"
)
@ -1648,45 +1682,22 @@ class CodeGen:
super().__init__()
self.exit_stack = contextlib.ExitStack()
def __enter__(self):
def __enter__(self) -> typing.Self:
self.exit_stack.__enter__()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
self.exit_stack.__exit__(exc_type, exc_val, exc_tb)
class ScopedDict:
def __init__(self, original_dict):
self.original_dict = original_dict
self.new_items = {}
def __getitem__(self, key):
if key in self.new_items:
return self.new_items[key]
return self.original_dict[key]
def __setitem__(self, key, value):
self.new_items[key] = value
def __contains__(self, key):
return key in self.new_items or key in self.original_dict
def get(self, key, default=None):
if key in self.new_items:
return self.new_items[key]
return self.original_dict.get(key, default)
class Kernel(CodeGen):
newvar_prefix = ""
suffix = ""
class Kernel(CodeGen, Generic[CSEVariableType]):
newvar_prefix: str = ""
suffix: str = ""
overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
# TODO: these look dead, but with all the getattr it's hard to tell...
load_format: None = None
store_format: None = None
def __init__(self, args=None, increase_kernel_count=True):
def __init__(
self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
) -> None:
super().__init__()
if increase_kernel_count:
metrics.generated_kernel_count += 1
@ -1698,13 +1709,13 @@ class Kernel(CodeGen):
self.num_load = 0
self.num_reduction = 0
self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
self.must_keep_buffers = OrderedSet[str]()
self.store_buffer_names = OrderedSet[str]()
self._load_mask = None
self._load_other = None
self._load_mask: Optional[str] = None
self._load_other: Union[None, int, float] = None
# OrderedSet in set_current_node
self.current_node = None
self.current_node: Optional[SchedulerNode] = None
self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None
self.removed_buffers = OrderedSet[str]()
@ -1713,10 +1724,10 @@ class Kernel(CodeGen):
# key: the buffer to write
# value: the buffer to read and whose memory can be reused for
# the buffer specified by key
self.inplace_update_buffers = {}
self.inplace_update_buffers: dict[str, str] = {}
# Set minimum number of elements processed per thread.
self.min_elem_per_thread = 1
self.kernel_name = None
self.kernel_name: Optional[str] = None
@contextlib.contextmanager
def set_current_node(self, node):
@ -1730,7 +1741,7 @@ class Kernel(CodeGen):
@contextlib.contextmanager
def swap_buffers(self, lb, cb=None, sb=None):
def scope_cse(cse):
def scope_cse(cse: CSE[CSEVariableType, Any]):
new_cse = cse.clone()
new_cse._cache = ScopedDict(cse._cache)
new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
@ -2057,6 +2068,7 @@ class Kernel(CodeGen):
@staticmethod
def _update_store_cache(name: str, value: CSEVariable):
value = cast(CSEVariableType, value)
self.cse.store_cache[name] = value
if self.current_node and name in V.graph.name_to_buffer:
buf = self.current_node.get_output(name)
@ -2283,6 +2295,14 @@ class Kernel(CodeGen):
def create_cse_var(self, *args, **kwargs):
return CSEVariable(*args, **kwargs)
def arg_name(self, node: IRNode) -> Optional[str]:
"""
Returns arg name of a given input or output node.
"""
if node is None:
return None
return self.args.arg_name(node.get_name())
@dataclasses.dataclass
class OptimizationContext:

View File

@ -2646,7 +2646,7 @@ class CppVecKernel(CppKernel):
return super().load(name, index)
elif stride == 1:
# load contiguously
line = self._get_vec_load_line(var, index, dtype, self._load_mask)
line = self._get_vec_load_line(var, index, dtype, self._load_mask) # type: ignore[arg-type]
csevar = self.cse.generate(self.loads, line) # type: ignore[assignment]
else:
csevar = self._load_or_store_non_contiguous(var, index, dtype) # type: ignore[assignment]

View File

@ -163,16 +163,6 @@ class CUDATemplateKernel(CUDAKernel):
super().__init__()
self.kernel_name = kernel_name
def arg_name(self, node: IRNode) -> Optional[str]:
"""
Returns arg name of a given input or output node.
"""
if node is None:
return None
return {**self.args.input_buffers, **self.args.output_buffers}.get(
node.get_name(), None
)
def check_not_null(self, node: IRNode) -> str:
"""
Generates code to check that a node is not null.
@ -273,6 +263,7 @@ class CUDATemplateKernel(CUDAKernel):
"""
wrapper = V.graph.wrapper_code
arg_types: list[Any]
if V.graph.cpp_wrapper:
# Make sure we initialize these kernels since they're exported as
# C-style symbol names.

View File

@ -8,7 +8,7 @@ import logging
import re
from collections import defaultdict
from math import inf
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
import sympy
@ -39,6 +39,7 @@ from .common import (
CSEVariable,
DeferredLine,
IndentedBuffer,
KernelArgType,
OpOverrides,
PythonPrinter,
SizeArg,
@ -1383,7 +1384,7 @@ class HalideKernel(SIMDKernel):
assert "in_ptr" in arg.name
return 0
result = []
result: list[tuple[Optional[str], KernelArgType]] = []
_, a, b, _ = self.args.python_argdefs()
for call_str, arg in sorted(zip(a, b), key=arg_order):
result.append((call_str, arg))
@ -1527,7 +1528,7 @@ class HalideKernel(SIMDKernel):
code.splice(self.indexing_code)
def update_index(m):
var = self.cse.varname_map[m.group(1)]
var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)])
assert var.used_dims is not None, var
return str(var)

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from typing import Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
@ -52,16 +52,6 @@ class ROCmTemplateKernel(ROCmKernel):
# Mapping from arg name to IRNode.
self.named_nodes: dict[str, IRNode] = {}
def arg_name(self, node: IRNode) -> Optional[str]:
"""
Returns arg name of a given input or output node.
"""
if node is None:
return None
return {**self.args.input_buffers, **self.args.output_buffers}.get(
node.get_name(), None
)
def get_signature(self):
return self.signature
@ -133,6 +123,7 @@ class ROCmTemplateKernel(ROCmKernel):
"""
wrapper = V.graph.wrapper_code
arg_types: list[Any]
if V.graph.cpp_wrapper:
# Make sure we initialize these kernels since they're exported as
# C-style symbol names.

View File

@ -14,12 +14,14 @@ from collections import Counter
from typing import (
Any,
Callable,
Generic,
Iterator,
no_type_check,
Optional,
TYPE_CHECKING,
Union,
)
from typing_extensions import TypeVar
import sympy
@ -339,7 +341,10 @@ def constant_repr(value: Union[int, float]) -> str:
return repr(value)
class SIMDKernel(Kernel):
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)
class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
"""
Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
"""
@ -456,7 +461,7 @@ class SIMDKernel(Kernel):
numels[prefix],
prefix,
index,
self,
self, # type: ignore[arg-type]
pid_cache=pid_cache,
is_loop=is_reduction and not self.persistent_reduction,
tensor_dim=tensor_dim,

View File

@ -102,6 +102,7 @@ if TYPE_CHECKING:
from typing import TypeVar
from ..ir import IRNode
from .simd_kernel_features import SIMDKernelFeatures
_T = TypeVar("_T")
@ -1522,20 +1523,20 @@ class FixedTritonConfig:
return item in self.config
class TritonCSE(CSE):
class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]):
"""
Subclasses CSE to apply the current load mask to the cache key to avoid CSEing
variables across separate masked blocks.
"""
def augment_key(self, cache_key: object) -> object:
def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]:
if mask := V.kernel._load_mask:
return (cache_key, mask.name)
else:
return cache_key
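
A toy illustration of the effect (hypothetical expression and mask names, not the real kernel state): when a load mask is active, the augmented key differs, so a value cached under one mask is not reused under another.

    # Hypothetical illustration of mask-augmented CSE cache keys.
    cache = {}

    def augmented(cache_key, load_mask):
        # Mirrors TritonCSE.augment_key: fold the active mask into the key.
        return (cache_key, load_mask) if load_mask else cache_key

    cache[augmented("tl.load(in_ptr0 + x0)", "xmask")] = "tmp0"

    # The same expression under a different mask misses the cache, as intended:
    assert augmented("tl.load(in_ptr0 + x0)", "rmask") not in cache
    assert augmented("tl.load(in_ptr0 + x0)", "xmask") in cache
    # With no mask active, the plain string key is used:
    assert augmented("tl.load(in_ptr0 + x0)", None) == "tl.load(in_ptr0 + x0)"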
class TritonKernel(SIMDKernel):
class TritonKernel(SIMDKernel[TritonCSEVariable]):
overrides = TritonKernelOverrides # type: ignore[assignment]
helper_functions: HelperFunctions
kexpr: Callable[[sympy.Expr], str] = texpr
@ -2802,7 +2803,7 @@ class TritonKernel(SIMDKernel):
# in the global namespace
helper = IndentedBuffer()
helper.writeline("@triton.jit")
cse = CSE(prefix="", suffix="")
cse = CSE()
args = [
tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args))
@ -2954,7 +2955,8 @@ class TritonKernel(SIMDKernel):
result_vars = partial_scan_vars
for result_var in result_vars:
result_var.mask_vars = masks # type: ignore[attr-defined]
assert isinstance(result_var, TritonCSEVariable)
result_var.mask_vars = OrderedSet(masks)
return tuple(result_vars)
@ -4054,9 +4056,12 @@ class TritonScheduling(SIMDScheduling):
store_cache()
return ms, mod.__file__
def create_kernel_choices(
self, kernel_features, kernel_args, kernel_kwargs
) -> list[SIMDKernel]:
def create_kernel_choices( # type: ignore[override]
self,
kernel_features: SIMDKernelFeatures,
kernel_args: list[Any],
kernel_kwargs: dict[str, Any],
) -> list[TritonKernel]:
is_scan = kernel_features.contains_op("scan")
is_split_scan = is_scan and any(
node.is_split_scan() for node in kernel_features.scheduler_nodes()
@ -4090,11 +4095,11 @@ class TritonScheduling(SIMDScheduling):
def add_multi_kernel_choices(
self,
kernel: SIMDKernel,
kernel: TritonKernel,
kernel_args: list[Any],
kernel_kwargs: dict[str, Any],
) -> list[SIMDKernel]:
kernels: list[SIMDKernel] = [kernel]
) -> list[TritonKernel]:
kernels: list[TritonKernel] = [kernel]
if not config.triton.multi_kernel:
return kernels

View File

@ -1,11 +1,16 @@
# mypy: allow-untyped-defs
import functools
from typing import Union
import sympy
from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.codegen.triton import (
triton_compute_type,
TritonCSEVariable,
TritonKernel,
)
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv
@ -68,7 +73,7 @@ class TritonSplitScanKernel(TritonKernel):
numel,
prefix,
grid_dim,
self,
self, # type: ignore[arg-type]
pid_cache=pid_cache,
is_loop=False,
tensor_dim=tensor_dim,
@ -115,6 +120,7 @@ class TritonSplitScanKernel(TritonKernel):
)
max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
nbytes = scratch_nbytes_per_block * max_blocks
scratch_base: Union[str, TritonCSEVariable]
scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
if offset != 0:
scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")

View File

@ -648,7 +648,7 @@ class TritonTemplateKernel(TritonKernel):
self.body.writeline(str(scatter))
body_val = self.body.getvalue()
self.cse.invalidate(OrderedSet[str]())
self.cse.invalidate(OrderedSet())
return body_val
def load_input(

View File

@ -28,6 +28,9 @@ from typing import (
Any,
Callable,
Generic,
Iterator,
Mapping,
MutableMapping,
NamedTuple,
Optional,
Protocol,
@ -2437,6 +2440,58 @@ def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
return dtype
KeyType = TypeVar("KeyType")
ValType = TypeVar("ValType")
class ScopedDict(MutableMapping[KeyType, ValType]):
"""
A dictionary-like object that allows for scoped updates. It maintains
an original dictionary and a set of new items that can override
the original items within the scope. The original dictionary is
unmodified.
"""
def __init__(self, original_dict: Mapping[KeyType, ValType]):
self.original_dict = original_dict
self.new_items: dict[KeyType, ValType] = {}
def __getitem__(self, key: KeyType) -> ValType:
if key in self.new_items:
return self.new_items[key]
return self.original_dict[key]
def __setitem__(self, key: KeyType, value: ValType):
self.new_items[key] = value
def __contains__(self, key: object) -> bool:
return key in self.new_items or key in self.original_dict
def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]: # type: ignore[override]
if key in self.new_items:
return self.new_items[key]
return self.original_dict.get(key, default)
def __len__(self) -> int:
n = len(self.original_dict)
for k in self.new_items:
if k not in self.original_dict:
n += 1
return n
def __iter__(self) -> Iterator[KeyType]:
yield from self.original_dict
for k in self.new_items:
if k not in self.original_dict:
yield k
def __bool__(self) -> bool:
return bool(self.original_dict or self.new_items)
def __delitem__(self, key: KeyType) -> None:
raise NotImplementedError
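
A brief usage sketch of the semantics described in the docstring above (toy keys); this layering is what Kernel.swap_buffers relies on when scope_cse wraps the CSE caches:

    # Uses the ScopedDict added above (it lives in torch/_inductor/utils.py
    # after this PR; previously a simpler version sat in codegen/common.py).
    from torch._inductor.utils import ScopedDict

    base = {"a": 1, "b": 2}
    scoped = ScopedDict(base)
    scoped["b"] = 20                   # shadows the base value inside the scope
    scoped["c"] = 3                    # new key, visible only through the scope

    assert scoped["a"] == 1 and scoped["b"] == 20 and scoped["c"] == 3
    assert base == {"a": 1, "b": 2}    # the original dict is unmodified
    assert sorted(scoped) == ["a", "b", "c"]
    assert len(scoped) == 3 and "c" in scoped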
@dataclass_transform(frozen_default=True)
def ir_dataclass(cls=None, /, *, frozen: bool = True):
def wrap(cls: _T) -> _T: