mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283 Test plan: dmypy restart && python3 scripts/lintrunner.py -a pyrefly check step 1: delete lines in the pyrefly.toml file from the project-excludes field step 2: run pyrefly check step 3: add suppressions, clean up unused suppressions before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199 after: INFO 0 errors (6,884 ignored) Pull Request resolved: https://github.com/pytorch/pytorch/pull/165659 Approved by: https://github.com/oulgen
1723 lines
62 KiB
Python
1723 lines
62 KiB
Python
# mypy: allow-untyped-defs
|
|
from __future__ import annotations
|
|
|
|
import dataclasses
|
|
import functools
|
|
import itertools
|
|
import logging
|
|
import re
|
|
from collections import defaultdict
|
|
from math import inf
|
|
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, Union
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
import torch._logging
|
|
|
|
from ..._prims_common import is_integer_dtype
|
|
from ...utils._ordered_set import OrderedSet
|
|
from ...utils._sympy.functions import FloorDiv, ModularIndexing
|
|
from ...utils._sympy.symbol import symbol_is_type, SymT
|
|
from ...utils._sympy.value_ranges import ValueRanges
|
|
from .. import config, ir
|
|
from ..codecache import HalideCodeCache
|
|
from ..ir import get_reduction_combine_fn
|
|
from ..metrics import is_metric_table_enabled, log_kernel_metadata
|
|
from ..ops_handler import AddParenHandler
|
|
from ..runtime.hints import HalideInputSpec, HalideMeta
|
|
from ..utils import (
|
|
get_bounds_index_expr,
|
|
get_kernel_metadata,
|
|
parallel_num_threads,
|
|
sympy_index_symbol,
|
|
sympy_subs,
|
|
)
|
|
from ..virtualized import _ops as ops, V
|
|
from .common import (
|
|
BackendFeature,
|
|
CSEVariable,
|
|
DeferredLine,
|
|
IndentedBuffer,
|
|
KernelArgType,
|
|
OpOverrides,
|
|
PythonPrinter,
|
|
SizeArg,
|
|
TensorArg,
|
|
)
|
|
from .cpp import DTYPE_TO_CPP
|
|
from .cpp_utils import cexpr
|
|
from .simd import constant_repr, SIMDKernel, SIMDScheduling
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Sequence
|
|
|
|
from ..ops_handler import ReductionType, StoreMode
|
|
from ..shape_propagation import BlockShapeType
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def halide_constant(val):
    """Render a Python scalar as Halide expression source text.

    Integers that do not fit in signed 32 bits become explicit 64-bit
    Halide constants (with the special ``hl.Int(64).min()/max()``
    spellings for the int64 extremes), floats become ``hl.f64``
    literals, and anything else falls back to ``repr``.
    """
    int32_min, int32_max = -2147483648, 2147483647
    if isinstance(val, int) and not (int32_min <= val <= int32_max):
        int64_info = torch.iinfo(torch.int64)
        if val == int64_info.min:
            return "hl.Int(64).min()"
        if val == int64_info.max:
            return "hl.Int(64).max()"
        return f"hl.i64({val!r})"
    if isinstance(val, float):
        return f"hl.f64({constant_repr(val)})"
    return repr(val)
|
|
|
|
|
|
class Unsupported(RuntimeError):
    """Raised when an op has no Halide lowering."""

    def __init__(self, thing) -> None:
        message = f"halide backend does not support: {thing}"
        super().__init__(message)
