[inductor] Add typing to common.CSE (#145993)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145993 Approved by: https://github.com/yanboliang ghstack dependencies: #145913, #145914, #145915, #145916
commit 8c657ae4be (parent 68cf36d5ab), committed by PyTorch MergeBot
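At a glance, the change makes CSE generic over the CSE-variable type and the backend-augmented cache-key type. Below is a minimal, self-contained sketch of that typing pattern (illustrative only, not the PyTorch implementation; it mirrors the names used in the diff and assumes a typing_extensions version that supports TypeVar defaults):

# Sketch of the typing pattern this PR applies to torch/_inductor's CSE:
# TypeVars with defaults (via typing_extensions) and a class generic over
# the variable type and the augmented cache-key type.
from typing import Generic, MutableMapping, Optional, cast

from typing_extensions import TypeVar


class CSEVariable:
    def __init__(self, name: str) -> None:
        self.name = name


AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)


class CSE(Generic[CSEVariableType, AugmentedKeyT]):
    def __init__(self) -> None:
        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}

    def augment_key(self, cache_key: str) -> AugmentedKeyT:
        # Backends override this to fold extra state (e.g. a load mask) into the key.
        return cast(AugmentedKeyT, cache_key)

    def put(self, cache_key: str, val: CSEVariableType) -> None:
        self._cache[self.augment_key(cache_key)] = val

    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
        return self._cache.get(self.augment_key(cache_key), None)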
@@ -10,6 +10,7 @@ import logging
import math
import operator
import re
import typing
from enum import auto, Enum
from itertools import chain
from typing import (
@@ -17,12 +18,15 @@ from typing import (
    Callable,
    cast,
    ClassVar,
    Generic,
    Iterator,
    MutableMapping,
    NamedTuple,
    Optional,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeVar

import sympy

@@ -44,6 +48,7 @@ from ..utils import (
    generate_assert,
    IndentedBuffer,
    ir_dataclass,
    ScopedDict,
    sympy_dot,
    sympy_subs,
    unique,
@@ -52,11 +57,9 @@ from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V


if TYPE_CHECKING:
    from typing import Never, TypeVar

    from ..ir import FixedLayout
    from ..ir import FixedLayout, IRNode
    from ..loop_body import LoopBody
    from ..scheduler import BaseScheduling, Scheduler
    from ..scheduler import BaseScheduling, Scheduler, SchedulerNode
    from .wrapper import PythonWrapperCodegen

_T = TypeVar("_T")
@@ -1336,6 +1339,18 @@ class KernelArgs:
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def arg_name(self, name: str) -> Optional[str]:
        """
        Returns inner name of a given outer name.
        """
        inplaced = self.inplace_buffers.get(name, None)
        if inplaced is not None and not isinstance(inplaced, RemovedArg):
            return inplaced.inner_name
        output_name = self.output_buffers.get(name, None)
        if output_name is not None and not isinstance(output_name, RemovedArg):
            return output_name
        return self.input_buffers.get(name, None)

    def wrap_ptr_arg(self, buf: str, dtype: torch.dtype) -> str:
        return buf

@@ -1477,17 +1492,18 @@ class CSEVariable:

    def __init__(
        self,
        name,
        name: str,
        bounds: ValueRanges[Any],
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds
        self.use_count = 1  # track how many times this expression is used
        self.dtype = dtype

    def __str__(self):
    def __str__(self) -> str:
        return self.name

    def __hash__(self) -> int:
@@ -1496,45 +1512,62 @@ class CSEVariable:
    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
    def update_on_args(self, name: str, args: Any, kwargs: Any) -> None:
        pass

    def __repr__(self):
    def __repr__(self) -> str:
        return f"{self.__class__.__name__}({self.name!r})"


class CSE:
AugmentedKeyT = TypeVar("AugmentedKeyT", default=str)
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)

if TYPE_CHECKING:
    ReductionCacheKey = tuple[
        torch.dtype,
        ReductionType,
        Union[CSEVariable, tuple[CSEVariable, ...]],
    ]


class CSE(Generic[CSEVariableType, AugmentedKeyT]):
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
        prefix: str = "",
        suffix: str = "",
        name_prefix: str = "tmp",
        iter_buffers: Optional[itertools.count[int]] = None,
        store_cache: Optional[MutableMapping[str, CSEVariableType]] = None,
        reduction_cache: Optional[
            MutableMapping[ReductionCacheKey, CSEVariableType]
        ] = None,
        varname_map: Optional[dict[str, CSEVariableType]] = None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self._cache = {}
        self._cache: MutableMapping[AugmentedKeyT, CSEVariableType] = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = OrderedSet[str]()
        self.varname_map = varname_map or {}
        self.store_cache: MutableMapping[str, CSEVariableType] = store_cache or {}
        self.reduction_cache: MutableMapping[ReductionCacheKey, CSEVariableType] = (
            reduction_cache or {}
        )
        self.iter_buffer_ids: itertools.count[int] = iter_buffers or itertools.count()
        self.invalidated_stores: OrderedSet[str] = OrderedSet()
        self.varname_map: dict[str, CSEVariableType] = varname_map or {}

    def invalidate(self, keep_vars: Union[OrderedSet[str], OrderedSet[Never]]):
        for name, tmp in list(self.store_cache.items()):
    def invalidate(self, keep_vars: OrderedSet[CSEVariable]):
        for name, tmp in [*self.store_cache.items()]:
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        if keep_vars:
            self._cache = {k: v for k, v in self._cache.items() if v in keep_vars}
        else:
            self._cache = {}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
    def clone(self) -> typing.Self:
        return type(self)(
            prefix=self.prefix,
            suffix=self.suffix,
@@ -1542,22 +1575,23 @@ class CSE:
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
            reduction_cache=self.reduction_cache,
        )

    def augment_key(self, cache_key: object) -> object:
    def augment_key(self, cache_key: str) -> AugmentedKeyT:
        "Override this method to augment cache key with backend specifics"
        return cache_key
        return cast(AugmentedKeyT, cache_key)

    def put(self, cache_key: object, val: CSEVariable) -> None:
    def put(self, cache_key: str, val: CSEVariableType) -> None:
        self._cache[self.augment_key(cache_key)] = val

    def contains(self, cache_key) -> bool:
    def contains(self, cache_key: str) -> bool:
        return self.augment_key(cache_key) in self._cache

    def try_get(self, cache_key: object) -> Optional[CSEVariable]:
    def try_get(self, cache_key: str) -> Optional[CSEVariableType]:
        return self._cache.get(self.augment_key(cache_key), None)

    def get(self, cache_key: object) -> CSEVariable:
    def get(self, cache_key: str) -> CSEVariableType:
        return self._cache[self.augment_key(cache_key)]

    def generate(
@@ -1566,10 +1600,10 @@ class CSE:
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer, DeferredLineBase],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
        write: bool = True,
        assignment: bool = True,
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariable:
    ) -> CSEVariableType:
        if isinstance(expr, OpsValue):
            expr = expr.value

@@ -1580,7 +1614,7 @@
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            expr.use_count += 1
            return expr
            return cast(CSEVariableType, expr)
        elif isinstance(expr, IndentedBuffer):
            cache_key = expr.getvalue()
        elif isinstance(expr, DeferredLineBase):
@@ -1623,7 +1657,7 @@
        self,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariable:
    ) -> CSEVariableType:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds, dtype)
        self.varname_map[var_name] = var
@@ -1634,7 +1668,7 @@
        name: str,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        dtype: Optional[torch.dtype] = None,
    ) -> CSEVariable:
    ) -> CSEVariableType:
        torch._check_value(
            name not in self.varname_map, lambda: f"duplicate name: {name}"
        )
@@ -1648,45 +1682,22 @@ class CodeGen:
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
    def __enter__(self) -> typing.Self:
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)


class ScopedDict:
    def __init__(self, original_dict):
        self.original_dict = original_dict
        self.new_items = {}

    def __getitem__(self, key):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key, value):
        self.new_items[key] = value

    def __contains__(self, key):
        return key in self.new_items or key in self.original_dict

    def get(self, key, default=None):
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)


class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
class Kernel(CodeGen, Generic[CSEVariableType]):
    newvar_prefix: str = ""
    suffix: str = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
    def __init__(
        self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
    ) -> None:
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
@@ -1698,13 +1709,13 @@ class Kernel(CodeGen):
        self.num_load = 0
        self.num_reduction = 0

        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.cse: CSE[CSEVariableType, Any] = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = OrderedSet[str]()
        self.store_buffer_names = OrderedSet[str]()
        self._load_mask = None
        self._load_other = None
        self._load_mask: Optional[str] = None
        self._load_other: Union[None, int, float] = None
        # OrderedSet in set_current_node
        self.current_node = None
        self.current_node: Optional[SchedulerNode] = None
        self.node_to_bounds: Optional[dict[torch.fx.Node, ValueRanges[Any]]] = None

        self.removed_buffers = OrderedSet[str]()
@@ -1713,10 +1724,10 @@ class Kernel(CodeGen):
        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers = {}
        self.inplace_update_buffers: dict[str, str] = {}
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None
        self.kernel_name: Optional[str] = None

    @contextlib.contextmanager
    def set_current_node(self, node):
@@ -1730,7 +1741,7 @@ class Kernel(CodeGen):

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        def scope_cse(cse):
        def scope_cse(cse: CSE[CSEVariableType, Any]):
            new_cse = cse.clone()
            new_cse._cache = ScopedDict(cse._cache)
            new_cse.reduction_cache = ScopedDict(cse.reduction_cache)
@@ -2057,6 +2068,7 @@ class Kernel(CodeGen):

        @staticmethod
        def _update_store_cache(name: str, value: CSEVariable):
            value = cast(CSEVariableType, value)
            self.cse.store_cache[name] = value
            if self.current_node and name in V.graph.name_to_buffer:
                buf = self.current_node.get_output(name)
@@ -2283,6 +2295,14 @@ class Kernel(CodeGen):
    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.
        """
        if node is None:
            return None
        return self.args.arg_name(node.get_name())


@dataclasses.dataclass
class OptimizationContext:

@@ -2646,7 +2646,7 @@ class CppVecKernel(CppKernel):
            return super().load(name, index)
        elif stride == 1:
            # load contiguously
            line = self._get_vec_load_line(var, index, dtype, self._load_mask)
            line = self._get_vec_load_line(var, index, dtype, self._load_mask)  # type: ignore[arg-type]
            csevar = self.cse.generate(self.loads, line)  # type: ignore[assignment]
        else:
            csevar = self._load_or_store_non_contiguous(var, index, dtype)  # type: ignore[assignment]

@@ -163,16 +163,6 @@ class CUDATemplateKernel(CUDAKernel):
        super().__init__()
        self.kernel_name = kernel_name

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.
        """
        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def check_not_null(self, node: IRNode) -> str:
        """
        Generates code to check that a node is not null.
@@ -273,6 +263,7 @@ class CUDATemplateKernel(CUDAKernel):
        """
        wrapper = V.graph.wrapper_code

        arg_types: list[Any]
        if V.graph.cpp_wrapper:
            # Make sure we initialize these kernels since they're exported as
            # C-style symbol names.

@@ -8,7 +8,7 @@ import logging
import re
from collections import defaultdict
from math import inf
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union

import sympy

@@ -39,6 +39,7 @@ from .common import (
    CSEVariable,
    DeferredLine,
    IndentedBuffer,
    KernelArgType,
    OpOverrides,
    PythonPrinter,
    SizeArg,
@@ -1383,7 +1384,7 @@ class HalideKernel(SIMDKernel):
            assert "in_ptr" in arg.name
            return 0

        result = []
        result: list[tuple[Optional[str], KernelArgType]] = []
        _, a, b, _ = self.args.python_argdefs()
        for call_str, arg in sorted(zip(a, b), key=arg_order):
            result.append((call_str, arg))
@@ -1527,7 +1528,7 @@ class HalideKernel(SIMDKernel):
        code.splice(self.indexing_code)

        def update_index(m):
            var = self.cse.varname_map[m.group(1)]
            var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)])
            assert var.used_dims is not None, var
            return str(var)

@@ -1,7 +1,7 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from typing import Callable, Optional, TYPE_CHECKING, Union
from typing import Any, Callable, Optional, TYPE_CHECKING, Union

from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu

@@ -52,16 +52,6 @@ class ROCmTemplateKernel(ROCmKernel):
        # Mapping from arg name to IRNode.
        self.named_nodes: dict[str, IRNode] = {}

    def arg_name(self, node: IRNode) -> Optional[str]:
        """
        Returns arg name of a given input or output node.
        """
        if node is None:
            return None
        return {**self.args.input_buffers, **self.args.output_buffers}.get(
            node.get_name(), None
        )

    def get_signature(self):
        return self.signature

@@ -133,6 +123,7 @@ class ROCmTemplateKernel(ROCmKernel):
        """
        wrapper = V.graph.wrapper_code

        arg_types: list[Any]
        if V.graph.cpp_wrapper:
            # Make sure we initialize these kernels since they're exported as
            # C-style symbol names.

@@ -14,12 +14,14 @@ from collections import Counter
from typing import (
    Any,
    Callable,
    Generic,
    Iterator,
    no_type_check,
    Optional,
    TYPE_CHECKING,
    Union,
)
from typing_extensions import TypeVar

import sympy

@@ -339,7 +341,10 @@ def constant_repr(value: Union[int, float]) -> str:
    return repr(value)


class SIMDKernel(Kernel):
CSEVariableType = TypeVar("CSEVariableType", bound=CSEVariable, default=CSEVariable)


class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
    """
    Common base class for Triton/Halide codegen which both use flattened indexing rather than loop nests.
    """
@@ -456,7 +461,7 @@ class SIMDKernel(Kernel):
            numels[prefix],
            prefix,
            index,
            self,
            self,  # type: ignore[arg-type]
            pid_cache=pid_cache,
            is_loop=is_reduction and not self.persistent_reduction,
            tensor_dim=tensor_dim,

@@ -102,6 +102,7 @@ if TYPE_CHECKING:
    from typing import TypeVar

    from ..ir import IRNode
    from .simd_kernel_features import SIMDKernelFeatures

    _T = TypeVar("_T")

@@ -1522,20 +1523,20 @@ class FixedTritonConfig:
        return item in self.config


class TritonCSE(CSE):
class TritonCSE(CSE[TritonCSEVariable, Union[str, tuple[str, str]]]):
    """
    Subclasses CSE to apply the current load mask to the cache key to avoid CSEing
    variables across separate masked blocks.
    """

    def augment_key(self, cache_key: object) -> object:
    def augment_key(self, cache_key: str) -> Union[str, tuple[str, str]]:
        if mask := V.kernel._load_mask:
            return (cache_key, mask.name)
        else:
            return cache_key


class TritonKernel(SIMDKernel):
class TritonKernel(SIMDKernel[TritonCSEVariable]):
    overrides = TritonKernelOverrides  # type: ignore[assignment]
    helper_functions: HelperFunctions
    kexpr: Callable[[sympy.Expr], str] = texpr
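For context, the masked keying above means two textually identical expressions generated under different load masks no longer collapse to one CSE variable. A small illustrative sketch of the keying rule (standalone, not part of the diff; the mask-name strings and the plain-dict cache are stand-ins for V.kernel._load_mask and TritonCSE's internal cache):

# Stand-alone illustration of the behavior of TritonCSE.augment_key.
from typing import Optional, Union


def augment_key(cache_key: str, load_mask: Optional[str]) -> Union[str, tuple[str, str]]:
    # Mirror of the rule above: fold the active mask name into the key, if any.
    return (cache_key, load_mask) if load_mask else cache_key


cache: dict[Union[str, tuple[str, str]], str] = {}
cache[augment_key("tl.load(in_ptr0 + x0)", "xmask")] = "tmp0"
cache[augment_key("tl.load(in_ptr0 + x0)", "rmask")] = "tmp1"
assert len(cache) == 2  # same expression text, different masks -> distinct entries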
@@ -2802,7 +2803,7 @@ class TritonKernel(SIMDKernel):
        # in the global namespace
        helper = IndentedBuffer()
        helper.writeline("@triton.jit")
        cse = CSE(prefix="", suffix="")
        cse = CSE()

        args = [
            tuple(cse.namedvar(f"arg{i}_{n}", dtype=dtypes[n]) for n in range(num_args))
@@ -2954,7 +2955,8 @@ class TritonKernel(SIMDKernel):
            result_vars = partial_scan_vars

        for result_var in result_vars:
            result_var.mask_vars = masks  # type: ignore[attr-defined]
            assert isinstance(result_var, TritonCSEVariable)
            result_var.mask_vars = OrderedSet(masks)

        return tuple(result_vars)

@@ -4054,9 +4056,12 @@ class TritonScheduling(SIMDScheduling):
            store_cache()
            return ms, mod.__file__

    def create_kernel_choices(
        self, kernel_features, kernel_args, kernel_kwargs
    ) -> list[SIMDKernel]:
    def create_kernel_choices(  # type: ignore[override]
        self,
        kernel_features: SIMDKernelFeatures,
        kernel_args: list[Any],
        kernel_kwargs: dict[str, Any],
    ) -> list[TritonKernel]:
        is_scan = kernel_features.contains_op("scan")
        is_split_scan = is_scan and any(
            node.is_split_scan() for node in kernel_features.scheduler_nodes()
@@ -4090,11 +4095,11 @@ class TritonScheduling(SIMDScheduling):

    def add_multi_kernel_choices(
        self,
        kernel: SIMDKernel,
        kernel: TritonKernel,
        kernel_args: list[Any],
        kernel_kwargs: dict[str, Any],
    ) -> list[SIMDKernel]:
        kernels: list[SIMDKernel] = [kernel]
    ) -> list[TritonKernel]:
        kernels: list[TritonKernel] = [kernel]
        if not config.triton.multi_kernel:
            return kernels

@@ -1,11 +1,16 @@
# mypy: allow-untyped-defs
import functools
from typing import Union

import sympy

from torch._inductor import config
from torch._inductor.codegen.simd import IterationRangesRoot, prefix_is_reduction
from torch._inductor.codegen.triton import triton_compute_type, TritonKernel
from torch._inductor.codegen.triton import (
    triton_compute_type,
    TritonCSEVariable,
    TritonKernel,
)
from torch._inductor.runtime.triton_heuristics import split_scan_grid
from torch.utils._ordered_set import OrderedSet
from torch.utils._sympy.functions import CeilDiv
@@ -68,7 +73,7 @@ class TritonSplitScanKernel(TritonKernel):
            numel,
            prefix,
            grid_dim,
            self,
            self,  # type: ignore[arg-type]
            pid_cache=pid_cache,
            is_loop=False,
            tensor_dim=tensor_dim,
@@ -115,6 +120,7 @@ class TritonSplitScanKernel(TritonKernel):
        )
        max_blocks = pointwise_numel * CeilDiv(reduction_numel, min_rblock)
        nbytes = scratch_nbytes_per_block * max_blocks
        scratch_base: Union[str, TritonCSEVariable]
        scratch_base, offset = self.args.workspace(nbytes=nbytes, zero_fill=True)
        if offset != 0:
            scratch_base = cse_load(f"{scratch_base} + {self.index_to_str(offset)}")

@@ -648,7 +648,7 @@ class TritonTemplateKernel(TritonKernel):
        self.body.writeline(str(scatter))

        body_val = self.body.getvalue()
        self.cse.invalidate(OrderedSet[str]())
        self.cse.invalidate(OrderedSet())
        return body_val

    def load_input(

@@ -28,6 +28,9 @@ from typing import (
    Any,
    Callable,
    Generic,
    Iterator,
    Mapping,
    MutableMapping,
    NamedTuple,
    Optional,
    Protocol,
@@ -2437,6 +2440,58 @@ def upcast_compute_type(dtype: torch.dtype) -> torch.dtype:
    return dtype


KeyType = TypeVar("KeyType")
ValType = TypeVar("ValType")


class ScopedDict(MutableMapping[KeyType, ValType]):
    """
    A dictionary-like object that allows for scoped updates. It maintains
    an original dictionary and a set of new items that can override
    the original items within the scope. The original dictionary is
    unmodified.
    """

    def __init__(self, original_dict: Mapping[KeyType, ValType]):
        self.original_dict = original_dict
        self.new_items: dict[KeyType, ValType] = {}

    def __getitem__(self, key: KeyType) -> ValType:
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict[key]

    def __setitem__(self, key: KeyType, value: ValType):
        self.new_items[key] = value

    def __contains__(self, key: object) -> bool:
        return key in self.new_items or key in self.original_dict

    def get(self, key: KeyType, default: Optional[ValType] = None) -> Optional[ValType]:  # type: ignore[override]
        if key in self.new_items:
            return self.new_items[key]
        return self.original_dict.get(key, default)

    def __len__(self) -> int:
        n = len(self.original_dict)
        for k in self.new_items:
            if k not in self.original_dict:
                n += 1
        return n

    def __iter__(self) -> Iterator[KeyType]:
        yield from self.original_dict
        for k in self.new_items:
            if k not in self.original_dict:
                yield k

    def __bool__(self) -> bool:
        return bool(self.original_dict or self.new_items)

    def __delitem__(self, key: KeyType) -> None:
        raise NotImplementedError


@dataclass_transform(frozen_default=True)
def ir_dataclass(cls=None, /, *, frozen: bool = True):
    def wrap(cls: _T) -> _T:
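For reference, a minimal usage sketch of the ScopedDict shown above (illustrative; assumes it stays importable from torch._inductor.utils, which is where the `from ..utils import ScopedDict` in common.py points): writes land in the scoped layer while the wrapped mapping stays untouched, which is what swap_buffers relies on when it wraps cse._cache.

# Illustrative use of the ScopedDict added above (import path assumed).
from torch._inductor.utils import ScopedDict

base = {"a": 1, "b": 2}
scoped = ScopedDict(base)
scoped["b"] = 20  # shadows base["b"] inside the scope only
scoped["c"] = 3   # new key lives only in the scoped layer
assert scoped["b"] == 20 and scoped["c"] == 3 and len(scoped) == 3
assert base == {"a": 1, "b": 2}  # the original dict is unmodified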