# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math as pymath
import warnings
from typing import Any, Callable, TypeVar

from .triton_compat import (  # noqa: F401
    _log2,
    builtins_use_semantic_kwarg,
    libdevice,
    math,
    tl,
    triton,
)


_T = TypeVar("_T")
_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e))


def set_driver_to_cpu():
    driver = triton.runtime.driver
    if backend := triton.backends.backends.get("cpu", None):
        if isinstance(driver.active, backend.driver):
            # Don't re-initialize the backend if it is already active
            return
        driver.set_active(backend.driver())
        return
    # This can become a hard error once triton-cpu is merged into fbcode
    warnings.warn(
        "Could not find an active CPU backend. Generated kernels will not be executable!"
    )


def set_driver_to_gpu():
    driver = triton.runtime.driver
    for name, backend in triton.backends.backends.items():
        if backend.driver.is_active() and name != "cpu":
            # After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4,
            # `driver.active` can be a `LazyProxy`; the telltale sign is its
            # `_obj` attribute, which wraps the underlying driver.
            if (
                isinstance(driver.active, backend.driver)
                or hasattr(driver.active, "_obj")
                and isinstance(driver.active._obj, backend.driver)
            ):
                # Don't re-initialize the backend if it is already active
                return
            driver.set_active(backend.driver())
            return
    raise RuntimeError("Could not find an active GPU backend")


def get_backend_options():
    from triton.runtime import driver

    target = driver.active.get_current_target()
    backend = triton.compiler.compiler.make_backend(target)
    options = backend.parse_options(dict())
    return options.__dict__


@triton.jit
def promote_to_tensor(x):
    # Addition promotes to tensor for us
    return x + tl.zeros((1,), tl.int1)


@triton.jit
def div_floor_integer(a, b):
    # NOTE: a // b is C division, but we want floor division
    # Based on c10::div_floor_integer
    quot = a // b
    remainder = a % b
    fixed = tl.where(remainder != 0, quot - 1, quot)
    return tl.where((a < 0) != (b < 0), fixed, quot)


@triton.jit
def remainder_integer(a, b):
    # NOTE: a % b matches C division, not floor division
    remainder = a % b
    return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder)
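

# Illustrative sketch (not used by generated kernels): a hypothetical jitted
# helper showing how the two functions above recover Python's divmod semantics
# for mixed-sign operands, e.g. -7 // 2 == -4 and -7 % 2 == 1, where C-style
# truncating division would give -3 and -1.
@triton.jit
def _example_floor_divmod(a, b):
    q = div_floor_integer(a, b)
    r = remainder_integer(a, b)
    # Invariant: a == q * b + r, with r taking the sign of b.
    return q, r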


@triton.jit
def is_floating(x):
    return promote_to_tensor(x).dtype.is_floating()


@triton.jit
def _prod_accumulate(a, b):
    return a * b


@triton.jit
def prod(input, axis):
    return tl.reduce(input, axis, _prod_accumulate)
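

# Illustrative sketch: `prod` above shows the general recipe for custom
# reductions; any associative combine function works with tl.reduce. The
# hypothetical helpers below reduce with bitwise-xor, purely as an example.
@triton.jit
def _example_xor_combine(a, b):
    return a ^ b


@triton.jit
def _example_xor_reduce(input, axis):
    # Reduces integer data along `axis` by folding with _example_xor_combine.
    return tl.reduce(input, axis, _example_xor_combine)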


@triton.jit
def minimum(a, b):
    mask = a < b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)


@triton.jit
def maximum(a, b):
    mask = a > b
    if is_floating(a):
        mask |= a != a
    return tl.where(mask, a, b)
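

# Note (for intuition): `a != a` is true exactly for NaN, so the mask above
# makes NaN win any comparison and propagate, e.g. minimum(float("nan"), 1.0)
# yields NaN rather than 1.0, matching the NaN-propagating semantics of
# torch.minimum/torch.maximum.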


@triton.jit
def min2(a, dim):
    return tl.reduce(a, dim, minimum)


@triton.jit
def max2(a, dim):
    return tl.reduce(a, dim, maximum)


@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value < b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan & (not b_isnan)
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
    mask = a_value > b_value
    equal = a_value == b_value
    if is_floating(a_value):
        a_isnan = a_value != a_value
        b_isnan = b_value != b_value
        mask |= a_isnan & (not b_isnan)
        # Consider NaNs as equal
        equal |= a_isnan & b_isnan

    # Prefer lowest index if values are equal
    mask |= equal & (a_index < b_index)
    return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)


@triton.jit
def min_with_index(value, index, dim):
    return tl.reduce((value, index), dim, minimum_with_index)


@triton.jit
def max_with_index(value, index, dim):
    return tl.reduce((value, index), dim, maximum_with_index)
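

# Usage sketch (hypothetical kernel fragment): computing a row-wise argmin.
# `values` is a 2D block and `rindex` the per-element index along dim 1; both
# names are illustrative.
#
#   idx = tl.broadcast_to(rindex[None, :], values.shape)
#   min_val, min_idx = min_with_index(values, idx, 1)
#
# Ties, including ties between NaNs, resolve to the lowest index, per the
# "Prefer lowest index" logic above.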


@triton.jit
def exp(x, use_fast_math: tl.constexpr):
    if use_fast_math:
        return math.exp(x)
    else:
        return libdevice.exp(x)


@triton.jit
def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr):
    out_max = max2(lhs_max, dim)
    out_max_keepdim = tl.expand_dims(out_max, dim)
    delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim)
    out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim)
    return out_max, out_sum


@triton.jit
def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexpr):
    """
    Combine two online-softmax states. We assume lhs is the accumulator and
    rhs is the next block of data, so rhs_sum is always 1. With that
    assumption we can save some registers and computation.
    """
    out_max = maximum(lhs_max, rhs_max)

    lhs_scale = tl.where(
        out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math)
    )
    rhs_scale = tl.where(
        out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math)
    )

    # Should be
    #   out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale
    # but since rhs_sum is all ones, we can simplify.
    out_sum = lhs_sum * lhs_scale + rhs_scale
    return out_max, out_sum
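

# Reference sketch (host-side math, for intuition only): the online-softmax
# recurrence the two helpers above implement. Given a running max m and a
# running denominator s, folding in a new element x is
#
#   m_new = max(m, x)
#   s_new = s * exp(m - m_new) + exp(x - m_new)
#
# online_softmax_combine specializes this with exp(x - m_new) appearing as
# rhs_scale (rhs_sum == 1), and the -inf guards keep exp well-defined when a
# row is entirely masked out.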


@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
    if first_iteration:
        new_weight = tl.full(weight.shape, 1, weight.dtype)
        new_mean = value
        new_m2 = tl.zeros_like(m2)
    else:
        delta = value - mean
        new_weight = weight + 1
        new_mean = mean + delta / new_weight
        new_m2 = m2 + delta * (value - new_mean)
    return new_mean, new_m2, new_weight


@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
    delta = mean_2 - mean_1
    new_weight = weight_1 + weight_2
    w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
    return (
        mean_1 + delta * w2_over_w,
        m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
        new_weight,
    )
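

# For intuition: welford_combine is the parallel (Chan et al.) update. With
# delta = mean_2 - mean_1 and n = n_1 + n_2, it computes
#
#   mean = mean_1 + delta * n_2 / n
#   M2   = M2_1 + M2_2 + delta^2 * n_1 * n_2 / n
#
# where M2 is the sum of squared deviations from the mean (variance is M2 / n
# or M2 / (n - 1)); the tl.where guard avoids 0/0 when both inputs are empty.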


@triton.jit
def welford(mean, m2, weight, dim):
    return tl.reduce((mean, m2, weight), dim, welford_combine)


@triton.jit
def device_assert_then(cond, msg, r):
    tl.device_assert(cond, msg)
    return r


@triton.jit
def randint64(seed, offset, low, high):
    r0, r1, _r2, _r3 = tl.randint4x(seed, offset)
    r0 = r0.to(tl.uint64)
    r1 = r1.to(tl.uint64)
    result = r0 | (r1 << 32)
    size = high - low
    result = result % size.to(tl.uint64)
    result = result.to(tl.int64) + low
    return result


@triton.jit
def _any_combine(a, b):
    return a | b


@triton.jit
def any(a, dim):
    return tl.reduce(a, dim, _any_combine)


@triton.jit
def bucketize_binary_search(
    values: tl.tensor,
    boundaries_ptr: tl.tensor,
    BOUNDARIES_SIZE: int,
    BOUNDARIES_UNDERLYING_NUMEL: int,
    BOUNDARIES_STRIDE: int,
    boundary_indices: tl.tensor,
    indexing_dtype: tl.dtype,
    right: "bool",  # triton can't handle the unquoted bool annotation
    sorter_ptr: tl.tensor,
    SORTER_STRIDE: int,
    sorter_indices: tl.tensor,
):
    """
    See [Note: Inductor bucketize op]

    Inputs:
    -------
    values: the values to bucketize.
    boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D.
    BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor
    (i.e. one individual set of boundaries).
    BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D,
    ignoring any striding.
    BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor.
    boundary_indices: a tensor of the same size as "values"; each element is an
    index into a 1-D, un-strided boundaries tensor, pointing to the first
    element in the set of boundaries used for that value.
    indexing_dtype: the dtype used for indexing into the boundaries tensor, and
    the return dtype.
    right: if true, use boundary intervals closed on the left; otherwise use
    intervals closed on the right.
    sorter_ptr: an optional pointer to a sorter tensor of the same shape as
    boundaries, but potentially different striding. If present, this allows us
    to treat boundaries as sorted even if the elements of boundaries are
    unsorted.
    SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the
    last dimension of the sorter tensor.
    sorter_indices: must be present if sorter_ptr is non-None; see
    "boundary_indices".
    """

    low = tl.zeros(values.shape, dtype=indexing_dtype)
    high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype)

    full_range = BOUNDARIES_SIZE + 1
    while full_range > 1:
        mid = (high + low) // 2
        mask = (
            (mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL
        ).logical_and(mid < BOUNDARIES_SIZE)
        mid_indices = (
            mid
            if sorter_ptr is None or SORTER_STRIDE is None
            else tl.load(
                sorter_ptr + sorter_indices + SORTER_STRIDE * mid,
                mask=mask,
                other=0,
            )
        )

        bucket_upper_bound = tl.load(
            boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices,
            mask=mask,
            other=0,
        )
        if right:
            is_above = values >= bucket_upper_bound
        else:
            is_above = values > bucket_upper_bound

        low = tl.where(is_above & mask, mid + 1, low)
        high = tl.where(is_above, high, mid)

        full_range = (full_range + 1) // 2

    return low
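

# Worked example (for intuition): with one set of boundaries [1, 3, 5]
# (BOUNDARIES_SIZE == 3, stride 1) and value 3, the search returns 1 when
# right=False (3 falls in the right-closed interval ending at boundary 3) and
# 2 when right=True (intervals are closed on the left, so 3 moves past that
# boundary), matching torch.bucketize.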


@triton.jit
def pack_value_flag(
    value,
    flag,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    # Workaround for a triton bug: tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
    return flag.to(DTYPE_PACK) | (uv << bitwidth)


@triton.jit
def unpack_value(
    pack,
    DTYPE_VALUE,
    DTYPE_VALUE_AS_UINT,
):
    # Workaround for a triton bug: tensor.to doesn't unwrap constexpr values
    DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE)
    DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
    bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
    value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
    return value_uint.to(DTYPE_VALUE, bitcast=True)


@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
    return pack.to(DTYPE_FLAG)
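

# Layout sketch: for a 32-bit value packed into a 64-bit word, the low 32 bits
# hold the flag and the high 32 bits hold the bit-pattern of the value:
#
#   pack = pack_value_flag(v, f, tl.uint32, tl.uint64)   # (v bits) << 32 | f
#   v2 = unpack_value(pack, v.dtype, tl.uint32)          # bitcast back to v
#   f2 = unpack_flag(pack, tl.uint32)                    # truncates to low bits
#
# so the (value, flag) pair round-trips through one atomic-friendly integer.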


@triton.jit
def exclusive_scan_decoupled_lookback(
    scratch_base,
    block_value,
    index,
    combine_fn,
    DTYPE_VALUE_AS_UINT: tl.constexpr,
    DTYPE_PACK: tl.constexpr,
):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
    DTYPE_PACK: Unsigned type twice the width of block_value

    NOTE: This function is limited to values which are 32-bits or less because
    we need to pack (value, flag) into a single unsigned int.
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    DTYPE_VALUE = block_value.dtype
    pack = pack_value_flag(
        block_value,
        tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    if index > 0:
        tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], DTYPE_VALUE)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        # Emulate a tl.atomic_load: atomic_add of 0 returns the current value
        flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
        while flag == 0:
            pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
            flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)

        value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    pack = pack_value_flag(
        inclusive_prefix,
        tl.full([], 2, DTYPE_VALUE_AS_UINT),
        DTYPE_VALUE_AS_UINT,
        DTYPE_PACK,
    )
    tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
    return exclusive_prefix


@triton.jit
def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):
    """Compute exclusive scan of a scalar value between blocks

    Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back

    scratch_base: Pointer to scratch space in global memory
    block_value: Scalar value for this block, must be 64-bits wide
    index: Scalar index of this block relative to the current scan
    combine_fn: Function ``(value, value) -> value`` which is scanned over
    """
    # Publish block sum so subsequent blocks don't get stuck waiting for us
    if index > 0:
        block_value_u64 = block_value.to(tl.uint64, bitcast=True)
        tl.store(scratch_base + 3 * index + 1, block_value_u64)
        tl.debug_barrier()
        flag_one = tl.full([], 1, tl.uint64)
        tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")

    # Calculate exclusive prefix scan
    exclusive_prefix = tl.zeros([], block_value.dtype)
    prefix_valid = False
    test_target = index - 1
    while test_target >= 0:
        flag = tl.full([], 0, tl.uint64)
        while flag == 0:
            flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")

        value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
        value = value_u64.to(block_value.dtype, bitcast=True)
        if prefix_valid:
            exclusive_prefix = combine_fn(value, exclusive_prefix)
        else:
            exclusive_prefix = value
            prefix_valid = True

        if flag == 2:
            test_target = -1
        else:
            test_target = test_target - 1

    # Make inclusive block sum visible to other blocks
    if prefix_valid:
        inclusive_prefix = combine_fn(exclusive_prefix, block_value)
    else:
        inclusive_prefix = block_value
    inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
    tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
    tl.debug_barrier()
    flag_two = tl.full([], 2, tl.uint64)
    tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")

    return exclusive_prefix
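

# Protocol sketch (for intuition): both lookback variants publish a per-block
# flag alongside the data. Flag 0 means "not ready", 1 means "block-local
# aggregate available", 2 means "inclusive prefix available". The lookback
# loop walks backwards, combining aggregates until it hits a flag-2 entry,
# which already summarizes everything before it. The 32-bit variant packs
# (value, flag) into one word so a single atomic reads both; the 64-bit
# variant instead uses three scratch slots per block (flag, aggregate,
# inclusive prefix) with release/acquire ordering, and reads the slot
# selected by the flag value.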


@triton.jit
def frexp(x):
    # TODO(isuruf): use inline_asm_elementwise here
    y = libdevice.ilogb(x) + 1
    exponent = tl.where(x == 0, 0, y)
    mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
    return mantissa, exponent
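

# Worked example: for x == 8.0, ilogb gives 3, so y == 4, yielding
# mantissa == ldexp(8.0, -4) == 0.5 and exponent == 4, i.e. 8.0 == 0.5 * 2**4
# with the mantissa in [0.5, 1), matching C's frexp convention.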


@triton.jit
def _compare_and_swap_with_index(
    x,
    idxs,
    rnumel,
    flip,
    i: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]

    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)

    y = tl.reshape(x, shape)
    iy = y.to(idtype, bitcast=True)
    # slice left/right with 'stride' 2**(n_dims - i - 1)
    right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
    left_mask = (1 - right_mask).to(idtype)
    ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape)
    iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape)
    ileft = tl.reshape(ileft, x.shape)
    iright = tl.reshape(iright, x.shape)
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # idx
    y_idx = tl.reshape(idxs, shape)
    left_idx = tl.broadcast_to(
        tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    right_idx = tl.broadcast_to(
        tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape
    )
    left_idx = tl.reshape(left_idx, x.shape)
    right_idx = tl.reshape(right_idx, x.shape)

    # valid
    if rnumel is None:
        left_valid_mask = tl.full(x.shape, True, tl.int1)
        right_valid_mask = tl.full(x.shape, True, tl.int1)
    else:
        left_valid_mask = left_idx < rnumel
        right_valid_mask = right_idx < rnumel

    # actual compare-and-swap
    ix = x.to(idtype, bitcast=True)

    # Sort treats NaN as the largest value, but comparisons with NaN always
    # return False. To align with sort semantics, fold a right_isnan check
    # into the descending comparison and a left_isnan check into the
    # ascending one.
    left_isnan = left != left
    right_isnan = right != right

    if descending:
        cond = left < right
        if is_floating(left):
            if not stable:
                cond = cond | right_isnan
            else:
                cond = cond | (right_isnan & (~left_isnan))
    else:
        cond = left > right
        if is_floating(left):
            if not stable:
                cond = cond | left_isnan
            else:
                cond = cond | (left_isnan & (~right_isnan))

    if stable:
        # When stable sorting, tie break by index
        eq = left == right
        if is_floating(left):
            eq = eq | (left_isnan & right_isnan)
        cond = cond | (eq & (left_idx > right_idx))

    cond = (right_valid_mask > left_valid_mask) | (
        (right_valid_mask == left_valid_mask) & cond
    )
    cond = (cond ^ flip).to(tl.int1)
    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
    new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))

    return ret.to(x.dtype, bitcast=True), new_idxs


@triton.jit
def _bitonic_merge_with_index(
    x,
    idxs,
    rnumel,
    stage: tl.constexpr,
    alternating: tl.constexpr,
    n_dims: tl.constexpr,
    stable: tl.constexpr,
    descending: tl.constexpr,
):
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    # flip denotes whether to re-arrange sub-sequences of elements in ascending
    # or descending order:
    # if flip = 00000000... then all elements are re-arranged in ascending
    # order at this stage;
    # if flip = 00110011... then the elements are re-arranged in alternating
    # directions (with a stride of 2) at this stage
    if alternating:
        shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
        flip = tl.reshape(
            tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
        )
    else:
        flip = False
    # perform `stage` rounds of `compare-and-swap`
    for i in tl.static_range(stage):
        x, idxs = _compare_and_swap_with_index(
            x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending
        )
    return x, idxs


@triton.jit
def sort_with_index(
    x,  # value
    idxs,  # index
    rnumel,  # number of elements
    dim: tl.constexpr = None,
    stable: tl.constexpr = tl.constexpr(False),
    descending: tl.constexpr = tl.constexpr(False),
):
    x, idxs = tl.broadcast(x, idxs)
    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(
        _dim == len(x.shape) - 1, "only minor dimension is currently supported"
    )
    # iteratively run bitonic merge-sort steps
    n_dims: tl.constexpr = _log2(x.shape[_dim])

    for i in tl.static_range(1, n_dims + 1):
        x, idxs = _bitonic_merge_with_index(
            x,
            idxs,
            rnumel,
            i,
            alternating=i < n_dims,
            n_dims=n_dims,
            stable=stable,
            descending=descending,
        )
    return x, idxs
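

# Usage sketch (hypothetical kernel fragment): sorting one power-of-2-sized
# row of values along the last dim while tracking original positions. `vals`
# and `rindex` are illustrative names; `rnumel` marks lanes past the logical
# length as invalid, and the compare-and-swap step pushes those past all
# valid elements.
#
#   idxs = tl.broadcast_to(rindex[None, :], vals.shape)
#   sorted_vals, sorted_idxs = sort_with_index(
#       vals, idxs, rnumel, stable=True, descending=False
#   )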


@triton.jit
def select_one(x, mask, dim, keep_dims=False):
    idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)
    ix = x.to(idtype, bitcast=True)
    iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)
    return iy.to(x.dtype, bitcast=True)


@triton.jit
def x_grid_barrier(sem):
    """
    Wait for all other thread blocks in the grid sharing the same y/z
    program_id to reach this barrier before returning.

    Args:
        sem: a uint32 semaphore, initialized to zero or 0x80000000.
             Must be unique to each y/z program ID.
    """
    # ensure stores before this are visible
    tl.debug_barrier()

    one_i32 = 1
    one_u32 = one_i32.to(tl.uint32)  # type: ignore[attr-defined]
    expected = tl.num_programs(0).to(tl.uint32)
    if tl.program_id(0) == 0:
        nb = 0x80000000 - (expected - one_u32)
    else:
        nb = one_u32

    old_arrive = tl.atomic_add(sem, nb, sem="release")

    bar_flipped = False
    while not bar_flipped:
        # want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it
        current_arrive = tl.atomic_add(sem, 0, sem="acquire")
        # current_arrive = tl.load(sem, volatile=True)
        bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0

    # TODO(jansel): is this needed?
    tl.debug_barrier()


def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]:
    """
    Decorator to mark a function as a Triton built-in function. These functions
    are evaluated at compile time.

    Args:
        f (function): The function to be marked as a Triton built-in.

    Returns:
        function: The same function, marked as a Triton built-in.
    """
    if builtins_use_semantic_kwarg:
        # support Triton before and after https://github.com/triton-lang/triton/pull/7054
        # and after https://github.com/triton-lang/triton/pull/7239
        def wrapper(*args, _semantic, **kwargs):
            kwargs["_builder"] = _semantic
            return f(*args, **kwargs)
    else:
        wrapper = f  # type: ignore[assignment]

    wrapper.__triton_builtin__ = True  # type: ignore[attr-defined]
    return wrapper


@triton_builtin
def constexpr_next_power_of_2(
    n: tl.constexpr, *, _builder: object = None
) -> tl.constexpr:
    """
    A version of triton.next_power_of_2 that can be used within a kernel on
    constexpr values.
    """
    assert isinstance(n, tl.constexpr)
    return tl.constexpr(triton.next_power_of_2(n.value))


@triton_builtin
def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr:
    """
    Work around the triton compile error ``ValueError: `other` cannot be
    provided without `mask```: a compile-time check that returns either
    `val` or `None` depending on the value of `mask`.
    """
    if isinstance(mask, tl.constexpr) and mask.value is None:
        return tl.constexpr(None)
    return val
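

# Usage sketch (hypothetical kernel fragment): pass `other` to tl.load only
# when a mask is actually present.
#
#   val = tl.load(ptrs, mask=xmask, other=if_mask(xmask, 0.0))
#
# When `xmask` is tl.constexpr(None) (no masking needed), if_mask evaluates to
# tl.constexpr(None) at compile time, so `other` is effectively omitted and
# the "`other` cannot be provided without `mask`" error is avoided.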