|
|
|
|
|
|
class HalidePrinter(PythonPrinter):
    """Sympy-expression printer that emits Halide (``hl.*``) source text.

    Index-valued results are cast to ``V.kernel.index_dtype`` so that all
    index arithmetic seen by Halide has a consistent type; intermediate
    float math is forced to 32-bit via :meth:`cast_float`.
    """

    @staticmethod
    def cast_index(expr):
        """Wrap *expr* (a string) in a cast to the kernel's index dtype."""
        return f"hl.cast({V.kernel.index_dtype}, {expr})"

    @staticmethod
    def cast_float(expr):
        """Wrap *expr* (a string) in a cast to 32-bit float."""
        return f"hl.cast(hl.Float(32), {expr})"

    def _print_Float(self, expr):
        return f"hl.f32({expr})"

    def _print_ToFloat(self, expr):
        assert len(expr.args) == 1
        return f"hl.f32({self._print(expr.args[0])})"

    def _print_floor(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.floor({self._print(expr.args[0])})")

    _print_FloorToInt = _print_floor

    def _print_Trunc(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.trunc({self._print(expr.args[0])})")

    _print_TruncToInt = _print_Trunc

    def _print_ceiling(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.ceil({self._print(expr.args[0])})")

    def _helper_sqrt(self, expr):
        return f"hl.sqrt({self.cast_float(self._print(expr))})"

    def _print_Where(self, expr):
        c = self.doprint(expr.args[0])
        p = self.doprint(expr.args[1])
        q = self.doprint(expr.args[2])
        return f"hl.select({c}, {p}, {q})"

    def _print_Min(self, expr):
        # hl.min is binary; reduce an n-ary Min by recursive halving.
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Min(*expr.args[:mid]))
        b = self._print(sympy.Min(*expr.args[mid:]))
        return f"hl.min({a}, {b})"

    def _print_Max(self, expr):
        # hl.max is binary; reduce an n-ary Max by recursive halving.
        if len(expr.args) == 1:
            return self._print(expr.args[0])

        mid = len(expr.args) // 2
        a = self._print(sympy.Max(*expr.args[:mid]))
        b = self._print(sympy.Max(*expr.args[mid:]))

        return f"hl.max({a}, {b})"

    def _print_Abs(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.abs({self._print(expr.args[0])})")

    def _print_OpaqueUnaryFn_cos(self, expr):
        assert len(expr.args) == 1
        return f"hl.cos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_cosh(self, expr):
        assert len(expr.args) == 1
        return f"hl.cosh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_acos(self, expr):
        assert len(expr.args) == 1
        return f"hl.acos({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sin(self, expr):
        assert len(expr.args) == 1
        return f"hl.sin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_sinh(self, expr):
        assert len(expr.args) == 1
        return f"hl.sinh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_asin(self, expr):
        assert len(expr.args) == 1
        return f"hl.asin({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tan(self, expr):
        assert len(expr.args) == 1
        return f"hl.tan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_tanh(self, expr):
        assert len(expr.args) == 1
        return f"hl.tanh({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_atan(self, expr):
        assert len(expr.args) == 1
        return f"hl.atan({self._print(expr.args[0])})"

    def _print_OpaqueUnaryFn_log2(self, expr):
        raise NotImplementedError("log2")

    def _print_FloorDiv(self, expr):
        if expr.is_integer:
            return super()._print_FloorDiv(expr)

        # Non-integer floordiv: compute in float32 and floor back to index.
        x, div = expr.args
        x = self.cast_float(self.doprint(x))
        div = self.cast_float(self.doprint(div))
        return self.cast_index(f"hl.floor({x} / {div})")

    def _print_Round(self, expr):
        assert len(expr.args) == 1
        return self.cast_index(f"hl.round({self._print(expr.args[0])})")

    _print_RoundToInt = _print_Round

    def _print_IntTrueDiv(self, expr):
        a, b = expr.args
        # Route both operands through this printer: previously they were
        # interpolated raw, so compound sub-expressions fell back to
        # sympy's default str() and could emit non-Halide syntax.
        # Adding hl.f32(0) forces a cast to float.
        return f"({self._print(a)}) / ({self._print(b)}+hl.f32(0))"

    def _print_RoundDecimal(self, expr):
        val, n = expr.args
        val = self._print(val)
        n = int(n)
        # round(x, n) == 10**-n * round(x * 10**n)
        return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
|
|
|
|
|
|
# Shorthand expression printers: ``texpr`` renders sympy expressions as
# Halide (hl.*) source, ``pexpr`` renders them as plain Python.
texpr = HalidePrinter().doprint
pexpr = PythonPrinter().doprint
|
|
|
|
|
|
# Mapping from torch dtype to the Halide type-constructor expression used
# in generated code. Dtypes absent here (e.g. complex) are unsupported.
_halide_type = {
    torch.bool: "hl.Bool()",
    torch.bfloat16: "hl.BFloat(16)",
    torch.float16: "hl.Float(16)",
    torch.float32: "hl.Float(32)",
    torch.float64: "hl.Float(64)",
    torch.int8: "hl.Int(8)",
    torch.int16: "hl.Int(16)",
    torch.int32: "hl.Int(32)",
    torch.int64: "hl.Int(64)",
    torch.uint8: "hl.UInt(8)",
    torch.uint16: "hl.UInt(16)",
    torch.uint32: "hl.UInt(32)",
    torch.uint64: "hl.UInt(64)",
}
|
|
|
|
|
|
def halide_type(dtype):
    """Map a torch.dtype to its Halide type string; KeyError if unsupported."""
    return _halide_type[dtype]
|
|
|
|
|
|
def halide_acc_type(dtype):
    """Return the Halide type string used to accumulate values of *dtype*.

    Narrow signed integers accumulate as int32 and half-precision floats
    as float32; every other dtype accumulates as itself.
    """
    is_narrow_signed_int = (
        is_integer_dtype(dtype) and dtype.is_signed and dtype != torch.int64
    )
    if is_narrow_signed_int:
        dtype = torch.int32
    if dtype in (torch.float16, torch.bfloat16):
        dtype = torch.float32
    return halide_type(dtype)
|
|
|
|
|
|
class HalideOverrides(OpOverrides):
    """Lowering table mapping pointwise ops to Halide source strings.

    Each method receives operand strings (or CSE variables, which
    stringify to indexed Func accesses) and returns the Halide Python
    expression text for the op. Ops Halide cannot express raise
    ``Unsupported``.
    """

    @staticmethod
    def to_dtype(
        x,
        dtype: torch.dtype,
        src_dtype: Optional[torch.dtype] = None,
        use_compute_types=True,
    ):
        # bool is compared against zero rather than cast directly.
        if dtype == torch.bool:
            return f"({x} != 0)"
        return f"hl.cast({halide_type(dtype)}, {x})"

    @staticmethod
    def to_dtype_bitcast(x, dtype: torch.dtype, src_dtype: torch.dtype):
        # Narrow the value back to its declared source type first, then
        # reinterpret the bits as the target type.
        if src_dtype in (torch.float16, torch.bfloat16):
            x = f"hl.cast({halide_type(src_dtype)}, {x})"  # body compute is upcast to fp32
        line = f"hl.reinterpret({halide_type(dtype)}, {x})"
        if dtype in (torch.float16, torch.bfloat16):
            line = f"hl.cast(hl.Float(32), {line})"
        return line

    @classmethod
    def constant(cls, value, dtype):
        return cls.to_dtype(halide_constant(value), dtype)

    @staticmethod
    def abs(x):
        return f"hl.abs({x})"

    @staticmethod
    def exp(x):
        if not hasattr(x, "name"):
            return f"hl.exp({x})"
        # The returned conditional is evaluated by the generated Python
        # code: fast_exp for <=32-bit types, exact exp otherwise.
        return f"hl.fast_exp(hl.cast(hl.Float(32), {x})) if {x.name}.type().bits() <= 32 else hl.exp({x})"

    @staticmethod
    def sqrt(x):
        return f"hl.sqrt({x})"

    @staticmethod
    def minimum(a, b):
        # return f"hl.min({a}, {b})" <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.min({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        # NaN-propagating select for float types, plain hl.min otherwise.
        return f"hl.select(({a}<{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.min({a}, {b})"

    @staticmethod
    def maximum(a, b):
        # return f"hl.max({a}, {b})" <== handles nan wrong
        if not hasattr(a, "name"):
            return f"hl.max({a}, {b})"
        b = f"hl.cast({a.name}.type(), {b})"
        # NaN-propagating select for float types, plain hl.max otherwise.
        return f"hl.select(({a}>{b})|hl.is_nan({a}), {a}, {b}) if {a.name}.type().is_float() else hl.max({a}, {b})"

    @staticmethod
    def where(a, b, c):
        # Force c to b's type to avoid select() type mismatches.
        if hasattr(b, "name"):
            c = f"hl.cast({b.name}.type(), {c})"
        return f"hl.select({a}, {b}, {c})"

    @staticmethod
    def cos(x):
        return f"hl.cos({x})"

    @staticmethod
    def sin(x):
        return f"hl.sin({x})"

    @staticmethod
    def lgamma(x):
        raise Unsupported("lgamma")

    @staticmethod
    def erf(x):
        return f"hl.erf({x})"

    @staticmethod
    def cosh(x):
        return f"hl.cosh({x})"

    @staticmethod
    def sinh(x):
        return f"hl.sinh({x})"

    @staticmethod
    def acos(x):
        return f"hl.acos({x})"

    @staticmethod
    def acosh(x):
        return f"hl.acosh({x})"

    @staticmethod
    def asin(x):
        return f"hl.asin({x})"

    @staticmethod
    def asinh(x):
        return f"hl.asinh({x})"

    @staticmethod
    def atan2(x, y):
        return f"hl.atan2({x}, {y})"

    @staticmethod
    def atan(x):
        return f"hl.atan({x})"

    @staticmethod
    def atanh(x):
        return f"hl.atanh({x})"

    @staticmethod
    def copysign(x, y):
        raise Unsupported("copysign")

    @staticmethod
    def erfinv(x):
        raise Unsupported("erfinv")

    @staticmethod
    def hypot(x, y):
        return f"hl.hypot({x}, {y})"

    @staticmethod
    def nextafter(x, y):
        raise Unsupported("nextafter")

    @staticmethod
    def logical_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def logical_not(a):
        return f"{a} == 0"

    @staticmethod
    def logical_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def logical_xor(a, b):
        return f"({a} ^ {b})"

    @staticmethod
    def bitwise_and(a, b):
        return f"{a} & {b}"

    @staticmethod
    def bitwise_not(a):
        return f"~{a}"

    @staticmethod
    def bitwise_or(a, b):
        return f"{a} | {b}"

    @staticmethod
    def bitwise_xor(a, b):
        return f"{a} ^ {b}"

    @staticmethod
    def bitwise_left_shift(a, b):
        return f"{a} << {b}"

    @staticmethod
    def bitwise_right_shift(a, b):
        return f"{a} >> {b}"

    @staticmethod
    def rand(seed, offset):
        # RNG ops dispatch to runtime helpers rather than inline Halide.
        return f"halide_helpers.rand({seed}, {offset})"

    @staticmethod
    def randn(seed, offset):
        return f"halide_helpers.randn({seed}, {offset})"

    @staticmethod
    def randint64(seed, offset, low, high):
        return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"

    @staticmethod
    def load_seed(name, offset):
        return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"

    @staticmethod
    def rsqrt(x):
        # return f"hl.fast_inverse_sqrt({x})" <== accuracy issues
        return f"1./hl.sqrt({x})"

    @staticmethod
    def tan(x):
        return f"hl.tan({x})"

    @staticmethod
    def tanh(x):
        return f"hl.tanh({x})"

    @staticmethod
    def signbit(x):
        # Extract the IEEE-754 sign bit from the fp32 representation.
        return f"(hl.reinterpret(hl.UInt(32), hl.cast(hl.Float(32), {x})) >> 31) != 0"

    @staticmethod
    def fmod(a, b):
        # TODO(jansel): find a better way to do this, builtin % has wrong sign
        return f"{a} - hl.trunc({a}/{b})*{b}"

    @staticmethod
    def pow(a, b):
        return f"hl.pow({a}, {b})"  # hl.fast_pow fails accuracy

    @staticmethod
    def log(x):
        return f"hl.log({x})"  # hl.fast_log fails accuracy

    @staticmethod
    def log2(x):
        raise NotImplementedError("log2")

    @staticmethod
    def isinf(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_inf(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def isnan(x):
        # workaround https://github.com/halide/Halide/issues/8309
        return f"hl.is_nan(hl.cast(hl.Float(32), {x}))"

    @staticmethod
    def round(x):
        return f"hl.round({x})"

    @staticmethod
    def floor(x):
        return f"hl.floor({x})"

    @staticmethod
    def int_truediv(a, b):
        # Adding hl.f32(0) forces the division to happen in float.
        return f"({a}) / ({b} + hl.f32(0))"

    @staticmethod
    def floordiv(a, b):
        # TODO(jansel): find a better ways to do this, the select-based trick from triton.py didn't work
        return (
            f"hl.floor(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @classmethod
    def sign(cls, x):
        # sign(x) = (0 < x) - (x < 0), computed in int8 then cast back.
        left = ops.to_dtype(ops.lt("0", x), torch.int8)
        right = ops.to_dtype(ops.lt(x, "0"), torch.int8)
        sub = ops.sub(left, right)
        return f"hl.cast({x.name}.type(), {sub})"

    @staticmethod
    def trunc(x):
        return f"hl.trunc({x})"

    @staticmethod
    def truncdiv(a, b):
        # this causes crashes with floating point exception, see test_div_zero_dim_cpu
        # return f"hl.div_round_to_zero({a}, {b})"
        return (
            f"hl.trunc(hl.cast(hl.Float(max(32, {a.name}.type().bits())), {a}) / {b})"
        )

    @staticmethod
    def ceil(x):
        return f"hl.ceil({x})"

    @staticmethod
    def relu(x):
        return f"hl.max({x}, 0)"

    @classmethod
    def index_expr(cls, expr, dtype):
        """Materialize a sympy index expression as a CSE Func."""
        index = V.kernel.prepare_indexing(expr)
        var = V.kernel.genfunc(
            V.kernel.index_to_str(index),
            V.kernel.used_dims_from_index(index),
            bounds=get_bounds_index_expr(expr),
        )
        if dtype not in (torch.int32, torch.int64):
            return ops.to_dtype(var, dtype)
        return var

    @classmethod
    def indirect_indexing(cls, index_var, size, check=True, wrap_neg=True):
        """Turn a runtime index value into a symbol usable in indexing."""
        # TODO(jansel): Halide only supports 32-bit indexing, we should error on overflow
        index_var = ops.to_dtype(index_var, torch.int32)
        index_var = ops.halide_clamp(index_var, size, check)
        # Remember the bound so sym_size() can recover it later.
        index_var.indirect_indexing_size = size
        return sympy_index_symbol(str(index_var))

    @classmethod
    def halide_clamp(cls, value, size, check):
        """Clamp *value* into [0, size-1]."""
        end = V.kernel.kexpr(V.kernel.rename_indexing(size) - 1)
        if not isinstance(size, (int, sympy.Integer)):
            end = f"hl.cast({value.name}.type(), {end})"
        # Skip unsafe_promise_clamped to workaround: https://github.com/halide/Halide/issues/8261#issuecomment-2148835692
        # return f"hl.unsafe_promise_clamped({value}, 0, {end})"
        return f"hl.clamp({value}, 0, {end})"

    @staticmethod
    def masked(mask, body, other):
        """Evaluate *body* under *mask*, substituting *other* where masked out."""
        with V.kernel.mask_loads(mask, other) as new_mask:
            result = body()

        if result.bounds.is_bool:
            other = bool(other)

        # Take dtype from result to prevent accidental promotion
        other = V.kernel.genfunc(
            f"hl.cast({result.name}.type(), {halide_constant(other)})",
            [],
            bounds=ValueRanges.wrap(other),
            shape=result.shape,
        )
        # TODO(jansel): look into removing the where in the same places triton does
        return ops.where(new_mask, result, other)

    @staticmethod
    def frexp(x):
        raise NotImplementedError("frexp")

    @staticmethod
    def device_assert_async(cond, msg):
        raise NotImplementedError("device_assert_async")
|
|
|
|
|
|
# Register the generic pointwise lowerings for the "halide" backend
# (helper inherited from OpOverrides).
HalideOverrides._initialize_pointwise_overrides("halide")
|
|
|
|
|
|
class HalideCSEVariable(CSEVariable):
    """CSE variable that tracks which Halide loop dims index its Func."""

    # Matches "tmpN[?]" placeholders emitted before used_dims is known.
    undefined_re = re.compile(r"\b(tmp\d+)\[\?\]")

    def __init__(
        self,
        name,
        bounds: ValueRanges[Any],
        dtype: Optional[torch.dtype] = None,
        shape: BlockShapeType = None,
    ) -> None:
        super().__init__(name, bounds, dtype, shape=shape)
        # Filled in lazily; None means "not yet determined".
        self.used_dims: Optional[list[sympy.Symbol]] = None

    def update_on_args(self, name, args, kwargs):
        """Union our used_dims with those of every CSE-variable argument."""
        combined = OrderedSet(self.used_dims or ())
        for value in itertools.chain(args, kwargs.values()):
            if isinstance(value, HalideCSEVariable):
                assert value.used_dims is not None, (name, value, args)
                combined.update(value.used_dims)
        self.used_dims = V.kernel.sort_used_dims(combined)

    def index_str(self, dims):
        """Render an indexed access of this Func with the given dims."""
        if not dims:
            # zero-dimensional Func access
            return f"{self.name}[()]"
        joined = ", ".join(map(str, dims))
        return f"{self.name}[{joined}]"

    def __str__(self) -> str:
        if self.used_dims is not None:
            return self.index_str(self.used_dims)
        # Placeholder; this gets recomputed and replaced in codegen_kernel()
        return f"{self.name}[?]"

    def subs_str(self, replacements):
        """Like __str__, but with each dim remapped through *replacements*."""
        assert self.used_dims is not None and all(
            isinstance(d, sympy.Expr) for d in self.used_dims
        )
        return self.index_str([replacements.get(d, d) for d in self.used_dims])
|
|
|
|
|
|
@dataclasses.dataclass
class DimensionInfo:
    """One dimension of a Halide buffer access.

    expr:   index expression for this dim (None once accesses with
            differing exprs have been merged by install_dims)
    size:   extent of the dimension
    stride: element stride of the dimension
    """

    expr: Optional[sympy.Expr]
    size: sympy.Expr
    stride: sympy.Expr

    def __init__(self, expr, size, stride) -> None:
        super().__init__()
        # Normalize to a non-negative stride by negating the index expr.
        if V.graph.sizevars.statically_known_lt(stride, 0):
            stride = -stride
            expr = -expr
        self.expr = expr
        self.size = size
        self.stride = stride

    def index_str(self, replacements=None, zero_vars=False):
        """Render this dim's index expression as Halide source.

        Indirect (tmpN) symbols are expanded to their CSE Func accesses;
        with zero_vars=True a literal 0 index becomes a fresh hl.Var().
        """
        assert self.expr is not None
        expr = self.expr
        if zero_vars and expr == 0:
            return "hl.Var()"
        if replacements:
            replacements = {**replacements}
            # pyrefly: ignore # missing-attribute
            for sym in expr.free_symbols:
                if symbol_is_type(sym, SymT.TMP):
                    assert isinstance(sym, sympy.Symbol)
                    var = V.kernel.lookup_cse_var(sym.name)
                    assert isinstance(var, HalideCSEVariable)
                    replacements[sym] = sympy_index_symbol(var.subs_str(replacements))
            expr = sympy_subs(expr, replacements)
        return V.kernel.index_to_str(expr)
|
|
|
|
|
|
def eq(left, right):
    """Return True if *left* and *right* are provably equal sizes.

    Falls back to size hints when static reasoning fails; a hint match is
    recorded as a runtime guard via check_equals. Unbacked symints
    conservatively compare unequal.
    """
    if V.graph.sizevars.statically_known_equals(left, right):
        return True
    try:
        hint_left = V.graph.sizevars.size_hint_or_throw(left)
        hint_right = V.graph.sizevars.size_hint_or_throw(right)
    except TypeError:  # unbacked symints
        return False
    if hint_left != hint_right:
        return False
    # Guard on the equality we are about to rely on.
    V.graph.sizevars.check_equals(left, right)
    return True
|
|
|
|
|
|
def lt(left, right):
    """Return True if *left* is provably less than *right*.

    Uses size hints (recording a runtime guard via check_lt) when static
    reasoning fails; for unbacked symints, a divisibility argument is
    tried before conservatively returning False.
    """
    if V.graph.sizevars.statically_known_lt(left, right):
        return True
    try:
        hint_left = V.graph.sizevars.size_hint_or_throw(left)
        hint_right = V.graph.sizevars.size_hint_or_throw(right)
    except TypeError:  # unbacked symints
        # If left divides right, then left < right unless they are equal.
        gcd = sympy.gcd(left, right)
        if gcd == left:
            return left != right
        return False
    if hint_left >= hint_right:
        return False
    # Guard on the ordering we are about to rely on.
    V.graph.sizevars.check_lt(left, right)
    return True
|
|
|
|
|
|
class HalideKernel(SIMDKernel):
    """SIMD kernel codegen backend that emits Halide (hl.*) Python source."""

    overrides = HalideOverrides  # type: ignore[assignment]
    # Printer used for index expressions (emits Halide syntax).
    kexpr: Callable[[sympy.Expr], str] = texpr
|
|
|
|
    def __init__(
        self,
        tiling: dict[str, sympy.Expr],
        **kwargs,
    ) -> None:
        """Set up Halide-specific state on top of the generic SIMD kernel."""
        super().__init__(tiling, **kwargs)
        # For halide, we just write directly to the body
        self.compute = self.body
        self.loads = self.body
        self.stores = self.body
        # Indexing emitted for RDom-based (reduction) iteration.
        self.indexing_code_dom = IndentedBuffer()
        self.needs_dom_indexing = self.inside_reduction
        self.has_reduction = self.inside_reduction
        # Per-buffer dimension layouts and scalar base offsets.
        self.buffer_dimensions: dict[str, list[DimensionInfo]] = {}
        self.buffer_offsets: dict[str, sympy.Expr] = {}
        # {h0: size1, h1: size2, ...}
        self.halide_vars: dict[sympy.Symbol, sympy.Expr] = {}
        # {x0: h0, x1: h1+10*h2, ...}
        self.index_replacements: dict[sympy.Expr, sympy.Expr] = {}
        # {h1: hr1, ...}
        self.reduction_renames: dict[sympy.Symbol, sympy.Symbol] = {}
        # {"i": {h0: hi0}, "o": ...}
        self.dom_renames: dict[str, dict[sympy.Symbol, sympy.Symbol]] = {}
        # {"in_ptr0": ["in_ptr0_view0"], ...}
        self.buffer_aliases: dict[str, list[str]] = defaultdict(list)
        self.has_indirect_indexing = False
|
|
|
|
    def dtype_to_str(self, dtype: torch.dtype) -> str:
        """Render *dtype* as a Halide type string, e.g. ``hl.Float(32)``."""
        return halide_type(dtype)
|
|
|
|
    # pyrefly: ignore # bad-override
    def create_cse_var(self, name, bounds=None, dtype=None, shape=None):
        """Declare a fresh hl.Func in the kernel body and wrap it as a CSE variable."""
        self.body.writeline(f"{name} = hl.Func({name!r})")
        # pyrefly: ignore # bad-argument-type
        return HalideCSEVariable(name, bounds, dtype, shape)
|
|
|
|
    def finalize_indexing(self, indices: Sequence[sympy.Expr]):
        """
        Hook called right before codegen with every index that will be
        used in the fused kernel.

        This populates self.halide_vars/index_replacements/reduction_renames which is an alternate indexing
        scheme that avoids using divide and modulus. Instead of xindex/yindex/rindex
        we base indexing on a larger number of vars whose product combines to those.

        This function populates self.halide_vars, self.index_replacements, and self.reduction_renames
        """
        assert not (
            self.index_replacements or self.halide_vars or self.reduction_renames
        )
        size_hint = functools.partial(V.graph.sizevars.size_hint, fallback=inf)  # type: ignore[arg-type]
        # dict.fromkeys dedupes while preserving order
        # pyrefly: ignore # bad-assignment
        indices = dict.fromkeys(map(super().prepare_indexing, indices))
        all_used_symbols = OrderedSet[Any]()
        sym_to_node = {
            n.symbol(): n
            for n in itertools.chain.from_iterable(
                [tree.nodes.values() for tree in self.range_trees]
            )
        }

        def simplify(expr):
            # Simplify after stripping precomputed-size replacements.
            return sympy.simplify(
                V.graph.sizevars.remove_precomputed_replacements(expr)
            )

        def visit_modular_indexing(base, divisor, modulus):
            # Record the sub-range node implied by base % modulus // divisor.
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        V.graph.sizevars.evaluate_min(
                            modulus, FloorDiv(node.length, divisor)
                        ),
                    ).symbol()
                )

        def visit_floor_div(base, divisor):
            # Record the sub-range node implied by base // divisor.
            if base in sym_to_node:
                node = sym_to_node[base]
                all_used_symbols.add(
                    node.root.lookup(
                        node.divisor * divisor,
                        FloorDiv(node.length, divisor),
                    ).symbol()
                )

        # first figure out all_used_symbols to do dead symbol elimination
        for index in indices:
            if index.has(ModularIndexing):
                # replace() is used only for its visitor side effect here.
                index.replace(
                    ModularIndexing(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                        sympy.Wild("modulus"),
                    ),
                    visit_modular_indexing,
                )
            if index.has(FloorDiv):
                index.replace(
                    FloorDiv(
                        sympy.Wild("base"),
                        sympy.Wild("divisor"),
                    ),
                    visit_floor_div,
                )
            all_used_symbols.update(super().prepare_indexing(index).free_symbols)

        self.has_indirect_indexing = any(
            symbol_is_type(sym, SymT.INDIRECT) for sym in all_used_symbols
        )

        had_fallback = False
        for tree in reversed(self.range_trees):
            nodes = [n for n in tree.nodes.values() if n.symbol() in all_used_symbols]
            nodes.sort(key=lambda n: size_hint(n.divisor))
            if not nodes:
                nodes.append(tree.lookup(1, tree.numel))
            handled_count = 0
            divisor = sympy.S.One
            added_sym_size = []
            # decide on a minimal set of symbols and put them in self.halide_vars
            while handled_count < len(nodes) and not eq(tree.numel, divisor):
                sizes_to_add = [
                    simplify(n.length) for n in nodes if eq(n.divisor, divisor)
                ]
                handled_count += len(sizes_to_add)
                assert sizes_to_add, nodes
                end = divisor * functools.reduce(
                    V.graph.sizevars.evaluate_max, sizes_to_add
                )
                # also consider divisors of nodes starting inside this span
                sizes_to_add.extend(
                    [
                        simplify(n.divisor / divisor)
                        for n in nodes
                        if lt(divisor, n.divisor) and lt(n.divisor, end)
                    ]
                )
                while sizes_to_add:
                    # Peel off the common factor of all pending sizes as the
                    # next halide var's extent.
                    next_size = functools.reduce(sympy.gcd, sizes_to_add)
                    if eq(next_size, 1):
                        # sizes share no common factors, e.g [2, 21, 42, 441, 889056]
                        # TODO(jansel): we should just prevent fusion in cases that hit this
                        next_size = simplify(tree.numel / divisor)
                        assert not eq(next_size, 1)
                        sizes_to_add = []
                        handled_count = len(nodes)
                        had_fallback = True
                    sym = sympy_index_symbol(f"h{len(self.halide_vars)}")
                    # pyrefly: ignore # missing-argument
                    if tree.is_reduction:
                        self.reduction_renames[sym] = sympy_index_symbol(
                            f"hr{len(self.halide_vars)}"
                        )
                    self.halide_vars[sym] = next_size
                    added_sym_size.append((sym, next_size))
                    divisor *= next_size
                    new_sizes = [n.length for n in nodes if eq(n.divisor, divisor)]
                    handled_count += len(new_sizes)
                    prior_len = len(sizes_to_add)
                    sizes_to_add = [
                        sympy.simplify(s / next_size)
                        for s in sizes_to_add
                        if not eq(s, next_size)
                    ]
                    # progress guarantee: the pending list must shrink
                    assert len(sizes_to_add) < prior_len or prior_len == 0
                    sizes_to_add.extend(new_sizes)

            # create a mapping to the new set of symbols in self.index_replacements
            for node in nodes:
                try:
                    # skip vars below this node's divisor ...
                    idx = 0
                    divisor = 1
                    while not eq(node.divisor, divisor):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        divisor *= size
                    # ... then combine vars spanning node.length into one expr
                    length = 1
                    expr = sympy.S.Zero
                    while not eq(node.length, length):
                        sym, size = added_sym_size[idx]
                        idx += 1
                        expr += length * sym
                        length *= size
                    self.index_replacements[node.symbol()] = expr
                except IndexError:
                    assert had_fallback
                    # Fallback: express the node via ModularIndexing over the
                    # full flattened index.
                    full_index = sympy.S.Zero
                    stride = sympy.S.One
                    for sym, size in added_sym_size:
                        full_index += stride * sym
                        stride *= size
                    self.index_replacements[node.symbol()] = (
                        V.graph.sizevars.simplify_with_ranges(
                            ModularIndexing(full_index, node.divisor, node.length),
                            self.halide_vars,  # type: ignore[arg-type]
                        )
                    )

        # codegen the variable definitions
        for sym in self.halide_vars:
            self.indexing_code.writeline(f"{sym} = hl.Var({sym.name!r})")
        if self.reduction_renames:
            self.codegen_rdom(
                "rdom",
                {rv: self.halide_vars[v] for v, rv in self.reduction_renames.items()},
            )
|
|
|
|
    def setup_dom_indexing(self):
        """RDom based indexing uses explicit iteration ranges for Func updates"""
        # "i" = inner/reduction update, "o" = outer/pointwise update
        prefix = "i" if self.inside_reduction else "o"
        if prefix in self.dom_renames:
            # already emitted for this prefix; reuse the cached renames
            return self.dom_renames[prefix]

        renames = {}
        for var in self.halide_vars.keys():
            # outside a reduction, reduction dims are not iterated
            if not self.inside_reduction and var in self.reduction_renames:
                continue
            m = re.match(r"^h(\d+)$", var.name)
            assert m
            # h3 -> hi3 / ho3 depending on prefix
            renames[var] = sympy_index_symbol(f"h{prefix}{m.group(1)}")

        self.codegen_rdom(
            f"{prefix}dom", {rv: self.halide_vars[v] for v, rv in renames.items()}
        )

        self.dom_renames[prefix] = renames
        return renames
|
|
|
|
def codegen_rdom(self, name, vars):
|
|
rsizes = [
|
|
f"hl.Range(0, {self.kexpr(self.rename_indexing(size))})"
|
|
for size in vars.values()
|
|
]
|
|
self.indexing_code.writeline(f"{name} = hl.RDom([{', '.join(rsizes)}])")
|
|
for i, rsym in enumerate(vars.keys()):
|
|
self.indexing_code.writeline(f"{rsym} = {name}[{i}]")
|
|
|
|
    def prepare_indexing(
        self,
        index: sympy.Expr,
    ):
        """Rewrite *index* into the halide_vars (h0, h1, ...) indexing scheme."""
        index = super().prepare_indexing(index)
        index = sympy_subs(index, self.index_replacements)
        return V.graph.sizevars.simplify_with_ranges(index, self.halide_vars)  # type: ignore[arg-type]
|
|
|
|
def sym_size(self, sym):
|
|
"""The size of an index symbol"""
|
|
if symbol_is_type(sym, SymT.TMP):
|
|
return self.lookup_cse_var(sym.name).indirect_indexing_size
|
|
return self.halide_vars[sym]
|
|
|
|
    def indexing_to_dimensions(self, var: str, index: sympy.Expr, is_store: bool):
        """Convert address-based indexing into dimensions using self.halide_vars

        Returns (var, dims) where var may be an aliased view name when the
        access conflicts with a previously installed layout for the buffer.
        """
        # Collect the halide/indirect symbols appearing in the index; any
        # other symbol must be a size-like constant.
        symbols = []
        for sym in sorted(index.free_symbols, key=lambda x: x.name):  # type: ignore[attr-defined]
            if symbol_is_type(sym, (SymT.HALIDE, SymT.TMP)):
                symbols.append(sym)
            else:
                assert symbol_is_type(
                    sym,
                    (
                        SymT.UNBACKED_INT,
                        SymT.SIZE,
                        SymT.PRECOMPUTED_SIZE,
                    ),
                ), sym

        # group the expression by variables used
        offset = sympy.S.Zero
        split_expr = dict.fromkeys(symbols, sympy.S.Zero)
        split_failed: list[tuple[list[sympy.Symbol], sympy.Expr]] = []
        index = sympy.expand(self.rename_indexing(index))
        for part in index.args if isinstance(index, sympy.Add) else [index]:
            part_vars = [v for v in part.free_symbols if v in split_expr]
            if len(part_vars) == 0:
                # constant term: contributes to the scalar base offset
                offset += part
            elif len(part_vars) == 1:
                split_expr[part_vars[0]] += part
            else:
                # term mixes several vars: merge with any overlapping group
                new_split_failed = []
                for i in range(len(split_failed)):
                    assert split_failed[i] is not None
                    other_vars, other_part = split_failed[i]
                    if OrderedSet(other_vars) & OrderedSet(part_vars):
                        part_vars.extend([v for v in other_vars if v not in part_vars])
                        part += other_part
                    else:
                        new_split_failed.append((other_vars, other_part))
                split_failed = [*new_split_failed, (part_vars, part)]

        def expr_to_dimension(expr, syms):
            """Build a DimensionInfo (expr, size, stride) for one var group."""
            expr = sympy.factor(expr)
            if len(syms) == 1:
                # common case: expr == stride * sym
                stride_wild = sympy.Wild("wild", exclude=symbols)
                m = expr.match(stride_wild * syms[0])
                if m:
                    return DimensionInfo(
                        syms[0], self.sym_size(syms[0]), m[stride_wild]
                    )
            assert not is_store, expr
            # fallback: bound the extent by substituting each sym's max value
            length = sympy.simplify(
                sympy_subs(expr, {sym: self.sym_size(sym) - 1 for sym in syms}) + 1
            )
            stride = sympy.S.One
            if isinstance(expr, sympy.Mul):
                # pull integer factors out of the expr into the stride
                for term in expr.args:
                    if isinstance(term, sympy.Integer):
                        stride *= term
                        expr = sympy.simplify(expr / term)
                        length = sympy.simplify(sympy.ceiling(length / term))
            return DimensionInfo(expr, length, stride)

        # try to turn each group into a strided access
        dims = []
        for syms, expr in split_failed:
            for v in syms:
                expr += split_expr.pop(v)
            dims.append(expr_to_dimension(expr, syms))
        for sym, expr in split_expr.items():
            dims.append(expr_to_dimension(expr, [sym]))
        dims.sort(key=lambda d: V.graph.sizevars.size_hint(d.stride, fallback=inf))  # type: ignore[arg-type]

        if not dims:  # scalar load/store
            if self.has_indirect_indexing:
                # workaround https://github.com/halide/Halide/issues/8338
                dims.append(DimensionInfo(sympy.S.Zero, 1, 1))
        elif not V.graph.sizevars.statically_known_equals(dims[0].stride, 1):
            # Halide assumes dimension 0 is stride == 1, so add a dummy dimension
            dims.insert(
                0, DimensionInfo(sympy.S.Zero, 1 if is_store else dims[0].stride, 1)
            )

        if dims and not is_store:
            if var in self.buffer_offsets and V.graph.sizevars.statically_known_geq(
                offset, self.buffer_offsets[var]
            ):
                # reuse the existing offset to avoid needing an input alias
                self.apply_offset_to_dimension(dims, offset - self.buffer_offsets[var])
                offset = self.buffer_offsets[var]
            elif V.graph.sizevars.statically_known_gt(
                offset, 0
            ):  # TODO(jansel): negative offsets
                # roll the offset into the dimensions for cleaner indexing
                self.apply_offset_to_dimension(dims, offset)
                offset = 0

        # Install the layout; on conflict, retry under "<var>_viewN" aliases.
        orig_var = var
        for i in itertools.count():
            if self.install_dims(var, dims, offset, is_store):
                return var, dims
            assert not is_store
            var = f"{orig_var}_view{i}"
            if var not in self.buffer_aliases[orig_var]:
                self.buffer_aliases[orig_var].append(var)
|
|
def install_dims(self, var, dims, offset, is_store):
    """Try to set self.buffer_dimensions[var], return True on success"""
    existing = self.buffer_dimensions.get(var)
    if existing is None:
        # First sighting of this buffer view: record it and accept.
        self.buffer_dimensions[var] = dims
        self.buffer_offsets[var] = offset
        return True
    if self.buffer_offsets[var] != offset or len(existing) != len(dims):
        return False
    if is_store:
        # Stores must match the recorded view exactly.
        return existing == dims
    # Loads: strides must agree; differing sizes/exprs are widened in place
    # (expr is dropped since a single expr no longer describes both views).
    for old, new in zip(existing, dims):
        if old.stride != new.stride:
            return False
        if old.size != new.size or old.expr != new.expr:
            old.size = V.graph.sizevars.evaluate_max(old.size, new.size)
            old.expr = None
    return True
|
|
|
|
def apply_offset_to_dimension(self, dims, offset):
    """
    Fold a flat buffer offset into the per-dimension index expressions,
    mutating ``dims`` in place.  Walks dimensions from last (largest
    stride, given the sort order used by the caller) to first, peeling
    off ``offset // stride`` whole steps per dimension.  Asserts the
    offset is fully consumed.
    """
    if offset == 0:
        return
    # Iterate from the highest-stride dimension down so each dimension
    # absorbs as many whole strides as possible.
    for i in reversed(range(len(dims))):
        if dims[i].stride == 1 or V.graph.sizevars.statically_known_geq(
            offset, dims[i].stride
        ):
            part = FloorDiv(offset, dims[i].stride)
            offset -= part * dims[i].stride
            dims[i].expr += part
    # Any remainder would mean the offset is not representable in these dims.
    assert offset == 0
|
|
|
|
def used_dims_from_index(self, index: sympy.Expr):
    """Detect which range trees are used to populate HalideCSEVariable.used_dims"""
    used_dims = OrderedSet[sympy.Symbol]()
    for sym in index.free_symbols:
        assert isinstance(sym, sympy.Symbol)
        if symbol_is_type(sym, SymT.TMP):
            # indirect indexing: the tmp symbol stands for a computed CSE
            # value, so inherit the dims that value depends on
            cse_var = self.lookup_cse_var(sym.name)
            assert (
                isinstance(cse_var, HalideCSEVariable)
                and cse_var.used_dims is not None
            )
            used_dims.update(cse_var.used_dims)
        elif symbol_is_type(sym, SymT.HALIDE):
            # a Halide loop variable is itself a used dimension
            used_dims.add(sym)
        elif symbol_is_type(
            sym, (SymT.UNBACKED_INT, SymT.SIZE, SymT.PRECOMPUTED_SIZE, SymT.INDEX)
        ):
            # size/offset symbols are kernel scalars, not loop dims
            pass
        else:
            raise NotImplementedError(f"unhandled symbol {sym}")
    # return in canonical dimension order, not discovery order
    return self.sort_used_dims(used_dims)
|
|
|
|
def sort_used_dims(self, used_dims):
    """
    Order ``used_dims`` canonically: loop variables first (in
    ``self.halide_vars`` order), then reduction replacement symbols.
    Asserts every entry of ``used_dims`` appears in that ordering.
    """
    assert all(isinstance(d, sympy.Expr) for d in used_dims)
    canonical = itertools.chain(self.halide_vars, self.reduction_renames.values())
    result = []
    for candidate in canonical:
        if candidate in used_dims:
            result.append(candidate)
    # every requested dim must be accounted for exactly once
    assert len(result) == len(used_dims)
    return result
|
|
|
|
def make_index_str(self, dims, replacements=None, zero_vars=False):
    """Render a comma-separated Halide index for ``dims``."""
    if not dims:
        return "()"
    parts = [d.index_str(replacements, zero_vars) for d in dims]
    if len(parts) == 1:
        # workaround for https://github.com/halide/Halide/issues/8299
        return f"{parts[0]},"
    return ", ".join(parts)
|
|
|
|
def load(self, name: str, index: sympy.Expr):
    """Codegen a load from an InputBuffer"""
    var = self.args.input(name)
    index = self.prepare_indexing(index)
    var, dims = self.indexing_to_dimensions(var, index, False)
    line = f"{var}[{self.make_index_str(dims)}]"
    dtype = V.graph.get_dtype(name)
    if dtype in (torch.float16, torch.bfloat16):
        # compute in fp32; the store side casts back to the narrow type
        dtype = torch.float32
        line = f"hl.cast(hl.Float(32), {line})"

    if self._load_mask:
        # masked load: the result depends on both the index dims and the
        # dims the mask itself was computed over
        assert (
            isinstance(self._load_mask, HalideCSEVariable)
            and self._load_mask.used_dims is not None
        )
        used_dims = OrderedSet(
            (*self.used_dims_from_index(index), *self._load_mask.used_dims)
        )
        result = self.newfunc(self.sort_used_dims(used_dims))
        if result.used_dims:
            # express the mask as a size-1 RDom with a where() clause so the
            # load is skipped (and `other` kept) where the mask is false
            self.body.writeline(f"{result.name}_mask = hl.RDom([hl.Range(0, 1)])")
            self.body.writeline(f"{result.name}_mask.where({self._load_mask})")
            other = self.kexpr(self._load_other or 0)  # type: ignore[arg-type]
            self.body.writeline(
                f"{result} = hl.cast({halide_type(dtype)}, {other})"
            )
            # adding the (zero-valued) RDom var makes the update masked
            self.body.writeline(
                f"{result} = {line} + hl.cast({halide_type(dtype)}, {result.name}_mask)"
            )
        else:
            # scalar case
            self.body.writeline(
                f"{result} = hl.select({self._load_mask}, {line}, hl.cast({halide_type(dtype)}, 0))"
            )
        return result
    else:
        return self.genfunc(line, self.used_dims_from_index(index))
|
|
|
|
def lookup_cse_var(self, name: str):
    """Map a (possibly subscripted) variable name back to its CSE variable."""
    base_name = re.sub(r"\[.*", "", name)  # strip any "[...]" index suffix
    return self.cse.varname_map[base_name]
|
|
|
|
def store(
    self, name: str, index: sympy.Expr, value: CSEVariable, mode: StoreMode = None
) -> None:
    """Codegen a store to an OutputBuffer"""
    assert isinstance(value, HalideCSEVariable)
    var = self.args.output(name)
    index = self.prepare_indexing(index)
    var, dims = self.indexing_to_dimensions(var, index, True)
    if self.is_indirect_indexing(index) or mode is not None:
        # scatter/atomic path: index via an RDom and pre-fill the buffer
        # with hl.undef so untouched elements keep their prior contents
        replacements = self.setup_dom_indexing()
        index_str = self.make_index_str(dims, replacements)
        value_str = value.subs_str(replacements)
        undef_dims = (", ".join(["hl.Var()"] * len(dims))) or "()"
        self.body.writeline(
            DeferredLine(name, f"{var}[{undef_dims}] = hl.undef({var}.type())")
        )
    else:
        # dense path: pure-definition store over the loop variables
        index_str = self.make_index_str(dims, zero_vars=True)
        value_str = str(value)

    dtype = V.graph.get_dtype(name)
    if mode is None:
        line = f"{var}[{index_str}] = hl.cast({halide_type(dtype)}, {value_str})"
    elif mode == "atomic_add":
        line = f"{var}[{index_str}] += hl.cast({halide_type(dtype)}, {value_str})"
    else:
        raise NotImplementedError(f"store mode={mode}")
    # DeferredLine drops the store if the buffer ends up unused
    self.body.writeline(DeferredLine(name, line))
|
|
|
|
def reduction(
    self,
    dtype: torch.dtype,
    src_dtype: torch.dtype,
    reduction_type: ReductionType,
    value: Union[CSEVariable, tuple[CSEVariable, ...]],
) -> Union[CSEVariable, tuple[CSEVariable, ...]]:
    """Codegen a reduction operation"""
    assert self.inside_reduction
    assert not self._load_mask
    # identical reductions are deduplicated via the CSE reduction cache
    cache_key = (src_dtype, reduction_type, value)
    if cache_key in self.cse.reduction_cache:
        return self.cse.reduction_cache[cache_key]

    if isinstance(value, tuple):
        # tuple-valued input only occurs for welford_combine
        assert reduction_type == "welford_combine"
        self.cse.reduction_cache[cache_key] = result_tuple = (
            self.welford_combine_impl(*value)
        )
        return result_tuple

    assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
    reduction_vars = OrderedSet(self.reduction_renames)
    # the result ranges over the non-reduction dims only
    result_var = self.newfunc(
        [v for v in value.used_dims if v not in reduction_vars],
    )
    if reduction_vars - OrderedSet(value.used_dims):
        # broadcast value over any reduction dims it does not yet use so
        # that subs_str(self.reduction_renames) below is well-defined
        value = self.genfunc(
            f"{value}",
            self.sort_used_dims(OrderedSet((*value.used_dims, *reduction_vars))),
            shape=value.shape,
        )
    value_str = value.subs_str(self.reduction_renames)
    default = ir.Reduction.default_accumulator(reduction_type, src_dtype)
    acc_type = halide_acc_type(dtype)

    if reduction_type in ("argmax", "argmin"):
        index = f"{result_var.name}_{reduction_type}"
        self.body.writeline(f"{index} = hl.{reduction_type}(rdom, {value_str})")
        # turn the N-D argmax index into a 1-D one
        parts = []
        stride = 1
        for i, sym in enumerate(self.reduction_renames):
            # pyrefly: ignore # bad-argument-type
            parts.append(f"{index}[{i}]")
            if stride != 1:
                # pyrefly: ignore # unsupported-operation
                parts[-1] += f"*{stride}"
            stride *= self.halide_vars[sym]
        self.body.writeline(f"{result_var} = {' + '.join(parts)}")
    elif reduction_type == "welford_reduce":
        # TODO(jansel): implement welford_reduce without fallback
        result_var = self.welford_reduce_fallback(dtype, value)
    else:
        # generic case: init to the identity element, then fold with the
        # reduction's combine function over the rdom
        combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
        with V.set_ops_handler(AddParenHandler(HalideOverrides())):
            combine_str = combine_fn(result_var, value_str)  # type: ignore[arg-type]
        default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
        self.body.writeline(f"{result_var} = {default_str}")
        self.body.writeline(f"{result_var} = {combine_str}")

    self.cse.reduction_cache[cache_key] = result_var
    return result_var
|
|
|
|
def welford_combine_impl(self, mean, m2, weight):
    """
    Emit a Welford-style combine of two (mean, m2, weight) partial
    statistics as a Halide Tuple reduction over the reduction dims,
    returning the three unpacked result variables.
    """
    assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None
    assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None
    assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None
    used_dims = OrderedSet(
        (*mean.used_dims, *m2.used_dims, *weight.used_dims) or self.halide_vars
    )
    # the combined result ranges over the non-reduction dims only
    used_dims -= OrderedSet(self.reduction_renames)
    result_var = self.newfunc(self.sort_used_dims(used_dims))
    # accumulator "_1" starts at zero; "_2" is the incoming partial
    default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)]
    pfx = result_var.name
    self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])")
    self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]")
    self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]")
    self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]")
    self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}")
    self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}")
    self.body.writeline(
        f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}"
    )
    # standard parallel-variance combine: delta between means, weighted
    # by w2/(w1+w2); select() guards the zero-total-weight case
    self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1")
    self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2")
    self.body.writeline(
        f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)"
    )
    update = [
        f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w",
        f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w",
        f"{pfx}_new_weight",
    ]
    self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])")

    # split the 3-element Tuple back into scalar CSE variables
    unpacked = []
    for i in range(3):
        unpacked.append(self.newfunc(result_var.used_dims))
        self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
    return tuple(unpacked)
