Revert "Factor out the strings to templates for better editor integration (#160357)"

This reverts commit cbffde774557752cf20447d42d99ec6102673c31.

Reverted https://github.com/pytorch/pytorch/pull/160357 on behalf of https://github.com/clee2000 due to broke a bunch of internal builds due to not being able to find the file  No such file or directory: torch/_inductor/kernel/flex/templates/flex_decode.py.jinja D80145761, might need a buck targets change? ([comment](https://github.com/pytorch/pytorch/pull/160357#issuecomment-3184435581))
This commit is contained in:
PyTorch MergeBot
2025-08-13 15:40:50 +00:00
parent 31c9ac4319
commit c656334120
9 changed files with 1477 additions and 1451 deletions

View File

@ -1670,7 +1670,6 @@ def main() -> None:
"_inductor/codegen/aoti_runtime/*.h",
"_inductor/codegen/aoti_runtime/*.cpp",
"_inductor/script.ld",
"_inductor/kernel/flex/templates/*.jinja",
"_export/serde/*.yaml",
"_export/serde/*.thrift",
"share/cmake/ATen/*.cmake",

View File

@ -3,7 +3,6 @@
import math
from collections.abc import Sequence
from pathlib import Path
from typing import Any, Optional, Union
import sympy
@ -324,13 +323,267 @@ def next_power_of_two(n):
return 2 ** math.ceil(math.log2(n))
_TEMPLATE_DIR = Path(__file__).parent / "templates"
# ---- Common Template Strings ----
compute_forward_block_mn = r"""
@triton.jit
def forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
# -- load k --
# NB reversed order to since K is transposed
{%- if USE_TMA %}
k = tl.load_tensor_descriptor(
desc_k,
[kv_start + kv_offset, 0],
)
{%- else %}
k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
{%- endif %}
if USE_TMA:
k = tl.trans(k)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid access memory out of bound,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=1,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
{%- if USE_TMA %}
v = tl.load_tensor_descriptor(
desc_v,
[kv_start + kv_offset, 0],
)
{%- else %}
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
{%- endif %}
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
"""
compute_forward_inner = r"""
@triton.jit
def forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
kv_offset = 0
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
else:
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
# it's on par or slightly faster than only applying to the last block in fwd.
# However, we choose different strategy for bwd, where we only apply mod & mask
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
def load_template(name: str) -> str:
"""Load a template file and return its content."""
with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
return f.read()
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
offs_n = offs_n + offset
kv_offset += offset
if not USE_TMA:
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
# Template strings have been moved to templates/common.py.jinja
return acc, l_i, m_i
"""
# Inner Triton functions shared by flex_attention & split-k decoding kernels.
compute_next_offset_func = r"""
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
"""
get_bounded_indices_func = r"""
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
"""
load_checked_block = r"""
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
"""
load_checked_2d = r"""
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)
"""

View File

