Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Factor out the strings to templates for better editor integration (#160357)
# Summary
More code motion. TL;DR: with the template strings factored out into standalone Jinja files, installing the "Better Jinja" extension in VS Code now gives you syntax highlighting for them.
Before: https://github.com/user-attachments/assets/10868b31-f8ac-4cf5-99fe-19b8789ce06b (screenshot)
After: https://github.com/user-attachments/assets/45203765-589e-4d76-8196-d895a2f2fbf6 (screenshot)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160357
Approved by: https://github.com/eellison
Committed by PyTorch MergeBot
parent 78a2fe1d42
commit cbffde7745

setup.py (+1)
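The pattern this commit adopts, shown in the hunks below, is to keep each Triton kernel template in a standalone *.py.jinja file next to the module and read it at import time. A minimal sketch of that pattern, mirroring the load_template helper this diff adds to torch/_inductor/kernel/flex/common.py (standalone here only for illustration):

from pathlib import Path

# Directory holding the extracted Jinja templates, e.g. templates/common.py.jinja
_TEMPLATE_DIR = Path(__file__).parent / "templates"

def load_template(name: str) -> str:
    """Load a template file and return its content."""
    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
        return f.read()

# Kernel sources are then assembled by string concatenation, e.g.
# load_template("flex_attention") + load_template("utilities") + load_template("common"),
# instead of concatenating module-level r"""...""" strings.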
setup.py
@@ -1669,6 +1669,7 @@ def main() -> None:
             "_inductor/codegen/aoti_runtime/*.h",
             "_inductor/codegen/aoti_runtime/*.cpp",
             "_inductor/script.ld",
+            "_inductor/kernel/flex/templates/*.jinja",
             "_export/serde/*.yaml",
             "_export/serde/*.thrift",
             "share/cmake/ATen/*.cmake",
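The new glob is what ships the extracted templates inside the wheel; without it, load_template would not find its files in an installed build. A quick, hypothetical sanity check against an installed torch (the module path is assumed from this diff, not part of the change itself):

from pathlib import Path
import torch._inductor.kernel.flex.common as flex_common

templates = Path(flex_common.__file__).parent / "templates"
# Expect the extracted files, e.g. common.py.jinja, utilities.py.jinja, flex_attention.py.jinja
print(sorted(p.name for p in templates.glob("*.jinja")))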
torch/_inductor/kernel/flex/common.py
@@ -3,6 +3,7 @@

 import math
 from collections.abc import Sequence
+from pathlib import Path
 from typing import Any, Optional, Union

 import sympy
@@ -323,267 +324,13 @@ def next_power_of_two(n):
     return 2 ** math.ceil(math.log2(n))


-# ---- Common Template Strings ----
-compute_forward_block_mn = r"""
-    [... inline Triton/Jinja source for forward_block_mn(): loads one K/V block (TMA descriptor or block
-     pointer path), computes qk, applies the score_mod and mask_mod subgraphs with optional block-boundary
-     masking, and updates the online-softmax state acc, l_i, m_i; moved to templates/common.py.jinja ...]
-"""
-
-compute_forward_inner = r"""
-    [... inline Triton/Jinja source for forward_inner(): loops over KV blocks from block_n_start to
-     block_n_end, calling forward_block_mn per block and advancing the block pointers and offsets via
-     get_offset_for_next_block; moved to templates/common.py.jinja ...]
-"""
-
-# Inner Triton functions shared by flex_attention & split-k decoding kernels.
-compute_next_offset_func = r"""
-@triton.jit
-def get_offset_for_next_block(
-    loop_iter, col_indices, total_blocks,
-    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
-    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
-):
-    if BLOCKS_ARE_CONTIGUOUS:
-        return BLOCK
-    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
-    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
-    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
-    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
-    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
-    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
-    return offset
-"""
-
-get_bounded_indices_func = r"""
-@triton.jit
-def get_bounded_indices(indices, max_len=None):
-    return indices % max_len if max_len is not None else indices
-"""
-
-load_checked_block = r"""
-@triton.jit
-def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
-    if IS_DIVISIBLE and SAFE_HEAD_DIM:
-        return tl.load(block_ptr)
-    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
-        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
-    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
-        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
-    else:
-        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")
-"""
-
-load_checked_2d = r"""
-@triton.jit
-def load_checked_2d(
-    ptr,
-    offs_m,
-    offs_n,
-    stride_m,
-    stride_n,
-    IS_DIVISIBLE_M: tl.constexpr,
-    IS_DIVISIBLE_N: tl.constexpr,
-    M_LEN: tl.constexpr,
-    N_DIM: tl.constexpr,
-):
-    # Calculate final pointer if strides are provided
-    if stride_m is not None and stride_n is not None:
-        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n
-
-    # Handle all masking cases
-    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
-        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
-    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
-        return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
-    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
-        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
-    else:  # Both divisible
-        return tl.load(ptr)
-"""
+_TEMPLATE_DIR = Path(__file__).parent / "templates"
+
+
+def load_template(name: str) -> str:
+    """Load a template file and return its content."""
+    with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f:
+        return f.read()
+
+
+# Template strings have been moved to templates/common.py.jinja
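The block-sparse pointer arithmetic in get_offset_for_next_block is the least obvious of these helpers, so here is a pure-Python rendition of the same logic with scalar ints (the index values below are hypothetical, purely to show the arithmetic):

def offset_for_next_block(loop_iter, col_indices, total_blocks,
                          SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
                          BLOCKS_ARE_CONTIGUOUS):
    # Mirrors the Triton helper above, minus tl.load and the eviction hints.
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = col_indices[cur_block_idx]
    next_block = col_indices[cur_block_idx + 1] if cur_block_idx + 1 < total_blocks else 0
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    return jump_to_block * needs_jump + (1 - needs_jump) * BLOCK

# SPARSE_BLOCK=128, BLOCK=32 -> SPARSE_BLOCK_MULTIPLE=4: walk three sub-blocks by 32, then on the
# fourth iteration jump to the start of the next sparse block listed in col_indices.
print(offset_for_next_block(2, [0, 5, 9], 3, 128, 4, 32, False))   # 32 (still inside sparse block 0)
print(offset_for_next_block(3, [0, 5, 9], 3, 128, 4, 32, False))   # (5 - 0) * 128 - 3 * 32 = 544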
torch/_inductor/kernel/flex/flex_attention.py
@@ -22,17 +22,12 @@ from ...select_algorithm import (
 )
 from .common import (
     build_subgraph_buffer,
-    compute_forward_block_mn,
-    compute_forward_inner,
-    compute_next_offset_func,
     create_indices_fake,
     create_num_blocks_fake_generator,
     create_placeholder,
-    get_bounded_indices_func,
     get_fwd_subgraph_outputs,
     infer_dense_strides,
-    load_checked_2d,
-    load_checked_block,
+    load_template,
     maybe_realize,
     set_head_dim_values,
     SubgraphResults,
@@ -67,267 +62,12 @@ def get_float32_precision():
     return "'tf32'"


-compute_flex_attention = r"""
-    [... inline Triton/Jinja source for the flex_attention forward kernel: documents the sub-notation
-     (Q/K/V, KV_NUM_BLKS/KV_IDX, FULL_KV_NUM_BLKS/FULL_KV_IDX, OUTPUT_LOGSUMEXP, BLOCK_M/BLOCK_N,
-     PRESCALE_QK, ROWS_GUARANTEED_SAFE, BLOCKS_ARE_CONTIGUOUS), sets up strides, GQA/batch offsets and
-     TMA descriptors or block pointers for Q, K and V, runs forward_inner over the "partial" and then
-     the "full" KV blocks, fixes up fully masked rows (l_i == 0.0 -> 1), and stores the output and,
-     if OUTPUT_LOGSUMEXP, the logsumexp; moved to templates/flex_attention.py.jinja ...]
-"""
-
 flex_attention_template = TritonTemplate(
     name="flex_attention",
     grid=flex_attention_grid,
-    source=compute_flex_attention
-    + compute_forward_inner
-    + compute_next_offset_func
-    + compute_forward_block_mn
-    + load_checked_block
-    + get_bounded_indices_func,
+    source=load_template("flex_attention")
+    + load_template("utilities")
+    + load_template("common"),
 )


@@ -684,693 +424,7 @@ def flex_attention_backward_grid(
 flex_attention_backward_template = TritonTemplate(
     name="flex_attention_backward",
     grid=flex_attention_backward_grid,
-    source=r"""
-    [... inline Triton/Jinja source for the flex_attention_backward kernel: the main program splits its
-     program ids between dQ blocks and dK/dV blocks, plus the helper kernels bwd_dq_inner,
-     bwd_dq_block_mn, bwd_dkdv_inner and bwd_dkdv_block_mn that apply score_mod / mask_mod and their
-     joint gradients; moved to templates/flex_backwards.py.jinja ...]
-"""
-    + compute_next_offset_func
-    + get_bounded_indices_func
-    + load_checked_2d,
+    source=load_template("flex_backwards") + load_template("utilities"),
 )
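With the kernels now assembled from standalone Jinja files, the composed source backing a TritonTemplate can be dumped and opened in an editor with Jinja highlighting, which is the point of the PR. A small, hypothetical inspection snippet (the module path is assumed from this diff; TritonTemplate renders the string internally, this only writes it out for viewing):

from torch._inductor.kernel.flex.common import load_template

composed = (
    load_template("flex_attention")
    + load_template("utilities")
    + load_template("common")
)
with open("/tmp/flex_attention_composed.py.jinja", "w") as f:
    f.write(composed)
print(len(composed.splitlines()), "lines of template source")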
torch/_inductor/kernel/flex/flex_decoding.py
@@ -18,15 +18,10 @@ from ...select_algorithm import (
     TritonTemplate,
 )
 from .common import (
-    compute_forward_block_mn,
-    compute_forward_inner,
-    compute_next_offset_func,
     create_indices_fake,
     create_num_blocks_fake_generator,
-    get_bounded_indices_func,
     get_fwd_subgraph_outputs,
-    load_checked_2d,
-    load_checked_block,
+    load_template,
     maybe_realize,
     set_head_dim_values,
 )
@ -90,266 +85,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me
|
||||
flex_decoding_template = TritonTemplate(
|
||||
name="flex_decoding",
|
||||
grid=flex_decoding_grid,
|
||||
source=r"""
|
||||
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
# Q: Query, K: Key, V: Value
|
||||
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
||||
# M: Number of queries, N: Number of keys/values
|
||||
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
||||
# V_HEAD_DIM: The dimension of the value embeddings
|
||||
# BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block
|
||||
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits
|
||||
# (Modifiable) Config options:
|
||||
# SPLIT_KV: number of blocks K & V are split into
|
||||
# TILE_KV: length of each local KV split
|
||||
# BLOCK_M: block size that Q is padded along seqlen dim.
|
||||
# BLOCK_N: block size of K & V along N dimension.
|
||||
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
||||
#
|
||||
# change of base out of the loop
|
||||
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
||||
# is not masked out? If so, we can skip an extra safety check
|
||||
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
||||
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
||||
|
||||
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
||||
#
|
||||
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
||||
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
||||
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
||||
#
|
||||
#
|
||||
# Output: ACC output accumulated across local KV split.
|
||||
|
||||
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
||||
|
||||
# Define Q Strides
|
||||
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
|
||||
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
||||
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
||||
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
|
||||
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
|
||||
|
||||
|
||||
Z = {{size("Q", 0)}}
|
||||
ZKV = {{size("K", 0)}}
|
||||
HKV = {{size("Q", 1)}}
|
||||
G: tl.constexpr = GQA_SHARED_HEADS
|
||||
HQ = HKV * G
|
||||
Q_LEN = {{size("Q", 3)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
MATMUL_PRECISION = Q.dtype.element_ty
|
||||
|
||||
# Make sure each split is a multiple of BLOCK_N
|
||||
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
||||
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
||||
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
|
||||
|
||||
off_z = tl.program_id(0) // HKV
|
||||
off_zkv = off_z % ZKV
|
||||
off_hkv = tl.program_id(0) % HKV
|
||||
off_t = tl.program_id(1)
|
||||
|
||||
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
||||
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
||||
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
||||
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
sparse_idx_z = off_z % SPARSE_Z
|
||||
sparse_idx_h = off_hkv % SPARSE_HQ
|
||||
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
|
||||
# initialize offsets
|
||||
tl.device_assert(BLOCK_M % G == 0)
|
||||
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
||||
off_g = tl.arange(0, G) # [G]
|
||||
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_hq = offs_g + off_hkv * G
|
||||
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
||||
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
||||
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
||||
|
||||
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
||||
stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
|
||||
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
||||
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
|
||||
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
||||
|
||||
# Calculate KV blocks that belong this CTA.
|
||||
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
||||
|
||||
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
||||
|
||||
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
||||
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
||||
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
||||
else:
|
||||
q = tl.load(Q + q_offset + q_range)
|
||||
|
||||
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Apply both score_mod and mask_mod
|
||||
|
||||
# find first kv block we are loading and the number of blocks we are loading
|
||||
# Offset the kv_indices tensor by the correct batch and head
|
||||
kv_indices = KV_IDX + sparse_idx_hz_offset
|
||||
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
# first kv block we're loading
|
||||
|
||||
# last valid block according to sparse mask
|
||||
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, off_n),
|
||||
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(KV_LEN, V_HEAD_DIM),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(off_n, 0),
|
||||
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
|
||||
order=(1, 0)
|
||||
)
|
||||
offs_n = tl.arange(0, BLOCK_N) + off_n
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
||||
None,
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=False,
|
||||
)
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We know these blocks are guaranteed to be "full", so we don't need to
|
||||
# apply mask_mod to them - only score_mod
|
||||
if HAS_FULL_BLOCKS:
|
||||
kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
|
||||
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
|
||||
# Assign full block in a reverse order for off_t. Prioritize the last CTA.
|
||||
block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
|
||||
# last valid block according to sparse mask
|
||||
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K + k_offset,
|
||||
shape=(QK_HEAD_DIM, KV_LEN), # (d, N)
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, off_n),
|
||||
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V + v_offset,
|
||||
shape=(KV_LEN, V_HEAD_DIM),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(off_n, 0),
|
||||
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
|
||||
order=(1, 0)
|
||||
)
|
||||
offs_n = tl.arange(0, BLOCK_N) + off_n
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
#offsets
|
||||
off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
|
||||
None,
|
||||
#block sparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
m_offset = off_t * stride_mt + off_z * stride_mz
|
||||
l_offset = off_t * stride_lt + off_z * stride_lz
|
||||
|
||||
M_block_ptr = tl.make_block_ptr(
|
||||
base=M + m_offset,
|
||||
shape=(G, Q_LEN), # (G, M)
|
||||
strides=(stride_mh, stride_mm),
|
||||
offsets=(off_hkv*G, 0),
|
||||
block_shape=(G, BLOCK_M_PER_HQ),
|
||||
order=(1, 0)
|
||||
)
|
||||
L_block_ptr = tl.make_block_ptr(
|
||||
base=L + l_offset,
|
||||
shape=(G, Q_LEN), # (G, M)
|
||||
strides=(stride_lh, stride_lm),
|
||||
offsets=(off_hkv*G, 0),
|
||||
block_shape=(G, BLOCK_M_PER_HQ),
|
||||
order=(1, 0)
|
||||
)
|
||||
|
||||
# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
|
||||
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
|
||||
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
|
||||
if SAFE_M_BOUNDARY:
|
||||
tl.store(M_block_ptr, m_i)
|
||||
tl.store(L_block_ptr, l_i)
|
||||
else:
|
||||
tl.store(M_block_ptr, m_i, boundary_check=(1,))
|
||||
tl.store(L_block_ptr, l_i, boundary_check=(1,))
|
||||
|
||||
# -- store output
|
||||
idx_z = off_z
|
||||
idx_t = off_t
|
||||
idx_hq = off_hkv*G + off_g[:, None, None]
|
||||
idx_m = off_m[None, :, None]
|
||||
idx_d = offs_vd[None, None, :]
|
||||
|
||||
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
||||
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
|
||||
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
|
||||
"""
|
||||
+ compute_forward_inner
|
||||
+ get_bounded_indices_func
|
||||
+ load_checked_block
|
||||
+ load_checked_2d
|
||||
+ compute_next_offset_func
|
||||
+ compute_forward_block_mn,
|
||||
source=load_template("flex_decode")
|
||||
+ load_template("utilities")
|
||||
+ load_template("common"),
|
||||
)
|
||||
|
||||
|
||||
|
torch/_inductor/kernel/flex/templates/common.py.jinja (new file, 193 lines)
@ -0,0 +1,193 @@
|
||||
|
||||
|
||||
# Common Imports
|
||||
@triton.jit
|
||||
def forward_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
# Offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
# Offsets needed for TMA loads
|
||||
kv_start,
|
||||
kv_offset,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
||||
|
||||
):
|
||||
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
||||
{{gen_defines() | indent_except_first(1)}}
|
||||
|
||||
# -- load k --
|
||||
# NB reversed order since K is transposed
|
||||
{%- if USE_TMA %}
|
||||
k = tl.load_tensor_descriptor(
|
||||
desc_k,
|
||||
[kv_start + kv_offset, 0],
|
||||
)
|
||||
{%- else %}
|
||||
k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE)
|
||||
{%- endif %}
|
||||
|
||||
if USE_TMA:
|
||||
k = tl.trans(k)
|
||||
# -- compute qk ---
|
||||
qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2.
|
||||
if not PRESCALE_QK:
|
||||
qk *= SM_SCALE
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
||||
# If this is the last block of a non-divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements,
# which is larger than the actual number of elements. To avoid accessing memory out of bounds,
# we need to mask out the elements that are out of Q_LEN & KV_LEN.
|
||||
m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
||||
n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
||||
|
||||
{{ modification(
|
||||
subgraph_number=0,
|
||||
output_name="post_mod_scores",
|
||||
score="qk",
|
||||
b="off_z",
|
||||
h="off_h",
|
||||
m="m",
|
||||
n="n",
|
||||
out="qk"
|
||||
) | indent_except_first(1) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
||||
post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf"))
|
||||
|
||||
if not IS_FULL_BLOCKS:
|
||||
{{ modification(
|
||||
subgraph_number=1,
|
||||
output_name="mask_mod_output",
|
||||
score="qk",
|
||||
b="off_z",
|
||||
h="off_h",
|
||||
m="m",
|
||||
n="n",
|
||||
) | indent_except_first(2) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False)
|
||||
# apply mask for partially unmasked blocks
|
||||
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
||||
|
||||
if not PRESCALE_QK:
|
||||
post_mod_scores *= RCP_LN2
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# -- compute scaling constant ---
|
||||
m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1))
|
||||
if not ROWS_GUARANTEED_SAFE:
|
||||
masked_out_rows = (m_ij == float("-inf"))
|
||||
m_ij_masked = tl.where(masked_out_rows, 0, m_ij)
|
||||
else:
|
||||
m_ij_masked = m_ij
|
||||
|
||||
alpha = tl.math.exp2(m_i - m_ij_masked)
|
||||
p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None])
|
||||
|
||||
# NB: l_i update is pulled up here since it's a bit faster
|
||||
# NB: For headdim=256, it's faster to move it back down to after m_i = m_ij
|
||||
l_i = l_i * alpha + tl.sum(p, 1)
|
||||
# -- scale and update acc --
|
||||
acc = acc * alpha[:, None]
|
||||
{%- if USE_TMA %}
|
||||
v = tl.load_tensor_descriptor(
|
||||
desc_v,
|
||||
[kv_start + kv_offset, 0],
|
||||
)
|
||||
{%- else %}
|
||||
v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
|
||||
{%- endif %}
|
||||
acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION)
|
||||
|
||||
# -- update m_i
|
||||
m_i = m_ij
|
||||
|
||||
return acc, l_i, m_i
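# For intuition, the update above is the standard streaming ("online") softmax
# recurrence, shown here as a minimal NumPy-style sketch (reference only; base-2
# exponentials are used because the scores were already multiplied by RCP_LN2):
#
#   m_new = np.maximum(m_i, scores.max(axis=1))        # running row max
#   alpha = np.exp2(m_i - m_new)                        # rescales previous state
#   p     = np.exp2(scores - m_new[:, None])            # this block's weights
#   l_new = l_i * alpha + p.sum(axis=1)                 # running denominator
#   acc   = acc * alpha[:, None] + p @ v                # running numerator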
|
||||
|
||||
@triton.jit
|
||||
def forward_inner(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
# Offsets used as inputs to score_mod & mask_mod
|
||||
# of size [BLOCK_M, BLOCK_N] or scalar.
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
# Offsets needed for TMA loads
|
||||
kv_start,
|
||||
# blocksparse data
|
||||
kv_indices, kv_num_blocks,
|
||||
# start kv and end kv block
|
||||
block_n_start, block_n_end,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS,
|
||||
):
|
||||
# Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through
|
||||
{{gen_defines() | indent_except_first(1)}}
|
||||
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
RCP_LN2: tl.constexpr = 1.44269504
|
||||
|
||||
if PRESCALE_QK:
|
||||
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
||||
|
||||
kv_offset = 0
|
||||
|
||||
# loop over k, v and update accumulator until block_n_end
|
||||
for start_n in range(block_n_start, block_n_end):
|
||||
# Here IS_DIVISIBLE acts as the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention.
|
||||
if IS_DIVISIBLE:
|
||||
acc, l_i, m_i = forward_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
# Offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
# Offsets needed for TMA loads
|
||||
kv_start,
|
||||
kv_offset,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS,
|
||||
)
|
||||
else:
|
||||
# Benchmarking shows that even if we apply mod & mask to each block for non-divisible seqlen,
# it's on par with or slightly faster than only applying them to the last block in fwd.
# However, we choose a different strategy for bwd, where we only apply mod & mask
# to the last block because it's a lot faster.
|
||||
acc, l_i, m_i = forward_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
# accumulated values
|
||||
acc, l_i, m_i,
|
||||
# Offsets
|
||||
off_z, off_h, offs_m, offs_n,
|
||||
# Offsets needed for TMA loads
|
||||
kv_start,
|
||||
kv_offset,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
offset = get_offset_for_next_block(
|
||||
start_n, kv_indices, kv_num_blocks,
|
||||
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS
|
||||
)
|
||||
|
||||
offs_n = offs_n + offset
|
||||
kv_offset += offset
|
||||
if not USE_TMA:
|
||||
K_block_ptr = tl.advance(K_block_ptr, (0, offset))
|
||||
V_block_ptr = tl.advance(V_block_ptr, (offset, 0))
|
||||
|
||||
|
||||
return acc, l_i, m_i
|
torch/_inductor/kernel/flex/templates/flex_attention.py.jinja (new file, 248 lines)
@ -0,0 +1,248 @@
|
||||
{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
#
|
||||
# Q: Query, K: Key, V: Value
|
||||
# M: Number of queries, N: Number of keys/values, D: Model dimension
|
||||
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
||||
# V_HEAD_DIM: The dimension of the value embeddings
|
||||
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head
|
||||
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
||||
#
|
||||
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
||||
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
||||
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
||||
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
||||
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
||||
#
|
||||
# OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad
|
||||
#
|
||||
# (Modifiable) Performance tuning options
|
||||
# BLOCK_M: The thread block size across the seqlen dim of Q.
|
||||
# BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block.
|
||||
|
||||
# The below are kernel options that can be applied for certain score_mods,
|
||||
# or involve a numerics vs. perf tradeoff
|
||||
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
||||
# about 20% more numerical error, but slightly faster.
|
||||
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
||||
# is not masked out? If so, we can skip an extra safety check
|
||||
# BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are
|
||||
# contiguous? If so, we don't need to do an indirect jump for every block
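# Illustrative example of the block-sparse metadata described above (hypothetical numbers):
# with KV_LEN = 512 and SPARSE_KV_BLOCK_SIZE = 128 there are 4 KV blocks per row of the
# mask grid. For a causal mask, the row covering query positions 128..255 might record
# FULL_KV_NUM_BLKS = 1 with FULL_KV_IDX = [0] (block fully visible, no mask_mod needed)
# and KV_NUM_BLKS = 1 with KV_IDX = [1] (the diagonal block, which still needs mask_mod
# applied element-wise).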
|
||||
|
||||
tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0)
|
||||
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
||||
|
||||
# Define strides of inputs
|
||||
stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}}
|
||||
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
||||
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
||||
|
||||
ZQ = {{size("Q", 0)}}
|
||||
HQ = {{size("Q", 1)}}
|
||||
Q_LEN = {{size("Q", 2)}}
|
||||
ZKV = {{size("K", 0)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
MATMUL_PRECISION = Q.dtype.element_ty
|
||||
|
||||
q_start = tl.program_id(0)
|
||||
off_zq = tl.program_id(1)
|
||||
off_hq = tl.program_id(2)
|
||||
|
||||
# We support two cases for the batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq.
# b) (ZKV == 1 and ZQ > 1) where KV is broadcast along the batch dimension and off_zkv = 0.
|
||||
off_zkv = off_zq % ZKV
|
||||
off_hkv = off_hq // GQA_SHARED_HEADS
|
||||
off_g = off_hq % GQA_SHARED_HEADS
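# e.g. with GQA_SHARED_HEADS = 4 and HQ = 8 query heads: query heads 0-3 map to
# kv head 0 (off_g = 0..3) and query heads 4-7 map to kv head 1.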
|
||||
|
||||
q_offset = off_zq * stride_qz + off_hq * stride_qh
|
||||
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
||||
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
||||
|
||||
Q = Q + q_offset
|
||||
K = K + k_offset
|
||||
V = V + v_offset
|
||||
|
||||
# Setting up the TMA descriptors for Q, K, V
|
||||
desc_q = None
|
||||
desc_k = None
|
||||
desc_v = None
|
||||
{%- if USE_TMA %}
|
||||
desc_q = tl.make_tensor_descriptor(
|
||||
base=Q,
|
||||
shape=[Q_LEN, QK_HEAD_DIM],
|
||||
strides=[stride_qm, 1],
|
||||
block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED],
|
||||
)
|
||||
|
||||
desc_k = tl.make_tensor_descriptor(
|
||||
base=K,
|
||||
shape=[KV_LEN, QK_HEAD_DIM],
|
||||
strides=[stride_kn, 1],
|
||||
block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED],
|
||||
)
|
||||
|
||||
desc_v = tl.make_tensor_descriptor(
|
||||
base=V,
|
||||
shape=[KV_LEN, V_HEAD_DIM],
|
||||
strides=[stride_vn, 1],
|
||||
block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED],
|
||||
)
|
||||
{%- endif %}
|
||||
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
sparse_idx_z = off_zq % SPARSE_Z
|
||||
sparse_idx_hq = off_hq % SPARSE_HQ
|
||||
|
||||
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M)
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
|
||||
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
|
||||
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
|
||||
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
|
||||
offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M)
|
||||
|
||||
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq
|
||||
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE
|
||||
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950
|
||||
K_block_ptr = None
|
||||
V_block_ptr = None
|
||||
Q_block_ptr = None
|
||||
|
||||
if not USE_TMA:
|
||||
Q_block_ptr = tl.make_block_ptr(
|
||||
base=Q,
|
||||
shape=(Q_LEN, QK_HEAD_DIM),
|
||||
strides=(stride_qm, stride_qk),
|
||||
offsets=(q_start * BLOCK_M, 0),
|
||||
block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED),
|
||||
order=(1, 0)
|
||||
)
|
||||
|
||||
{%- if USE_TMA %}
|
||||
q = tl.load_tensor_descriptor(
|
||||
desc_q,
|
||||
[(q_start * BLOCK_M).to(tl.int32), 0],
|
||||
)
|
||||
{%- else %}
|
||||
q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM)
|
||||
{%- endif %}
|
||||
|
||||
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We don't know anything "special" about these blocks, so we need to apply
|
||||
# both score_mod and mask_mod to it
|
||||
kv_indices = KV_IDX + sparse_kv_idx_offset
|
||||
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
||||
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
||||
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
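# e.g. (illustrative) kv_num_blocks = 3 sparse blocks with SPARSE_KV_MULTIPLE = 2 gives
# at most 6 BLOCK_N iterations, clamped so we never iterate past cdiv(KV_LEN, BLOCK_N);
# the tl.maximum(..., 1) keeps at least one iteration even for very short KV_LEN.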
|
||||
|
||||
|
||||
if not USE_TMA:
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(QK_HEAD_DIM, KV_LEN),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, kv_start),
|
||||
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(KV_LEN, V_HEAD_DIM),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(kv_start, 0),
|
||||
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
|
||||
order=(1, 0)
|
||||
)
|
||||
|
||||
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
||||
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
acc, l_i, m_i,
|
||||
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
||||
kv_start,
|
||||
kv_indices, kv_num_blocks,
|
||||
0, block_n_end,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=False,
|
||||
)
|
||||
|
||||
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# We know these blocks are guaranteed to be "full", so we don't need to
|
||||
# apply mask_mod to them - only score_mod
|
||||
if HAS_FULL_BLOCKS:
|
||||
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
||||
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
||||
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
||||
kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
||||
block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
||||
if not USE_TMA:
|
||||
K_block_ptr = tl.make_block_ptr(
|
||||
base=K,
|
||||
shape=(QK_HEAD_DIM, KV_LEN),
|
||||
strides=(stride_kk, stride_kn),
|
||||
offsets=(0, kv_start),
|
||||
block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
|
||||
order=(0, 1)
|
||||
)
|
||||
V_block_ptr = tl.make_block_ptr(
|
||||
base=V,
|
||||
shape=(KV_LEN, V_HEAD_DIM),
|
||||
strides=(stride_vn, stride_vk),
|
||||
offsets=(kv_start, 0),
|
||||
block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
|
||||
order=(1, 0)
|
||||
)
|
||||
offs_n = kv_start + tl.arange(0, BLOCK_N)
|
||||
|
||||
acc, l_i, m_i = forward_inner(
|
||||
{{gen_argdefs()}},
|
||||
q, K_block_ptr, V_block_ptr,
|
||||
desc_k, desc_v, Q_LEN, KV_LEN,
|
||||
acc, l_i, m_i,
|
||||
off_zq, off_hq, offs_m[:, None], offs_n[None, :],
|
||||
kv_start,
|
||||
kv_indices, kv_num_blocks,
|
||||
0, block_n_end,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
|
||||
# [Note] Handle fully masked out rows:
|
||||
# Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf.
|
||||
# We set Li to 1.0 which will result in lse/out = 0.0 after the log(li) + mi(0.0) step
|
||||
l_i = tl.where(l_i == 0.0, 1, l_i)
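# Worked example: for a fully masked row every score is -inf, so p = exp2(-inf) = 0 in
# each block, leaving acc = 0 and l_i = 0. Substituting l_i = 1 makes the division below
# 0 / 1 = 0 rather than 0 / 0 = NaN, so masked rows produce a zero output row.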
|
||||
|
||||
acc = acc / l_i[:, None]
|
||||
idx_zq = tl.program_id(1)
|
||||
idx_hq = tl.program_id(2)
|
||||
idx_m = offs_m[:, None]
|
||||
idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :]
|
||||
|
||||
mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
|
||||
|
||||
{{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
|
||||
|
||||
if OUTPUT_LOGSUMEXP:
|
||||
off_hz = off_zq * HQ + off_hq
|
||||
l_ptrs = LSE + off_hz * Q_LEN + offs_m
|
||||
lse = m_i + tl.math.log2(l_i)
|
||||
if IS_DIVISIBLE:
|
||||
tl.store(l_ptrs, lse)
|
||||
else:
|
||||
tl.store(l_ptrs, lse, mask=offs_m < Q_LEN)
|
torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja (new file, 682 lines)
@ -0,0 +1,682 @@
|
||||
{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
#
|
||||
# Q: Query, K: Key, V: Value
|
||||
# LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype)
|
||||
# DELTA: Precomputed sum(OUT*DO, axis=-1)
|
||||
# DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value
|
||||
# DK: Derivative of Key, is written to via the store_output call due to some limitations with
|
||||
# inductor codegen
|
||||
# M: Number of queries, N: Number of keys/values
|
||||
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
||||
# V_HEAD_DIM: The dimension of the value embeddings
|
||||
# z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim
|
||||
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
||||
# (Modifiable) Performance tuning options
|
||||
# BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block.
|
||||
# BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V.
|
||||
# BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q.
|
||||
# BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block.
|
||||
#
|
||||
# The following FULL_* and PARTIAL_* are defined in the block sparse mask grid, rather than the thread block grid.
|
||||
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
||||
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
||||
# Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query.
|
||||
# Q_IDX: The indices of Q blocks (that may or may not require masking) for each query.
|
||||
# FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query.
|
||||
# FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query.
|
||||
# FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query.
|
||||
# FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query.
|
||||
|
||||
# The below are kernel options that can be applied for certain score_mods,
|
||||
# or involve a numerics vs. perf tradeoff
|
||||
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has
|
||||
# about 20% more numerical error, but slightly faster.
|
||||
|
||||
# Define strides of inputs
|
||||
stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}}
|
||||
stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}}
|
||||
stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}}
|
||||
stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}}
|
||||
|
||||
stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}}
|
||||
stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}}
|
||||
|
||||
ZQ = {{size("Q", 0)}}
|
||||
HQ = {{size("Q", 1)}}
|
||||
HKV = {{size("K", 1)}}
|
||||
Q_LEN = {{size("Q", 2)}}
|
||||
ZKV = {{size("K", 0)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
MATMUL_PRECISION = Q.dtype.element_ty
|
||||
|
||||
pid = tl.program_id(0)
|
||||
NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1)
|
||||
NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2)
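# Program ids 0..NUM_KV_BLOCKS-1 each compute a dK/dV tile for one BLOCK_N1 slice of KV;
# the remaining programs (pid >= NUM_KV_BLOCKS, see the branch below) each compute a dQ
# tile for one (query head, BLOCK_M2 slice) pair of this kv head.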
|
||||
|
||||
off_zq = tl.program_id(1) # q batch idx
|
||||
off_hkv = tl.program_id(2) # kv head idx
|
||||
off_zkv = off_zq % ZKV # kv batch idx
|
||||
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
sparse_idx_z = off_zq % SPARSE_Z
|
||||
|
||||
k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64)
|
||||
v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64)
|
||||
# first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM]
|
||||
# then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM]
|
||||
dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64)
|
||||
|
||||
# offset K, V, DV pointers for batch/kv-head
|
||||
K += k_adj
|
||||
V += v_adj
|
||||
DV += dv_adj
|
||||
|
||||
RCP_LN2 = 1.44269504
|
||||
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
||||
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
||||
|
||||
if pid >= NUM_KV_BLOCKS:
|
||||
off_pid = pid - NUM_KV_BLOCKS
|
||||
# THIS BLOCK DOES DQ
|
||||
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2)
|
||||
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
||||
off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS
|
||||
start_m2_block = off_pid % NUM_Q_BLOCKS
|
||||
off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE
|
||||
stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}}
|
||||
stride_kv_idx_h = {{stride("KV_IDX", 1)}}
|
||||
stride_kv_idx_m = {{stride("KV_IDX", 2)}}
|
||||
|
||||
sparse_idx_hq2 = off_hq2 % SPARSE_HQ
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2
|
||||
|
||||
sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask
|
||||
sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950
|
||||
|
||||
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
|
||||
q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64)
|
||||
do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64)
|
||||
dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64)
|
||||
off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64)
|
||||
|
||||
Q2 = Q + q_adj2
|
||||
DO2 = DO + do_adj2
|
||||
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
||||
# if Q is broadcasted)
|
||||
DQ2 = DQ + dq_adj2
|
||||
LSE2 = LSE + off_chz2
|
||||
DELTA2 = DELTA + off_chz2
|
||||
|
||||
# dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32)
|
||||
dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
|
||||
start_m2 = start_m2_block * BLOCK_M2
|
||||
offs_m2 = start_m2 + tl.arange(0, BLOCK_M2)
|
||||
|
||||
# load Q and do: they stay in SRAM throughout the inner loop.
|
||||
q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM)
|
||||
do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
||||
|
||||
if PRESCALE_QK:
|
||||
q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
||||
|
||||
if IS_DIVISIBLE:
|
||||
Di = tl.load(DELTA2 + offs_m2)
|
||||
lse = tl.load(LSE2 + offs_m2)
|
||||
else:
|
||||
Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN)
|
||||
lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN)
|
||||
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
||||
lse = lse[:, None]
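# Guarding lse == -inf (fully masked rows) above avoids computing exp2(-inf - (-inf)) = NaN
# in bwd_dq_block_mn; with lse = 0.0 the masked scores stay -inf and p comes out as 0 instead.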
|
||||
|
||||
# ~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# KV_IDX and KV_NUM_BLKS are always contiguous.
|
||||
kv_indices = KV_IDX + sparse_kv_idx_offset
|
||||
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
||||
sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
||||
|
||||
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
||||
dq = bwd_dq_inner(
|
||||
{{gen_argdefs()}},
|
||||
K, V,
|
||||
dq, q, do, Di, lse,
|
||||
off_zq, off_hq2, offs_m2, offs_n2,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=False,
|
||||
)
|
||||
|
||||
if HAS_FULL_BLOCKS:
|
||||
# ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous.
|
||||
kv_indices = FULL_KV_IDX + sparse_kv_idx_offset
|
||||
kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading
|
||||
sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset)
|
||||
|
||||
offs_n2 = kv_start + tl.arange(0, BLOCK_N2)
|
||||
dq = bwd_dq_inner(
|
||||
{{gen_argdefs()}},
|
||||
K, V,
|
||||
dq, q, do, Di, lse,
|
||||
off_zq, off_hq2, offs_m2, offs_n2,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
# Write back dQ.
|
||||
dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd
|
||||
dq *= SM_SCALE
|
||||
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
||||
tl.store(dq_ptrs, dq)
|
||||
else:
|
||||
tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM))
|
||||
else:
|
||||
# THIS BLOCK DOES DK & DV
|
||||
SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
||||
SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1)
|
||||
|
||||
pid_mask = pid // SPARSE_KV_MULTIPLE
|
||||
|
||||
stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}}
|
||||
stride_q_idx_h = {{stride("Q_IDX", 1)}}
|
||||
stride_q_idx_n = {{stride("Q_IDX", 2)}}
|
||||
|
||||
|
||||
dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
|
||||
start_n1 = pid * BLOCK_N1
|
||||
offs_n1 = start_n1 + tl.arange(0, BLOCK_N1)
|
||||
|
||||
# load K and V: they stay in SRAM throughout the inner loop.
|
||||
k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM)
|
||||
v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM)
|
||||
|
||||
if PRESCALE_QK:
|
||||
k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION)
|
||||
|
||||
for off_g in range(0, GQA_SHARED_HEADS):
|
||||
off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g
|
||||
|
||||
# Offset Q, DQ, DO, DELTA & LSE. These inputs are offset by query heads.
|
||||
q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64)
|
||||
do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64)
|
||||
dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64)
|
||||
off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64)
|
||||
|
||||
Q1 = Q + q_adj1
|
||||
DO1 = DO + do_adj1
|
||||
# TODO: This does not work if DQ is not the same layout as Q (for example,
|
||||
# if Q is broadcasted)
|
||||
LSE1 = LSE + off_chz1
|
||||
DELTA1 = DELTA + off_chz1
|
||||
|
||||
sparse_idx_hq1 = off_hq1 % SPARSE_HQ
|
||||
sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1
|
||||
|
||||
sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask
|
||||
sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950
|
||||
|
||||
# ~~~~~~~~~~~~~~~ partially masked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Q_IDX and Q_NUM_BLKS are always contiguous.
|
||||
q_indices = Q_IDX + sparse_q_idx_offset
|
||||
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
||||
sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset)
|
||||
|
||||
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
||||
dk, dv = bwd_dkdv_inner(
|
||||
{{gen_argdefs()}},
|
||||
Q1, DO1, DELTA1, LSE1,
|
||||
dk, dv, k, v,
|
||||
off_zq, off_hq1, offs_n1, offs_m1,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=False,
|
||||
)
|
||||
|
||||
|
||||
if HAS_FULL_BLOCKS:
|
||||
# ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous.
|
||||
q_indices = FULL_Q_IDX + sparse_q_idx_offset
|
||||
q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading
|
||||
sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset)
|
||||
|
||||
offs_m1 = q_start + tl.arange(0, BLOCK_M1)
|
||||
dk, dv = bwd_dkdv_inner(
|
||||
{{gen_argdefs()}},
|
||||
Q1, DO1, DELTA1, LSE1,
|
||||
dk, dv, k, v,
|
||||
off_zq, off_hq1, offs_n1, offs_m1,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS=True,
|
||||
)
|
||||
|
||||
# Write back dV and dK.
|
||||
dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd
|
||||
|
||||
index_n = offs_n1[:, None]
|
||||
index_k = offs_k[None, :]
|
||||
index_v = offs_v[None, :]
|
||||
|
||||
if IS_DIVISIBLE and SAFE_HEAD_DIM:
|
||||
tl.store(dv_ptrs, dv)
|
||||
else:
|
||||
tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM))
|
||||
|
||||
dk *= SM_SCALE
|
||||
|
||||
if SAFE_HEAD_DIM:
|
||||
mask = index_n < KV_LEN
|
||||
else:
|
||||
mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM)
|
||||
|
||||
# first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, QK_HEAD_DIM]
# then reduce to dk of shape [Bkv, Hkv, KV_LEN, QK_HEAD_DIM]
|
||||
{{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}}
|
||||
|
||||
@triton.jit
|
||||
def bwd_dq_inner(
|
||||
{{gen_argdefs()}},
|
||||
K, V, # pointers
|
||||
dq, q, do, Di, lse,
|
||||
off_z, off_hq, offs_m2, offs_n2,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS,
|
||||
):
|
||||
{{gen_defines() | indent_except_first(1) }}
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2)
|
||||
RCP_LN2: tl.constexpr = 1.44269504
|
||||
Q_LEN = {{size("Q", 2)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
||||
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
||||
|
||||
kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd
|
||||
vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd
|
||||
# BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0)
|
||||
|
||||
hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1))
|
||||
if not IS_DIVISIBLE:
|
||||
if hi >= 1:
|
||||
for start_n in range(0, hi - 1):
|
||||
dq = bwd_dq_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS,
|
||||
)
|
||||
|
||||
# Increment pointers.
|
||||
offset = get_offset_for_next_block(
|
||||
start_n, kv_indices, sparse_kv_num_blocks,
|
||||
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
||||
)
|
||||
|
||||
kT_ptrs += offset * stride_kn
|
||||
vT_ptrs += offset * stride_vn
|
||||
|
||||
offs_n2 += offset
|
||||
|
||||
dq = bwd_dq_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
||||
)
|
||||
else:
|
||||
for start_n in range(0, hi):
|
||||
dq = bwd_dq_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS,
|
||||
)
|
||||
|
||||
# Increment pointers.
|
||||
offset = get_offset_for_next_block(
|
||||
start_n, kv_indices, sparse_kv_num_blocks,
|
||||
SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS
|
||||
)
|
||||
|
||||
kT_ptrs += offset * stride_kn
|
||||
vT_ptrs += offset * stride_vn
|
||||
|
||||
offs_n2 += offset
|
||||
|
||||
return dq
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_dq_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v,
|
||||
stride_kn, stride_kd, stride_vn, stride_vd,
|
||||
kv_indices, sparse_kv_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
||||
):
|
||||
{{gen_defines() | indent_except_first(1)}}
|
||||
|
||||
# NB reversed order since K is transposed
|
||||
kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN)
|
||||
qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION)
|
||||
if not PRESCALE_QK:
|
||||
qk *= SM_SCALE
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
||||
pre_mod_scores = qk
|
||||
n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None)
|
||||
# The boundary check is done for the outer loop, but here, since we're iterating across the N dim,
# it's possible that the M reads go out of bounds prior to the last loop iteration
|
||||
m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
|
||||
|
||||
{{ modification(
|
||||
subgraph_number=0,
|
||||
output_name="post_mod_scores",
|
||||
score="qk",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
out="qk"
|
||||
) | indent_except_first(1) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
||||
post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf"))
|
||||
|
||||
if not IS_FULL_BLOCKS:
|
||||
{{ modification(
|
||||
subgraph_number=2,
|
||||
output_name="mask_mod_output",
|
||||
score="qk",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
) | indent_except_first(2) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
|
||||
# apply mask for partial masked block
|
||||
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
if not PRESCALE_QK:
|
||||
post_mod_scores *= RCP_LN2
|
||||
p = tl.math.exp2(post_mod_scores - lse)
|
||||
# Compute dP and dS.
|
||||
# NB reversed order since V is transposed
|
||||
vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN)
|
||||
|
||||
dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION)
|
||||
ds = p * (dp - Di[:, None])
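# This is the standard softmax backward identity: with p = softmax(s) per row,
# ds = p * (dp - sum(p * dp)), and sum(p * dp) = dot(do, p @ V) = dot(do, out) = Di,
# which is why DELTA is precomputed as sum(OUT * DO, axis=-1).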
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
||||
{{ modification(
|
||||
subgraph_number=1,
|
||||
output_name = "grad_scores",
|
||||
score="pre_mod_scores",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
grad_score_mod="ds"
|
||||
) | indent_except_first(1) }}
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0)
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
||||
if WRITE_DQ:
|
||||
scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN)
|
||||
{{ modification(
|
||||
subgraph_number=3,
|
||||
output_name=None,
|
||||
mask="scatter_mask",
|
||||
score="pre_mod_scores",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
grad_score_mod="ds"
|
||||
) | indent_except_first(2) }}
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
ds = grad_scores
|
||||
|
||||
if not IS_FULL_BLOCKS:
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False)
|
||||
# (grads) apply mask for partially unmasked block
|
||||
ds = tl.where(mask_mod_output, ds, 0.0)
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
ds = ds.to(MATMUL_PRECISION)
|
||||
# Compute dQ.
|
||||
dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION)
|
||||
|
||||
return dq
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_dkdv_inner(
|
||||
{{gen_argdefs()}},
|
||||
Q, DO, DELTA, LSE, # pointers
|
||||
dk, dv, k, v,
|
||||
off_z, off_hq, offs_n1, offs_m1,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION,
|
||||
IS_FULL_BLOCKS,
|
||||
):
|
||||
{{gen_defines() | indent_except_first(1) }}
|
||||
SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1)
|
||||
RCP_LN2: tl.constexpr = 1.44269504
|
||||
Q_LEN = {{size("Q", 2)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
||||
offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED)
|
||||
|
||||
qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd
|
||||
do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod
|
||||
# BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work.
|
||||
tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0)
|
||||
hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1))
|
||||
|
||||
if not IS_DIVISIBLE:
|
||||
if hi >= 1:
|
||||
for start_m in range(0, hi - 1):
|
||||
dk, dv = bwd_dkdv_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS,
|
||||
)
|
||||
# Increment pointers.
|
||||
offset = get_offset_for_next_block(
|
||||
start_m, q_indices, sparse_q_num_blocks,
|
||||
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
||||
)
|
||||
|
||||
qT_ptrs += offset * stride_qm
|
||||
do_ptrs += offset * stride_dom
|
||||
|
||||
offs_m1 += offset
|
||||
|
||||
dk, dv = bwd_dkdv_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True,
|
||||
)
|
||||
else:
|
||||
for start_m in range(0, hi):
|
||||
dk, dv = bwd_dkdv_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS,
|
||||
)
|
||||
# Increment pointers.
|
||||
offset = get_offset_for_next_block(
|
||||
start_m, q_indices, sparse_q_num_blocks,
|
||||
SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS
|
||||
)
|
||||
|
||||
qT_ptrs += offset * stride_qm
|
||||
do_ptrs += offset * stride_dom
|
||||
|
||||
offs_m1 += offset
|
||||
|
||||
return dk, dv
|
||||
|
||||
|
||||
@triton.jit
|
||||
def bwd_dkdv_block_mn(
|
||||
{{gen_argdefs()}},
|
||||
dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN,
|
||||
off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v,
|
||||
stride_qm, stride_qd, stride_dom, stride_dod,
|
||||
q_indices, sparse_q_num_blocks,
|
||||
MATMUL_PRECISION, RCP_LN2,
|
||||
IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False,
|
||||
):
|
||||
{{gen_defines() | indent_except_first(1) }}
|
||||
|
||||
# NB reversed order since Q is transposed
|
||||
qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN)
|
||||
# Load LSE before computing qk to reduce pipeline stall.
|
||||
if IS_DIVISIBLE:
|
||||
lse = tl.load(LSE + offs_m1)
|
||||
else:
|
||||
lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN)
|
||||
lse = tl.where(lse == -float("inf"), 0.0, lse)
|
||||
qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION)
|
||||
if not PRESCALE_QK:
|
||||
qkT *= SM_SCALE
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~
|
||||
m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None)
|
||||
# The boundary check is done for the outer loop, but here, since we're iterating across the M dim,
# it's possible that the n reads go out of bounds prior to the last loop iteration
|
||||
n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None)
|
||||
|
||||
pre_mod_scores = qkT
|
||||
{{ modification(
|
||||
subgraph_number=0,
|
||||
output_name="post_mod_scores",
|
||||
score="qkT",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
out="qkT"
|
||||
) | indent_except_first(1) }}
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
# Mask out the elements that are out of the KV_LEN for non divisible seqlen.
|
||||
post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf"))
|
||||
|
||||
if not IS_FULL_BLOCKS:
|
||||
{{ modification(
|
||||
subgraph_number=2,
|
||||
output_name="mask_mod_output",
|
||||
score="qkT",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
) | indent_except_first(2) }}
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
|
||||
# (grads) apply mask for fully masked block
|
||||
post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf"))
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
if not PRESCALE_QK:
|
||||
post_mod_scores *= RCP_LN2
|
||||
pT = tl.math.exp2(post_mod_scores - lse[None, :])
|
||||
do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM)
|
||||
# Compute dV.
|
||||
ppT = pT
|
||||
dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION)
|
||||
if IS_DIVISIBLE:
|
||||
Di = tl.load(DELTA + offs_m1)
|
||||
else:
|
||||
Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN)
|
||||
# Compute dP and dS.
|
||||
dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION)
|
||||
dsT = pT * (dpT - Di[None, :])
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~
|
||||
{{ modification(
|
||||
subgraph_number=1,
|
||||
output_name = "grad_scores",
|
||||
score="pre_mod_scores",
|
||||
b="off_z",
|
||||
h="off_hq",
|
||||
m="m",
|
||||
n="n",
|
||||
grad_score_mod="dsT"
|
||||
) | indent_except_first(1) }}
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~
|
||||
if not WRITE_DQ:
|
||||
idx_b = off_z
|
||||
idx_h = off_hq
|
||||
idx_m = m
|
||||
idx_n = n
|
||||
scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN)
|
||||
{{ modification(
|
||||
subgraph_number=3,
|
||||
output_name=None,
|
||||
mask="scatter_mask",
|
||||
score="pre_mod_scores",
|
||||
b="idx_b",
|
||||
h="idx_h",
|
||||
m="idx_m",
|
||||
n="idx_n",
|
||||
grad_score_mod="dsT"
|
||||
) | indent_except_first(2) }}
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0)
|
||||
|
||||
dsT = grad_scores
|
||||
if not IS_FULL_BLOCKS:
|
||||
if CHECK_BLOCK_BOUNDARY:
|
||||
mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False)
|
||||
# (grads) apply mask for partially unmasked block
|
||||
dsT = tl.where(mask_mod_output, dsT, 0.0)
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION)
|
||||
|
||||
return dk, dv
|
torch/_inductor/kernel/flex/templates/flex_decode.py.jinja (new file, 252 lines)
@ -0,0 +1,252 @@
|
||||
{{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}}
|
||||
# Sub notation for this kernel:
|
||||
# Q: Query, K: Key, V: Value
|
||||
# reduction buffers: M rowmax across local KV split, L local sumexp across local KV split
|
||||
# M: Number of queries, N: Number of keys/values
|
||||
# QK_HEAD_DIM: The dimension of the query and key embeddings
|
||||
# V_HEAD_DIM: The dimension of the value embeddings
|
||||
# BLOCK_M, QK_HEAD_DIM: the M and D dimensions are always assigned to the same block
|
||||
# z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head, t: Number of kv splits
|
||||
# (Modifiable) Config options:
|
||||
# SPLIT_KV: number of blocks K & V are split into
|
||||
# TILE_KV: length of each local KV split
|
||||
# BLOCK_M: block size that Q is padded along seqlen dim.
|
||||
# BLOCK_N: block size of K & V along N dimension.
|
||||
# GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups.
|
||||
#
|
||||
# change of base out of the loop
|
||||
# ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row
|
||||
# is not masked out? If so, we can skip an extra safety check
|
||||
# SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query.
|
||||
# SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value.
|
||||
|
||||
# PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base.
|
||||
#
|
||||
# SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim.
|
||||
# KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query.
|
||||
# KV_IDX: The indices of KV blocks (that may or may not require masking) for each query.
|
||||
#
|
||||
#
|
||||
# Output: ACC output accumulated across local KV split.
|
||||
|
||||
tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0)
|
||||
|
||||
# Define Q Strides
|
||||
stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}}
|
||||
stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}}
|
||||
stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}}
|
||||
stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}}
|
||||
stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}}
|
||||
|
||||
|
||||
Z = {{size("Q", 0)}}
|
||||
ZKV = {{size("K", 0)}}
|
||||
HKV = {{size("Q", 1)}}
|
||||
G: tl.constexpr = GQA_SHARED_HEADS
|
||||
HQ = HKV * G
|
||||
Q_LEN = {{size("Q", 3)}}
|
||||
KV_LEN = {{size("K", 2)}}
|
||||
|
||||
MATMUL_PRECISION = Q.dtype.element_ty
|
||||
|
||||
# Make sure each split is a multiple of BLOCK_N
|
||||
TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV)
|
||||
TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N
|
||||
TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N)
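# e.g. KV_LEN = 1000, SPLIT_KV = 4, BLOCK_N = 64: TILE_KV_OG = 250 is rounded up to
# TILE_KV = 256, so each of the 4 splits scans TILE_KV_MULTIPLE = 4 BLOCK_N tiles.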
|
||||
|
||||
off_z = tl.program_id(0) // HKV
|
||||
off_zkv = off_z % ZKV
|
||||
off_hkv = tl.program_id(0) % HKV
|
||||
off_t = tl.program_id(1)
|
||||
|
||||
q_offset = off_z * stride_qz + off_hkv * stride_qh
|
||||
k_offset = off_zkv * stride_kz + off_hkv * stride_kh
|
||||
v_offset = off_zkv * stride_vz + off_hkv * stride_vh
|
||||
|
||||
SPARSE_Z = {{size("KV_NUM_BLKS", 0)}}
|
||||
SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}}
|
||||
|
||||
sparse_idx_z = off_z % SPARSE_Z
|
||||
sparse_idx_h = off_hkv % SPARSE_HQ
|
||||
|
||||
SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N)
|
||||
SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE)
|
||||
|
||||
# initialize pointer to m and l
|
||||
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
|
||||
l_i = tl.zeros([BLOCK_M], dtype=tl.float32)
|
||||
acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32)
|
||||
|
||||
# initialize offsets
|
||||
tl.device_assert(BLOCK_M % G == 0)
|
||||
BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G
|
||||
off_g = tl.arange(0, G) # [G]
|
||||
offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_hq = offs_g + off_hkv * G
|
||||
off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ]
|
||||
offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M]
|
||||
offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED)
|
||||
offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED)
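# Illustration of the offsets above: with G = 4 and BLOCK_M_PER_HQ = 2 (so BLOCK_M = 8),
# offs_g = [0, 0, 1, 1, 2, 2, 3, 3] and offs_m = [0, 1, 0, 1, 0, 1, 0, 1],
# i.e. the BLOCK_M rows pack G query heads, each contributing BLOCK_M_PER_HQ queries.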
|
||||
|
||||
# Get HZ offsets for KV_NUM_BLKS and KV_IDX
|
||||
stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}}
|
||||
sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h
|
||||
stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}}
|
||||
sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h
|
||||
|
||||
# Calculate KV blocks that belong to this CTA.
|
||||
block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block
|
||||
block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N
|
||||
|
||||
q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :]
|
||||
|
||||
if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN))
|
||||
elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM)
|
||||
elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM:
|
||||
q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN)
|
||||
else:
|
||||
q = tl.load(Q + q_offset + q_range)
|
||||
|
||||
q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED])
|
||||
|
||||
|
||||
# ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
# Apply both score_mod and mask_mod
|
||||
|
||||
# find first kv block we are loading and the number of blocks we are loading
|
||||
# Offset the kv_indices tensor by the correct batch and head
|
||||
kv_indices = KV_IDX + sparse_idx_hz_offset
|
||||
kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
|
||||
indices_idx = block_n_start // SPARSE_KV_MULTIPLE
|
||||
off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
|
||||
off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
|
||||
# first kv block we're loading
|
||||
|
||||
# last valid block according to sparse mask
|
||||
block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
|
||||
|
||||
K_block_ptr = tl.make_block_ptr(
    base=K + k_offset,
    shape=(QK_HEAD_DIM, KV_LEN),  # (d, N)
    strides=(stride_kk, stride_kn),
    offsets=(0, off_n),
    block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
    order=(0, 1)
)
V_block_ptr = tl.make_block_ptr(
    base=V + v_offset,
    shape=(KV_LEN, V_HEAD_DIM),
    strides=(stride_vn, stride_vk),
    offsets=(off_n, 0),
    block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
    order=(1, 0)
)
offs_n = tl.arange(0, BLOCK_N) + off_n

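# K is addressed through a (head_dim, KV_LEN) block pointer, so each
# [QK_HEAD_DIM_ROUNDED, BLOCK_N] K tile multiplies directly against the
# [BLOCK_M, QK_HEAD_DIM_ROUNDED] q tile; V keeps its natural (KV_LEN, head_dim) layout.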
acc, l_i, m_i = forward_inner(
    {{gen_argdefs()}},
    q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
    # accumulated values
    acc, l_i, m_i,
    # offsets
    off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
    None,
    # block sparse data
    kv_indices, kv_num_blocks,
    block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
    MATMUL_PRECISION,
    IS_FULL_BLOCKS=False,
)
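# IS_FULL_BLOCKS=False: these KV blocks may be partially masked, so forward_inner applies
# mask_mod in addition to score_mod. The ternary above clamps block_n_end to
# block_n_last_valid, so a program whose slice lies past the last valid sparse block
# simply iterates over fewer (possibly zero) tiles.
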
# ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# We know these blocks are guaranteed to be "full", so we don't need to
# apply mask_mod to them - only score_mod
if HAS_FULL_BLOCKS:
    kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
    kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
    # Assign full blocks in reverse order of off_t. Prioritize the last CTA.
    block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
    block_n_end = block_n_start + TILE_KV_MULTIPLE
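    # Walking the full blocks from the opposite end is presumably a load-balancing choice:
    # programs with a high off_t, whose partial-block slice may fall past
    # block_n_last_valid above, are handed the first full blocks instead.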
    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
    off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
    off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N

    # last valid block according to sparse mask
    block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))

    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(QK_HEAD_DIM, KV_LEN),  # (d, N)
        strides=(stride_kk, stride_kn),
        offsets=(0, off_n),
        block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N),
        order=(0, 1)
    )
    V_block_ptr = tl.make_block_ptr(
        base=V + v_offset,
        shape=(KV_LEN, V_HEAD_DIM),
        strides=(stride_vn, stride_vk),
        offsets=(off_n, 0),
        block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED),
        order=(1, 0)
    )
    offs_n = tl.arange(0, BLOCK_N) + off_n

    acc, l_i, m_i = forward_inner(
        {{gen_argdefs()}},
        q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN,
        # accumulated values
        acc, l_i, m_i,
        # offsets
        off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :],
        None,
        # block sparse data
        kv_indices, kv_num_blocks,
        block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid,
        MATMUL_PRECISION,
        IS_FULL_BLOCKS=True,
    )

m_offset = off_t * stride_mt + off_z * stride_mz
l_offset = off_t * stride_lt + off_z * stride_lz

M_block_ptr = tl.make_block_ptr(
    base=M + m_offset,
    shape=(G, Q_LEN),  # (G, M)
    strides=(stride_mh, stride_mm),
    offsets=(off_hkv*G, 0),
    block_shape=(G, BLOCK_M_PER_HQ),
    order=(1, 0)
)
L_block_ptr = tl.make_block_ptr(
    base=L + l_offset,
    shape=(G, Q_LEN),  # (G, M)
    strides=(stride_lh, stride_lm),
    offsets=(off_hkv*G, 0),
    block_shape=(G, BLOCK_M_PER_HQ),
    order=(1, 0)
)

# Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16)
m_i = m_i.reshape(G, BLOCK_M_PER_HQ)
l_i = l_i.reshape(G, BLOCK_M_PER_HQ)
if SAFE_M_BOUNDARY:
    tl.store(M_block_ptr, m_i)
    tl.store(L_block_ptr, l_i)
else:
    tl.store(M_block_ptr, m_i, boundary_check=(1,))
    tl.store(L_block_ptr, l_i, boundary_check=(1,))
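# Each off_t program only holds a partial softmax over its slice of KV, so the running
# max (m_i), normalizer (l_i) and the unnormalized acc are written per split; a follow-up
# reduction step (not part of this diff) is expected to combine them into the final output.
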
# -- store output
idx_z = off_z
idx_t = off_t
idx_hq = off_hkv*G + off_g[:, None, None]
idx_m = off_m[None, :, None]
idx_d = offs_vd[None, None, :]

mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
{{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
59
torch/_inductor/kernel/flex/templates/utilities.py.jinja
Normal file
@ -0,0 +1,59 @@
# Utility triton funcs
@triton.jit
def get_offset_for_next_block(
    loop_iter, col_indices, total_blocks,
    SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK,
    BLOCKS_ARE_CONTIGUOUS: tl.constexpr
):
    if BLOCKS_ARE_CONTIGUOUS:
        return BLOCK
    cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE
    cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last")
    next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks)
    needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0
    jump_to_block = (next_block - cur_block) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK
    offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK
    return offset
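# Worked example for the jump arithmetic (assuming SPARSE_BLOCK=128 and BLOCK=32, so
# SPARSE_BLOCK_MULTIPLE=4): within a sparse block the pointer advances by BLOCK; on the
# tile that finishes a sparse block it advances by (next_block - cur_block) * 128 - 3 * 32,
# which lands exactly at the start of the next selected sparse block.
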
@triton.jit
def get_bounded_indices(indices, max_len=None):
    return indices % max_len if max_len is not None else indices

@triton.jit
def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr):
    if IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr)
    elif IS_DIVISIBLE and not SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(1,), padding_option="zero")
    elif not IS_DIVISIBLE and SAFE_HEAD_DIM:
        return tl.load(block_ptr, boundary_check=(0,), padding_option="zero")
    else:
        return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero")

@triton.jit
def load_checked_2d(
    ptr,
    offs_m,
    offs_n,
    stride_m,
    stride_n,
    IS_DIVISIBLE_M: tl.constexpr,
    IS_DIVISIBLE_N: tl.constexpr,
    M_LEN: tl.constexpr,
    N_DIM: tl.constexpr,
):
    # Calculate the final pointer if strides are provided
    if stride_m is not None and stride_n is not None:
        ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n

    # Handle all masking cases
    if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0)
    elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0)
    elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N:
        return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0)
    else:  # Both divisible
        return tl.load(ptr)