|
|
|
|
def scan(
    self,
    dtypes: tuple[torch.dtype, ...],
    combine_fn: Callable[
        [tuple[CSEVariable, ...], tuple[CSEVariable, ...]], tuple[CSEVariable, ...]
    ],
    values_orig: tuple[CSEVariable, ...],
) -> tuple[CSEVariable, ...]:
    """
    Codegen an inclusive scan (e.g. cumsum) along the single reduction
    dim using a sequential Halide RDom update:
    result[i] = combine(result[i-1], result[i]).

    Returns one output variable per input value.
    """
    assert self.inside_reduction
    assert len(dtypes) == len(values_orig)
    values: list[HalideCSEVariable] = []
    all_used_dims = OrderedSet[sympy.Symbol]()

    for value in values_orig:
        assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
        if OrderedSet(value.used_dims) & OrderedSet(self.reduction_renames):
            values.append(value)
        else:
            # broadcast along the scan dim so subs_str(reduction_renames)
            # below is well-defined.
            # Fix: spread the first reduction symbol into used_dims; the
            # previous code appended the sliced *list* itself as a single
            # element, but used_dims is everywhere else a flat list of
            # sympy symbols (see sort_used_dims / index_str).
            values.append(
                self.genfunc(
                    f"{value}",
                    [*value.used_dims, *[*self.reduction_renames][:1]],
                    shape=value.shape,
                )
            )
        all_used_dims.update(value.used_dims)
    result_var = self.newfunc(self.sort_used_dims(all_used_dims))
    # at least one input must already range over the scan dim
    assert result_var.used_dims and OrderedSet(result_var.used_dims) & OrderedSet(
        self.reduction_renames
    )
    initial = [
        f"hl.cast({halide_acc_type(dtype)}, {value})"
        for dtype, value in zip(dtypes, values)
    ]

    # serial RDom over positions 1..length-1; position 0 keeps `initial`
    length = self.kexpr(self.rename_indexing(self.range_trees[-1].numel))
    scan_dom = f"{result_var.name}_rdom"
    scan = f"{scan_dom}.x"
    self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")

    assert len(self.reduction_renames) == 1, (
        "multi-dimensional scan not implemented"
    )
    (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
    scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
    scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}

    # single value => scalar Func; multiple values => Halide Tuple
    if len(values) == 1:

        def maybe_tuple(x):
            return x[0]

        read_left = [result_var.subs_str(scan_renames_pri)]
        read_right = [result_var.subs_str(scan_renames_cur)]
    else:

        def maybe_tuple(x):
            return f"hl.Tuple([{', '.join(x)}])"

        read_left = [
            result_var.subs_str(scan_renames_pri) + f"[{i}]"
            for i in range(len(values))
        ]
        read_right = [
            result_var.subs_str(scan_renames_cur) + f"[{i}]"
            for i in range(len(values))
        ]

    self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")

    # Disable CSE for update fn
    with V.set_ops_handler(AddParenHandler(HalideOverrides())):
        combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
    self.body.writeline(
        f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"
    )

    if len(values) == 1:
        return (result_var,)

    # unpack the Tuple result into one variable per input
    unpack_vars = [self.newfunc(self.sort_used_dims(all_used_dims)) for _ in values]
    for i, v in enumerate(unpack_vars):
        self.body.writeline(f"{v} = {result_var}[{i}]")
    return tuple(unpack_vars)