@ -22,12 +22,17 @@ from ...select_algorithm import (
)
from .common import (
build_subgraph_buffer,
compute_forward_block_mn,
compute_forward_inner,
compute_next_offset_func,
create_indices_fake,
create_num_blocks_fake_generator,
create_placeholder,
get_bounded_indices_func,
get_fwd_subgraph_outputs,
infer_dense_strides,
load_template,
load_checked_2d,
load_checked_block,
maybe_realize,
set_head_dim_values,
SubgraphResults,
@ -62,12 +67,267 @@ def get_float32_precision():
return "'tf32'"
compute_flex_attention = r"""
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
# contiguous? If so, we don't need to do an indirect jump for every block
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_zq = tl.program_id(1)
off_hq = tl.program_id(2)
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_zq * stride_qz + off_hq * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
# Setting up the TMA descriptors for Q, K, V
desc_q = None
desc_k = None
desc_v = None
{%- if USE_TMA %}
desc_q = tl.make_tensor_descriptor(
base=Q,
shape=[Q_LEN, QK_HEAD_DIM],
strides=[stride_qm, 1],
block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
)
desc_k = tl.make_tensor_descriptor(
base=K,
shape=[KV_LEN, QK_HEAD_DIM],
strides=[stride_kn, 1],
block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
)
desc_v = tl.make_tensor_descriptor(
base=V,
shape=[KV_LEN, V_HEAD_DIM],
strides=[stride_vn, 1],
block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
)
{%- endif %}
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
K_block_ptr = None
V_block_ptr = None
Q_block_ptr = None
if not USE_TMA:
Q_block_ptr = tl.make_block_ptr(
base=Q ,
shape=(Q_LEN, QK_HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(q_start * BLOCK_M, 0),
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
order=(1, 0)
)
{%- if USE_TMA %}
q = tl.load_tensor_descriptor(
desc_q,
[(q_start * BLOCK_M).to(tl.int32), 0],
)
{%- else %}
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
{%- endif %}
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# [Note] Handle fully masked out rows:
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1)
idx_hq = tl.program_id(2)
idx_m = offs_m[:, None]
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
{{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
if OUTPUT_LOGSUMEXP:
off_hz = off_zq * HQ + off_hq
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
if IS_DIVISIBLE:
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
"""
flex_attention_template = TritonTemplate(
name="flex_attention",
grid=flex_attention_grid,
source=load_template("flex_attention")
+ load_template("utilities")
+ load_template("common"),
source=compute_flex_attention
+ compute_forward_inner
+ compute_next_offset_func
+ compute_forward_block_mn
+ load_checked_block
+ get_bounded_indices_func,
)
@ -424,7 +684,693 @@ def flex_attention_backward_grid(
flex_attention_backward_template = TritonTemplate(
name="flex_attention_backward",
grid=flex_attention_backward_grid,
source=load_template("flex_backwards") + load_template("utilities"),
source=r"""
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
HKV = {{size("K", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_zq = tl.program_id(1) # q batch idx
off_hkv = tl.program_id(2) # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
stride_q_idx_h = {{stride("Q_IDX", 1)}}
stride_q_idx_n = {{stride("Q_IDX", 2)}}
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
index_v = offs_v[None, :]
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
dk *= SM_SCALE
if SAFE_HEAD_DIM:
mask = index_n < KV_LEN
else:
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
{{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
@triton.jit
def bwd_dq_inner(
{{gen_argdefs()}},
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_n in range(0, hi - 1):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_n in range(0, hi):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
return dq
@triton.jit
def bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1)}}
# NB reversed order to since K is transposed
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
# that the M reads out of bounds prior to the last loop
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partial masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
# NB reversed order to since V is transposed
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if WRITE_DQ:
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
return dq
@triton.jit
def bwd_dkdv_inner(
{{gen_argdefs()}},
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1) }}
# NB reversed order since Q is transposed
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
lse = tl.load(LSE + offs_m1)
else:
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
# that the n reads out of bounds prior to the last loop
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
pre_mod_scores = qkT
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qkT"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for fully masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="dsT"
) | indent_except_first(1) }}
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if not WRITE_DQ:
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="idx_b",
h="idx_h",
m="idx_m",
n="idx_n",
grad_score_mod="dsT"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
return dk, dv
"""
+ compute_next_offset_func
+ get_bounded_indices_func
+ load_checked_2d,
)

View File

@ -18,10 +18,15 @@ from ...select_algorithm import (
TritonTemplate,
)
from .common import (
compute_forward_block_mn,
compute_forward_inner,
compute_next_offset_func,
create_indices_fake,
create_num_blocks_fake_generator,
get_bounded_indices_func,
get_fwd_subgraph_outputs,
load_template,
load_checked_2d,
load_checked_block,
maybe_realize,
set_head_dim_values,
)
@ -85,9 +90,266 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
flex_decoding_template = TritonTemplate(
name="flex_decoding",
grid=flex_decoding_grid,
source=load_template("flex_decode")
+ load_template("utilities")
+ load_template("common"),
source=r"""
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
# Q: Query, K: Key, V: Value
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
# (Modifiable) Config options:
# SPLIT_KV: number of blocks K & V are split into
# TILE_KV: length of each local KV split
# BLOCK_M: block size that Q is padded along seqlen dim.
# BLOCK_N: block size of K & V along N dimension.
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
#
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
#
#
# Output: ACC output accumulated across local KV split.
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define Q Strides
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
Z = {{size("Q", 0)}}
ZKV = {{size("K", 0)}}
HKV = {{size("Q", 1)}}
G: tl.constexpr = GQA_SHARED_HEADS
HQ = HKV * G
Q_LEN = {{size("Q", 3)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
# Make sure each split is a multiple of BLOCK_N
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
off_z = tl.program_id(0) // HKV
off_zkv = off_z % ZKV
off_hkv = tl.program_id(0) % HKV
off_t = tl.program_id(1)
q_offset = off_z * stride_qz + off_hkv * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SPARSE_Z
sparse_idx_h = off_hkv % SPARSE_HQ
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
# initialize offsets
tl.device_assert(BLOCK_M % G == 0)
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
off_g = tl.arange(0, G) # [G]
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_hq = offs_g + off_hkv * G
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
# Calculate KV blocks that belong this CTA.
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
else:
q = tl.load(Q + q_offset + q_range)
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apply both score_mod and mask_mod
# find first kv block we are loading and the number of blocks we are loading
# Offset the kv_indices tensor by the correct batch and head
kv_indices = KV_IDX + sparse_idx_hz_offset
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
# first kv block we're loading
# last valid block according to sparse mask
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
# Assign full block in a reverse order for off_t. Prioritize the last CTA.
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
block_n_end = block_n_start + TILE_KV_MULTIPLE
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
# last valid block according to sparse mask
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
m_offset = off_t * stride_mt + off_z * stride_mz
l_offset = off_t * stride_lt + off_z * stride_lz
M_block_ptr = tl.make_block_ptr(
base=M + m_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_mh, stride_mm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
L_block_ptr = tl.make_block_ptr(
base=L + l_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_lh, stride_lm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
if SAFE_M_BOUNDARY:
tl.store(M_block_ptr, m_i)
tl.store(L_block_ptr, l_i)
else:
tl.store(M_block_ptr, m_i, boundary_check=(1,))
tl.store(L_block_ptr, l_i, boundary_check=(1,))
# -- store output
idx_z = off_z
idx_t = off_t
idx_hq = off_hkv*G + off_g[:, None, None]
idx_m = off_m[None, :, None]
idx_d = offs_vd[None, None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
"""
+ compute_forward_inner
+ get_bounded_indices_func
+ load_checked_block
+ load_checked_2d
+ compute_next_offset_func
+ compute_forward_block_mn,
)

View File

@ -1,193 +0,0 @@
# Common Imports
@triton.jit
def forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
# -- load k --
# NB reversed order to since K is transposed
{%- if USE_TMA %}
k = tl.load_tensor_descriptor(
desc_k,
[kv_start + kv_offset, 0],
)
{%- else %}
k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
{%- endif %}
if USE_TMA:
k = tl.trans(k)
# -- compute qk ---
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
# If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid access memory out of bound,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=1,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_h",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
# apply mask for partially unmasked blocks
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# -- compute scaling constant ---
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
if not ROWS_GUARANTEED_SAFE:
masked_out_rows = (m_ij == float("-inf"))
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
else:
m_ij_masked = m_ij
alpha = tl.math.exp2(m_i - m_ij_masked)
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
# NB: l_i update is pulled up here since it's a bit faster
# NB: For headdim=256, it's faster to move it back down to after m_i =
# m_ij
l_i = l_i * alpha + tl.sum(p, 1)
# # -- scale and update acc --
acc = acc * alpha[:, None]
{%- if USE_TMA %}
v = tl.load_tensor_descriptor(
desc_v,
[kv_start + kv_offset, 0],
)
{%- else %}
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
{%- endif %}
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
# -- update m_i
m_i = m_ij
return acc, l_i, m_i
@triton.jit
def forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets used as inputs to score_mod & mask_mod
# of size [BLOCK_M, BLOCK_N] or scalar.
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
# blocksparse data
kv_indices, kv_num_blocks,
# start kv and end kv block
block_n_start, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
{{gen_defines() | indent_except_first(1)}}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
RCP_LN2: tl.constexpr = 1.44269504
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
kv_offset = 0
# loop over k, v and update accumulator until block_n_end
for start_n in range(block_n_start, block_n_end):
# Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
if IS_DIVISIBLE:
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
else:
# Benchmark shows even we applied mod & mask to each block for non divisible seqlen,
# it's on par or slightly faster than only applying to the last block in fwd.
# However, we choose different strategy for bwd, where we only apply mod & mask
# to the last block because it's faster a lot.
acc, l_i, m_i = forward_block_mn(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
# accumulated values
acc, l_i, m_i,
# Offsets
off_z, off_h, offs_m, offs_n,
# Offsets needed for TMA loads
kv_start,
kv_offset,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
offset = get_offset_for_next_block(
start_n, kv_indices, kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
)
offs_n = offs_n + offset
kv_offset += offset
if not USE_TMA:
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
return acc, l_i, m_i

View File

@ -1,248 +0,0 @@
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# M: Number of queries, N: Number of keys/values, D: Model dimension
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
#
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
#
# (Modifiable) Performance tuning options
# BLOCK_M: The thread block size across the seqlen dim of Q.
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
# contiguous? If so, we don't need to do an indirect jump for every block
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
q_start = tl.program_id(0)
off_zq = tl.program_id(1)
off_hq = tl.program_id(2)
# We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0.
off_zkv = off_zq % ZKV
off_hkv = off_hq // GQA_SHARED_HEADS
off_g = off_hq % GQA_SHARED_HEADS
q_offset = off_zq * stride_qz + off_hq * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
Q = Q + q_offset
K = K + k_offset
V = V + v_offset
# Setting up the TMA descriptors for Q, K, V
desc_q = None
desc_k = None
desc_v = None
{%- if USE_TMA %}
desc_q = tl.make_tensor_descriptor(
base=Q,
shape=[Q_LEN, QK_HEAD_DIM],
strides=[stride_qm, 1],
block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
)
desc_k = tl.make_tensor_descriptor(
base=K,
shape=[KV_LEN, QK_HEAD_DIM],
strides=[stride_kn, 1],
block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
)
desc_v = tl.make_tensor_descriptor(
base=V,
shape=[KV_LEN, V_HEAD_DIM],
strides=[stride_vn, 1],
block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
)
{%- endif %}
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
sparse_idx_hq = off_hq % SPARSE_HQ
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
# KV_IDX and KV_NUM_BLKS are always contiguous.
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
K_block_ptr = None
V_block_ptr = None
Q_block_ptr = None
if not USE_TMA:
Q_block_ptr = tl.make_block_ptr(
base=Q ,
shape=(Q_LEN, QK_HEAD_DIM),
strides=(stride_qm, stride_qk),
offsets=(q_start * BLOCK_M, 0),
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
order=(1, 0)
)
{%- if USE_TMA %}
q = tl.load_tensor_descriptor(
desc_q,
[(q_start * BLOCK_M).to(tl.int32), 0],
)
{%- else %}
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
{%- endif %}
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We don't know anything "special" about these blocks, so we need to apply
# both score_mod and mask_mod to it
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
if not USE_TMA:
K_block_ptr = tl.make_block_ptr(
base=K,
shape=(QK_HEAD_DIM, KV_LEN),
strides=(stride_kk, stride_kn),
offsets=(0, kv_start),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(kv_start, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = kv_start + tl.arange(0, BLOCK_N)
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr,
desc_k, desc_v, Q_LEN, KV_LEN,
acc, l_i, m_i,
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
kv_start,
kv_indices, kv_num_blocks,
0, block_n_end,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# [Note] Handle fully masked out rows:
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
# We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step
l_i = tl.where(l_i == 0.0, 1, l_i)
acc = acc / l_i[:, None]
idx_zq = tl.program_id(1)
idx_hq = tl.program_id(2)
idx_m = offs_m[:, None]
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
{{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
if OUTPUT_LOGSUMEXP:
off_hz = off_zq * HQ + off_hq
l_ptrs = LSE + off_hz * Q_LEN + offs_m
lse = m_i + tl.math.log2(l_i)
if IS_DIVISIBLE:
tl.store(l_ptrs, lse)
else:
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)

View File

@ -1,682 +0,0 @@
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
# Sub notation for this kernel:
#
# Q: Query, K: Key, V: Value
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
# DELTA: Precomputed sum(OUT*DO, axis=-1)
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
# DK: Derivative of Key, is the written to via the store_output call due to some limitations with
# inductor codegen
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
# (Modifiable) Performance tuning options
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
#
# The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
# The below are kernel options that can be applied for certain score_mods,
# or involve a numerics vs. perf tradeoff
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
# about 20% more numerical error, but slightly faster.
# Define strides of inputs
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
ZQ = {{size("Q", 0)}}
HQ = {{size("Q", 1)}}
HKV = {{size("K", 1)}}
Q_LEN = {{size("Q", 2)}}
ZKV = {{size("K", 0)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
pid = tl.program_id(0)
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
off_zq = tl.program_id(1) # q batch idx
off_hkv = tl.program_id(2) # kv head idx
off_zkv = off_zq % ZKV # kv batch idx
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_zq % SPARSE_Z
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
# offset K, V, DV pointers for batch/kv-head
K += k_adj
V += v_adj
DV += dv_adj
RCP_LN2 = 1.44269504
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
if pid >= NUM_KV_BLOCKS:
off_pid = pid - NUM_KV_BLOCKS
# THIS BLOCK DOES DQ
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
start_m2_block = off_pid % NUM_Q_BLOCKS
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
Q2 = Q + q_adj2
DO2 = DO + do_adj2
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
DQ2 = DQ + dq_adj2
LSE2 = LSE + off_chz2
DELTA2 = DELTA + off_chz2
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_m2 = start_m2_block * BLOCK_M2
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
# load Q and do: they stay in SRAM throughout the inner loop.
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
if PRESCALE_QK:
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA2 + offs_m2)
lse = tl.load(LSE2 + offs_m2)
else:
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
lse = lse[:, None]
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# KV_IDX and KV_NUM_BLKS are always contiguous.
kv_indices = KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
dq = bwd_dq_inner(
{{gen_argdefs()}},
K, V,
dq, q, do, Di, lse,
off_zq, off_hq2, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dQ.
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
dq *= SM_SCALE
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dq_ptrs, dq)
else:
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
else:
# THIS BLOCK DOES DK & DV
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
pid_mask = pid // SPARSE_KV_MULTIPLE
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
stride_q_idx_h = {{stride("Q_IDX", 1)}}
stride_q_idx_n = {{stride("Q_IDX", 2)}}
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
start_n1 = pid * BLOCK_N1
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
# load K and V: they stay in SRAM throughout the inner loop.
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
if PRESCALE_QK:
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
for off_g in range(0, GQA_SHARED_HEADS):
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads.
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
Q1 = Q + q_adj1
DO1 = DO + do_adj1
# TODO: This does not work if DQ is not the same layout as Q (for example,
# if Q is broadcasted)
LSE1 = LSE + off_chz1
DELTA1 = DELTA + off_chz1
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Q_IDX and Q_NUM_BLKS are always contiguous.
q_indices = Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
if HAS_FULL_BLOCKS:
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
q_indices = FULL_Q_IDX + sparse_q_idx_offset
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
dk, dv = bwd_dkdv_inner(
{{gen_argdefs()}},
Q1, DO1, DELTA1, LSE1,
dk, dv, k, v,
off_zq, off_hq1, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
# Write back dV and dK.
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
index_n = offs_n1[:, None]
index_k = offs_k[None, :]
index_v = offs_v[None, :]
if IS_DIVISIBLE and SAFE_HEAD_DIM:
tl.store(dv_ptrs, dv)
else:
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
dk *= SM_SCALE
if SAFE_HEAD_DIM:
mask = index_n < KV_LEN
else:
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
{{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
@triton.jit
def bwd_dq_inner(
{{gen_argdefs()}},
K, V, # pointers
dq, q, do, Di, lse,
off_z, off_hq, offs_m2, offs_n2,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_n in range(0, hi - 1):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_n in range(0, hi):
dq = bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_n, kv_indices, sparse_kv_num_blocks,
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
)
kT_ptrs += offset * stride_kn
vT_ptrs += offset * stride_vn
offs_n2 += offset
return dq
@triton.jit
def bwd_dq_block_mn(
{{gen_argdefs()}},
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
stride_kn, stride_kd, stride_vn, stride_vd,
kv_indices, sparse_kv_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1)}}
# NB reversed order to since K is transposed
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qk *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
pre_mod_scores = qk
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim
# that the M reads out of bounds prior to the last loop
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qk"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qk",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# apply mask for partial masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
p = tl.math.exp2(post_mod_scores - lse)
# Compute dP and dS.
# NB reversed order to since V is transposed
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
ds = p * (dp - Di[:, None])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if WRITE_DQ:
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="ds"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
ds = tl.where(mask_mod_output, ds, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
ds = ds.to(MATMUL_PRECISION)
# Compute dQ.
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
return dq
@triton.jit
def bwd_dkdv_inner(
{{gen_argdefs()}},
Q, DO, DELTA, LSE, # pointers
dk, dv, k, v,
off_z, off_hq, offs_n1, offs_m1,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION,
IS_FULL_BLOCKS,
):
{{gen_defines() | indent_except_first(1) }}
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
RCP_LN2: tl.constexpr = 1.44269504
Q_LEN = {{size("Q", 2)}}
KV_LEN = {{size("K", 2)}}
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
if not IS_DIVISIBLE:
if hi >= 1:
for start_m in range(0, hi - 1):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
)
else:
for start_m in range(0, hi):
dk, dv = bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS,
)
# Increment pointers.
offset = get_offset_for_next_block(
start_m, q_indices, sparse_q_num_blocks,
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
)
qT_ptrs += offset * stride_qm
do_ptrs += offset * stride_dom
offs_m1 += offset
return dk, dv
@triton.jit
def bwd_dkdv_block_mn(
{{gen_argdefs()}},
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
stride_qm, stride_qd, stride_dom, stride_dod,
q_indices, sparse_q_num_blocks,
MATMUL_PRECISION, RCP_LN2,
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
):
{{gen_defines() | indent_except_first(1) }}
# NB reversed order since Q is transposed
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
# Load LSE before computing qk to reduce pipeline stall.
if IS_DIVISIBLE:
lse = tl.load(LSE + offs_m1)
else:
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
lse = tl.where(lse == -float("inf"), 0.0, lse)
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
if not PRESCALE_QK:
qkT *= SM_SCALE
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
# The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim
# that the n reads out of bounds prior to the last loop
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
pre_mod_scores = qkT
{{ modification(
subgraph_number=0,
output_name="post_mod_scores",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
out="qkT"
) | indent_except_first(1) }}
if CHECK_BLOCK_BOUNDARY:
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
if not IS_FULL_BLOCKS:
{{ modification(
subgraph_number=2,
output_name="mask_mod_output",
score="qkT",
b="off_z",
h="off_hq",
m="m",
n="n",
) | indent_except_first(2) }}
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for fully masked block
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if not PRESCALE_QK:
post_mod_scores *= RCP_LN2
pT = tl.math.exp2(post_mod_scores - lse[None, :])
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
# Compute dV.
ppT = pT
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
if IS_DIVISIBLE:
Di = tl.load(DELTA + offs_m1)
else:
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
dsT = pT * (dpT - Di[None, :])
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
{{ modification(
subgraph_number=1,
output_name = "grad_scores",
score="pre_mod_scores",
b="off_z",
h="off_hq",
m="m",
n="n",
grad_score_mod="dsT"
) | indent_except_first(1) }}
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
if not WRITE_DQ:
idx_b = off_z
idx_h = off_hq
idx_m = m
idx_n = n
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
{{ modification(
subgraph_number=3,
output_name=None,
mask="scatter_mask",
score="pre_mod_scores",
b="idx_b",
h="idx_h",
m="idx_m",
n="idx_n",
grad_score_mod="dsT"
) | indent_except_first(2) }}
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
if CHECK_BLOCK_BOUNDARY:
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
dsT = grad_scores
if not IS_FULL_BLOCKS:
if CHECK_BLOCK_BOUNDARY:
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
# (grads) apply mask for partially unmasked block
dsT = tl.where(mask_mod_output, dsT, 0.0)
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
return dk, dv

View File

@ -1,252 +0,0 @@
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
# Sub notation for this kernel:
# Q: Query, K: Key, V: Value
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
# M: Number of queries, N: Number of keys/values
# QK_HEAD_DIM: The dimension of the query and key embeddings
# V_HEAD_DIM: The dimension of the value embeddings
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
# (Modifiable) Config options:
# SPLIT_KV: number of blocks K & V are split into
# TILE_KV: length of each local KV split
# BLOCK_M: block size that Q is padded along seqlen dim.
# BLOCK_N: block size of K & V along N dimension.
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
#
# change of base out of the loop
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
# is not masked out? If so, we can skip an extra safety check
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
#
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
#
#
# Output: ACC output accumulated across local KV split.
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
# Define Q Strides
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
Z = {{size("Q", 0)}}
ZKV = {{size("K", 0)}}
HKV = {{size("Q", 1)}}
G: tl.constexpr = GQA_SHARED_HEADS
HQ = HKV * G
Q_LEN = {{size("Q", 3)}}
KV_LEN = {{size("K", 2)}}
MATMUL_PRECISION = Q.dtype.element_ty
# Make sure each split is a multiple of BLOCK_N
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
off_z = tl.program_id(0) // HKV
off_zkv = off_z % ZKV
off_hkv = tl.program_id(0) % HKV
off_t = tl.program_id(1)
q_offset = off_z * stride_qz + off_hkv * stride_qh
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
sparse_idx_z = off_z % SPARSE_Z
sparse_idx_h = off_hkv % SPARSE_HQ
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
# initialize pointer to m and l
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
# initialize offsets
tl.device_assert(BLOCK_M % G == 0)
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
off_g = tl.arange(0, G) # [G]
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_hq = offs_g + off_hkv * G
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
# Calculate KV blocks that belong this CTA.
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
else:
q = tl.load(Q + q_offset + q_range)
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Apply both score_mod and mask_mod
# find first kv block we are loading and the number of blocks we are loading
# Offset the kv_indices tensor by the correct batch and head
kv_indices = KV_IDX + sparse_idx_hz_offset
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
# first kv block we're loading
# last valid block according to sparse mask
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
IS_FULL_BLOCKS=False,
)
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
# Assign full block in a reverse order for off_t. Prioritize the last CTA.
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
block_n_end = block_n_start + TILE_KV_MULTIPLE
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
# last valid block according to sparse mask
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
K_block_ptr = tl.make_block_ptr(
base=K + k_offset,
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
strides=(stride_kk, stride_kn),
offsets=(0, off_n),
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
base=V + v_offset,
shape=(KV_LEN, V_HEAD_DIM),
strides=(stride_vn, stride_vk),
offsets=(off_n, 0),
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n
acc, l_i, m_i = forward_inner(
{{gen_argdefs()}},
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
# accumulatd values
acc, l_i, m_i,
#offsets
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
None,
#block sparse data
kv_indices, kv_num_blocks,
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
MATMUL_PRECISION,
IS_FULL_BLOCKS=True,
)
m_offset = off_t * stride_mt + off_z * stride_mz
l_offset = off_t * stride_lt + off_z * stride_lz
M_block_ptr = tl.make_block_ptr(
base=M + m_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_mh, stride_mm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
L_block_ptr = tl.make_block_ptr(
base=L + l_offset,
shape=(G, Q_LEN), # (G, M)
strides=(stride_lh, stride_lm),
offsets=(off_hkv*G, 0),
block_shape=(G, BLOCK_M_PER_HQ),
order=(1, 0)
)
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
if SAFE_M_BOUNDARY:
tl.store(M_block_ptr, m_i)
tl.store(L_block_ptr, l_i)
else:
tl.store(M_block_ptr, m_i, boundary_check=(1,))
tl.store(L_block_ptr, l_i, boundary_check=(1,))
# -- store output
idx_z = off_z
idx_t = off_t
idx_hq = off_hkv*G + off_g[:, None, None]
idx_m = off_m[None, :, None]
idx_d = offs_vd[None, None, :]
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}

View File

@ -1,59 +0,0 @@
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
loop_iter, col_indices, total_blocks,
SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
if BLOCKS_ARE_CONTIGUOUS:
return BLOCK
cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
return offset
@triton.jit
def get_bounded_indices(indices, max_len=None):
return indices % max_len if max_len is not None else indices
@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
if IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr)
elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
else:
return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
@triton.jit
def load_checked_2d(
ptr,
offs_m,
offs_n,
stride_m,
stride_n,
IS_DIVISIBLE_M: tl.constexpr,
IS_DIVISIBLE_N: tl.constexpr,
M_LEN: tl.constexpr,
N_DIM: tl.constexpr,
):
# Calculate final pointer if strides are provided
if stride_m is not None and stride_n is not None:
ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
# Handle all masking cases
if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
else: # Both divisible
return tl.load(ptr)