Files
pytorch/torch/_inductor/runtime/triton_helpers.py
Shunting Zhang 7b8a64557d [inductor] fix 3d tiled online softmax (#162341)
The online_softmax_reduce runtime helper previously assumed its input tl.Tensor arguments were 2D. With tiled reduction, they can be 3D (y, x, r).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162341
Approved by: https://github.com/jansel, https://github.com/eellison
ghstack dependencies: #162311
2025-09-09 02:59:52 +00:00


# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
import math as pymath
import warnings
from typing import Any, Callable, TypeVar
from .triton_compat import ( # noqa: F401
_log2,
builtins_use_semantic_kwarg,
libdevice,
math,
tl,
triton,
)
_T = TypeVar("_T")
_LOG_2_E: tl.constexpr = tl.constexpr(pymath.log2(pymath.e))
def set_driver_to_cpu():
driver = triton.runtime.driver
if backend := triton.backends.backends.get("cpu", None):
if isinstance(driver.active, backend.driver):
# Don't re-initialize backend if it is already active
return
driver.set_active(backend.driver())
return
# This can be a hard error once triton-cpu is merged into fbcode
warnings.warn(
"Could not find an active CPU backend. Generated kernels will not be executable!"
)
def set_driver_to_gpu():
driver = triton.runtime.driver
for name, backend in triton.backends.backends.items():
if backend.driver.is_active() and name != "cpu":
# After https://github.com/triton-lang/triton/commit/b844d519bc5e86edf00fe6b3c6c2d1badcd509a4,
# `driver.active` can be a `LazyProxy`; the presence of the `_obj` attribute indicates this.
if (
isinstance(driver.active, backend.driver)
or hasattr(driver.active, "_obj")
and isinstance(driver.active._obj, backend.driver)
):
# Don't re-initialize backend if it is already active
return
driver.set_active(backend.driver())
return
raise RuntimeError("Could not find an active GPU backend")
def get_backend_options():
from triton.runtime import driver
target = driver.active.get_current_target()
backend = triton.compiler.compiler.make_backend(target)
options = backend.parse_options(dict())
return options.__dict__
@triton.jit
def promote_to_tensor(x):
# Addition promotes to tensor for us
return x + tl.zeros((1,), tl.int1)
@triton.jit
def div_floor_integer(a, b):
# NOTE: a // b is C division, but we want floor division
# Based on c10::div_floor_integer
quot = a // b
remainder = a % b
fixed = tl.where(remainder != 0, quot - 1, quot)
return tl.where((a < 0) != (b < 0), fixed, quot)
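# For example, a = -7, b = 2: C-style truncation gives -3, but because the signs differ
# and the remainder is nonzero we subtract 1, yielding floor(-7 / 2) = -4.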
@triton.jit
def remainder_integer(a, b):
# NOTE: a % b matches C division, not floor division
remainder = a % b
return tl.where((remainder != 0) & ((a < 0) != (b < 0)), remainder + b, remainder)
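# For example, a = -7, b = 2: the C remainder is -1; since the signs differ we add b,
# giving 1, which matches Python's floor-division remainder (-7 % 2 == 1).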
@triton.jit
def is_floating(x):
return promote_to_tensor(x).dtype.is_floating()
@triton.jit
def _prod_accumulate(a, b):
return a * b
@triton.jit
def prod(input, axis):
return tl.reduce(input, axis, _prod_accumulate)
@triton.jit
def minimum(a, b):
mask = a < b
if is_floating(a):
mask |= a != a
return tl.where(mask, a, b)
@triton.jit
def maximum(a, b):
mask = a > b
if is_floating(a):
mask |= a != a
return tl.where(mask, a, b)
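# Note that minimum/maximum above propagate NaN: if `a` is NaN the mask selects it, and
# if only `b` is NaN the untouched mask falls through to it, matching the NaN behavior
# of torch.minimum/torch.maximum.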
@triton.jit
def min2(a, dim):
return tl.reduce(a, dim, minimum)
@triton.jit
def max2(a, dim):
return tl.reduce(a, dim, maximum)
@triton.jit
def minimum_with_index(a_value, a_index, b_value, b_index):
mask = a_value < b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = a_value != a_value
b_isnan = b_value != b_value
mask |= a_isnan & (not b_isnan)
# Consider NaNs as equal
equal |= a_isnan & b_isnan
# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
@triton.jit
def maximum_with_index(a_value, a_index, b_value, b_index):
mask = a_value > b_value
equal = a_value == b_value
if is_floating(a_value):
a_isnan = a_value != a_value
b_isnan = b_value != b_value
mask |= a_isnan & (not b_isnan)
# Consider NaNs as equal
equal |= a_isnan & b_isnan
# Prefer lowest index if values are equal
mask |= equal & (a_index < b_index)
return tl.where(mask, a_value, b_value), tl.where(mask, a_index, b_index)
@triton.jit
def min_with_index(value, index, dim):
return tl.reduce((value, index), dim, minimum_with_index)
@triton.jit
def max_with_index(value, index, dim):
return tl.reduce((value, index), dim, maximum_with_index)
@triton.jit
def exp(x, use_fast_math: tl.constexpr):
if use_fast_math:
return math.exp(x)
else:
return libdevice.exp(x)
@triton.jit
def online_softmax_reduce(lhs_max, lhs_sum, dim, use_fast_math: tl.constexpr):
out_max = max2(lhs_max, dim)
out_max_keepdim = tl.expand_dims(out_max, dim)
delta = tl.where(out_max_keepdim == float("-inf"), 0, lhs_max - out_max_keepdim)
out_sum = tl.sum(lhs_sum * exp(delta, use_fast_math), dim)
return out_max, out_sum
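# online_softmax_reduce finishes the reduction along `dim`: each partial sum is rescaled
# by exp(lhs_max - out_max) before being summed, so out_sum is the softmax denominator
# evaluated at the global max (the -inf guard keeps fully-masked rows at delta = 0).
# Per the commit above, lhs_max/lhs_sum may be 2D or, with tiled reduction, 3D (y, x, r);
# max2, expand_dims, and sum all take `dim`, so the same code handles both layouts.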
@triton.jit
def online_softmax_combine(lhs_max, lhs_sum, rhs_max, use_fast_math: tl.constexpr):
"""
When we combine, we assume lhs is the accumulator and rhs is the next
block of data, so rhs_sum is implicitly all ones. With that assumption
we can save some registers and computation.
"""
out_max = maximum(lhs_max, rhs_max)
lhs_scale = tl.where(
out_max == float("-inf"), 1.0, exp(lhs_max - out_max, use_fast_math)
)
rhs_scale = tl.where(
out_max == float("-inf"), 1.0, exp(rhs_max - out_max, use_fast_math)
)
# Should be
# out_sum = lhs_sum * lhs_scale + rhs_sum * rhs_scale
# but since rhs_sum is all 1, we can simplify it.
out_sum = lhs_sum * lhs_scale + rhs_scale
return out_max, out_sum
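# This is the standard online-softmax update step: out_max = max(lhs_max, rhs_max) and
# out_sum = lhs_sum * exp(lhs_max - out_max) + exp(rhs_max - out_max), where the trailing
# term uses the implicit rhs_sum of 1 described in the docstring. The -inf guards keep
# fully-masked rows (out_max == -inf) from producing NaN, since (-inf) - (-inf) is NaN.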
@triton.jit
def welford_reduce(value, mean, m2, weight, first_iteration):
if first_iteration:
new_weight = tl.full(weight.shape, 1, weight.dtype)
new_mean = value
new_m2 = tl.zeros_like(m2)
else:
delta = value - mean
new_weight = weight + 1
new_mean = mean + delta / new_weight
new_m2 = m2 + delta * (value - new_mean)
return new_mean, new_m2, new_weight
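# Welford's single-sample update: new_mean = mean + (value - mean) / new_weight, and m2
# accumulates delta * (value - new_mean), so m2 always holds the running sum of squared
# deviations (variance = m2 / weight). The first iteration simply seeds mean = value,
# m2 = 0, weight = 1.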
@triton.jit
def welford_combine(mean_1, m2_1, weight_1, mean_2, m2_2, weight_2):
delta = mean_2 - mean_1
new_weight = weight_1 + weight_2
w2_over_w = tl.where(new_weight == 0.0, 0.0, weight_2 / new_weight)
return (
mean_1 + delta * w2_over_w,
m2_1 + m2_2 + delta * delta * weight_1 * w2_over_w,
new_weight,
)
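# This is the pairwise (parallel) Welford merge: the combined mean shifts by
# delta * weight_2 / new_weight and the combined m2 gains
# delta**2 * weight_1 * weight_2 / new_weight, with weight_2 / new_weight guarded so
# that merging two empty accumulators yields 0 instead of dividing by zero.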
@triton.jit
def welford(mean, m2, weight, dim):
return tl.reduce((mean, m2, weight), dim, welford_combine)
@triton.jit
def device_assert_then(cond, msg, r):
tl.device_assert(cond, msg)
return r
@triton.jit
def randint64(seed, offset, low, high):
r0, r1, _r2, _r3 = tl.randint4x(seed, offset)
r0 = r0.to(tl.uint64)
r1 = r1.to(tl.uint64)
result = r0 | (r1 << 32)
size = high - low
result = result % size.to(tl.uint64)
result = result.to(tl.int64) + low
return result
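# Two independent 32-bit draws are packed into a single uint64 and then mapped into
# [low, high) via modulo; this carries a small modulo bias whenever (high - low) does
# not evenly divide 2**64.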
@triton.jit
def _any_combine(a, b):
return a | b
@triton.jit
def any(a, dim):
return tl.reduce(a, dim, _any_combine)
@triton.jit
def bucketize_binary_search(
values: tl.tensor,
boundaries_ptr: tl.tensor,
BOUNDARIES_SIZE: int,
BOUNDARIES_UNDERLYING_NUMEL: int,
BOUNDARIES_STRIDE: int,
boundary_indices: tl.tensor,
indexing_dtype: tl.dtype,
right: "bool", # triton can't handle the unquoted bool annotation
sorter_ptr: tl.tensor,
SORTER_STRIDE: int,
sorter_indices: tl.tensor,
):
"""
See [Note: Inductor bucketize op]
Inputs:
-------
values: the values to bucketize.
boundaries_ptr: a pointer to the beginning of the boundaries tensor, in 1-D.
BOUNDARIES_SIZE: the length of the last dimension of the boundaries tensor (i.e. one
individual set of boundaries).
BOUNDARIES_UNDERLYING_NUMEL: the length of the boundaries tensor, in 1-D, ignoring
any striding.
BOUNDARIES_STRIDE: the stride of the last dimension of the boundaries tensor
boundary_indices: a tensor of the same size as "values"; each element is an index
into a 1-D, un-strided boundaries tensor, pointing to the first element in the set
of boundaries used for that value.
indexing_dtype: the dtype used for indexing into the boundaries tensor, and the
return dtype.
right: if true, use boundary intervals closed on the left; otherwise use intervals
closed on the right.
sorter_ptr: an optional pointer to a sorter tensor of the same shape as boundaries,
but potentially different striding. If present, this allows us to treat boundaries
as sorted even if the elements of boundaries are unsorted.
SORTER_STRIDE: must be present if sorter_ptr is non-None; the stride of the last
dimension of the sorter tensor.
sorter_indices: must be present if sorter_ptr is non-None; see "boundary_indices".
"""
low = tl.zeros(values.shape, dtype=indexing_dtype)
high = tl.full(values.shape, BOUNDARIES_SIZE, dtype=indexing_dtype)
full_range = BOUNDARIES_SIZE + 1
while full_range > 1:
mid = (high + low) // 2
mask = (
(mid * BOUNDARIES_STRIDE + boundary_indices) < BOUNDARIES_UNDERLYING_NUMEL
).logical_and(mid < BOUNDARIES_SIZE)
mid_indices = (
mid
if sorter_ptr is None or SORTER_STRIDE is None
else tl.load(
sorter_ptr + sorter_indices + SORTER_STRIDE * mid,
mask=mask,
other=0,
)
)
bucket_upper_bound = tl.load(
boundaries_ptr + boundary_indices + BOUNDARIES_STRIDE * mid_indices,
mask=mask,
other=0,
)
if right:
is_above = values >= bucket_upper_bound
else:
is_above = values > bucket_upper_bound
low = tl.where(is_above & mask, mid + 1, low)
high = tl.where(is_above, high, mid)
full_range = (full_range + 1) // 2
return low
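# The loop maintains the usual binary-search invariant, so `low` converges to the number
# of boundaries each value exceeds (or equals-or-exceeds when right=True), which is the
# bucket index. This effectively mirrors torch.bucketize / torch.searchsorted, with
# sorter_ptr playing the role of searchsorted's optional `sorter` argument.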
@triton.jit
def pack_value_flag(
value,
flag,
DTYPE_VALUE_AS_UINT: tl.constexpr,
DTYPE_PACK: tl.constexpr,
):
# Workaround for a Triton bug: tensor.to doesn't unwrap constexpr values
DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
uv = value.to(DTYPE_VALUE_AS_UINT, bitcast=True).to(DTYPE_PACK)
return flag.to(DTYPE_PACK) | (uv << bitwidth)
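# Example layout: with DTYPE_VALUE_AS_UINT=tl.uint32 and DTYPE_PACK=tl.uint64, the
# value's bit pattern occupies the upper 32 bits of the pack and the flag the lower 32,
# so a single 64-bit atomic can publish both at once.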
@triton.jit
def unpack_value(
pack,
DTYPE_VALUE,
DTYPE_VALUE_AS_UINT,
):
# Workaround for a Triton bug: tensor.to doesn't unwrap constexpr values
DTYPE_VALUE = tl.core._unwrap_if_constexpr(DTYPE_VALUE)
DTYPE_VALUE_AS_UINT = tl.core._unwrap_if_constexpr(DTYPE_VALUE_AS_UINT)
bitwidth = DTYPE_VALUE_AS_UINT.primitive_bitwidth
value_uint = (pack >> bitwidth).to(DTYPE_VALUE_AS_UINT)
return value_uint.to(DTYPE_VALUE, bitcast=True)
@triton.jit
def unpack_flag(pack, DTYPE_FLAG):
return pack.to(DTYPE_FLAG)
@triton.jit
def exclusive_scan_decoupled_lookback(
scratch_base,
block_value,
index,
combine_fn,
DTYPE_VALUE_AS_UINT: tl.constexpr,
DTYPE_PACK: tl.constexpr,
):
"""Compute exclusive scan of a scalar value between blocks
Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
scratch_base: Pointer to scratch space in global memory
block_value: Scalar value for this block
index: Scalar index of this block relative to the current scan
combine_fn: Function ``(value, value) -> value`` which is scanned over
DTYPE_VALUE_AS_UINT: A tl.uint{n} type equal in size to ``block_value``
DTYPE_PACK: Unsigned type twice the width of block_value
NOTE: This function is limited to values which are 32-bits or less because
we need to pack (value, flag) into a single unsigned int.
"""
# Publish block sum so subsequent blocks don't get stuck waiting for us
DTYPE_VALUE = block_value.dtype
pack = pack_value_flag(
block_value,
tl.full(block_value.shape, 1, DTYPE_VALUE_AS_UINT),
DTYPE_VALUE_AS_UINT,
DTYPE_PACK,
)
if index > 0:
tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
# Calculate exclusive prefix scan
exclusive_prefix = tl.zeros([], DTYPE_VALUE)
prefix_valid = False
test_target = index - 1
while test_target >= 0:
# Spin until the flag is nonzero; atomic_add(ptr, 0) stands in for an atomic load
flag = tl.full([], 0, DTYPE_VALUE_AS_UINT)
while flag == 0:
pack = tl.atomic_add(scratch_base + test_target, 0, sem="relaxed")
flag = unpack_flag(pack, DTYPE_VALUE_AS_UINT)
value = unpack_value(pack, DTYPE_VALUE, DTYPE_VALUE_AS_UINT)
if prefix_valid:
exclusive_prefix = combine_fn(value, exclusive_prefix)
else:
exclusive_prefix = value
prefix_valid = True
if flag == 2:
test_target = -1
else:
test_target = test_target - 1
# Make inclusive block sum visible to other blocks
if prefix_valid:
inclusive_prefix = combine_fn(exclusive_prefix, block_value)
else:
inclusive_prefix = block_value
pack = pack_value_flag(
inclusive_prefix,
tl.full([], 2, DTYPE_VALUE_AS_UINT),
DTYPE_VALUE_AS_UINT,
DTYPE_PACK,
)
tl.atomic_xchg(scratch_base + index, pack, sem="relaxed")
return exclusive_prefix
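# Flag protocol for the decoupled look-back: 0 = nothing published yet, 1 = this block's
# local aggregate is available, 2 = its inclusive prefix is available. The look-back
# walks predecessors, folding in their values until it reaches a flag-2 entry, at which
# point the exclusive prefix is complete. The 64-bit variant below uses the same
# protocol but keeps flag, aggregate, and inclusive prefix in three separate u64 slots;
# the flag value doubles as the offset of the slot to read.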
@triton.jit
def exclusive_scan_decoupled_lookback_64(scratch_base, block_value, index, combine_fn):
"""Compute exclusive scan of a scalar value between blocks
Ref: https://research.nvidia.com/publication/2016-03_single-pass-parallel-prefix-scan-decoupled-look-back
scratch_base: Pointer to scratch space in global memory
block_value: Scalar value for this block, must be 64-bits wide
index: Scalar index of this block relative to the current scan
combine_fn: Function ``(value, value) -> value`` which is scanned over
"""
# Publish block sum so subsequent blocks don't get stuck waiting for us
if index > 0:
block_value_u64 = block_value.to(tl.uint64, bitcast=True)
tl.store(scratch_base + 3 * index + 1, block_value_u64)
tl.debug_barrier()
flag_one = tl.full([], 1, tl.uint64)
tl.atomic_xchg(scratch_base + 3 * index + 0, flag_one, sem="release")
# Calculate exclusive prefix scan
exclusive_prefix = tl.zeros([], block_value.dtype)
prefix_valid = False
test_target = index - 1
while test_target >= 0:
flag = tl.full([], 0, tl.uint64)
while flag == 0:
flag = tl.atomic_add(scratch_base + 3 * test_target + 0, 0, sem="acquire")
value_u64 = tl.load(scratch_base + 3 * test_target + flag.to(tl.int32))
value = value_u64.to(block_value.dtype, bitcast=True)
if prefix_valid:
exclusive_prefix = combine_fn(value, exclusive_prefix)
else:
exclusive_prefix = value
prefix_valid = True
if flag == 2:
test_target = -1
else:
test_target = test_target - 1
# Make inclusive block sum visible to other blocks
if prefix_valid:
inclusive_prefix = combine_fn(exclusive_prefix, block_value)
else:
inclusive_prefix = block_value
inclusive_prefix_u64 = inclusive_prefix.to(tl.uint64, bitcast=True)
tl.store(scratch_base + 3 * index + 2, inclusive_prefix_u64)
tl.debug_barrier()
flag_two = tl.full([], 2, tl.uint64)
tl.atomic_xchg(scratch_base + 3 * index + 0, flag_two, sem="release")
return exclusive_prefix
@triton.jit
def frexp(x):
# TODO(isuruf): use inline_asm_elementwise here
y = libdevice.ilogb(x) + 1
exponent = tl.where(x == 0, 0, y)
mantissa = tl.where(x == 0, 0, libdevice.ldexp(x, -y))
return mantissa, exponent
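# For example, frexp(8.0) returns (0.5, 4), since 8.0 == 0.5 * 2**4; as with C frexp,
# the mantissa magnitude lies in [0.5, 1) for nonzero inputs and both outputs are 0
# when x == 0.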
@triton.jit
def _compare_and_swap_with_index(
x,
idxs,
rnumel,
flip,
i: tl.constexpr,
n_dims: tl.constexpr,
stable: tl.constexpr,
descending: tl.constexpr,
):
n_outer: tl.constexpr = x.numel >> n_dims
shape: tl.constexpr = [n_outer * 2**i, 2, 2 ** (n_dims - i - 1)]
idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
y = tl.reshape(x, shape)
iy = y.to(idtype, bitcast=True)
# slice left/right with 'stride' 2**(n_dims - i - 1)
right_mask = tl.arange(0, 2)[None, :, None].to(idtype)
left_mask = (1 - right_mask).to(idtype)
ileft = tl.broadcast_to(tl.sum(iy * left_mask, 1).to(idtype)[:, None, :], shape)
iright = tl.broadcast_to(tl.sum(iy * right_mask, 1).to(idtype)[:, None, :], shape)
ileft = tl.reshape(ileft, x.shape)
iright = tl.reshape(iright, x.shape)
left = ileft.to(x.dtype, bitcast=True)
right = iright.to(x.dtype, bitcast=True)
# idx
y_idx = tl.reshape(idxs, shape)
left_idx = tl.broadcast_to(
tl.sum(y_idx * left_mask.to(y_idx.dtype), 1)[:, None, :], shape
)
right_idx = tl.broadcast_to(
tl.sum(y_idx * right_mask.to(y_idx.dtype), 1)[:, None, :], shape
)
left_idx = tl.reshape(left_idx, x.shape)
right_idx = tl.reshape(right_idx, x.shape)
# valid
if rnumel is None:
left_valid_mask = tl.full(x.shape, True, tl.int1)
right_valid_mask = tl.full(x.shape, True, tl.int1)
else:
left_valid_mask = left_idx < rnumel
right_valid_mask = right_idx < rnumel
# actual compare-and-swap
ix = x.to(idtype, bitcast=True)
# Sort treats NaN as the largest value, but comparisons involving NaN always return
# False. To match sort semantics, the descending branch must also check right_isnan
# and the ascending branch left_isnan.
left_isnan = left != left
right_isnan = right != right
if descending:
cond = left < right
if is_floating(left):
if not stable:
cond = cond | right_isnan
else:
cond = cond | (right_isnan & (~left_isnan))
else:
cond = left > right
if is_floating(left):
if not stable:
cond = cond | left_isnan
else:
cond = cond | (left_isnan & (~right_isnan))
if stable:
# When stable sorting, tie break by index
eq = left == right
if is_floating(left):
eq = eq | (left_isnan & right_isnan)
cond = cond | (eq & (left_idx > right_idx))
cond = (right_valid_mask > left_valid_mask) | (
(right_valid_mask == left_valid_mask) & cond
)
cond = (cond ^ flip).to(tl.int1)
ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
new_idxs = idxs ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(idxs))
return ret.to(x.dtype, bitcast=True), new_idxs
@triton.jit
def _bitonic_merge_with_index(
x,
idxs,
rnumel,
stage: tl.constexpr,
alternating: tl.constexpr,
n_dims: tl.constexpr,
stable: tl.constexpr,
descending: tl.constexpr,
):
n_outer: tl.constexpr = x.numel >> n_dims
tl.static_assert(stage <= n_dims)
# flip denotes whether to re-arrange sub-sequences of elements in ascending or
# descending order.
# if flip = 00000000... then all elements are re-arranged in ascending order at this stage
# if flip = 00110011... then elements are re-arranged in alternating order (with a
# stride of 2) at this stage
if alternating:
shape: tl.constexpr = [n_outer * 2 ** (n_dims - 1 - stage), 2, 2**stage]
flip = tl.reshape(
tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape
)
else:
flip = False
# perform `stage` rounds of `compare-and-swap`
for i in tl.static_range(stage):
x, idxs = _compare_and_swap_with_index(
x, idxs, rnumel, flip, i + (n_dims - stage), n_dims, stable, descending
)
return x, idxs
@triton.jit
def sort_with_index(
x, # value
idxs, # index
rnumel, # number of elements
dim: tl.constexpr = None,
stable: tl.constexpr = tl.constexpr(False),
descending: tl.constexpr = tl.constexpr(False),
):
x, idxs = tl.broadcast(x, idxs)
# handle default dimension or check that it is the most minor dim
_dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
tl.static_assert(
_dim == len(x.shape) - 1, "only minor dimension is currently supported"
)
# iteratively run bitonic merge-sort steps
n_dims: tl.constexpr = _log2(x.shape[_dim])
for i in tl.static_range(1, n_dims + 1):
x, idxs = _bitonic_merge_with_index(
x,
idxs,
rnumel,
i,
alternating=i < n_dims,
n_dims=n_dims,
stable=stable,
descending=descending,
)
return x, idxs
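# The sorted dimension must be a power of two (n_dims comes from _log2 of its size).
# `rnumel` lets callers pad up to that size: indices >= rnumel are treated as invalid by
# the compare-and-swap and are pushed behind the valid elements, so padding does not
# disturb the sorted results; passing rnumel=None treats every lane as valid.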
@triton.jit
def select_one(x, mask, dim, keep_dims=False):
idtype = tl.core.get_int_dtype(x.dtype.primitive_bitwidth, signed=False)
ix = x.to(idtype, bitcast=True)
iy = tl.sum(ix * mask, dim, keep_dims=keep_dims)
return iy.to(x.dtype, bitcast=True)
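# `mask` is expected to be one-hot along `dim`: multiplying the bit-cast integer
# representation by the 0/1 mask and summing returns the selected element's exact bit
# pattern, avoiding any floating-point rounding from summing real values.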
@triton.jit
def x_grid_barrier(sem):
"""
Wait for all other thread blocks in grid sharing same y/z program_id
to reach this barrier before returning.
Args:
sem: a uint32 semaphore, initialized to zero or 0x80000000. Must be unique to each y/z program ID.
"""
# ensure stores before this are visible
tl.debug_barrier()
one_i32 = 1
one_u32 = one_i32.to(tl.uint32) # type: ignore[attr-defined]
expected = tl.num_programs(0).to(tl.uint32)
if tl.program_id(0) == 0:
nb = 0x80000000 - (expected - one_u32)
else:
nb = one_u32
old_arrive = tl.atomic_add(sem, nb, sem="release")
bar_flipped = False
while not bar_flipped:
# want a `ld.acquire.gpu.u32 $0,[$1];` but Triton doesn't have it
current_arrive = tl.atomic_add(sem, 0, sem="acquire")
# current_arrive = tl.load(sem, volatile=True)
bar_flipped = ((old_arrive ^ current_arrive) & 0x80000000) != 0
# TODO(jansel): is this needed?
tl.debug_barrier()
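# Arrival accounting: block 0 contributes 0x80000000 - (expected - 1) and each of the
# other expected - 1 blocks contributes 1, so the semaphore advances by exactly
# 0x80000000 once every block has arrived. Each block then spins on acquire reads until
# the semaphore's high bit differs from the value it saw at its own arrival.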
def triton_builtin(f: Callable[..., _T]) -> Callable[..., _T]:
"""
Decorator to mark a function as a Triton built-in function. These functions
are evaluated at compile time.
Args:
f (function): The function to be marked as a Triton built-in.
Returns:
function: The same function, marked as a Triton built-in.
"""
if builtins_use_semantic_kwarg:
# support Triton before and after https://github.com/triton-lang/triton/pull/7054
# and after https://github.com/triton-lang/triton/pull/7239
def wrapper(*args, _semantic, **kwargs):
kwargs["_builder"] = _semantic
return f(*args, **kwargs)
else:
wrapper = f # type: ignore[assignment]
wrapper.__triton_builtin__ = True # type: ignore[attr-defined]
return wrapper
@triton_builtin
def constexpr_next_power_of_2(
n: tl.constexpr, *, _builder: object = None
) -> tl.constexpr:
"""
A version of triton.next_power_of_2 that can be used within a kernel on constants.
"""
assert isinstance(n, tl.constexpr)
return tl.constexpr(triton.next_power_of_2(n.value))
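# For example, constexpr_next_power_of_2(tl.constexpr(5)) evaluates to tl.constexpr(8)
# at compile time, since the function runs on the host as a Triton builtin rather than
# being traced into the kernel.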
@triton_builtin
def if_mask(mask: Any, val, *, _builder: object = None) -> tl.constexpr:
"""
Work around triton compile error: `ValueError: `other` cannot be provided without `mask``
A compile-time check that returns either `val` or `None` depending on the value of `mask`.
"""
if isinstance(mask, tl.constexpr) and mask.value is None:
return tl.constexpr(None)
return val
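# Hypothetical call site: tl.load(ptrs, mask=mask, other=if_mask(mask, 0.0)); when
# `mask` is a constexpr None, the `other` argument collapses to None as well, avoiding
# the compile error quoted above.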