|
|
|
|
def genfunc(
    self,
    line,
    used_dims,
    *,
    bounds=ValueRanges.unknown(),
    shape: BlockShapeType = None,
) -> HalideCSEVariable:
    """Run ``line`` through the CSE cache and tag the result with its dims."""
    cse_var = self.cse.generate(self.body, line, bounds=bounds, shape=shape)
    assert isinstance(cse_var, HalideCSEVariable)
    cse_var.used_dims = used_dims
    return cse_var
|
|
|
|
def newfunc(self, used_dims, *, shape: BlockShapeType = None) -> HalideCSEVariable:
    """Allocate a fresh (undefined) CSE variable tagged with ``used_dims``."""
    fresh = self.cse.newvar(shape=shape)
    assert isinstance(fresh, HalideCSEVariable)
    fresh.used_dims = used_dims
    return fresh
|
|
|
|
def halide_buffer_numel(self, name: str):
    """
    We map all tensors to 1D buffers in Halide since Halide has trouble
    representing some strides that PyTorch supports.  If there are gaps in
    the underlying layout the numel we pass to Halide includes the gaps
    while PyTorch's numel excludes them.
    """
    layout = V.graph.get_buffer(name).get_layout()
    return layout.storage_size()
|
|
|
|
def halide_argdefs(self):
    """
    Halide requires scalar inputs before outputs, so need to reorder args.
    """

    def arg_order(arg_tuple):
        # sort key: inputs (0) < scalars (1) < outputs (2)
        _call_str, arg = arg_tuple
        if isinstance(arg, SizeArg):
            return 1  # this would normally be at the end, move it to middle
        elif "out_ptr" in arg.name:
            return 2
        else:
            assert "in_ptr" in arg.name
            return 0

    result: list[tuple[Optional[str], KernelArgType]] = []
    _, a, b, _ = self.args.python_argdefs()
    for call_str, arg in sorted(zip(a, b), key=arg_order):
        result.append((call_str, arg))
        if isinstance(arg, TensorArg):
            assert arg.offset == 0 and arg.alias_of is None
            # emit one extra TensorArg per recorded alias view of this
            # buffer (call_str is None: aliases are not passed at call time)
            result.extend(
                (
                    None,
                    TensorArg(
                        alias,
                        arg.buffer,
                        arg.dtype,
                        arg.offset,
                        alias_of=arg.name,
                    ),
                )
                for alias in self.buffer_aliases.get(arg.name, ())
            )
    return result
