import contextlib
import dataclasses
import functools
import itertools
import logging
import operator
import re
from itertools import chain
from typing import (
    Any,
    Callable,
    ClassVar,
    Dict,
    List,
    NamedTuple,
    Optional,
    Set,
    Tuple,
    TYPE_CHECKING,
    Union,
)

import sympy
from sympy.printing.printer import Printer

import torch
import torch.fx
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
from torch.utils import _pytree as pytree
from torch.utils._sympy.value_ranges import ValueRanges

from .. import config, metrics
from ..utils import (
    DeferredLineBase,
    do_bench,
    free_symbol_startswith,
    IndentedBuffer,
    sympy_dot,
    sympy_index_symbol,
    sympy_subs,
    unique,
)
from ..virtualized import ops, OpsHandler, OpsValue, ReductionType, StoreMode, V

if TYPE_CHECKING:
    from ..ir import TensorBox

schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")


def data_type_logger(msg):
    if schedule_log.isEnabledFor(logging.DEBUG):
        schedule_log.debug("Data type propagation: %s", msg)


@dataclasses.dataclass
class WorkspaceArg:
    """A temporary buffer used for a single kernel, then discarded.

    Not registered as a traditional buffer since there are no users,
    so it would be dead code eliminated.
    """

    nbytes: sympy.Expr
    zero_fill: bool


@dataclasses.dataclass
class TensorArg:
    name: str
    buffer: str
    dtype: torch.dtype
    offset: sympy.Expr = sympy.Integer(0)


@dataclasses.dataclass
class SizeArg:
    name: str
    expr: sympy.Expr


@dataclasses.dataclass
class DeviceCodegen:
    scheduling: type
    wrapper_codegen: type


KernelArgType = Union[WorkspaceArg, TensorArg, SizeArg]

device_codegens: Dict[str, DeviceCodegen] = {}


class DeviceOpOverrides:
    def import_get_raw_stream_as(self, name):
        raise NotImplementedError()

    def set_device(self, device_idx):
        raise NotImplementedError()

    def synchronize(self):
        raise NotImplementedError()

    def device_guard(self, device_idx):
        raise NotImplementedError()


device_op_overrides_dict: Dict[str, DeviceOpOverrides] = {}


# The code generated by Inductor consists of two main parts: kernel code and wrapper code.
# For any new backend looking to integrate with Inductor, customization of these two main
# parts is necessary to generate its specific code.
#
# Kernel code generation is determined by different Scheduling. Consequently, a new
# backend needs to provide a custom Scheduling for its unique kernel code generation. Currently,
# CppScheduling and TritonScheduling serve the C++/OpenMP and Triton backends, respectively.
#
# For the Wrapper, Inductor provides a WrapperCodeGen class to generate the Python wrapper code
# that bridges kernels. This allows out-of-tree backends to inherit from WrapperCodeGen,
# and override specific member functions to create backend-specific Python wrapper code.
#
# Other classes, such as CppKernel and TritonKernel, used for code generation, typically form part
# of the logic for either Scheduling or WrapperCodeGen. So the Scheduling and WrapperCodeGen interfaces
# provide flexibility to the backend. A backend can choose to implement these classes from scratch,
# or reuse them by extending and overriding as necessary. And Inductor provides the registration API,
# register_backend_for_device, to equip a new backend at runtime.
#
# Intel has developed a new backend on top of Triton to support Intel GPUs, leveraging these interfaces.
# This backend can be used as a reference:
# https://github.com/intel/intel-extension-for-pytorch/blob/5dcc9d57e5422cf295e1a1ee97896d6b6a554a85/intel_extension_for_pytorch/_inductor/__init__.py#L9
def register_backend_for_device(
    device: str, device_scheduling: type, device_wrapper_codegen: type
):
    device_codegens[device] = DeviceCodegen(device_scheduling, device_wrapper_codegen)


def get_scheduling_for_device(device: str):
    return device_codegens[device].scheduling if device in device_codegens else None


def get_wrapper_codegen_for_device(device: str):
    return (
        device_codegens[device].wrapper_codegen if device in device_codegens else None
    )


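# For illustration only: a minimal out-of-tree registration might look like
# the sketch below (the backend name and classes are hypothetical; a real
# backend would provide its own Scheduling and a WrapperCodeGen subclass):
#
#     class MyDeviceScheduling:
#         ...  # kernel codegen for "my_device"
#
#     class MyDeviceWrapperCodeGen(WrapperCodeGen):
#         ...  # backend-specific Python wrapper codegen
#
#     register_backend_for_device(
#         "my_device", MyDeviceScheduling, MyDeviceWrapperCodeGen
#     )
#     assert get_scheduling_for_device("my_device") is MyDeviceScheduling

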
def index_prevent_reordering(index: List[sympy.Expr], index_vars, sizes):
    from ..ir import FlexibleLayout

    # added contiguous index prevents reordering
    return [*index, sympy_dot(index_vars, FlexibleLayout.contiguous_strides(sizes))]


def register_device_op_overrides(device: str, device_op_overrides: DeviceOpOverrides):
    device_op_overrides_dict[device] = device_op_overrides


def get_device_op_overrides(device: str):
    assert isinstance(device, str)

    if not device_op_overrides_dict.keys():
        from .cuda import device_op_overrides  # noqa: F401

    if device in device_op_overrides_dict.keys():
        return device_op_overrides_dict[device]

    return DeviceOpOverrides()


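# Similarly, a backend can register its device-op overrides. An illustrative
# sketch (hypothetical device name and generated strings):
#
#     class MyDeviceOpOverrides(DeviceOpOverrides):
#         def synchronize(self):
#             return "my_backend.synchronize()"
#
#         def device_guard(self, device_idx):
#             return f"my_backend.device({device_idx})"
#
#     register_device_op_overrides("my_device", MyDeviceOpOverrides())

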
@functools.lru_cache(None)
def boolean_ops():
    return (
        "is_inf",
        "is_nan",
        "bitwise_xor",
        "logical_not",
        "signbit",
        "le",
        "lt",
        "ge",
        "gt",
        "eq",
        "ne",
    )


DTYPE_TO_COMPUTATION_DTYPE = {
    torch.bfloat16: torch.float,
    torch.float16: torch.float,
    **{
        dtype: dtype
        for dtype in [
            torch.bool,
            torch.float32,
            torch.float64,
            torch.int8,
            torch.int16,
            torch.int32,
            torch.int64,
            torch.uint8,
            torch.uint16,
            torch.uint32,
            torch.uint64,
        ]
    },
}


class DataTypePropagation:
    def __init__(self, body) -> None:
        self.body = body
        self.graphs: Dict[Union[Callable[..., Any], str], Any] = {
            "root": body.root_block.graph
        }
        for k, v in body.subblocks.items():
            self.graphs[k] = v.graph

    def deduce_node_dtype_by_inputs(self, node: torch.fx.Node):
        inputs = node.all_input_nodes
        input_nodes = [
            n for n in inputs if isinstance(n, torch.fx.Node) and n.op != "placeholder"
        ]
        if len(input_nodes) == 0:
            return None

        all_input_nodes_propagated = all(
            OptimizationContext.key in n.meta
            and n.meta[OptimizationContext.key].dtype is not None
            for n in input_nodes
        )
        if not all_input_nodes_propagated:
            return None

        return functools.reduce(
            torch.promote_types,
            [n.meta[OptimizationContext.key].dtype for n in input_nodes],
        )

    def deduce_node_dtype_by_subgraph(self, node: torch.fx.Node):
        sub_graph = self.graphs[node.target]
        dtype = self.propagate_graph(sub_graph)
        assert dtype
        return dtype

    def deduce_node_dtype(self, node: torch.fx.Node):
        if node.target in boolean_ops():
            return torch.bool

        if node.op == "placeholder":
            return None

        if node.target == "output":
            # we can only infer the output node's dtype if it has exactly one arg
            if len(node.args) != 1:
                return None

        if node.target in (
            "to_dtype",
            "index_expr",
        ):
            return node.args[-1]

        if node.target in (
            "rand",
            "randn",
        ):
            return torch.float

        if node.target in (
            "get_index",
            "index_expr",
        ):
            return torch.int64

        if node.target in (
            "load",
            "store",
            "store_reduction",
        ):
            buf_name = node.args[1]
            return V.graph.get_dtype(buf_name)  # type: ignore[arg-type]

        if node.target == operator.getitem:
            return self.deduce_node_dtype(node.args[0])  # type: ignore[arg-type]

        assert isinstance(node.target, str)

        if node.target == "reduction":
            return node.args[1]

        if node.target == "constant":
            return DTYPE_TO_COMPUTATION_DTYPE[node.args[-1]]  # type: ignore[index]

        if node.target.startswith("masked_subblock"):
            return self.deduce_node_dtype_by_subgraph(node)

        return self.deduce_node_dtype_by_inputs(node)

    def propagate_graph(self, graph: torch.fx.Graph):
        assert graph.nodes
        graph_dtype = None
        # For masked_subblock, we use the output's dtype to represent
        # the dtype of this subgraph. For other cases, graph_dtype
        # might be None.
        for node in graph.nodes:
            if OptimizationContext.key in node.meta:
                opt_ctx = node.meta[OptimizationContext.key]
            else:
                opt_ctx = OptimizationContext()

            opt_ctx.dtype = self.deduce_node_dtype(node)
            node.meta[OptimizationContext.key] = opt_ctx
            if node.target == "output":
                graph_dtype = opt_ctx.dtype
        return graph_dtype

    def propagate(self):
        self.propagate_graph(self.graphs["root"])

    @classmethod
    def propagate_loopbody(cls, body):
        return cls(body).propagate()

    @classmethod
    def propagate_scheduler_node(cls, node):
        from ..ir import LoopBody
        from ..scheduler import SchedulerNode

        assert isinstance(node, SchedulerNode)
        assert isinstance(node._body, LoopBody)
        DataTypePropagation.propagate_loopbody(node._body)


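# Illustrative: the input-based fallback in deduce_node_dtype_by_inputs follows
# standard PyTorch type promotion, e.g.
#
#     functools.reduce(torch.promote_types, [torch.int32, torch.float16])
#     # -> torch.float16
#     functools.reduce(torch.promote_types, [torch.uint8, torch.int8])
#     # -> torch.int16

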
class ExprPrinter(Printer):
    @staticmethod
    def paren(string):
        def all_in_parens(string):
            if string[0] != "(" or len(string) < 2:
                return False
            count = 1
            for i, char in enumerate(string[1:]):
                if char == "(":
                    count += 1
                elif char == ")":
                    count -= 1
                if count == 0 and i != len(string) - 2:
                    return False
            assert count == 0
            return True

        if (
            isinstance(string, CSEVariable)
            or re.match(r"^[a-z0-9_.]+$", string, re.I)
            or re.match(r"^\([^)]*\)$", string, re.I)
            or string == ""
        ):
            return string
        # don't put extra parens for strings that are already wrapped in parens
        if all_in_parens(string):
            return string
        return f"({string})"

    def _print_Infinity(self, expr):
        return "math.inf"

    def _print_NegativeInfinity(self, expr):
        return "-math.inf"

    def _print_Relational(self, expr):
        return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mul(self, expr):
        return "*".join(map(self.paren, map(self._print, expr.args)))

    def _print_Add(self, expr):
        return " + ".join(map(self.paren, map(self._print, expr.args)))

    def _print_Mod(self, expr):
        return " % ".join(map(self.paren, map(self._print, expr.args)))

    def _print_FloorDiv(self, expr):
        raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}")

    def _print_CleanDiv(self, expr):
        return self._print_FloorDiv(expr)

    def _print_GreaterThan(self, expr):
        # GreaterThan:          >=
        # StrictlyGreaterThan:  >
        # Go figure...
        return " >= ".join(map(self.paren, map(self._print, expr.args)))

    def _print_align(self, expr):
        assert len(expr.args) == 1
        return f"align({self._print(expr.args[0])})"


class PythonPrinter(ExprPrinter):
    def _print_ModularIndexing(self, expr):
        x, div, mod = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        mod = self.paren(self.doprint(mod))
        if div != "1":
            x = f"({x} // {div})"
        return f"{x} % {mod}"

    def _print_FloorDiv(self, expr):
        x, div = expr.args
        x = self.paren(self.doprint(x))
        div = self.paren(self.doprint(div))
        return f"({x} // {div})"

    def _helper_sqrt(self, expr):
        return f"math.sqrt({self._print(expr)})"

    def _print_Pow(self, expr):
        # Pow() confuses triton
        base, exp = expr.args
        # NB: Remember this is sizevar computation!  You don't typically
        # expect to have to do floating point computation including exponents
        # in sizevar compute.  Instead of adding support for floating
        # point pow, you should make upstream retranslate the Sympy expression
        # into Tensor expressions earlier and do that instead.
        if exp == 0.5:
            return self._helper_sqrt(base)
        elif exp == -0.5:
            return "1/" + self._helper_sqrt(base)
        base = self._print(base)
        assert exp == int(exp), exp
        exp = int(exp)
        if exp > 0:
            return "*".join([self.paren(base)] * exp)
        elif exp < 0:
            return "1/" + self.paren("*".join([self.paren(base)] * abs(exp)))
        else:  # exp == 0
            return "1"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return f"math.floor({self._print(expr.args[0])})"

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return f"math.ceil({self._print(expr.args[0])})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return f"abs({self._print(expr.args[0])})"

    def _print_Max(self, expr):
        assert len(expr.args) >= 2
        return f"max({', '.join(map(self._print, expr.args))})"

    def _print_Min(self, expr):
        assert len(expr.args) >= 2
        return f"min({', '.join(map(self._print, expr.args))})"

    def _print_cos(self, expr):
        assert len(expr.args) == 1
        return f"math.cos({self._print(expr.args[0])})"

    def _print_cosh(self, expr):
        assert len(expr.args) == 1
        return f"math.cosh({self._print(expr.args[0])})"

    def _print_acos(self, expr):
        assert len(expr.args) == 1
        return f"math.acos({self._print(expr.args[0])})"

    def _print_sin(self, expr):
        assert len(expr.args) == 1
        return f"math.sin({self._print(expr.args[0])})"

    def _print_sinh(self, expr):
        assert len(expr.args) == 1
        return f"math.sinh({self._print(expr.args[0])})"

    def _print_asin(self, expr):
        assert len(expr.args) == 1
        return f"math.asin({self._print(expr.args[0])})"

    def _print_tan(self, expr):
        assert len(expr.args) == 1
        return f"math.tan({self._print(expr.args[0])})"

    def _print_tanh(self, expr):
        assert len(expr.args) == 1
        return f"math.tanh({self._print(expr.args[0])})"

    def _print_atan(self, expr):
        assert len(expr.args) == 1
        return f"math.atan({self._print(expr.args[0])})"

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return f"round({self._print(expr.args[0])})"

    def _print_RoundDecimal(self, expr):
        assert len(expr.args) == 2
        number, ndigits = expr.args
        assert isinstance(ndigits, sympy.Integer)
        return f"round({self._print(number)}, {ndigits})"


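# Illustrative outputs of PythonPrinter (derived from _print_Pow and friends):
#
#     s0 = sympy.Symbol("s0")
#     PythonPrinter().doprint(s0**3)    # -> "s0*s0*s0"
#     PythonPrinter().doprint(s0**-2)   # -> "1/(s0*s0)"
#     PythonPrinter().doprint(s0**0.5)  # -> "math.sqrt(s0)"

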
class OpOverrides:
    def __init__(self, parent):
        super().__init__()
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)

    @staticmethod
    def identity(value):
        # used to trigger cse
        return value

    @staticmethod
    def constant(value, dtype):
        return repr(value)

    @staticmethod
    def reciprocal(x):
        return ops.truediv("1", x)

    @staticmethod
    def square(x):
        return ops.mul(x, x)

    @staticmethod
    def bitwise_not(x):
        return f"~{ExprPrinter.paren(x)}"

    @staticmethod
    def logical_not(a):
        return f"{ExprPrinter.paren(a)} == 0"

    @staticmethod
    def bitwise_and(x, y):
        return f"{ExprPrinter.paren(x)} & {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_or(x, y):
        return f"{ExprPrinter.paren(x)} | {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_xor(x, y):
        return f"{ExprPrinter.paren(x)} ^ {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_left_shift(x, y):
        return f"{ExprPrinter.paren(x)} << {ExprPrinter.paren(y)}"

    @staticmethod
    def bitwise_right_shift(x, y):
        return f"{ExprPrinter.paren(x)} >> {ExprPrinter.paren(y)}"

    @staticmethod
    def remainder(a, b):
        r = ops.mod(a, b)
        return ops.where(f"(({r} != 0) & (({r} < 0) != ({b} < 0)))", ops.add(r, b), r)

    @staticmethod
    def load_seed(name, offset):
        return ops.load(name, sympy.Integer(offset))

    @classmethod
    def _initialize_pointwise_overrides(cls, target):
        assert target in {"triton", "cpp", "cppvec"}, target

        def pointwise_factory_1(impl):
            def func(x):
                return impl.format(x=x)

            return func

        def pointwise_factory_2(impl):
            def func(x, y):
                return impl.format(x=x, y=y)

            return func

        for funcname, data in pointwise_overrides_data.items():
            impl = getattr(data, target)
            if isinstance(impl, str):
                nof_args = 2 if "{y}" in impl else 1
                # extend the following dictionary with factory
                # functions for a specific number of arguments as
                # needed:
                factory = {1: pointwise_factory_1, 2: pointwise_factory_2}[nof_args]
                setattr(cls, funcname, staticmethod(factory(impl)))


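# Illustrative: _initialize_pointwise_overrides turns the format strings in
# pointwise_overrides_data (defined below) into static methods. After a
# backend calls cls._initialize_pointwise_overrides("cpp"), the "digamma"
# entry behaves roughly like:
#
#     @staticmethod
#     def digamma(x):
#         return "calc_digamma({x})".format(x=x)
#     # e.g. digamma("tmp0") -> "calc_digamma(tmp0)"

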
@dataclasses.dataclass
class OverridesData:
    name: str
    cpp: str
    triton: Optional[str] = None  # None when not impl in libdevice/triton
    cppvec: Optional[str] = None  # None when not impl in aten/.../vec
    type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
        ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )


pointwise_overrides_data: Dict[str, OverridesData] = dict(
    airy_ai=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="airy_ai_forward({x})",
        name="special_airy_ai",
    ),
    bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j0_forward({x})",
        triton="tl.math.j0({x})",
        name="special_bessel_j0",
    ),
    bessel_j1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_j1_forward({x})",
        triton="tl.math.j1({x})",
        name="special_bessel_j1",
    ),
    bessel_y0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y0_forward({x})",
        triton="tl.math.y0({x})",
        name="special_bessel_y0",
    ),
    bessel_y1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="bessel_y1_forward({x})",
        triton="tl.math.y1({x})",
        name="special_bessel_y1",
    ),
    digamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_digamma({x})",
        cppvec="{x}.digamma()",
        name="digamma",
    ),
    # no cpp nor triton implementation for entr, it is defined as decomposition
    # erf, erfc
    erfcx=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_erfcx({x})",
        triton="tl.math.erfcx({x})",
        name="special_erfcx",
    ),
    # erfinv, exp2, expit, gammaln
    igamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="igamma",
    ),
    igammac=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="igammac",
    ),
    gammainc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igamma({x}, {y})",
        name="special_gammainc",
    ),
    gammaincc=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_igammac({x}, {y})",
        name="special_gammaincc",
    ),
    i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0({x})",
        triton="tl.math.cyl_bessel_i0({x})",
        cppvec="{x}.i0()",
        name="i0",
    ),
    i0e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i0e({x})",
        cppvec="{x}.i0e()",
        name="special_i0e",
    ),
    i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1({x})",
        triton="tl.math.cyl_bessel_i1({x})",
        name="special_i1",
    ),
    i1e=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_i1e({x})",
        name="special_i1e",
    ),
    log_ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_log_ndtr({x})",
        name="special_log_ndtr",
    ),
    # logit
    modified_bessel_i0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i0_forward({x})",
        triton="tl.math.cyl_bessel_i0({x})",
        name="special_modified_bessel_i0",
    ),
    modified_bessel_i1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_i1_forward({x})",
        triton="tl.math.cyl_bessel_i1({x})",
        name="special_modified_bessel_i1",
    ),
    modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k0_forward({x})",
        name="special_modified_bessel_k0",
    ),
    modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="modified_bessel_k1_forward({x})",
        name="special_modified_bessel_k1",
    ),
    # multigamma
    ndtr=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtr({x})",
        name="special_ndtr",
    ),
    ndtri=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_ndtri({x})",
        name="special_ndtri",
    ),
    polygamma=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="calc_polygamma({y}, {x})",
        name="polygamma",
    ),
    # psi - alias to digamma
    # round
    scaled_modified_bessel_k0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k0_forward({x})",
        name="special_scaled_modified_bessel_k0",
    ),
    scaled_modified_bessel_k1=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="scaled_modified_bessel_k1_forward({x})",
        name="special_scaled_modified_bessel_k1",
    ),
    # sinc
    spherical_bessel_j0=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="spherical_bessel_j0_forward({x})",
        name="special_spherical_bessel_j0",
    ),
    zeta=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="zeta({x}, {y})",
        name="special_zeta",
    ),
    chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_t_forward({x}, {y})",
        name="special_chebyshev_polynomial_t",
    ),
    chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_u_forward({x}, {y})",
        name="special_chebyshev_polynomial_u",
    ),
    chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_v_forward({x}, {y})",
        name="special_chebyshev_polynomial_v",
    ),
    chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="chebyshev_polynomial_w_forward({x}, {y})",
        name="special_chebyshev_polynomial_w",
    ),
    legendre_polynomial_p=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="legendre_polynomial_p_forward({x}, {y})",
        name="special_legendre_polynomial_p",
    ),
    shifted_chebyshev_polynomial_t=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_t_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_t",
    ),
    shifted_chebyshev_polynomial_u=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_u_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_u",
    ),
    shifted_chebyshev_polynomial_v=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_v_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_v",
    ),
    shifted_chebyshev_polynomial_w=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="shifted_chebyshev_polynomial_w_forward({x}, {y})",
        name="special_shifted_chebyshev_polynomial_w",
    ),
    hermite_polynomial_h=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_h_forward({x}, {y})",
        name="special_hermite_polynomial_h",
    ),
    hermite_polynomial_he=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="hermite_polynomial_he_forward({x}, {y})",
        name="special_hermite_polynomial_he",
    ),
    laguerre_polynomial_l=OverridesData(
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
        cpp="laguerre_polynomial_l_forward({x}, {y})",
        name="special_laguerre_polynomial_l",
    ),
)


# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[str]:
    return h


class DeferredLine(DeferredLineBase):
    """A line that can be 'unwritten' by adding name to V.graph.removed_buffers"""

    def __init__(self, name, line):
        super().__init__(line)
        self.name = name
        assert not isinstance(line, DeferredLineBase)

    def __call__(self):
        if all(
            self.name not in x
            for x in (
                V.graph.removed_buffers,
                V.kernel.removed_buffers,
                V.graph.inplaced_to_remove,
                V.kernel.inplaced_to_remove,
            )
        ):
            return self.line
        return None

    def _new_line(self, line):
        return DeferredLine(self.name, line)


class BracesBuffer(IndentedBuffer):
    def indent(self, offset=1):
        @contextlib.contextmanager
        def ctx():
            for _ in range(offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(-offset):
                self._indent -= 1
                self.writeline("}")
            yield
            for _ in range(-offset):
                self.writeline("{")
                self._indent += 1
            for _ in range(offset):
                self._indent -= 1
                self.writeline("}")

        return ctx()


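# Illustrative usage of BracesBuffer (hypothetical C++ lines): a positive
# offset opens braces and indents on entry, then closes them on exit:
#
#     code = BracesBuffer()
#     code.writeline("for (int i = 0; i < n; ++i)")
#     with code.indent():
#         code.writeline("out[i] = in[i];")
#     # code.getvalue() ->
#     #     for (int i = 0; i < n; ++i)
#     #     {
#     #         out[i] = in[i];
#     #     }

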
class InplacedBuffer(NamedTuple):
    inner_name: str
    other_names: List[str]


class KernelArgs:
    @staticmethod
    def _lookup(prefix, odict, name):
        assert isinstance(name, (str, sympy.Symbol))
        if name not in odict:
            odict[name] = f"{prefix}{len(odict)}"
        return odict[name]

    def __init__(self, sizevars=None):
        self.input_buffers = dict()
        self.output_buffers = dict()
        self.inplace_buffers = dict()
        self.sizevars = sizevars or dict()
        self.workspace_arg = None

    def __repr__(self):
        return "KernelArgs({})".format(
            ", ".join(
                map(
                    repr,
                    [
                        self.input_buffers,
                        self.output_buffers,
                        self.inplace_buffers,
                        self.sizevars,
                    ],
                )
            )
        )

    def _buffer_is_marked_removed(self, name):
        return isinstance(name, str) and name.startswith("REMOVED")

    def input(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.output_buffers:
            return self.output_buffers[name]
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        if name.startswith("seed"):
            return self._lookup("seed", self.input_buffers, name)
        return self._lookup("in_ptr", self.input_buffers, name)

    def output(self, name):
        if V.graph.scheduler:
            name = V.graph.scheduler.mutation_real_name.get(name, name)
        assert name not in V.graph.removed_buffers, name
        if name in self.inplace_buffers:
            return self.inplace_buffers[name].inner_name
        return self._lookup("out_ptr", self.output_buffers, name)

    def make_inplace(self, input_name, output_name):
        assert output_name not in self.inplace_buffers
        if input_name in self.inplace_buffers:
            buf = self.inplace_buffers[input_name]
            buf.other_names.append(output_name)
            self.inplace_buffers[output_name] = buf
        else:
            buf = InplacedBuffer(
                f"in_out_ptr{len(unique(self.inplace_buffers.values()))}",
                [input_name, output_name],
            )
            self.inplace_buffers[input_name] = buf
            self.inplace_buffers[output_name] = buf

    def workspace(self, nbytes: sympy.Expr, zero_fill: bool):
        if self.workspace_arg is None:
            self.workspace_arg = WorkspaceArg(nbytes, zero_fill)
            return "ws_ptr", 0

        offset = self.workspace_arg.nbytes
        zero_fill = zero_fill or self.workspace_arg.zero_fill
        self.workspace_arg = WorkspaceArg(offset + nbytes, zero_fill)
        return "ws_ptr", offset

    def seed_offset(self, name, value):
        if value in self.sizevars:
            return self.sizevars[value]
        if name in self.sizevars.values():
            name = (
                f"{name}{sum(1 for v in self.sizevars.values() if v.startswith(name))}"
            )
        self.sizevars[value] = name
        return name

    def size(self, name):
        if str(name) == "seed":
            self.sizevars["seed"] = "seed"
            return "seed"
        return self._lookup("ks", self.sizevars, name)

    def call_names(self):
        return chain(
            self.input_buffers.keys(), self.output_buffers.keys(), self.sizevars.keys()
        )

    def wrap_ptr_arg(self, buf, dtype):
        return buf

    def wrap_size_arg(self, size):
        return str(size)

    def cpp_argdefs(self):
        from .cpp import DTYPE_TO_CPP, INDEX_TYPE

        call_args = []
        arg_defs = []
        arg_types = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            outer = inplaced.other_names[-1]
            inner = inplaced.inner_name
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.input_buffers.items():
            if outer in self.inplace_buffers:
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"const {cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"const {cpp_dtype}*")
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            dtype = V.graph.get_dtype(outer)
            cpp_dtype = DTYPE_TO_CPP[dtype]
            arg_defs.append(f"{cpp_dtype}* {inner}")
            call_args.append(self.wrap_ptr_arg(outer, dtype))
            arg_types.append(f"{cpp_dtype}*")
        for outer, inner in self.sizevars.items():
            arg_defs.append(f"const {INDEX_TYPE} {inner}")
            call_args.append(self.wrap_size_arg(outer))
            arg_types.append(f"const {INDEX_TYPE}")
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        assert self.workspace_arg is None, "Workspace not supported on CPU"
        return arg_defs, call_args, arg_types

    def python_argdefs(self):
        arg_defs = []
        call_args = []
        precompile_args: List[Union[TensorArg, SizeArg, WorkspaceArg]] = []
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            arg_defs.append(inplaced.inner_name)
            call_args.append(inplaced.other_names[-1])
            precompile_args.append(
                TensorArg(
                    name=inplaced.inner_name,
                    buffer=inplaced.other_names[-1],
                    dtype=V.graph.get_dtype(inplaced.other_names[-1]),
                )
            )
        for outer, inner in chain(
            self.input_buffers.items(), self.output_buffers.items()
        ):
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(
                TensorArg(
                    name=inner,
                    buffer=outer,
                    dtype=V.graph.get_dtype(outer),
                )
            )
        for outer, inner in self.sizevars.items():
            arg_defs.append(inner)
            call_args.append(outer)
            precompile_args.append(SizeArg(inner, outer))
            if V.graph.wrapper_code:
                V.graph.wrapper_code.ensure_size_computed(outer)
        if self.workspace_arg is not None:
            arg_defs.append("ws_ptr")
            call_args.append("workspace")
            precompile_args.append(self.workspace_arg)

        return arg_defs, call_args, precompile_args

    def aliases(self):
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            for other in inplaced.other_names:
                if (
                    other in V.graph.inplaced_to_remove
                    or other in V.kernel.inplaced_to_remove
                ):
                    continue
                if other in self.input_buffers:
                    yield self.input_buffers[other], inplaced.inner_name
                if other in self.output_buffers:
                    yield self.output_buffers[other], inplaced.inner_name

    def is_removed(self, name):
        def _is_removed(name, buffers):
            return name not in buffers or self._buffer_is_marked_removed(buffers[name])

        return _is_removed(name, self.output_buffers) and _is_removed(
            name, self.inplace_buffers
        )

    # Includes inplace buffers, excludes removed buffers.  Essentially,
    # after you do a call into this kernel, which buffers actually contain
    # updated data?  Modeled off of python_argdefs.
    def live_output_buffers(self):
        live_outs = set()
        for inplaced in unique(self.inplace_buffers.values()):
            if self._buffer_is_marked_removed(inplaced):
                continue
            live_outs.add(inplaced.other_names[-1])
        for outer, inner in self.output_buffers.items():
            if outer in self.inplace_buffers or self._buffer_is_marked_removed(inner):
                continue
            live_outs.add(outer)
        return live_outs


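# Illustrative argument naming produced by the _lookup prefixes above
# (hypothetical buffer names; assumes an active V.graph codegen context):
#
#     args = KernelArgs()
#     args.input("buf0")                 # -> "in_ptr0"
#     args.output("buf1")                # -> "out_ptr0"
#     args.make_inplace("buf2", "buf3")  # both alias "in_out_ptr0"
#     args.size(sympy.Symbol("s0"))      # -> "ks0"

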
class CSEVariable:
    """A CSEVariable is just a name for an expression, but it is useful to be
    able to annotate them on a backend-dependent basis.
    To do so, the backends can simply overload `Kernel.create_cse_var`
    The "CSEVariable.update_on_args" method gives you a hook for annotations
    See example of TritonCSEVariable in triton.py
    """

    def __init__(self, name, bounds: ValueRanges[Any]):
        assert isinstance(bounds, ValueRanges)
        self.name = name
        self.bounds = bounds

    def __str__(self):
        return self.name

    def __hash__(self) -> int:
        return hash(self.name)

    def __eq__(self, other) -> bool:
        return type(other) == type(self) and other.name == self.name

    def update_on_args(self, name, args, kwargs):
        pass


class CppWrapperKernelArgs(KernelArgs):
    def wrap_ptr_arg(self, buf, dtype):
        from .cpp import DTYPE_TO_CPP

        if config.abi_compatible:
            # In the abi_compatible model, we just return the buf here.
            # We will form correct call args later in wrapper.generate_kernel_all.
            return buf
        else:
            return f"({DTYPE_TO_CPP[dtype]}*)({buf}.data_ptr())"

    def wrap_size_arg(self, size):
        return f"{size}"


class CSE:
    """Common subexpression elimination"""

    def __init__(
        self,
        prefix="",
        suffix="",
        name_prefix="tmp",
        iter_buffers=None,
        store_cache=None,
        reduction_cache=None,
        varname_map=None,
    ):
        self.prefix = prefix
        self.suffix = suffix
        self.cache = {}
        self.name_prefix = name_prefix
        self.store_cache = store_cache or {}
        self.reduction_cache = reduction_cache or {}
        self.iter_buffer_ids = iter_buffers or itertools.count()
        self.invalidated_stores = set()
        self.varname_map = varname_map or {}

    def invalidate(self, keep_vars: Set[str]):
        for name, tmp in list(self.store_cache.items()):
            if tmp not in keep_vars:
                del self.store_cache[name]
                self.invalidated_stores.add(name)
        self.cache = {k: v for k, v in self.cache.items() if v in keep_vars}

    def clone(self):
        # Note(fdrocha): reduction_cache is not being cloned, not sure if this is intentional
        return CSE(
            prefix=self.prefix,
            suffix=self.suffix,
            name_prefix=self.name_prefix,
            iter_buffers=self.iter_buffer_ids,
            store_cache=self.store_cache,
            varname_map=self.varname_map,
        )

    def generate(
        self,
        buffer: IndentedBuffer,
        expr: Union[str, CSEVariable, OpsValue, IndentedBuffer],
        *,
        bounds: ValueRanges[Any] = ValueRanges.unknown(),
        write=True,
        assignment=True,
    ) -> CSEVariable:
        if isinstance(expr, OpsValue):
            expr = expr.value

        assert isinstance(expr, (str, CSEVariable, IndentedBuffer)), type(expr)
        assert write or assignment
        if isinstance(expr, CSEVariable):
            # If the expressions were always created with all the information, we could
            # assert expr.bounds == bounds, but sometimes the expression is created
            # with the loose ValueRanges.unknown(), so we need to tighten the bounds
            expr.bounds = expr.bounds.tighten(bounds)
            return expr
        cache_key = expr.getvalue() if isinstance(expr, IndentedBuffer) else expr
        var = self.cache.get(cache_key, None)
        if not var:
            var = self.newvar(bounds) if assignment else None
            self.cache[cache_key] = var
            if write:
                if V.kernel.current_node:
                    V.kernel.current_node.codegen_originating_info(
                        buffer, only_once=True
                    )
                if isinstance(expr, IndentedBuffer):
                    if assignment:
                        buffer.writeline(f"{self.prefix}{var} =")
                    buffer.splice(expr)
                    buffer.writeline(self.suffix)
                else:
                    if assignment:
                        line = f"{self.prefix}{var} = {expr}{self.suffix}"
                    else:
                        line = f"{expr}{self.suffix}"
                    buffer.writeline(line)
        else:
            var.bounds = var.bounds.tighten(bounds)

        return var

    def newvar(self, bounds: ValueRanges[Any] = ValueRanges.unknown()) -> CSEVariable:
        var_name = f"{self.name_prefix}{next(self.iter_buffer_ids)}"
        var = V.kernel.create_cse_var(var_name, bounds)
        self.varname_map[var_name] = var
        return var


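# Illustrative dedup behavior of CSE.generate (hypothetical expression;
# assumes an active V.kernel handler for newvar/current_node):
#
#     v1 = cse.generate(buf, "tl.load(in_ptr0 + x0)")  # emits: tmp0 = tl.load(in_ptr0 + x0)
#     v2 = cse.generate(buf, "tl.load(in_ptr0 + x0)")  # cache hit: returns tmp0, emits nothing

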
class IndirectAssertLine(DeferredLineBase):
    def __init__(self, line, assert_fn, var, mask, size_map):
        self.var = var
        self.mask = mask
        self.line = line
        self.assert_fn = assert_fn
        self.size_map = size_map

    def __call__(self):
        size, size_str = self.size_map[(self.var, self.mask)]

        # We assert if we've not been able to prove the bound
        assert_min = (self.var.bounds.lower >= 0) != sympy.true
        assert_max = (self.var.bounds.upper < size) != sympy.true

        # FooBar interview question
        if not (assert_min or assert_max):
            return None
        elif assert_min and assert_max:
            # The conditions need to be in parens because of Python's operator precedence.
            # It'd be less error-prone to use and/or/not, which is supported by triton
            cond = f"(0 <= {self.var}) & ({self.var} < {size_str})"
            cond_print = f"0 <= {self.var} < {size_str}"
        elif assert_min:
            cond = f"0 <= {self.var}"
            cond_print = cond
        else:
            assert assert_max
            cond = f"{self.var} < {size_str}"
            cond_print = cond

        if self.mask:
            cond = f"({cond}) | ~{self.mask}"
        return self.line.format(
            assert_fn=self.assert_fn, cond=cond, cond_print=cond_print
        )

    def _new_line(self, line):
        return IndirectAssertLine(
            line, self.assert_fn, self.var, self.mask, self.size_map
        )


class CodeGen:
    def __init__(self):
        super().__init__()
        self.exit_stack = contextlib.ExitStack()

    def __enter__(self):
        self.exit_stack.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.exit_stack.__exit__(exc_type, exc_val, exc_tb)


class Kernel(CodeGen):
    newvar_prefix = ""
    suffix = ""
    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
    # TODO: these look dead, but with all the getattr it's hard to tell...
    load_format: None = None
    store_format: None = None

    def __init__(self, args=None, increase_kernel_count=True):
        super().__init__()
        if increase_kernel_count:
            metrics.generated_kernel_count += 1
        self.args = args or KernelArgs()
        self.loads = IndentedBuffer()
        self.compute = IndentedBuffer()
        self.stores = IndentedBuffer()
        self.cse: CSE = CSE(self.newvar_prefix, self.suffix)
        self.must_keep_buffers = set()
        self.store_buffer_names = set()
        self._load_mask = None
        # set in set_current_node
        self.current_node = None
        self.node_to_bounds: Optional[Dict[torch.fx.Node, ValueRanges[Any]]] = None
        # Upper bounds for indirect_indexing and their str representation
        # NB: None, None is never stored in map, but it is the assumed
        # "not set" value for the dict
        self.indirect_max_sizes: Dict[
            Tuple[CSEVariable, str], Union[Tuple[sympy.Expr, str], Tuple[None, None]]
        ] = {}

        self.removed_buffers = set()
        self.inplaced_to_remove = set()

        # key: the buffer to write
        # value: the buffer to read and whose memory can be reused for
        # the buffer specified by key
        self.inplace_update_buffers = dict()
        # Set minimum number of elements processed per thread.
        self.min_elem_per_thread = 1
        self.kernel_name = None

    @contextlib.contextmanager
    def set_current_node(self, node):
        prior = self.current_node
        self.current_node = node
        self.node_to_bounds = node._body.bounds().get_bounds()
        try:
            yield
        finally:
            self.current_node = prior

    @contextlib.contextmanager
    def swap_buffers(self, lb, cb=None, sb=None):
        if cb is None:
            cb = lb
        loads = self.loads
        compute = self.compute
        stores = self.stores
        cse = self.cse
        self.loads = lb
        self.compute = cb
        self.stores = sb
        self.cse = cse.clone()
        try:
            yield
        finally:
            self.loads = loads
            self.compute = compute
            self.stores = stores
            self.cse = cse

    def load(self, name: str, index: sympy.Expr) -> CSEVariable:
        raise NotImplementedError()

    def indirect_load(self, name: str, index: sympy.Expr):
        """A load that depends on an index we have read"""
        prior = self.loads
        try:
            # put the load in the compute section as it might have deps
            self.loads = self.compute
            return self.load(name, index)
        finally:
            self.loads = prior

    def store_reduction(self, name: str, index: sympy.Expr, value: CSEVariable):
        raise NotImplementedError()

    def store(
        self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
    ) -> None:
        raise NotImplementedError()

    def reduction(
        self,
        dtype: torch.dtype,
        src_dtype: torch.dtype,
        reduction_type: ReductionType,
        value: Union[CSEVariable, Tuple[CSEVariable, ...]],
    ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
        raise NotImplementedError()

    def scan(
        self,
        dtype: torch.dtype,
        combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
        value: CSEVariable,
        init: int,
    ) -> CSEVariable:
        raise NotImplementedError()

    def bucketize(
        self,
        values: CSEVariable,
        offsets_name: str,
        offsets_size: sympy.Expr,
        indexing_dtype: torch.dtype,
        right: bool,
    ) -> CSEVariable:
        """
        See [Note: Inductor bucketize op]
        """
        raise NotImplementedError()

    @property
    def assert_function(self) -> str:
        raise NotImplementedError()

    def index_to_str(self, index: sympy.Expr) -> str:
        raise NotImplementedError()

    def __enter__(self):
        # TODO: hoist this to top level
        class CSEProxy:
            self.name = "CSEProxy"

            @staticmethod
            def __getattr__(name: str) -> Callable[..., CSEVariable]:  # type: ignore[misc]
                def inner(*args, **kwargs):
                    # TritonTemplateKernel has no current_node
                    buf_bounds = ValueRanges.unknown()
                    if hasattr(V.interpreter, "current_node"):
                        fx_node = V.interpreter.current_node
                        assert isinstance(self.node_to_bounds, dict)
                        buf_bounds = self.node_to_bounds.get(
                            fx_node, ValueRanges.unknown()
                        )

                    value = getattr(parent_handler, name)(*args, **kwargs)  # type: ignore[has-type]

                    def do_cse(v):
                        csevar = self.cse.generate(self.compute, v, bounds=buf_bounds)
                        csevar.update_on_args(name, args, kwargs)
                        return csevar

                    return pytree.tree_map(do_cse, value)

                return inner

            @staticmethod
            def indirect_indexing(
                var: CSEVariable, size: sympy.Expr, check: bool = True
            ):
                # Skip CSE since this doesn't return an expression

                if var.bounds.lower < 0:  # type: ignore[operator]
                    new_bounds = ValueRanges.unknown()
                    if var.bounds != ValueRanges.unknown() and isinstance(
                        size, sympy.Number
                    ):
                        # Take the negative part of the bound and add size to it
                        # Then take union of that and the positive part
                        # This is a tighter bound than that of a generic ops.where, as we have info on the cond
                        neg = var.bounds & ValueRanges(-sympy.oo, -1)
                        new_bounds = ValueRanges(neg.lower + size, neg.upper + size)
                        # We don't have a good way of representing the empty range
                        if var.bounds.upper >= 0:  # type: ignore[operator]
                            pos = var.bounds & ValueRanges(0, sympy.oo)
                            new_bounds = new_bounds | pos

                    stm = ops.add(var, self.rename_indexing(size))
                    # Mixed negative and non-negative
                    if var.bounds.upper >= 0:  # type: ignore[operator]
                        lt = ops.lt(var, "0")
                        stm = ops.where(lt, stm, var)
                    new_var = self.cse.generate(self.compute, stm, bounds=new_bounds)

                    new_var.update_on_args("index_wrap", (var,), {})
                    var = new_var

                if self.generate_assert(check):
                    mask = self.load_mask(var)

                    # An assertion line may have been written already, if so just
                    # update the max size.
                    map_key = (var, mask)
                    existing_size, _ = self.indirect_max_sizes.get(
                        map_key, (None, None)
                    )
                    if existing_size is not None:
                        size = sympy.Min(size, existing_size)
                    else:
                        line = (
                            '{assert_fn}({cond}, "index out of bounds: {cond_print}")'
                        )
                        self.compute.writeline(
                            IndirectAssertLine(
                                line,
                                self.assert_function,
                                var,
                                mask,
                                self.indirect_max_sizes,
                            )
                        )

                    self.indirect_max_sizes[map_key] = (size, self.index_to_str(size))
                return sympy_index_symbol(str(var))

            @staticmethod
            def load(name: str, index: sympy.Expr) -> CSEVariable:
                if name in self.cse.invalidated_stores:
                    # A load from an invalidated store requires us to
                    # keep the actual buffer around
                    V.kernel.must_keep_buffers.add(name)
                if free_symbol_startswith(index, "tmp"):
                    return self.indirect_load(name, index)
                store_cache = self.cse.store_cache
                if name in store_cache:
                    return store_cache[name]
                return self.load(name, index)

            @staticmethod
            def store(
                name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
            ) -> None:
                self.store_buffer_names.add(name)
                if mode is None:
                    self.cse.store_cache[name] = value
                    if self.current_node:
                        for other_name in self.current_node.get_mutations():
                            self.cse.store_cache[other_name] = value
                if name not in V.graph.removed_buffers:
                    return self.store(name, index, value, mode=mode)
                else:
                    return None  # type: ignore[return-value]

            @staticmethod
            def store_reduction(name: str, index: sympy.Expr, value: CSEVariable):
                self.store_buffer_names.add(name)
                self.cse.store_cache[name] = value
                if self.current_node:
                    for other_name in self.current_node.get_mutations():
                        self.cse.store_cache[other_name] = value

                if name not in V.graph.removed_buffers:
                    return self.store_reduction(name, index, value)

            @staticmethod
            def reduction(
                dtype: torch.dtype,
                src_dtype: torch.dtype,
                reduction_type: ReductionType,
                value: Union[CSEVariable, Tuple[CSEVariable, ...]],
            ) -> Union[CSEVariable, Tuple[CSEVariable, ...]]:
                return self.reduction(dtype, src_dtype, reduction_type, value)

            @staticmethod
            def scan(
                dtype: torch.dtype,
                combine_fn: Callable[[CSEVariable, CSEVariable], CSEVariable],
                value: CSEVariable,
                init: int,
            ) -> CSEVariable:
                return self.scan(dtype, combine_fn, value, init)

            @staticmethod
            def bucketize(
                values: CSEVariable,
                offsets_name: str,
                offsets_size: sympy.Expr,
                indexing_dtype: torch.dtype,
                right: bool,
            ) -> CSEVariable:
                """
                [Note: Inductor bucketize op]

                Given values (tensor) and offsets_name (reference to the name of a 1D
                tensor), calculate the bucket that each value belongs to.

                e.g. for values [-1, 0, 1, 2, 3, 4, 5, 9], offsets [0, 4, 4, 8], right=True
                return =   [ 0, 1, 1, 1, 1, 3, 3, 4].

                When right == False, bucket i refers to range (offsets[i], offsets[i+1]].
                When right == True, bucket i refers to range [offsets[i], offsets[i+1]).

                Offsets must be non-decreasing or the result is undefined.
                """
                return self.bucketize(
                    values, offsets_name, offsets_size, indexing_dtype, right
                )

        # Use mypy to check protocol implemented correctly
        def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
            return h

        super().__enter__()
        assert self.overrides
        parent_handler = self.overrides(V.get_ops_handler())
        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy()))
        self.exit_stack.enter_context(V.set_kernel_handler(self))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """
        Note that V.graph.scheduler can be None when codegening triton template
        kernels.
        """
        if V.graph.scheduler:
            V.graph.scheduler.remove_kernel_local_buffers()
        super().__exit__(exc_type, exc_val, exc_tb)

    def generate_assert(self, check):
        return (check or config.debug_index_asserts) and config.assert_indirect_indexing

    def load_mask(self, var) -> str:
        # only the triton kernel requires mask
        return ""

    def rename_indexing(self, index) -> sympy.Expr:
        # adds the necessary kernel args for index expressions
        # and renames variables in index expressions to kernel arg names
        if isinstance(index, (list, tuple)):
            return [self.rename_indexing(x) for x in index]  # type: ignore[return-value]
        index = V.graph.sizevars.simplify(index)
        sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name)
        replacements = {
            x: self.args.size(x)
            for x in sorted_symbols
            if x.name.startswith(("s", "u", "ps"))
            or (x.name.startswith("i") and not x.name.startswith("idx"))
        }
        return sympy_subs(index, replacements)

    def create_cse_var(self, *args, **kwargs):
        return CSEVariable(*args, **kwargs)


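# Illustrative bound arithmetic for the negative-index wrapping in
# CSEProxy.indirect_indexing above (hypothetical values): a variable with
# bounds [-4, 3] indexing into size 10 is rewritten as
# where(var < 0, var + 10, var); the negative part [-4, -1] shifts to
# [6, 9] and is unioned with the non-negative part [0, 3], giving the
# tightened bounds [0, 9] for the new variable.

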
@dataclasses.dataclass
class OptimizationContext:
    key: ClassVar[str] = "opt_ctx"

    # Load value as mask
    is_load_as_mask: bool = False

    dtype: Optional[torch.dtype] = None
    ops_name: str = ""

    # Load uint8/int8 value as float32
    is_load_int8_as_float: bool = False


@functools.lru_cache(None)
def jinja2_env():
    try:
        import jinja2

        return jinja2.Environment(
            undefined=jinja2.StrictUndefined,
        )
    except ImportError:
        return None


class ChoiceCaller:
    """
    Represents a possible choice used in autotune_process.py.
    During autotuning, self.benchmark() is first called to get the benchmark result,
    and if this choice is selected, self.output_node() is called to get the output node.

    Child classes: TritonTemplateCaller, CUDATemplateCaller.
    """

    def __init__(self, name, input_nodes, layout):
        super().__init__()
        self.name = name
        self.layout = layout
        self.input_nodes = input_nodes

    def benchmark(self, *args, out) -> float:
        algo = self.to_callable()
        return do_bench(lambda: algo(*args, out=out))

    def call_name(self) -> str:
        raise NotImplementedError()

    def to_callable(self):
        raise NotImplementedError()

    def hash_key(self) -> str:
        raise NotImplementedError()

    def output_node(self) -> "TensorBox":
        raise NotImplementedError()


class KernelTemplate:
    """
    Base class for defining kernel templates.

    Child classes: TritonTemplate, CUDATemplate
    """

    @staticmethod
    def _template_from_string(source):
        env = jinja2_env()
        if env is not None:
            return env.from_string(source)
        return None

    @staticmethod
    def _fake_get_dtype(fake_out):
        _get_dtype_real = V.graph.get_dtype

        def get_dtype(name):
            if name == fake_out.get_name():
                return fake_out.get_dtype()
            return _get_dtype_real(name)

        return get_dtype

    def __init__(self, name: str):
        self.name = name

    def maybe_append_choice(self, choices, **kwargs):
        """
        Maybe generates a new ChoiceCaller and appends it to the existing choices.

        choices: A list of ChoiceCallers.
        kwargs: Additional kwargs to be passed to self.generate() to generate a new ChoiceCaller.
        """

        try:
            choices.append(self.generate(**kwargs))
        except NotImplementedError:
            pass

    def generate(self, **kwargs) -> ChoiceCaller:
        """
        Generates a ChoiceCaller instance from the given arguments.
        """

        raise NotImplementedError()
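

# Illustrative use of the jinja2-backed templating above (hypothetical
# template source; if jinja2 is not installed, _template_from_string
# returns None):
#
#     tmpl = KernelTemplate._template_from_string("void {{kernel_name}}() {}")
#     if tmpl is not None:
#         src = tmpl.render(kernel_name="my_kernel")
#         # -> "void my_kernel() {}"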