From 15e49f61643e4c0eef420f0981609709ef55b848 Mon Sep 17 00:00:00 2001 From: drisspg Date: Thu, 14 Aug 2025 01:07:53 +0000 Subject: [PATCH] Factor out the strings to templates for better editor integration (#160357) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Summary More code motion, tldr is that install 'Better Jinja' in vscode and now you can get highlighting Before Screenshot 2025-08-11 at 2 41 08 PM After: Screenshot 2025-08-11 at 2 40 27 PM Pull Request resolved: https://github.com/pytorch/pytorch/pull/160357 Approved by: https://github.com/eellison --- setup.py | 1 + torch/_inductor/kernel/flex/common.py | 267 +---- torch/_inductor/kernel/flex/flex_attention.py | 956 +----------------- torch/_inductor/kernel/flex/flex_decoding.py | 270 +---- .../kernel/flex/templates/common.py.jinja | 193 ++++ .../flex/templates/flex_attention.py.jinja | 248 +++++ .../flex/templates/flex_backwards.py.jinja | 682 +++++++++++++ .../flex/templates/flex_decode.py.jinja | 252 +++++ .../kernel/flex/templates/utilities.py.jinja | 59 ++ 9 files changed, 1451 insertions(+), 1477 deletions(-) create mode 100644 torch/_inductor/kernel/flex/templates/common.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_attention.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/flex_decode.py.jinja create mode 100644 torch/_inductor/kernel/flex/templates/utilities.py.jinja diff --git a/setup.py b/setup.py index 0f5c08ee8b8c..fc03de429801 100644 --- a/setup.py +++ b/setup.py @@ -1670,6 +1670,7 @@ def main() -> None: "_inductor/codegen/aoti_runtime/*.h", "_inductor/codegen/aoti_runtime/*.cpp", "_inductor/script.ld", + "_inductor/kernel/flex/templates/*.jinja", "_export/serde/*.yaml", "_export/serde/*.thrift", "share/cmake/ATen/*.cmake", diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index 8ee50753439e..6cc197a35b9c 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -3,6 +3,7 @@ import math from collections.abc import Sequence +from pathlib import Path from typing import Any, Optional, Union import sympy @@ -323,267 +324,13 @@ def next_power_of_two(n): return 2 ** math.ceil(math.log2(n)) -# ---- Common Template Strings ---- -compute_forward_block_mn = r""" -@triton.jit -def forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, - -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - # -- load k -- - # NB reversed order to since K is transposed - {%- if USE_TMA %} - k = tl.load_tensor_descriptor( - desc_k, - [kv_start + kv_offset, 0], - ) - {%- else %} - k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) - {%- endif %} - - if USE_TMA: - k = tl.trans(k) - # -- compute qk --- - qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. 
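# Illustrative aside (not part of this diff): these kernels do their softmax math in
# base-2 exponentials (tl.math.exp2), so scores are multiplied by RCP_LN2 = 1/ln(2)
# = log2(e) ~= 1.44269504, either up front when PRESCALE_QK is set or after score_mod
# runs. A minimal sketch of the change-of-base identity being relied on:
import math

x, RCP_LN2 = 0.7, 1.44269504
# exp(x) == 2 ** (x * log2(e)); exp2 is cheaper on GPU, hence the RCP_LN2 scaling above/below.
assert math.isclose(math.exp(x), 2.0 ** (x * RCP_LN2), rel_tol=1e-6)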
- if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, - # which is larger than the actual number of elements. To avoid access memory out of bound, - # we need to mask out the elements that are out of Q_LEN & KV_LEN. - m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) - n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. - post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=1, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_h", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) - # apply mask for partially unmasked blocks - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - # -- compute scaling constant --- - m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) - if not ROWS_GUARANTEED_SAFE: - masked_out_rows = (m_ij == float("-inf")) - m_ij_masked = tl.where(masked_out_rows, 0, m_ij) - else: - m_ij_masked = m_ij - - alpha = tl.math.exp2(m_i - m_ij_masked) - p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) - - # NB: l_i update is pulled up here since it's a bit faster - # NB: For headdim=256, it's faster to move it back down to after m_i = - # m_ij - l_i = l_i * alpha + tl.sum(p, 1) - # # -- scale and update acc -- - acc = acc * alpha[:, None] - {%- if USE_TMA %} - v = tl.load_tensor_descriptor( - desc_v, - [kv_start + kv_offset, 0], - ) - {%- else %} - v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) - - # -- update m_i - m_i = m_ij - - return acc, l_i, m_i - -""" - -compute_forward_inner = r""" -@triton.jit -def forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets used as inputs to score_mod & mask_mod - # of size [BLOCK_M, BLOCK_N] or scalar. - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - # blocksparse data - kv_indices, kv_num_blocks, - # start kv and end kv block - block_n_start, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - # Redefines all kernel parameters (BLOCK_M, etc.) so we don't need to plumb them all through - {{gen_defines() | indent_except_first(1)}} - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - RCP_LN2: tl.constexpr = 1.44269504 - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - kv_offset = 0 - - # loop over k, v and update accumulator until block_n_end - for start_n in range(block_n_start, block_n_end): - # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. 
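# Illustrative aside (not part of this diff): a pure-Python sketch of the iteration pattern
# forward_inner implements. kv_indices / kv_num_blocks name which sparse KV blocks to visit,
# each sparse block spans SPARSE_KV_MULTIPLE tiles of BLOCK_N, and get_offset_for_next_block
# (defined further down) advances by one tile within a block or jumps to the next indexed
# block. KV_LEN clamping and boundary masking are ignored here, and the helper name is ours.
def visited_kv_tiles(kv_indices, kv_num_blocks, sparse_kv_block_size, block_n):
    """Absolute KV offsets of every BLOCK_N tile the loop above would process."""
    multiple = sparse_kv_block_size // block_n
    tiles = []
    for sparse_block in kv_indices[:kv_num_blocks]:
        start = sparse_block * sparse_kv_block_size
        tiles.extend(start + i * block_n for i in range(multiple))
    return tiles

# e.g. sparse blocks 0 and 2 of size 128 with BLOCK_N == 64 -> tiles at 0, 64, 256, 320
assert visited_kv_tiles([0, 2], 2, 128, 64) == [0, 64, 256, 320]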
- if IS_DIVISIBLE: - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - else: - # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, - # it's on par or slightly faster than only applying to the last block in fwd. - # However, we choose different strategy for bwd, where we only apply mod & mask - # to the last block because it's faster a lot. - acc, l_i, m_i = forward_block_mn( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, - # accumulated values - acc, l_i, m_i, - # Offsets - off_z, off_h, offs_m, offs_n, - # Offsets needed for TMA loads - kv_start, - kv_offset, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) +_TEMPLATE_DIR = Path(__file__).parent / "templates" - - offset = get_offset_for_next_block( - start_n, kv_indices, kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS - ) - - offs_n = offs_n + offset - kv_offset += offset - if not USE_TMA: - K_block_ptr = tl.advance(K_block_ptr, (0, offset)) - V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) +def load_template(name: str) -> str: + """Load a template file and return its content.""" + with open(_TEMPLATE_DIR / f"{name}.py.jinja") as f: + return f.read() - return acc, l_i, m_i - -""" - -# Inner Triton functions shared by flex_attention & split-k decoding kernels. -compute_next_offset_func = r""" -@triton.jit -def get_offset_for_next_block( - loop_iter, col_indices, total_blocks, - SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, - BLOCKS_ARE_CONTIGUOUS: tl.constexpr -): - if BLOCKS_ARE_CONTIGUOUS: - return BLOCK - cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE - cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") - next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) - needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 - jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK - offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK - return offset -""" - -get_bounded_indices_func = r""" -@triton.jit -def get_bounded_indices(indices, max_len=None): - return indices % max_len if max_len is not None else indices -""" - - -load_checked_block = r""" -@triton.jit -def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): - if IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr) - elif IS_DIVISIBLE and not SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") - elif not IS_DIVISIBLE and SAFE_HEAD_DIM: - return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") - else: - return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") -""" - -load_checked_2d = r""" -@triton.jit -def load_checked_2d( - ptr, - offs_m, - offs_n, - stride_m, - stride_n, - IS_DIVISIBLE_M: tl.constexpr, - IS_DIVISIBLE_N: tl.constexpr, - M_LEN: tl.constexpr, - N_DIM: tl.constexpr, -): - # Calculate final pointer if strides are provided - if stride_m is not None and stride_n is not None: - ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n - - # Handle all masking cases - if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - 
return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) - elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) - elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: - return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) - else: # Both divisible - return tl.load(ptr) -""" +# Template strings have been moved to templates/common.py.jinja diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index 429f8d05c8cd..a3e441d033b3 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -22,17 +22,12 @@ from ...select_algorithm import ( ) from .common import ( build_subgraph_buffer, - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, create_placeholder, - get_bounded_indices_func, get_fwd_subgraph_outputs, infer_dense_strides, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, SubgraphResults, @@ -67,267 +62,12 @@ def get_float32_precision(): return "'tf32'" -compute_flex_attention = r""" -{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # M: Number of queries, N: Number of keys/values, D: Model dimension - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # - # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad - # - # (Modifiable) Performance tuning options - # BLOCK_M: The thread block size across the seqlen dim of Q. - # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. - - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are - # contiguous? 
If so, we don't need to do an indirect jump for every block - - tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - q_start = tl.program_id(0) - off_zq = tl.program_id(1) - off_hq = tl.program_id(2) - - # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. - # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. - off_zkv = off_zq % ZKV - off_hkv = off_hq // GQA_SHARED_HEADS - off_g = off_hq % GQA_SHARED_HEADS - - q_offset = off_zq * stride_qz + off_hq * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - Q = Q + q_offset - K = K + k_offset - V = V + v_offset - - # Setting up the TMA descriptors for Q, K, V - desc_q = None - desc_k = None - desc_v = None - {%- if USE_TMA %} - desc_q = tl.make_tensor_descriptor( - base=Q, - shape=[Q_LEN, QK_HEAD_DIM], - strides=[stride_qm, 1], - block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], - ) - - desc_k = tl.make_tensor_descriptor( - base=K, - shape=[KV_LEN, QK_HEAD_DIM], - strides=[stride_kn, 1], - block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], - ) - - desc_v = tl.make_tensor_descriptor( - base=V, - shape=[KV_LEN, V_HEAD_DIM], - strides=[stride_vn, 1], - block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], - ) - {%- endif %} - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - sparse_idx_hq = off_hq % SPARSE_HQ - - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) - - # KV_IDX and KV_NUM_BLKS are always contiguous. 
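# Illustrative aside (not part of this diff): the sparse_* offsets computed below select one
# row of the block mask for this (batch, head, q-block) program. Assuming the usual BlockMask
# layout (a sketch, not something this patch defines):
#   KV_NUM_BLKS[z, h, q_blk]  -> how many KV blocks this query block must visit
#   KV_IDX[z, h, q_blk, :]    -> which KV block indices those are (only the first
#                                KV_NUM_BLKS[z, h, q_blk] entries are meaningful)
# with z and h wrapped by modulo (sparse_idx_z / sparse_idx_hq above) so a mask broadcast
# over batch or heads can be reused. In eager code the lookup would be roughly:
def kv_blocks_for(kv_num_blks, kv_idx, off_z, off_h, q_blk):
    """Sketch: the list of KV blocks one program iterates (array/tensor inputs)."""
    z = off_z % kv_num_blks.shape[0]
    h = off_h % kv_num_blks.shape[1]
    n = kv_num_blks[z, h, q_blk]
    return kv_idx[z, h, q_blk, :n]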
- sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 - K_block_ptr = None - V_block_ptr = None - Q_block_ptr = None - - if not USE_TMA: - Q_block_ptr = tl.make_block_ptr( - base=Q , - shape=(Q_LEN, QK_HEAD_DIM), - strides=(stride_qm, stride_qk), - offsets=(q_start * BLOCK_M, 0), - block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - {%- if USE_TMA %} - q = tl.load_tensor_descriptor( - desc_q, - [(q_start * BLOCK_M).to(tl.int32), 0], - ) - {%- else %} - q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) - {%- endif %} - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We don't know anything "special" about these blocks, so we need to apply - # both score_mod and mask_mod to it - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - - offs_n = kv_start + tl.arange(0, BLOCK_N) - - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - if not USE_TMA: - K_block_ptr = tl.make_block_ptr( - base=K, - shape=(QK_HEAD_DIM, KV_LEN), - strides=(stride_kk, stride_kn), - offsets=(0, kv_start), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(kv_start, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = kv_start + tl.arange(0, BLOCK_N) - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, - desc_k, desc_v, Q_LEN, KV_LEN, - acc, l_i, m_i, - off_zq, off_hq, offs_m[:, None], offs_n[None, :], - kv_start, - kv_indices, kv_num_blocks, - 0, block_n_end, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - - # [Note] Handle fully masked out rows: - # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. 
- # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step - l_i = tl.where(l_i == 0.0, 1, l_i) - - acc = acc / l_i[:, None] - idx_zq = tl.program_id(1) - idx_hq = tl.program_id(2) - idx_m = offs_m[:, None] - idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - - {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - - if OUTPUT_LOGSUMEXP: - off_hz = off_zq * HQ + off_hq - l_ptrs = LSE + off_hz * Q_LEN + offs_m - lse = m_i + tl.math.log2(l_i) - if IS_DIVISIBLE: - tl.store(l_ptrs, lse) - else: - tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) - """ - - flex_attention_template = TritonTemplate( name="flex_attention", grid=flex_attention_grid, - source=compute_flex_attention - + compute_forward_inner - + compute_next_offset_func - + compute_forward_block_mn - + load_checked_block - + get_bounded_indices_func, + source=load_template("flex_attention") + + load_template("utilities") + + load_template("common"), ) @@ -684,693 +424,7 @@ def flex_attention_backward_grid( flex_attention_backward_template = TritonTemplate( name="flex_attention_backward", grid=flex_attention_backward_grid, - source=r""" -{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} - # Sub notation for this kernel: - # - # Q: Query, K: Key, V: Value - # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) - # DELTA: Precomputed sum(OUT*DO, axis=-1) - # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value - # DK: Derivative of Key, is the written to via the store_output call due to some limitations with - # inductor codegen - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # (Modifiable) Performance tuning options - # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. - # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. - # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. - # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. - # - # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. - # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. - # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. - # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. - # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. 
- - # The below are kernel options that can be applied for certain score_mods, - # or involve a numerics vs. perf tradeoff - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has - # about 20% more numerical error, but slightly faster. - - # Define strides of inputs - stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} - stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} - - stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} - stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} - - ZQ = {{size("Q", 0)}} - HQ = {{size("Q", 1)}} - HKV = {{size("K", 1)}} - Q_LEN = {{size("Q", 2)}} - ZKV = {{size("K", 0)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - pid = tl.program_id(0) - NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) - NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) - - off_zq = tl.program_id(1) # q batch idx - off_hkv = tl.program_id(2) # kv head idx - off_zkv = off_zq % ZKV # kv batch idx - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_zq % SPARSE_Z - - k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) - v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) - # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) - - # offset K, V, DV pointers for batch/kv-head - K += k_adj - V += v_adj - DV += dv_adj - - RCP_LN2 = 1.44269504 - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - if pid >= NUM_KV_BLOCKS: - off_pid = pid - NUM_KV_BLOCKS - # THIS BLOCK DOES DQ - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS - start_m2_block = off_pid % NUM_Q_BLOCKS - off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE - stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} - stride_kv_idx_h = {{stride("KV_IDX", 1)}} - stride_kv_idx_m = {{stride("KV_IDX", 2)}} - - sparse_idx_hq2 = off_hq2 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 - - sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask - sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. - q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) - do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) - dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) - off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) - - Q2 = Q + q_adj2 - DO2 = DO + do_adj2 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - DQ2 = DQ + dq_adj2 - LSE2 = LSE + off_chz2 - DELTA2 = DELTA + off_chz2 - - # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) - dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_m2 = start_m2_block * BLOCK_M2 - offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) - - # load Q and do: they stay in SRAM throughout the inner loop. 
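# Illustrative aside (not part of this diff): DELTA is described in the header above as the
# precomputed sum(OUT * DO, axis=-1), one fp32 value per query row, reloaded below as Di.
# In eager PyTorch that preprocessing would look roughly like this sketch (names are ours):
import torch

def precompute_delta(out: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
    # Row-wise dot product of the forward output with its incoming gradient.
    return (out.to(torch.float32) * grad_out.to(torch.float32)).sum(dim=-1)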
- q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) - do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - if IS_DIVISIBLE: - Di = tl.load(DELTA2 + offs_m2) - lse = tl.load(LSE2 + offs_m2) - else: - Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - lse = lse[:, None] - - # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # KV_IDX and KV_NUM_BLKS are always contiguous. - kv_indices = KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. - kv_indices = FULL_KV_IDX + sparse_kv_idx_offset - kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading - sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) - - offs_n2 = kv_start + tl.arange(0, BLOCK_N2) - dq = bwd_dq_inner( - {{gen_argdefs()}}, - K, V, - dq, q, do, Di, lse, - off_zq, off_hq2, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dQ. - dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd - dq *= SM_SCALE - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dq_ptrs, dq) - else: - tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) - else: - # THIS BLOCK DOES DK & DV - SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) - - pid_mask = pid // SPARSE_KV_MULTIPLE - - stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} - stride_q_idx_h = {{stride("Q_IDX", 1)}} - stride_q_idx_n = {{stride("Q_IDX", 2)}} - - - dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) - - start_n1 = pid * BLOCK_N1 - offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) - - # load K and V: they stay in SRAM throughout the inner loop. - k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) - v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) - - if PRESCALE_QK: - k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) - - for off_g in range(0, GQA_SHARED_HEADS): - off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g - - # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
- q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) - do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) - dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) - off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) - - Q1 = Q + q_adj1 - DO1 = DO + do_adj1 - # TODO: This does not work if DQ is not the same layout as Q (for example, - # if Q is broadcasted) - LSE1 = LSE + off_chz1 - DELTA1 = DELTA + off_chz1 - - sparse_idx_hq1 = off_hq1 % SPARSE_HQ - sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 - - sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask - sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 - - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Q_IDX and Q_NUM_BLKS are always contiguous. - q_indices = Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - if HAS_FULL_BLOCKS: - # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. - q_indices = FULL_Q_IDX + sparse_q_idx_offset - q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading - sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) - - offs_m1 = q_start + tl.arange(0, BLOCK_M1) - dk, dv = bwd_dkdv_inner( - {{gen_argdefs()}}, - Q1, DO1, DELTA1, LSE1, - dk, dv, k, v, - off_zq, off_hq1, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - # Write back dV and dK. - dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd - - index_n = offs_n1[:, None] - index_k = offs_k[None, :] - index_v = offs_v[None, :] - - if IS_DIVISIBLE and SAFE_HEAD_DIM: - tl.store(dv_ptrs, dv) - else: - tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) - - dk *= SM_SCALE - - if SAFE_HEAD_DIM: - mask = index_n < KV_LEN - else: - mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) - - # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] - # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] - {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} - -@triton.jit -def bwd_dq_inner( - {{gen_argdefs()}}, - K, V, # pointers - dq, q, do, Di, lse, - off_z, off_hq, offs_m2, offs_n2, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd - vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd - # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
- tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) - - hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) - if not IS_DIVISIBLE: - if hi >= 1: - for start_n in range(0, hi - 1): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_n in range(0, hi): - dq = bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - - # Increment pointers. - offset = get_offset_for_next_block( - start_n, kv_indices, sparse_kv_num_blocks, - SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS - ) - - kT_ptrs += offset * stride_kn - vT_ptrs += offset * stride_vn - - offs_n2 += offset - - return dq - - -@triton.jit -def bwd_dq_block_mn( - {{gen_argdefs()}}, - dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, - off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, - stride_kn, stride_kd, stride_vn, stride_vd, - kv_indices, sparse_kv_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1)}} - - # NB reversed order to since K is transposed - kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) - qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qk *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - pre_mod_scores = qk - n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim - # that the M reads out of bounds prior to the last loop - m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qk" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
- post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qk", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # apply mask for partial masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - p = tl.math.exp2(post_mod_scores - lse) - # Compute dP and dS. - # NB reversed order to since V is transposed - vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) - - dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) - ds = p * (dp - Di[:, None]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(1) }} - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if WRITE_DQ: - scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="ds" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = grad_scores - - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - ds = tl.where(mask_mod_output, ds, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - ds = ds.to(MATMUL_PRECISION) - # Compute dQ. - dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) - - return dq - - -@triton.jit -def bwd_dkdv_inner( - {{gen_argdefs()}}, - Q, DO, DELTA, LSE, # pointers - dk, dv, k, v, - off_z, off_hq, offs_n1, offs_m1, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, - IS_FULL_BLOCKS, -): - {{gen_defines() | indent_except_first(1) }} - SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) - RCP_LN2: tl.constexpr = 1.44269504 - Q_LEN = {{size("Q", 2)}} - KV_LEN = {{size("K", 2)}} - - offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) - - qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd - do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod - # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. - tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) - hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) - - if not IS_DIVISIBLE: - if hi >= 1: - for start_m in range(0, hi - 1): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. 
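# Illustrative aside (not part of this diff): the ds / dsT updates in bwd_dq_block_mn above
# and bwd_dkdv_block_mn below use the standard softmax-backward identity: with p = softmax(s)
# and dp the upstream gradient, dL/ds = p * (dp - sum(p * dp)), where Di = sum(OUT * DO, -1)
# equals that row-wise sum(p * dp). A quick numerical check of the identity:
import numpy as np

rng = np.random.default_rng(0)
s = rng.normal(size=8)
dp = rng.normal(size=8)
p = np.exp(s - s.max())
p /= p.sum()
ds_full = (np.diag(p) - np.outer(p, p)) @ dp   # multiply dp by the full softmax Jacobian
ds_fused = p * (dp - (p * dp).sum())           # the fused form, with Di standing in for (p * dp).sum()
assert np.allclose(ds_full, ds_fused)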
- offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, - ) - else: - for start_m in range(0, hi): - dk, dv = bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, - ) - # Increment pointers. - offset = get_offset_for_next_block( - start_m, q_indices, sparse_q_num_blocks, - SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS - ) - - qT_ptrs += offset * stride_qm - do_ptrs += offset * stride_dom - - offs_m1 += offset - - return dk, dv - - -@triton.jit -def bwd_dkdv_block_mn( - {{gen_argdefs()}}, - dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, - off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, - stride_qm, stride_qd, stride_dom, stride_dod, - q_indices, sparse_q_num_blocks, - MATMUL_PRECISION, RCP_LN2, - IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, -): - {{gen_defines() | indent_except_first(1) }} - - # NB reversed order since Q is transposed - qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) - # Load LSE before computing qk to reduce pipeline stall. - if IS_DIVISIBLE: - lse = tl.load(LSE + offs_m1) - else: - lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) - lse = tl.where(lse == -float("inf"), 0.0, lse) - qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) - if not PRESCALE_QK: - qkT *= SM_SCALE - # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ - m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) - # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim - # that the n reads out of bounds prior to the last loop - n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) - - pre_mod_scores = qkT - {{ modification( - subgraph_number=0, - output_name="post_mod_scores", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - out="qkT" - ) | indent_except_first(1) }} - - if CHECK_BLOCK_BOUNDARY: - # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
- post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) - - if not IS_FULL_BLOCKS: - {{ modification( - subgraph_number=2, - output_name="mask_mod_output", - score="qkT", - b="off_z", - h="off_hq", - m="m", - n="n", - ) | indent_except_first(2) }} - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for fully masked block - post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - if not PRESCALE_QK: - post_mod_scores *= RCP_LN2 - pT = tl.math.exp2(post_mod_scores - lse[None, :]) - do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) - # Compute dV. - ppT = pT - dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) - if IS_DIVISIBLE: - Di = tl.load(DELTA + offs_m1) - else: - Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) - # Compute dP and dS. - dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) - dsT = pT * (dpT - Di[None, :]) - # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ - {{ modification( - subgraph_number=1, - output_name = "grad_scores", - score="pre_mod_scores", - b="off_z", - h="off_hq", - m="m", - n="n", - grad_score_mod="dsT" - ) | indent_except_first(1) }} - - # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ - if not WRITE_DQ: - idx_b = off_z - idx_h = off_hq - idx_m = m - idx_n = n - scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) - {{ modification( - subgraph_number=3, - output_name=None, - mask="scatter_mask", - score="pre_mod_scores", - b="idx_b", - h="idx_h", - m="idx_m", - n="idx_n", - grad_score_mod="dsT" - ) | indent_except_first(2) }} - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - - if CHECK_BLOCK_BOUNDARY: - grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) - - dsT = grad_scores - if not IS_FULL_BLOCKS: - if CHECK_BLOCK_BOUNDARY: - mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) - # (grads) apply mask for partially unmasked block - dsT = tl.where(mask_mod_output, dsT, 0.0) - # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) - - return dk, dv - """ - + compute_next_offset_func - + get_bounded_indices_func - + load_checked_2d, + source=load_template("flex_backwards") + load_template("utilities"), ) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 7f92fbc705a5..361729d44b99 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -18,15 +18,10 @@ from ...select_algorithm import ( TritonTemplate, ) from .common import ( - compute_forward_block_mn, - compute_forward_inner, - compute_next_offset_func, create_indices_fake, create_num_blocks_fake_generator, - get_bounded_indices_func, get_fwd_subgraph_outputs, - load_checked_2d, - load_checked_block, + load_template, maybe_realize, set_head_dim_values, ) @@ -90,266 +85,9 @@ def flex_decoding_grid(batch_size, kv_heads, gqa_group_size, n_keys, d_model, me flex_decoding_template = TritonTemplate( name="flex_decoding", grid=flex_decoding_grid, - source=r""" - {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} - # Sub notation for this 
kernel: - # Q: Query, K: Key, V: Value - # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split - # M: Number of queries, N: Number of keys/values - # QK_HEAD_DIM: The dimension of the query and key embeddings - # V_HEAD_DIM: The dimension of the value embeddings - # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block - # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits - # (Modifiable) Config options: - # SPLIT_KV: number of blocks K & V are split into - # TILE_KV: length of each local KV split - # BLOCK_M: block size that Q is padded along seqlen dim. - # BLOCK_N: block size of K & V along N dimension. - # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. - # - # change of base out of the loop - # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row - # is not masked out? If so, we can skip an extra safety check - # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. - # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. - - # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. - # - # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. - # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. - # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. - # - # - # Output: ACC output accumulated across local KV split. - - tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) - - # Define Q Strides - stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} - stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} - stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} - stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} - stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} - - - Z = {{size("Q", 0)}} - ZKV = {{size("K", 0)}} - HKV = {{size("Q", 1)}} - G: tl.constexpr = GQA_SHARED_HEADS - HQ = HKV * G - Q_LEN = {{size("Q", 3)}} - KV_LEN = {{size("K", 2)}} - - MATMUL_PRECISION = Q.dtype.element_ty - - # Make sure each split is a multiple of BLOCK_N - TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) - TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N - TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) - - off_z = tl.program_id(0) // HKV - off_zkv = off_z % ZKV - off_hkv = tl.program_id(0) % HKV - off_t = tl.program_id(1) - - q_offset = off_z * stride_qz + off_hkv * stride_qh - k_offset = off_zkv * stride_kz + off_hkv * stride_kh - v_offset = off_zkv * stride_vz + off_hkv * stride_vh - - SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} - SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} - - sparse_idx_z = off_z % SPARSE_Z - sparse_idx_h = off_hkv % SPARSE_HQ - - SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) - SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) - - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) - - # initialize offsets - tl.device_assert(BLOCK_M % G == 0) - BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G - off_g = tl.arange(0, G) # [G] - offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_hq = offs_g + 
off_hkv * G - off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] - offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] - offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) - offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) - - # Get HZ offsets for KV_NUM_BLKS and KV_IDX - stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} - sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h - stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} - sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h - - # Calculate KV blocks that belong this CTA. - block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block - block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N - - q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] - - if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) - elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) - elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: - q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) - else: - q = tl.load(Q + q_offset + q_range) - - q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) - - - # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # Apply both score_mod and mask_mod - - # find first kv block we are loading and the number of blocks we are loading - # Offset the kv_indices tensor by the correct batch and head - kv_indices = KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - # first kv block we're loading - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=False, - ) - - - # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # We know these blocks are guaranteed to be "full", so we don't need to - # apply mask_mod to them - only score_mod - if HAS_FULL_BLOCKS: - kv_indices = FULL_KV_IDX + sparse_idx_hz_offset - kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) - # Assign full block in a reverse order for off_t. Prioritize the last CTA. 
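# Illustrative aside (not part of this diff): this decode kernel emits one partial result per
# KV split -- rowmax M, sumexp L, and the unnormalized accumulator -- and the cross-CTA
# reduction that merges the splits happens outside this template. Conceptually (a sketch in
# base-2 to match the kernel, not code from this PR), two partials combine as
#     g   = max(m1, m2)
#     l   = l1 * 2**(m1 - g) + l2 * 2**(m2 - g)
#     acc = acc1 * 2**(m1 - g) + acc2 * 2**(m2 - g)
#     out = acc / l        # and logsumexp = g + log2(l)
# which is why M and L are stored per split in float32 rather than only a normalized output.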
- block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE - block_n_end = block_n_start + TILE_KV_MULTIPLE - indices_idx = block_n_start // SPARSE_KV_MULTIPLE - off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE - off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N - - # last valid block according to sparse mask - block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) - - K_block_ptr = tl.make_block_ptr( - base=K + k_offset, - shape=(QK_HEAD_DIM, KV_LEN), # (d, N) - strides=(stride_kk, stride_kn), - offsets=(0, off_n), - block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + v_offset, - shape=(KV_LEN, V_HEAD_DIM), - strides=(stride_vn, stride_vk), - offsets=(off_n, 0), - block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), - order=(1, 0) - ) - offs_n = tl.arange(0, BLOCK_N) + off_n - - acc, l_i, m_i = forward_inner( - {{gen_argdefs()}}, - q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, - # accumulatd values - acc, l_i, m_i, - #offsets - off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], - None, - #block sparse data - kv_indices, kv_num_blocks, - block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, - MATMUL_PRECISION, - IS_FULL_BLOCKS=True, - ) - - m_offset = off_t * stride_mt + off_z * stride_mz - l_offset = off_t * stride_lt + off_z * stride_lz - - M_block_ptr = tl.make_block_ptr( - base=M + m_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_mh, stride_mm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - L_block_ptr = tl.make_block_ptr( - base=L + l_offset, - shape=(G, Q_LEN), # (G, M) - strides=(stride_lh, stride_lm), - offsets=(off_hkv*G, 0), - block_shape=(G, BLOCK_M_PER_HQ), - order=(1, 0) - ) - - # Store output, logsumexp and rowmax for cross CTA reduction. (all in float32, even when input data are in fp16) - m_i = m_i.reshape(G, BLOCK_M_PER_HQ) - l_i = l_i.reshape(G, BLOCK_M_PER_HQ) - if SAFE_M_BOUNDARY: - tl.store(M_block_ptr, m_i) - tl.store(L_block_ptr, l_i) - else: - tl.store(M_block_ptr, m_i, boundary_check=(1,)) - tl.store(L_block_ptr, l_i, boundary_check=(1,)) - - # -- store output - idx_z = off_z - idx_t = off_t - idx_hq = off_hkv*G + off_g[:, None, None] - idx_m = off_m[None, :, None] - idx_d = offs_vd[None, None, :] - - mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) - acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) - {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} - """ - + compute_forward_inner - + get_bounded_indices_func - + load_checked_block - + load_checked_2d - + compute_next_offset_func - + compute_forward_block_mn, + source=load_template("flex_decode") + + load_template("utilities") + + load_template("common"), ) diff --git a/torch/_inductor/kernel/flex/templates/common.py.jinja b/torch/_inductor/kernel/flex/templates/common.py.jinja new file mode 100644 index 000000000000..0e967570127d --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/common.py.jinja @@ -0,0 +1,193 @@ + + +# Common Imports +@triton.jit +def forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, + +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + # -- load k -- + # NB reversed order to since K is transposed + {%- if USE_TMA %} + k = tl.load_tensor_descriptor( + desc_k, + [kv_start + kv_offset, 0], + ) + {%- else %} + k = load_checked_block(K_block_ptr, SAFE_HEAD_DIM, IS_DIVISIBLE) + {%- endif %} + + if USE_TMA: + k = tl.trans(k) + # -- compute qk --- + qk = tl.dot(q, k, input_precision=FLOAT32_PRECISION) # TODO: use cuda matmul when q_len <= 2. + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + # If this is the last block of a non divisible seqlen, we still need to load [BLOCK_M, BLOCK_N] elements, + # which is larger than the actual number of elements. To avoid access memory out of bound, + # we need to mask out the elements that are out of Q_LEN & KV_LEN. + m = get_bounded_indices(offs_m, Q_LEN if CHECK_BLOCK_BOUNDARY else None) + n = get_bounded_indices(offs_n, KV_LEN if CHECK_BLOCK_BOUNDARY else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. + post_mod_scores = tl.where(offs_n < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=1, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_h", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n < KV_LEN, mask_mod_output, False) + # apply mask for partially unmasked blocks + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + # -- compute scaling constant --- + m_ij = tl.maximum(m_i, tl.max(post_mod_scores, 1)) + if not ROWS_GUARANTEED_SAFE: + masked_out_rows = (m_ij == float("-inf")) + m_ij_masked = tl.where(masked_out_rows, 0, m_ij) + else: + m_ij_masked = m_ij + + alpha = tl.math.exp2(m_i - m_ij_masked) + p = tl.math.exp2(post_mod_scores - m_ij_masked[:, None]) + + # NB: l_i update is pulled up here since it's a bit faster + # NB: For headdim=256, it's faster to move it back down to after m_i = + # m_ij + l_i = l_i * alpha + tl.sum(p, 1) + # # -- scale and update acc -- + acc = acc * alpha[:, None] + {%- if USE_TMA %} + v = tl.load_tensor_descriptor( + desc_v, + [kv_start + kv_offset, 0], + ) + {%- else %} + v = load_checked_block(V_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + acc = tl.dot(p.to(MATMUL_PRECISION), v, acc, input_precision=FLOAT32_PRECISION) + + # -- update m_i + m_i = m_ij + + return acc, l_i, m_i + +@triton.jit +def forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets used as inputs to score_mod & mask_mod + # of size [BLOCK_M, BLOCK_N] or scalar. + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + # blocksparse data + kv_indices, kv_num_blocks, + # start kv and end kv block + block_n_start, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + # Redefines all kernel parameters (BLOCK_M, etc.) 
so we don't need to plumb them all through + {{gen_defines() | indent_except_first(1)}} + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + RCP_LN2: tl.constexpr = 1.44269504 + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + kv_offset = 0 + + # loop over k, v and update accumulator until block_n_end + for start_n in range(block_n_start, block_n_end): + # Here IS_DIVISIBLE acts are the start_n = tl.multiple_of(start_n, BLOCK_N) from triton_fused_attention. + if IS_DIVISIBLE: + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + else: + # Benchmark shows even we applied mod & mask to each block for non divisible seqlen, + # it's on par or slightly faster than only applying to the last block in fwd. + # However, we choose different strategy for bwd, where we only apply mod & mask + # to the last block because it's faster a lot. + acc, l_i, m_i = forward_block_mn( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, desc_k, desc_v, Q_LEN, KV_LEN, + # accumulated values + acc, l_i, m_i, + # Offsets + off_z, off_h, offs_m, offs_n, + # Offsets needed for TMA loads + kv_start, + kv_offset, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + + + + offset = get_offset_for_next_block( + start_n, kv_indices, kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N, BLOCKS_ARE_CONTIGUOUS + ) + + offs_n = offs_n + offset + kv_offset += offset + if not USE_TMA: + K_block_ptr = tl.advance(K_block_ptr, (0, offset)) + V_block_ptr = tl.advance(V_block_ptr, (offset, 0)) + + + return acc, l_i, m_i diff --git a/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja new file mode 100644 index 000000000000..79410fb50046 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_attention.py.jinja @@ -0,0 +1,248 @@ +{{def_kernel("Q", "K", "V", "LSE", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # M: Number of queries, N: Number of keys/values, D: Model dimension + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # + # OUTPUT_LOGSUMEXP: We only need to store the logsumexp if we require grad + # + # (Modifiable) Performance tuning options + # BLOCK_M: The thread block size across the seqlen dim of Q. + # BLOCK_N: Iterate over BLOCK_N across the seqlen dim of K/V in each thread block. 
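+    # For example (hypothetical sizes, not from the template): with Q_LEN = 1024, BLOCK_M = 128
+    # and BLOCK_N = 64, each program owns one 128-query tile and sweeps its selected KV blocks
+    # 64 keys at a time.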
+ + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # BLOCKS_ARE_CONTIGUOUS: Is it guaranteed that all blocks in the mask are + # contiguous? If so, we don't need to do an indirect jump for every block + + tl.static_assert(SPARSE_Q_BLOCK_SIZE >= BLOCK_M and SPARSE_Q_BLOCK_SIZE % BLOCK_M == 0) + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + q_start = tl.program_id(0) + off_zq = tl.program_id(1) + off_hq = tl.program_id(2) + + # We support two cases for batch dimension. a) (ZKV == ZQ) where off_zkv = off_zq. + # b) (ZKV == 1 and ZQ > 1) where KV is broadcasted along the batch dimension and off_zkv=0. + off_zkv = off_zq % ZKV + off_hkv = off_hq // GQA_SHARED_HEADS + off_g = off_hq % GQA_SHARED_HEADS + + q_offset = off_zq * stride_qz + off_hq * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + Q = Q + q_offset + K = K + k_offset + V = V + v_offset + + # Setting up the TMA descriptors for Q, K, V + desc_q = None + desc_k = None + desc_v = None + {%- if USE_TMA %} + desc_q = tl.make_tensor_descriptor( + base=Q, + shape=[Q_LEN, QK_HEAD_DIM], + strides=[stride_qm, 1], + block_shape=[BLOCK_M, QK_HEAD_DIM_ROUNDED], + ) + + desc_k = tl.make_tensor_descriptor( + base=K, + shape=[KV_LEN, QK_HEAD_DIM], + strides=[stride_kn, 1], + block_shape=[BLOCK_N, QK_HEAD_DIM_ROUNDED], + ) + + desc_v = tl.make_tensor_descriptor( + base=V, + shape=[KV_LEN, V_HEAD_DIM], + strides=[stride_vn, 1], + block_shape=[BLOCK_N, V_HEAD_DIM_ROUNDED], + ) + {%- endif %} + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + sparse_idx_hq = off_hq % SPARSE_HQ + + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M) + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + offs_m = q_start * BLOCK_M + tl.arange(0, BLOCK_M) + + # KV_IDX and KV_NUM_BLKS are always contiguous. 
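+    # Each (sparse_idx_z, sparse_idx_hq) pair owns one row of block-sparse metadata; the two
+    # offsets below pick out the entry for this program's Q block (q_start // SPARSE_Q_MULTIPLE).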
+ sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + q_start // SPARSE_Q_MULTIPLE + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + (q_start // SPARSE_Q_MULTIPLE) * stride_kv_idx_m # noqa: B950 + K_block_ptr = None + V_block_ptr = None + Q_block_ptr = None + + if not USE_TMA: + Q_block_ptr = tl.make_block_ptr( + base=Q , + shape=(Q_LEN, QK_HEAD_DIM), + strides=(stride_qm, stride_qk), + offsets=(q_start * BLOCK_M, 0), + block_shape=(BLOCK_M, QK_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + {%- if USE_TMA %} + q = tl.load_tensor_descriptor( + desc_q, + [(q_start * BLOCK_M).to(tl.int32), 0], + ) + {%- else %} + q = load_checked_block(Q_block_ptr, IS_DIVISIBLE, SAFE_HEAD_DIM) + {%- endif %} + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We don't know anything "special" about these blocks, so we need to apply + # both score_mod and mask_mod to it + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + + offs_n = kv_start + tl.arange(0, BLOCK_N) + + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + block_n_end = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + if not USE_TMA: + K_block_ptr = tl.make_block_ptr( + base=K, + shape=(QK_HEAD_DIM, KV_LEN), + strides=(stride_kk, stride_kn), + offsets=(0, kv_start), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(kv_start, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = kv_start + tl.arange(0, BLOCK_N) + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, + desc_k, desc_v, Q_LEN, KV_LEN, + acc, l_i, m_i, + off_zq, off_hq, offs_m[:, None], offs_n[None, :], + kv_start, + kv_indices, kv_num_blocks, + 0, block_n_end, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + + # [Note] Handle fully masked out rows: + # Li will be the sum(e^(-inf)) == 0.0 for masked out rows, mi will be -inf. 
+ # We set Li to 1.0 which will result in lse/out = 0.0 | after the log(li) + mi(0.0) step + l_i = tl.where(l_i == 0.0, 1, l_i) + + acc = acc / l_i[:, None] + idx_zq = tl.program_id(1) + idx_hq = tl.program_id(2) + idx_m = offs_m[:, None] + idx_d = tl.arange(0, V_HEAD_DIM_ROUNDED)[None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + + {{store_output(("idx_zq", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} + + if OUTPUT_LOGSUMEXP: + off_hz = off_zq * HQ + off_hq + l_ptrs = LSE + off_hz * Q_LEN + offs_m + lse = m_i + tl.math.log2(l_i) + if IS_DIVISIBLE: + tl.store(l_ptrs, lse) + else: + tl.store(l_ptrs, lse, mask=offs_m < Q_LEN) diff --git a/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja new file mode 100644 index 000000000000..1775833b8e68 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_backwards.py.jinja @@ -0,0 +1,682 @@ +{{def_kernel("Q", "K", "V", "LSE", "DELTA", "DO", "DQ", "DV", "KV_NUM_BLKS", "KV_IDX", "Q_NUM_BLKS", "Q_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX", "FULL_Q_NUM_BLKS", "FULL_Q_IDX")}} + # Sub notation for this kernel: + # + # Q: Query, K: Key, V: Value + # LSE: logsumexp (logsumexp is always stored in fp32 regardless of the input dtype) + # DELTA: Precomputed sum(OUT*DO, axis=-1) + # DO: Derivative of Output, DQ: Derivative of Query, DV: Derivative of Value + # DK: Derivative of Key, is the written to via the store_output call due to some limitations with + # inductor codegen + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # z: Batch size, h: Number of heads, m: Number of queries or keys/values, d: Head dim + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # (Modifiable) Performance tuning options + # BLOCK_M1: when calculating DK & DV, iterate over BLOCK_M1 across the seqlen dim of Q in each thread block. + # BLOCK_N1: when calculating DK & DV, the thread block size across the seqlen dim of K/V. + # BLOCK_M2: when calculating DQ, the thread block size across the seqlen dim of Q. + # BLOCK_N2: when calculating DQ, iterate over BLOCK_N2 across the seqlen dim of K/V in each thread block. + # + # The following FULL_* and PARTIAL_* is defined in the block sparse mask grid, rather than the thread block grid. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # Q_NUM_BLKS: The number of Q blocks (that may or may not require masking) for each query. + # Q_IDX: The indices of Q blocks (that may or may not require masking) for each query. + # FULL_KV_NUM_BLKS: The number of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_KV_IDX: The indices of fully unmasked KV blocks (so we don't need masking) for each query. + # FULL_Q_NUM_BLKS: The number of fully unmasked Q blocks (so we don't need masking) for each query. + # FULL_Q_IDX: The indices of fully unmasked Q blocks (so we don't need masking) for each query. + + # The below are kernel options that can be applied for certain score_mods, + # or involve a numerics vs. perf tradeoff + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. Has + # about 20% more numerical error, but slightly faster. 
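+    # Work partitioning (see the branches below): the 1D grid over pid is split in two.
+    # Programs with pid < NUM_KV_BLOCKS each own one BLOCK_N1 tile of K/V and accumulate
+    # dK/dV for it (looping over the query heads that share this kv head), while programs
+    # with pid >= NUM_KV_BLOCKS each compute dQ for one BLOCK_M2 tile of a single query head.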
+ + # Define strides of inputs + stride_qz, stride_qh, stride_qm, stride_qd = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kd = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vd = {{stride("V")}} + stride_doz, stride_doh, stride_dom, stride_dod = {{stride("DO")}} + + stride_dqz, stride_dqh, stride_dqm, stride_dqd = {{stride("DQ")}} + stride_dvz, stride_dvh, stride_dvm, stride_dvd = {{stride("DV")}} + + ZQ = {{size("Q", 0)}} + HQ = {{size("Q", 1)}} + HKV = {{size("K", 1)}} + Q_LEN = {{size("Q", 2)}} + ZKV = {{size("K", 0)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + pid = tl.program_id(0) + NUM_KV_BLOCKS = tl.cdiv(KV_LEN, BLOCK_N1) + NUM_Q_BLOCKS = tl.cdiv(Q_LEN, BLOCK_M2) + + off_zq = tl.program_id(1) # q batch idx + off_hkv = tl.program_id(2) # kv head idx + off_zkv = off_zq % ZKV # kv batch idx + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_zq % SPARSE_Z + + k_adj = (stride_kh * off_hkv + stride_kz * off_zkv).to(tl.int64) + v_adj = (stride_vh * off_hkv + stride_vz * off_zkv).to(tl.int64) + # first compute broadcasted dv of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dv of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + dv_adj = (stride_dvh * off_hkv + stride_dvz * off_zq).to(tl.int64) + + # offset K, V, DV pointers for batch/kv-head + K += k_adj + V += v_adj + DV += dv_adj + + RCP_LN2 = 1.44269504 + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + if pid >= NUM_KV_BLOCKS: + off_pid = pid - NUM_KV_BLOCKS + # THIS BLOCK DOES DQ + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M2) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + off_hq2 = off_pid // NUM_Q_BLOCKS + off_hkv * GQA_SHARED_HEADS + start_m2_block = off_pid % NUM_Q_BLOCKS + off_pid_mask = start_m2_block // SPARSE_Q_MULTIPLE + stride_kv_num_blks_h = {{stride("KV_NUM_BLKS", 1)}} + stride_kv_idx_h = {{stride("KV_IDX", 1)}} + stride_kv_idx_m = {{stride("KV_IDX", 2)}} + + sparse_idx_hq2 = off_hq2 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq2 + + sparse_kv_num_blks_offset = sparse_hz_offset * stride_kv_num_blks_h + off_pid_mask + sparse_kv_idx_offset = sparse_hz_offset * stride_kv_idx_h + off_pid_mask * stride_kv_idx_m # noqa: B950 + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. + q_adj2 = (stride_qh * off_hq2 + stride_qz * off_zq).to(tl.int64) + do_adj2 = (stride_doh * off_hq2 + stride_doz * off_zq).to(tl.int64) + dq_adj2 = (stride_dqh * off_hq2 + stride_dqz * off_zq).to(tl.int64) + off_chz2 = ((off_zq * HQ + off_hq2) * Q_LEN).to(tl.int64) + + Q2 = Q + q_adj2 + DO2 = DO + do_adj2 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + DQ2 = DQ + dq_adj2 + LSE2 = LSE + off_chz2 + DELTA2 = DELTA + off_chz2 + + # dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM], dtype=tl.float32) + dq = tl.zeros([BLOCK_M2, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_m2 = start_m2_block * BLOCK_M2 + offs_m2 = start_m2 + tl.arange(0, BLOCK_M2) + + # load Q and do: they stay in SRAM throughout the inner loop. 
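+        # load_checked_2d (defined in utilities.py.jinja) masks only the dimensions that are
+        # not statically known to be in bounds (IS_DIVISIBLE / SAFE_HEAD_DIM), so the fully
+        # divisible case lowers to an unmasked tl.load.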
+ q = load_checked_2d(Q2, offs_m2, offs_k, stride_qm, stride_qd, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, QK_HEAD_DIM) + do = load_checked_2d(DO2, offs_m2, offs_v, stride_dom, stride_dod, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + q = (q * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + if IS_DIVISIBLE: + Di = tl.load(DELTA2 + offs_m2) + lse = tl.load(LSE2 + offs_m2) + else: + Di = tl.load(DELTA2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.load(LSE2 + offs_m2, mask=offs_m2 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + lse = lse[:, None] + + # ~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # KV_IDX and KV_NUM_BLKS are always contiguous. + kv_indices = KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~ partial unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_KV_IDX and FULL_KV_NUM_BLKS are always contiguous. + kv_indices = FULL_KV_IDX + sparse_kv_idx_offset + kv_start = tl.load(kv_indices) * SPARSE_KV_BLOCK_SIZE # first kv block we're loading + sparse_kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_kv_num_blks_offset) + + offs_n2 = kv_start + tl.arange(0, BLOCK_N2) + dq = bwd_dq_inner( + {{gen_argdefs()}}, + K, V, + dq, q, do, Di, lse, + off_zq, off_hq2, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dQ. + dq_ptrs = DQ2 + offs_m2[:, None] * stride_dqm + offs_k[None, :] * stride_dqd + dq *= SM_SCALE + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dq_ptrs, dq) + else: + tl.store(dq_ptrs, dq, mask=(offs_m2[:, None] < Q_LEN) & (offs_k[None, :] < QK_HEAD_DIM)) + else: + # THIS BLOCK DOES DK & DV + SPARSE_Q_MULTIPLE = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + SPARSE_KV_MULTIPLE = (SPARSE_KV_BLOCK_SIZE // BLOCK_N1) + + pid_mask = pid // SPARSE_KV_MULTIPLE + + stride_q_num_blks_h = {{stride("Q_NUM_BLKS", 1)}} + stride_q_idx_h = {{stride("Q_IDX", 1)}} + stride_q_idx_n = {{stride("Q_IDX", 2)}} + + + dv = tl.zeros([BLOCK_N1, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + dk = tl.zeros([BLOCK_N1, QK_HEAD_DIM_ROUNDED], dtype=tl.float32) + + start_n1 = pid * BLOCK_N1 + offs_n1 = start_n1 + tl.arange(0, BLOCK_N1) + + # load K and V: they stay in SRAM throughout the inner loop. + k = load_checked_2d(K, offs_n1, offs_k, stride_kn, stride_kd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, QK_HEAD_DIM) + v = load_checked_2d(V, offs_n1, offs_v, stride_vn, stride_vd, IS_DIVISIBLE, SAFE_HEAD_DIM, KV_LEN, V_HEAD_DIM) + + if PRESCALE_QK: + k = (k * SM_SCALE * RCP_LN2).to(MATMUL_PRECISION) + + for off_g in range(0, GQA_SHARED_HEADS): + off_hq1 = off_hkv * GQA_SHARED_HEADS + off_g + + # Offset Q, DQ, DO, DELTA & LSE. These inputs are offsetted by query heads. 
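+            # GQA: the enclosing loop visits every query head that shares this kv head and
+            # accumulates each head's contribution into the same dk/dv tiles.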
+ q_adj1 = (stride_qh * off_hq1 + stride_qz * off_zq).to(tl.int64) + do_adj1 = (stride_doh * off_hq1 + stride_doz * off_zq).to(tl.int64) + dq_adj1 = (stride_dqh * off_hq1 + stride_dqz * off_zq).to(tl.int64) + off_chz1 = ((off_zq * HQ + off_hq1) * Q_LEN).to(tl.int64) + + Q1 = Q + q_adj1 + DO1 = DO + do_adj1 + # TODO: This does not work if DQ is not the same layout as Q (for example, + # if Q is broadcasted) + LSE1 = LSE + off_chz1 + DELTA1 = DELTA + off_chz1 + + sparse_idx_hq1 = off_hq1 % SPARSE_HQ + sparse_hz_offset = sparse_idx_z * SPARSE_HQ + sparse_idx_hq1 + + sparse_q_num_blks_offset = sparse_hz_offset * stride_q_num_blks_h + pid_mask + sparse_q_idx_offset = sparse_hz_offset * stride_q_idx_h + pid_mask * stride_q_idx_n # noqa: B950 + + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Q_IDX and Q_NUM_BLKS are always contiguous. + q_indices = Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + if HAS_FULL_BLOCKS: + # ~~~~~~~~~~~~~~~ fully unmasked blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # FULL_Q_IDX and FULL_Q_NUM_BLKS are always contiguous. + q_indices = FULL_Q_IDX + sparse_q_idx_offset + q_start = tl.load(q_indices) * SPARSE_Q_BLOCK_SIZE # first q block we're loading + sparse_q_num_blocks = tl.load(FULL_Q_NUM_BLKS + sparse_q_num_blks_offset) + + offs_m1 = q_start + tl.arange(0, BLOCK_M1) + dk, dv = bwd_dkdv_inner( + {{gen_argdefs()}}, + Q1, DO1, DELTA1, LSE1, + dk, dv, k, v, + off_zq, off_hq1, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + # Write back dV and dK. + dv_ptrs = DV + offs_n1[:, None] * stride_dvm + offs_v[None, :] * stride_dvd + + index_n = offs_n1[:, None] + index_k = offs_k[None, :] + index_v = offs_v[None, :] + + if IS_DIVISIBLE and SAFE_HEAD_DIM: + tl.store(dv_ptrs, dv) + else: + tl.store(dv_ptrs, dv, mask=(index_n < KV_LEN) & (index_v < V_HEAD_DIM)) + + dk *= SM_SCALE + + if SAFE_HEAD_DIM: + mask = index_n < KV_LEN + else: + mask = (index_n < KV_LEN) & (index_k < QK_HEAD_DIM) + + # first compute broadcasted dk of shape [Bq, Hkv, KV_LEN, V_HEAD_DIM] + # then reduce to dk of shape [Bkv, Hkv, KV_LEN, V_HEAD_DIM] + {{store_output(("off_zq", "off_hkv", "index_n", "index_k"), "dk", "mask", indent_width=8)}} + +@triton.jit +def bwd_dq_inner( + {{gen_argdefs()}}, + K, V, # pointers + dq, q, do, Di, lse, + off_z, off_hq, offs_m2, offs_n2, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N2) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + kT_ptrs = K + offs_n2[None, :] * stride_kn + offs_k[:, None] * stride_kd + vT_ptrs = V + offs_n2[None, :] * stride_vn + offs_v[:, None] * stride_vd + # BLOCK_M2 must be a multiple of BLOCK_N2, otherwise the code wouldn't work. 
+ tl.static_assert(BLOCK_M2 % BLOCK_N2 == 0) + + hi = tl.minimum(sparse_kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N2), 1)) + if not IS_DIVISIBLE: + if hi >= 1: + for start_n in range(0, hi - 1): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_n in range(0, hi): + dq = bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + + # Increment pointers. + offset = get_offset_for_next_block( + start_n, kv_indices, sparse_kv_num_blocks, + SPARSE_KV_BLOCK_SIZE, SPARSE_KV_MULTIPLE, BLOCK_N2, BLOCKS_ARE_CONTIGUOUS + ) + + kT_ptrs += offset * stride_kn + vT_ptrs += offset * stride_vn + + offs_n2 += offset + + return dq + + +@triton.jit +def bwd_dq_block_mn( + {{gen_argdefs()}}, + dq, q, kT_ptrs, vT_ptrs, do, Di, lse, Q_LEN, KV_LEN, + off_z, off_hq, offs_m2, offs_n2, offs_k, offs_v, + stride_kn, stride_kd, stride_vn, stride_vd, + kv_indices, sparse_kv_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1)}} + + # NB reversed order to since K is transposed + kT = load_checked_2d(kT_ptrs, offs_k, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, KV_LEN) + qk = tl.dot(q, kT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qk *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + pre_mod_scores = qk + n = get_bounded_indices(offs_n2[None, :], KV_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across N dim + # that the M reads out of bounds prior to the last loop + m = get_bounded_indices(offs_m2[:, None], Q_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qk" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n2[None, :] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qk", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # apply mask for partial masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + p = tl.math.exp2(post_mod_scores - lse) + # Compute dP and dS. + # NB reversed order to since V is transposed + vT = load_checked_2d(vT_ptrs, offs_v, offs_n2, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, V_HEAD_DIM, KV_LEN) + + dp = tl.dot(do, vT, input_precision=FLOAT32_PRECISION) + ds = p * (dp - Di[:, None]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(1) }} + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n2[None, :] < KV_LEN, grad_scores, 0.0) + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if WRITE_DQ: + scatter_mask = (offs_m2[:, None] < Q_LEN ) & (offs_n2[None, :] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="ds" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = grad_scores + + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n2[None, :] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + ds = tl.where(mask_mod_output, ds, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + ds = ds.to(MATMUL_PRECISION) + # Compute dQ. + dq += tl.dot(ds, tl.trans(kT), input_precision=FLOAT32_PRECISION) + + return dq + + +@triton.jit +def bwd_dkdv_inner( + {{gen_argdefs()}}, + Q, DO, DELTA, LSE, # pointers + dk, dv, k, v, + off_z, off_hq, offs_n1, offs_m1, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, + IS_FULL_BLOCKS, +): + {{gen_defines() | indent_except_first(1) }} + SPARSE_Q_MULTIPLE: tl.constexpr = (SPARSE_Q_BLOCK_SIZE // BLOCK_M1) + RCP_LN2: tl.constexpr = 1.44269504 + Q_LEN = {{size("Q", 2)}} + KV_LEN = {{size("K", 2)}} + + offs_k = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_v = tl.arange(0, V_HEAD_DIM_ROUNDED) + + qT_ptrs = Q + offs_m1[None, :] * stride_qm + offs_k[:, None] * stride_qd + do_ptrs = DO + offs_m1[:, None] * stride_dom + offs_v[None, :] * stride_dod + # BLOCK_N1 must be a multiple of BLOCK_M1, otherwise the code wouldn't work. + tl.static_assert(BLOCK_N1 % BLOCK_M1 == 0) + hi = tl.minimum(sparse_q_num_blocks * SPARSE_Q_MULTIPLE, tl.maximum(tl.cdiv(Q_LEN, BLOCK_M1), 1)) + + if not IS_DIVISIBLE: + if hi >= 1: + for start_m in range(0, hi - 1): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. 
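+                # get_offset_for_next_block (utilities.py.jinja) returns BLOCK_N2 while we stay
+                # inside the current sparse block and, once it is exhausted, the jump to the
+                # start of the next sparse block listed in kv_indices.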
+ offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=True, + ) + else: + for start_m in range(0, hi): + dk, dv = bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, + ) + # Increment pointers. + offset = get_offset_for_next_block( + start_m, q_indices, sparse_q_num_blocks, + SPARSE_Q_BLOCK_SIZE, SPARSE_Q_MULTIPLE, BLOCK_M1, BLOCKS_ARE_CONTIGUOUS + ) + + qT_ptrs += offset * stride_qm + do_ptrs += offset * stride_dom + + offs_m1 += offset + + return dk, dv + + +@triton.jit +def bwd_dkdv_block_mn( + {{gen_argdefs()}}, + dk, dv, qT_ptrs, k, v, do_ptrs, DELTA, LSE, Q_LEN, KV_LEN, + off_z, off_hq, offs_n1, offs_m1, offs_k, offs_v, + stride_qm, stride_qd, stride_dom, stride_dod, + q_indices, sparse_q_num_blocks, + MATMUL_PRECISION, RCP_LN2, + IS_FULL_BLOCKS, CHECK_BLOCK_BOUNDARY=False, +): + {{gen_defines() | indent_except_first(1) }} + + # NB reversed order since Q is transposed + qT = load_checked_2d(qT_ptrs, offs_k, offs_m1, None, None, SAFE_HEAD_DIM, IS_DIVISIBLE, QK_HEAD_DIM, Q_LEN) + # Load LSE before computing qk to reduce pipeline stall. + if IS_DIVISIBLE: + lse = tl.load(LSE + offs_m1) + else: + lse = tl.load(LSE + offs_m1, mask=offs_m1 < Q_LEN) + lse = tl.where(lse == -float("inf"), 0.0, lse) + qkT = tl.dot(k, qT, input_precision=FLOAT32_PRECISION) + if not PRESCALE_QK: + qkT *= SM_SCALE + # ~~~~~~~~~~~~~~~~~~~ Apply score modification ~~~~~~~~~~~~~~~~~~~ + m = get_bounded_indices(offs_m1[None, :], Q_LEN if CHECK_BLOCK_BOUNDARY else None) + # The boundary check is done for the outer loop, but here it's possible since we're iterating across M dim + # that the n reads out of bounds prior to the last loop + n = get_bounded_indices(offs_n1[:, None], KV_LEN if (not IS_DIVISIBLE or CHECK_BLOCK_BOUNDARY) else None) + + pre_mod_scores = qkT + {{ modification( + subgraph_number=0, + output_name="post_mod_scores", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + out="qkT" + ) | indent_except_first(1) }} + + if CHECK_BLOCK_BOUNDARY: + # Mask out the elements that are out of the KV_LEN for non divisible seqlen. 
+ post_mod_scores = tl.where(offs_n1[:, None] < KV_LEN, post_mod_scores, float("-inf")) + + if not IS_FULL_BLOCKS: + {{ modification( + subgraph_number=2, + output_name="mask_mod_output", + score="qkT", + b="off_z", + h="off_hq", + m="m", + n="n", + ) | indent_except_first(2) }} + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for fully masked block + post_mod_scores = tl.where(mask_mod_output, post_mod_scores, float("-inf")) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + if not PRESCALE_QK: + post_mod_scores *= RCP_LN2 + pT = tl.math.exp2(post_mod_scores - lse[None, :]) + do = load_checked_2d(do_ptrs, offs_m1, offs_v, None, None, IS_DIVISIBLE, SAFE_HEAD_DIM, Q_LEN, V_HEAD_DIM) + # Compute dV. + ppT = pT + dv += tl.dot(ppT.to(MATMUL_PRECISION), do, input_precision=FLOAT32_PRECISION) + if IS_DIVISIBLE: + Di = tl.load(DELTA + offs_m1) + else: + Di = tl.load(DELTA + offs_m1, mask=offs_m1 < Q_LEN) + # Compute dP and dS. + dpT = tl.dot(v, tl.trans(do), input_precision=FLOAT32_PRECISION) + dsT = pT * (dpT - Di[None, :]) + # ~~~~~~~~~~~~~~~~~~~ Apply joint modification ~~~~~~~~~~~~~~~~~~~ + {{ modification( + subgraph_number=1, + output_name = "grad_scores", + score="pre_mod_scores", + b="off_z", + h="off_hq", + m="m", + n="n", + grad_score_mod="dsT" + ) | indent_except_first(1) }} + + # ~~~~~~~~~~~~~~~~~~~ Apply other buffer grad writes ~~~~~~~~~~~~~ + if not WRITE_DQ: + idx_b = off_z + idx_h = off_hq + idx_m = m + idx_n = n + scatter_mask = (offs_m1[None, :] < Q_LEN) & (offs_n1[:, None] < KV_LEN) + {{ modification( + subgraph_number=3, + output_name=None, + mask="scatter_mask", + score="pre_mod_scores", + b="idx_b", + h="idx_h", + m="idx_m", + n="idx_n", + grad_score_mod="dsT" + ) | indent_except_first(2) }} + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + + if CHECK_BLOCK_BOUNDARY: + grad_scores = tl.where(offs_n1[:, None] < KV_LEN, grad_scores, 0.0) + + dsT = grad_scores + if not IS_FULL_BLOCKS: + if CHECK_BLOCK_BOUNDARY: + mask_mod_output = tl.where(offs_n1[:, None] < KV_LEN, mask_mod_output, False) + # (grads) apply mask for partially unmasked block + dsT = tl.where(mask_mod_output, dsT, 0.0) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + dk += tl.dot(dsT.to(MATMUL_PRECISION), tl.trans(qT), input_precision=FLOAT32_PRECISION) + + return dk, dv \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja new file mode 100644 index 000000000000..f4596070c833 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/flex_decode.py.jinja @@ -0,0 +1,252 @@ + {{def_kernel("Q", "K", "V", "M", "L", "KV_NUM_BLKS", "KV_IDX", "FULL_KV_NUM_BLKS", "FULL_KV_IDX")}} + # Sub notation for this kernel: + # Q: Query, K: Key, V: Value + # reduction buffers: M rowmax across local KV split, L local sumexp across local KV split + # M: Number of queries, N: Number of keys/values + # QK_HEAD_DIM: The dimension of the query and key embeddings + # V_HEAD_DIM: The dimension of the value embeddings + # BLOCK_M, QK_HEAD_DIM: M, and D dimemsion are always assigned to the same block + # z: Batch size, h: Number of heads, m: Number of queries per head, k: Number of keys per head t: Number of kv splits + # (Modifiable) Config options: + # SPLIT_KV: number of blocks K & V are split into + # TILE_KV: length of each local KV split + # BLOCK_M: block size that Q is padded along 
seqlen dim. + # BLOCK_N: block size of K & V along N dimension. + # GQA_SHARED_HEADS: number of query heads sharing one kv head in GQA setups. + # + # change of base out of the loop + # ROWS_GUARANTEED_SAFE: Is it guaranteed that at least one value in each row + # is not masked out? If so, we can skip an extra safety check + # SAFE_M_BOUNDARY: Is Q seqlen a multiple of BLOCK_M? If so, we can skip an extra boundary check for loading query. + # SAFE_N_BOUNDARY: Is KV seqlen a multiple of BLOCK_N? If so, we can skip an extra boundary check for loading key/value. + + # PRESCALE_QK: Whether to pre-scale QK by 1/sqrt(d) and change of base. + # + # SPARSE_KV_BLOCK_SIZE: sparse mask block size along KV seqlen dim. + # KV_NUM_BLKS: The number of KV blocks (that may or may not require masking) for each query. + # KV_IDX: The indices of KV blocks (that may or may not require masking) for each query. + # + # + # Output: ACC output accumulated across local KV split. + + tl.static_assert(SPARSE_KV_BLOCK_SIZE >= BLOCK_N and SPARSE_KV_BLOCK_SIZE % BLOCK_N == 0) + + # Define Q Strides + stride_qz, stride_qh, stride_qg, stride_qm, stride_qk = {{stride("Q")}} + stride_kz, stride_kh, stride_kn, stride_kk = {{stride("K")}} + stride_vz, stride_vh, stride_vn, stride_vk = {{stride("V")}} + stride_mz, stride_mt, stride_mh, stride_mm = {{stride("M")}} + stride_lz, stride_lt, stride_lh, stride_lm = {{stride("L")}} + + + Z = {{size("Q", 0)}} + ZKV = {{size("K", 0)}} + HKV = {{size("Q", 1)}} + G: tl.constexpr = GQA_SHARED_HEADS + HQ = HKV * G + Q_LEN = {{size("Q", 3)}} + KV_LEN = {{size("K", 2)}} + + MATMUL_PRECISION = Q.dtype.element_ty + + # Make sure each split is a multiple of BLOCK_N + TILE_KV_OG = tl.cdiv(KV_LEN, SPLIT_KV) + TILE_KV = tl.cdiv(TILE_KV_OG, BLOCK_N) * BLOCK_N + TILE_KV_MULTIPLE: tl.constexpr = (TILE_KV // BLOCK_N) + + off_z = tl.program_id(0) // HKV + off_zkv = off_z % ZKV + off_hkv = tl.program_id(0) % HKV + off_t = tl.program_id(1) + + q_offset = off_z * stride_qz + off_hkv * stride_qh + k_offset = off_zkv * stride_kz + off_hkv * stride_kh + v_offset = off_zkv * stride_vz + off_hkv * stride_vh + + SPARSE_Z = {{size("KV_NUM_BLKS", 0)}} + SPARSE_HQ = {{size("KV_NUM_BLKS", 1)}} + + sparse_idx_z = off_z % SPARSE_Z + sparse_idx_h = off_hkv % SPARSE_HQ + + SPARSE_KV_MULTIPLE: tl.constexpr = (SPARSE_KV_BLOCK_SIZE // BLOCK_N) + SPARSE_KV_BLOCK_CNT = tl.cdiv(KV_LEN, SPARSE_KV_BLOCK_SIZE) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, V_HEAD_DIM_ROUNDED], dtype=tl.float32) + + # initialize offsets + tl.device_assert(BLOCK_M % G == 0) + BLOCK_M_PER_HQ: tl.constexpr = BLOCK_M // G + off_g = tl.arange(0, G) # [G] + offs_g = tl.ravel(tl.broadcast_to(off_g[:, None], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_hq = offs_g + off_hkv * G + off_m = tl.arange(0, BLOCK_M_PER_HQ) # [BLOCK_M_PER_HQ] + offs_m = tl.ravel(tl.broadcast_to(off_m[None, :], [G, BLOCK_M_PER_HQ])) # [BLOCK_M] + offs_d = tl.arange(0, QK_HEAD_DIM_ROUNDED) + offs_vd = tl.arange(0, V_HEAD_DIM_ROUNDED) + + # Get HZ offsets for KV_NUM_BLKS and KV_IDX + stride_block_z, stride_block_h, stride_block_row = {{stride("KV_NUM_BLKS")}} + sparse_block_hz_offset = sparse_idx_z * stride_block_z + sparse_idx_h * stride_block_h + stride_kv_z, stride_kv_h, stride_kv_row, stride_kv_col = {{stride("KV_IDX")}} + sparse_idx_hz_offset = sparse_idx_z * stride_kv_z + sparse_idx_h * stride_kv_h + + # Calculate KV blocks that belong this CTA. 
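+    # Each of the SPLIT_KV programs (indexed by off_t) handles TILE_KV_MULTIPLE consecutive
+    # BLOCK_N tiles; its local rowmax/sumexp go to M and L for the cross-CTA reduction.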
+ block_n_start = off_t * TILE_KV_MULTIPLE # n_offset inside sparse block + block_n_end = block_n_start + TILE_KV_MULTIPLE # end BLOCK_N + + q_range = stride_qg * off_g[:, None, None] + stride_qm * off_m[None, :, None] + stride_qk * offs_d[None, None, :] + + if not SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=(offs_d[None, None, :] < QK_HEAD_DIM) & (off_m[None, :, None] < Q_LEN)) + elif SAFE_M_BOUNDARY and not SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=offs_d[None, None, :] < QK_HEAD_DIM) + elif not SAFE_M_BOUNDARY and SAFE_HEAD_DIM: + q = tl.load(Q + q_offset + q_range, mask=off_m[None, :, None] < Q_LEN) + else: + q = tl.load(Q + q_offset + q_range) + + q = tl.reshape(q, [BLOCK_M, QK_HEAD_DIM_ROUNDED]) + + + # ~~~~~~~~~~~~~~ normal blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # Apply both score_mod and mask_mod + + # find first kv block we are loading and the number of blocks we are loading + # Offset the kv_indices tensor by the correct batch and head + kv_indices = KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset) + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + # first kv block we're loading + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=False, + ) + + + # ~~~~~~~~~~~~~~ "full" blocks ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + # We know these blocks are guaranteed to be "full", so we don't need to + # apply mask_mod to them - only score_mod + if HAS_FULL_BLOCKS: + kv_indices = FULL_KV_IDX + sparse_idx_hz_offset + kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset) + # Assign full block in a reverse order for off_t. Prioritize the last CTA. 
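+        # i.e. program off_t == SPLIT_KV - 1 takes full-block tile 0 and off_t == 0 takes the
+        # last one, the reverse of the partial-block assignment used for the masked blocks above.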
+ block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE + block_n_end = block_n_start + TILE_KV_MULTIPLE + indices_idx = block_n_start // SPARSE_KV_MULTIPLE + off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE + off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N + + # last valid block according to sparse mask + block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1)) + + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(QK_HEAD_DIM, KV_LEN), # (d, N) + strides=(stride_kk, stride_kn), + offsets=(0, off_n), + block_shape=(QK_HEAD_DIM_ROUNDED, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(KV_LEN, V_HEAD_DIM), + strides=(stride_vn, stride_vk), + offsets=(off_n, 0), + block_shape=(BLOCK_N, V_HEAD_DIM_ROUNDED), + order=(1, 0) + ) + offs_n = tl.arange(0, BLOCK_N) + off_n + + acc, l_i, m_i = forward_inner( + {{gen_argdefs()}}, + q, K_block_ptr, V_block_ptr, None, None, Q_LEN, KV_LEN, + # accumulatd values + acc, l_i, m_i, + #offsets + off_z, offs_hq[:, None], offs_m[:, None], offs_n[None, :], + None, + #block sparse data + kv_indices, kv_num_blocks, + block_n_start, block_n_end if block_n_end <= block_n_last_valid else block_n_last_valid, + MATMUL_PRECISION, + IS_FULL_BLOCKS=True, + ) + + m_offset = off_t * stride_mt + off_z * stride_mz + l_offset = off_t * stride_lt + off_z * stride_lz + + M_block_ptr = tl.make_block_ptr( + base=M + m_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_mh, stride_mm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + L_block_ptr = tl.make_block_ptr( + base=L + l_offset, + shape=(G, Q_LEN), # (G, M) + strides=(stride_lh, stride_lm), + offsets=(off_hkv*G, 0), + block_shape=(G, BLOCK_M_PER_HQ), + order=(1, 0) + ) + + # Store output, logsumexp and rowmax for cross CTA reduction. 
(all in float32, even when input data are in fp16) + m_i = m_i.reshape(G, BLOCK_M_PER_HQ) + l_i = l_i.reshape(G, BLOCK_M_PER_HQ) + if SAFE_M_BOUNDARY: + tl.store(M_block_ptr, m_i) + tl.store(L_block_ptr, l_i) + else: + tl.store(M_block_ptr, m_i, boundary_check=(1,)) + tl.store(L_block_ptr, l_i, boundary_check=(1,)) + + # -- store output + idx_z = off_z + idx_t = off_t + idx_hq = off_hkv*G + off_g[:, None, None] + idx_m = off_m[None, :, None] + idx_d = offs_vd[None, None, :] + + mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM) + acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM) + {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}} \ No newline at end of file diff --git a/torch/_inductor/kernel/flex/templates/utilities.py.jinja b/torch/_inductor/kernel/flex/templates/utilities.py.jinja new file mode 100644 index 000000000000..7e2367e4f269 --- /dev/null +++ b/torch/_inductor/kernel/flex/templates/utilities.py.jinja @@ -0,0 +1,59 @@ + + +# Utility triton funcs +@triton.jit +def get_offset_for_next_block( + loop_iter, col_indices, total_blocks, + SPARSE_BLOCK, SPARSE_BLOCK_MULTIPLE, BLOCK, + BLOCKS_ARE_CONTIGUOUS: tl.constexpr +): + if BLOCKS_ARE_CONTIGUOUS: + return BLOCK + cur_block_idx = loop_iter // SPARSE_BLOCK_MULTIPLE + cur_block = tl.load(col_indices + cur_block_idx, eviction_policy="evict_last") + next_block = tl.load(col_indices + cur_block_idx + 1, eviction_policy="evict_last", mask=cur_block_idx + 1 < total_blocks) + needs_jump = (loop_iter + 1) % SPARSE_BLOCK_MULTIPLE == 0 + jump_to_block = (next_block - cur_block ) * SPARSE_BLOCK - (SPARSE_BLOCK_MULTIPLE - 1) * BLOCK + offset = jump_to_block * needs_jump + (1 - needs_jump) * BLOCK + return offset + +@triton.jit +def get_bounded_indices(indices, max_len=None): + return indices % max_len if max_len is not None else indices + +@triton.jit +def load_checked_block(block_ptr, IS_DIVISIBLE: tl.constexpr, SAFE_HEAD_DIM: tl.constexpr): + if IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr) + elif IS_DIVISIBLE and not SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(1,), padding_option="zero") + elif not IS_DIVISIBLE and SAFE_HEAD_DIM: + return tl.load(block_ptr, boundary_check=(0,), padding_option="zero") + else: + return tl.load(block_ptr, boundary_check=(0, 1), padding_option="zero") + +@triton.jit +def load_checked_2d( + ptr, + offs_m, + offs_n, + stride_m, + stride_n, + IS_DIVISIBLE_M: tl.constexpr, + IS_DIVISIBLE_N: tl.constexpr, + M_LEN: tl.constexpr, + N_DIM: tl.constexpr, +): + # Calculate final pointer if strides are provided + if stride_m is not None and stride_n is not None: + ptr = ptr + offs_m[:, None] * stride_m + offs_n[None, :] * stride_n + + # Handle all masking cases + if not IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN) & (offs_n[None, :] < N_DIM), other=0.0) + elif IS_DIVISIBLE_M and not IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_n[None, :] < N_DIM), other=0.0) + elif not IS_DIVISIBLE_M and IS_DIVISIBLE_N: + return tl.load(ptr, mask=(offs_m[:, None] < M_LEN), other=0.0) + else: # Both divisible + return tl.load(ptr)
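
For reference: the load_template helper used in the source= assembly above is defined
outside this hunk; it only has to return the raw text of the named templates/*.py.jinja
file. A minimal sketch of such a helper, assuming the templates sit next to the kernel
module and that a plain string is all the caller needs:

    # Sketch only; names and caching strategy are assumptions, not the actual
    # torch/_inductor implementation.
    from functools import lru_cache
    from pathlib import Path

    _TEMPLATE_DIR = Path(__file__).parent / "templates"

    @lru_cache
    def load_template(name: str) -> str:
        """Return the Jinja source stored in templates/<name>.py.jinja."""
        return (_TEMPLATE_DIR / f"{name}.py.jinja").read_text()

    # Usage mirroring the diff above:
    # source = load_template("flex_decode") + load_template("utilities") + load_template("common")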