|
|
|
|
def halide_kernel_meta(self) -> HalideMeta:
    """Compute metadata required by codecache.py"""
    argtypes = []
    for _, arg in self.halide_argdefs():
        if isinstance(arg, SizeArg):
            # scalar size arg: no buffer geometry
            shape = None
            stride = None
            offset = None
            dtype = "long"
        else:
            shape = [
                cexpr(self.rename_indexing(x.size))
                for x in self.buffer_dimensions[arg.name]
            ]
            stride = [
                cexpr(self.rename_indexing(x.stride))
                for x in self.buffer_dimensions[arg.name]
            ]
            assert len(shape) == len(stride)
            offset = cexpr(self.buffer_offsets[arg.name])
            dtype = f"{DTYPE_TO_CPP[arg.dtype]}*"
        argtypes.append(
            HalideInputSpec(
                dtype,
                arg.name,
                shape=shape,
                stride=stride,
                offset=offset,
                alias_of=arg.alias_of,
            )
        )

    current_device = V.graph.get_current_device_or_throw()
    if current_device.type == "cpu":
        target = [config.halide.cpu_target]
        scheduler = config.halide.scheduler_cpu
        scheduler_flags = {
            "parallelism": parallel_num_threads(),
        }
        cuda_device = None
    else:
        assert current_device.type == "cuda", "only cpu/cuda supported"
        assert current_device.index <= 0, "only default device supported"
        target = [config.halide.gpu_target]
        scheduler = config.halide.scheduler_cuda
        capability = torch.cuda.get_device_properties(current_device)
        if "cuda_capability" not in target[0]:
            # pick the highest listed capability this device satisfies
            for major, minor in [(8, 6), (8, 0), (7, 5), (7, 0), (6, 1)]:
                if capability.major >= major and capability.minor >= minor:
                    target.append(f"cuda_capability_{major}{minor}")
                    break
        target.append("user_context")
        scheduler_flags = {
            "parallelism": capability.multi_processor_count,
            # TODO(jansel): explore other flags, see:
            # grep parser.parse ~/Halide/src/autoschedulers/anderson2021/AutoSchedule.cpp
        }
        cuda_device = max(0, current_device.index)

    # strict_float is required for correctness
    target.append("strict_float")

    # without this we will initialize cuda once per kernel and hit errors
    target.append("no_runtime")

    if not config.halide.asserts:
        target.append("no_asserts")

    if config.halide.debug:
        target.append("debug")

    if "64" in self.index_dtype:
        # TODO(jansel): it is unclear if this does anything, since input sizes are still int32
        target.append("large_buffers")

    return HalideMeta(
        argtypes,
        target="-".join(target),
        scheduler=scheduler,
        scheduler_flags=scheduler_flags,  # type: ignore[arg-type]
        cuda_device=cuda_device,
    )
|
|
|
|
def codegen_kernel(self, name=None):
    """Called at the end to generate a final kernel string"""
    if self.args.inplace_buffers:
        raise Unsupported("inplace_buffers")
    meta = self.halide_kernel_meta()  # ensure needed args are added early
    code = IndentedBuffer()
    # emit a standalone Halide generator module
    code.splice(
        """
        import halide as hl
        from torch._inductor.runtime import halide_helpers
        from math import inf, nan

        @hl.generator(name="kernel")
        class Kernel:
        """,
        strip=True,
    )
    code.do_indent()
    # class-level generator params: one InputScalar/InputBuffer/OutputBuffer
    # declaration per kernel argument
    for _, arg in self.halide_argdefs():
        if isinstance(arg, SizeArg):
            code.writeline(f"{arg.name} = hl.InputScalar({self.index_dtype})")
        else:
            assert arg.buffer, arg
            argcls = "hl.OutputBuffer" if "out" in arg.name else "hl.InputBuffer"
            argtype = halide_type(arg.dtype)
            ndim = len(self.buffer_dimensions[arg.name])
            code.writeline(f"{arg.name} = {argcls}({argtype}, {ndim})")
    code.splice(
        """
        def generate(g):
        """
    )
    code.do_indent()
    # bind generator params to locals, then alias views
    for _, arg in self.halide_argdefs():
        code.writeline(f"{arg.name} = g.{arg.name}")
    for old, new in self.args.aliases():
        code.writeline(f"{old} = {new}")
    code.splice(self.indexing_code)

    def update_index(m):
        # replace a bare CSE name with its fully-indexed form
        var = cast(HalideCSEVariable, self.cse.varname_map[m.group(1)])
        assert var.used_dims is not None, var
        return str(var)

    for line in self.body._lines:
        if isinstance(line, str):
            # fill in missing indices
            line = HalideCSEVariable.undefined_re.sub(update_index, line)
        code.writeline(line)
    code.writeline("")
    code.writeline("assert g.using_autoscheduler()")

    for _, arg in self.halide_argdefs():
        # fallback=1 below because halide requires buffers to be at least as large as the estimates
        # This causes crashes if our estimate is greater than the vector length
        # https://github.com/halide/Halide/issues/3103
        if isinstance(arg, SizeArg):
            hint = V.graph.sizevars.size_hint(arg.expr, fallback=1)
            code.writeline(f"{arg.name}.set_estimate({hint})")
        else:
            dims = self.buffer_dimensions[arg.name]
            range_hints = []
            for i, dim in enumerate(dims):
                hint = self._autoscheduler_workarounds(
                    V.graph.sizevars.size_hint(dim.size, fallback=1), dims
                )
                # pyrefly: ignore # bad-argument-type
                range_hints.append(f"hl.Range(0, {hint})")
                if "out" not in arg.name:
                    code.writeline(f"{arg.name}.dim({i}).set_min(0)")
                    # pin stride/extent only when they are compile-time ints
                    try:
                        code.writeline(
                            f"{arg.name}.dim({i}).set_stride({int(dim.stride)})"
                        )
                    except TypeError:
                        pass  # not integer
                    try:
                        code.writeline(
                            f"{arg.name}.dim({i}).set_extent({int(dim.size)})"
                        )
                    except TypeError:
                        pass  # not integer
            code.writeline(f"{arg.name}.set_estimates([{', '.join(range_hints)}])")

    code.do_unindent(2)
    # module tail: either run hl.main() as a script, or build a callable
    # when imported (with or without the autoscheduler plugin)
    code.splice(
        """
        if __name__ == "__main__":
            hl.main()
        """.rstrip(),
    )
    if meta.scheduler:
        code.splice(
            f"""
            else:
                hl.load_plugin({HalideCodeCache.find_libautoschedule(meta.scheduler)!r})
                target = hl.Target({meta.target!r})
                autoscheduler = hl.AutoschedulerParams({meta.scheduler!r}, {meta.scheduler_flags!r})
                with hl.GeneratorContext(target, autoscheduler):
                    gen = Kernel()
                    pipeline = gen._build_pipeline()
                    # gen.compile_to_callable() does not run the autoscheduler
                    pipeline.apply_autoscheduler(target, autoscheduler)
                    kernel = pipeline.compile_to_callable([
                        gen._get_input_parameter(a.name)._to_argument()
                        for a in gen._get_arginfos()
                        if a.dir == hl.ArgInfoDirection.Input
                    ], target)
            """,
            strip=True,
        )
    else:
        code.splice(
            f"""
            else:
                with hl.GeneratorContext(hl.Target({meta.target!r})):
                    kernel = Kernel().compile_to_callable()
            """,
            strip=True,
        )
    return code.getvalue()
|
|
|
|
@staticmethod
|
|
def _autoscheduler_workarounds(n, dims):
|
|
if (
|
|
len(dims) == 1
|
|
and config.halide.scheduler_cuda == "Anderson2021"
|
|
and V.graph.get_current_device_or_throw().type == "cuda"
|
|
):
|
|
# workaround https://github.com/halide/Halide/issues/8246
|
|
n = max(2, n)
|
|
return n
|
|
|
|
def call_kernel(self, name: str, node=None):
    """Codegen a call to this kernel"""
    wrapper = V.graph.wrapper_code
    # alias args are synthesized (alias_of set), not passed at call time
    call_args = [f"{n}" for n, arg in self.halide_argdefs() if arg.alias_of is None]
    current_device = V.graph.get_current_device_or_throw()
    if current_device.type == "cuda":
        # cuda kernels additionally take the raw stream as the last arg
        stream_name = wrapper.write_get_raw_stream(
            current_device.index, V.graph.name
        )
        call_args.append(stream_name)
    wrapper.generate_kernel_call(
        name,
        call_args,
        device=current_device,
        triton=False,
    )
|
|
|
|
def generate_assert(self, check):
    """Device-side asserts are not emitted by the Halide backend."""
    # TODO(jansel): support asserts
    return False
|
|
|
|
def check_bounds(
    self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
):
    """Bounds checks are intentionally a no-op for the Halide backend."""
    # TODO(jansel): support asserts
    return None
|
|
|
|
|
|
class HalideScheduling(SIMDScheduling):
    # use the Halide kernel codegen instead of the default SIMD kernel
    kernel_type = HalideKernel  # type: ignore[arg-type,assignment]

    @classmethod
    def get_backend_features(cls, device: torch.device) -> OrderedSet[BackendFeature]:
        """Advertise which scheduler features this backend supports."""
        result = OrderedSet(
            [
                BackendFeature.TUPLE_REDUCTION,
                BackendFeature.PREFER_STORE_LOOP_ORDER,
                BackendFeature.REDUCE_TO_SINGLE_ELEMENT,
            ]
        )
        # scan support is opt-in via config
        if config.halide.scan_kernels:
            result.add(BackendFeature.SCAN)
        return result

    def define_kernel(self, src_code, node_schedule, kernel):
        """Codegen kernel definition to go in output wrapper code"""
        wrapper = V.graph.wrapper_code
        if src_code in wrapper.src_to_kernel:
            # identical source already emitted: reuse the existing kernel
            kernel_name = wrapper.src_to_kernel[src_code]
        else:
            kernel_name = f"halide_kernel_{wrapper.next_kernel_suffix()}"
            wrapper.src_to_kernel[src_code] = kernel_name
            wrapper.add_import_once(
                "from torch._inductor.runtime.hints import HalideMeta, HalideInputSpec"
            )

            # wrap the generator source in an async_compile.halide(...) call
            compile_wrapper = IndentedBuffer()
            compile_wrapper.writeline(
                f"async_compile.halide({kernel.halide_kernel_meta()!r}, '''"
            )
            compile_wrapper.splice(src_code, strip=True)
            compile_wrapper.writeline("''')")

            origins, detailed_origins = get_kernel_metadata(node_schedule, wrapper)
            metadata_comment = f"{origins}\n{detailed_origins}"
            wrapper.define_kernel(
                kernel_name, compile_wrapper.getvalue(), metadata_comment
            )
            if is_metric_table_enabled("kernel_metadata"):
                log_kernel_metadata(kernel_name, "", src_code)

        return kernel_name
|