Compare commits


13 Commits

6 changed files with 80 additions and 17 deletions

View File

@@ -659,7 +659,10 @@ class TestFlexAttention(InductorTestCase):
         paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
         # convert block mask and score mod
-        converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64)
+        converted_block_mask = paged_attention.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
         converted_score_mod = paged_attention.get_score_mod(score_mod)
 
         return k_cache, v_cache, converted_block_mask, converted_score_mod
@@ -2381,6 +2384,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test_with_paged_attention(
             score_mod, dtype=torch.float16, device=device
         )
+        self.run_test_with_paged_attention(
+            score_mod=score_mod,
+            dtype=torch.bfloat16,
+            KV_S=64,
+            device=device,
+        )
 
     @supported_platform
     @skip("TODO: Figure out why this is erroring")
@@ -5129,7 +5138,12 @@ class TestPagedAttention(InductorTestCase):
         block_mask = create_block_mask(
             causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device
         )
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full(
+            (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
+        )
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
 
         zeros = [0, 0, 0, 0]
         # Check that the new block mask is correct
@@ -5404,7 +5418,9 @@ class TestPagedAttention(InductorTestCase):
         )
         paged_cache.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=max_seq_len
+        )
 
         compiled_sdpa = torch.compile(
             create_attention(
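
Across the test hunks above the pattern is the same: build a per-batch KV-length tensor and hand it to convert_logical_block_mask as the new kv_len argument. Below is a minimal, hedged sketch of that calling convention, assuming a PagedAttention(n_pages, page_size, max_batch_size, device=...) constructor; all sizes and the causal mask are invented for illustration and are not part of this diff:

import torch
from torch.nn.attention.flex_attention import create_block_mask
from torch.nn.attention.experimental._paged_attention import PagedAttention

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx

device = "cuda"
max_batch_size, max_seq_len, page_size, n_pages = 4, 512, 64, 64

# in the tests, pages are reserved and k/v assigned before conversion
paged_cache = PagedAttention(n_pages, page_size, max_batch_size, device=device)
block_mask = create_block_mask(
    causal_mask, max_batch_size, 1, max_seq_len, max_seq_len, device=device
)

# new in this change: the true KV length per batch element bounds which
# physical slots the converted mask may admit
kv_len_tensor = torch.full(
    (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
)
new_block_mask = paged_cache.convert_logical_block_mask(
    block_mask, kv_len=kv_len_tensor
)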

View File

@@ -540,7 +540,10 @@ class TestFlexDecoding(InductorTestCase):
         paged_attention.assign(batch_idx, input_pos, k, v, k_cache, v_cache)
 
         # convert block mask and score mod
-        converted_block_mask = paged_attention.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full((KV_B,), KV_S, device=device, dtype=torch.int64)
+        converted_block_mask = paged_attention.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
         converted_score_mod = paged_attention.get_score_mod(score_mod)
 
         return k_cache, v_cache, converted_block_mask, converted_score_mod
@@ -1526,6 +1529,19 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(score_mod, device=device)
         self.run_test_with_paged_attention(score_mod, device=device)
+        self.run_test_with_paged_attention(
+            score_mod=score_mod,
+            dtype=torch.bfloat16,
+            Q_B=4,
+            Q_H=1,
+            Q_S=1,
+            QK_D=16,
+            KV_B=4,
+            KV_H=1,
+            KV_S=64,
+            V_D=16,
+            device=device,
+        )
 
     @supported_platform
     @patch.object(torch._inductor.config, "max_autotune", True)
@@ -1993,7 +2009,12 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         input_pos = torch.tensor(prefill_length, device=device, dtype=torch.int32).view(
             max_batch_size, 1
         )
-        new_block_mask = paged_cache.convert_logical_block_mask(block_mask)
+        kv_len_tensor = torch.full(
+            (max_batch_size,), max_seq_len, device=device, dtype=torch.int64
+        )
+        new_block_mask = paged_cache.convert_logical_block_mask(
+            block_mask, kv_len=kv_len_tensor
+        )
         new_block_mask.seq_lengths = (1, new_block_mask.seq_lengths[1])
         compiled_sdpa = torch.compile(
             create_attention(
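
Because kv_len is indexed per batch element (kv_len[b] in the new mask_mod shown later in this diff), ragged caches are expressible: each entry independently bounds the logical KV indices admitted for that batch. A small illustration, with made-up lengths:

import torch

device = "cuda"
# hypothetical ragged fill levels: batch 0 holds 512 valid KV tokens,
# batch 1 only 37, and so on; either tensor is passed as kv_len=...
ragged_kv_len = torch.tensor([512, 37, 256, 1], device=device, dtype=torch.int64)

# the uniform case used by the tests above
uniform_kv_len = torch.full((4,), 512, device=device, dtype=torch.int64)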

View File

@@ -792,7 +792,7 @@ class CppFlexAttentionTemplate(CppTemplate):
                 return ""
 
             if start_offset == -1:
-                start_offset = getattr(self, len_attr)
+                start_offset = self.len_score_other
 
             length = getattr(self, len_attr)
             for i in range(length):
@@ -995,9 +995,9 @@ class CppFlexAttentionTemplate(CppTemplate):
             value=value,
             kv_num_blocks=self.input_nodes[3],
             kv_indices=self.input_nodes[4],
-            full_kv_num_blocks=self.input_nodes[5]
-            if not self.no_full_kv_block
-            else None,
+            full_kv_num_blocks=(
+                self.input_nodes[5] if not self.no_full_kv_block else None
+            ),
             full_kv_indices=self.input_nodes[6] if not self.no_full_kv_block else None,
             score_mod_other_buffers=self.score_mod_other_buffers,
             mask_mod_other_buffers=self.mask_mod_other_buffers,
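
The one-line fix in this file changes the sentinel handling: when start_offset comes in as -1, it should resolve to the number of score-mod buffers (where the mask-mod buffers begin in the combined list), not to the length of whichever attribute len_attr names. A toy Python model of that offset convention, with invented buffer names; this is a sketch of the indexing the fix restores, not the template's actual code:

# score-mod buffers come first in the combined list, mask-mod buffers follow
score_mod_other_buffers = ["score_buf0", "score_buf1"]  # len_score_other == 2
mask_mod_other_buffers = ["mask_buf0"]                  # len_mask_other == 1
combined = score_mod_other_buffers + mask_mod_other_buffers

len_score_other = len(score_mod_other_buffers)
len_mask_other = len(mask_mod_other_buffers)

def emit(start_offset, length):
    # mirrors the loop shape: walk `length` buffers starting at `start_offset`
    return combined[start_offset : start_offset + length]

start_offset = -1  # sentinel: "start where the mask-mod buffers begin"
if start_offset == -1:
    start_offset = len_score_other  # the fix; getattr(self, len_attr) gave 1 here
print(emit(start_offset, len_mask_other))  # ['mask_buf0'], not ['score_buf1']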

View File

@@ -215,6 +215,11 @@ def create_flex_decoding_kernel(*args, **kwargs):
     kernel_options.setdefault("SPLIT_KV", get_split_k(B, Hkv, seq_len_kv))
     MAX_SPLIT_KV = kernel_options["SPLIT_KV"]
 
+    # Calculate the maximum valid KV index to prevent invalid memory access
+    # This is based on the actual number of KV blocks available
+    max_kv_blocks = V.graph.sizevars.size_hint(kv_indices.get_size()[-1], fallback=5)
+    kernel_options.setdefault("MAX_VALID_KV_IDX", max_kv_blocks - 1)
+
     # create config dependent intermediate buffers
     buf_ACC_shape = [B, MAX_SPLIT_KV, Hq, seq_len_q, v_head_dim]
     buf_ML_shape = buf_ACC_shape[:-1]
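
Host-side, the new kernel option is derived from the last dimension of kv_indices: that dimension is how many KV blocks the sparse mask can address, so the largest index a program may safely load is one less. A rough standalone sketch of the same arithmetic, with a made-up shape standing in for the buffer Inductor sees (size_hint resolves a concrete size, falling back when it is symbolic):

import torch

# stand-in for the kv_indices buffer; the shape is invented
kv_indices = torch.zeros(1, 1, 1, 16, dtype=torch.int32)

max_kv_blocks = kv_indices.shape[-1]  # size_hint(...) resolves this in Inductor
kernel_options = {}
kernel_options.setdefault("MAX_VALID_KV_IDX", max_kv_blocks - 1)
assert kernel_options["MAX_VALID_KV_IDX"] == 15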

View File

@@ -119,7 +119,11 @@
     # Offset the kv_indices tensor by the correct batch and head
     kv_indices = KV_IDX + sparse_idx_hz_offset
     kv_num_blocks = tl.load(KV_NUM_BLKS + sparse_block_hz_offset)
+    MAX_KV_IDX = {{size("KV_IDX", -1)}} - 1
     indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+    # Early exit: skip CTA computation if indices_idx exceeds available KV blocks to prevent invalid memory access
+    if indices_idx > MAX_KV_IDX:
+        return
     off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
     off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
     # first kv block we're loading
@@ -166,13 +170,18 @@
     # apply mask_mod to them - only score_mod
     if HAS_FULL_BLOCKS:
         kv_indices = FULL_KV_IDX + sparse_idx_hz_offset
+        MAX_KV_IDX = {{size("FULL_KV_IDX", -1)}} - 1
         kv_num_blocks = tl.load(FULL_KV_NUM_BLKS + sparse_block_hz_offset)
         # Assign full block in a reverse order for off_t. Prioritize the last CTA.
         block_n_start = (SPLIT_KV - off_t - 1) * TILE_KV_MULTIPLE
         block_n_end = block_n_start + TILE_KV_MULTIPLE
         indices_idx = block_n_start // SPARSE_KV_MULTIPLE
+        # Early exit: skip CTA computation if indices_idx exceeds available KV blocks to prevent invalid memory access
+        if indices_idx > MAX_KV_IDX:
+            return
         off_n_block_in_sparse = block_n_start % SPARSE_KV_MULTIPLE
         off_n = tl.load(kv_indices + indices_idx) * SPARSE_KV_BLOCK_SIZE + off_n_block_in_sparse * BLOCK_N
         # first kv block we're loading
         # last valid block according to sparse mask
         block_n_last_valid = tl.minimum(kv_num_blocks * SPARSE_KV_MULTIPLE, tl.maximum(tl.cdiv(KV_LEN, BLOCK_N), 1))
@@ -210,6 +219,7 @@
             IS_FULL_BLOCKS=True,
         )
+
     m_offset = off_t * stride_mt + off_z * stride_mz
     l_offset = off_t * stride_lt + off_z * stride_lz
@@ -249,4 +259,4 @@
     mask = (idx_m < Q_LEN) & (idx_d < V_HEAD_DIM)
 
     acc = acc.reshape(G, BLOCK_M_PER_HQ, V_HEAD_DIM)
-    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
+    {{store_output(("idx_z", "idx_t", "idx_hq", "idx_m", "idx_d"), "acc", "mask")}}
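
Both guards added to this template protect the same hazard: with the KV range split across SPLIT_KV programs, a program whose starting tile lands past the end of KV_IDX (or FULL_KV_IDX) would feed an out-of-range indices_idx into tl.load. A plain-Python model of the guard arithmetic, with all sizes invented:

# plain-Python model of the per-program early exit (sizes invented)
SPLIT_KV = 8             # number of programs splitting the KV range
TILE_KV_MULTIPLE = 2     # KV tiles handled per program
SPARSE_KV_MULTIPLE = 4   # KV tiles per sparse block
num_kv_indices = 3       # last dim of KV_IDX
MAX_KV_IDX = num_kv_indices - 1

for off_t in range(SPLIT_KV):
    block_n_start = off_t * TILE_KV_MULTIPLE
    indices_idx = block_n_start // SPARSE_KV_MULTIPLE
    if indices_idx > MAX_KV_IDX:
        continue  # the kernel returns here instead of loading out of bounds
    # ... tl.load(kv_indices + indices_idx) is now in range ...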

View File

@@ -116,7 +116,6 @@ class PagedAttention:
         Args:
             batch_idx (Tensor): batch index to be removed; shape :math:`(1)`.
         """
-
         # find allocated pages
         allocated_page_idx = self.page_table[batch_idx] != -1
         allocated_pages = self.page_table[batch_idx][allocated_page_idx]
@@ -182,12 +181,13 @@ class PagedAttention:
         logical_block_offset = input_pos % self.page_size  # [B, S]
         physical_block_idx = torch.gather(
             self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
-        ).to(torch.int32)  # [B, S]
+        ).to(
+            torch.int32
+        )  # [B, S]
 
         addr = (physical_block_idx * self.page_size + logical_block_offset).view(
             -1
         )  # [B*S]
 
         k_val = k_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, K_D)
         v_val = v_val.permute(1, 0, 2, 3).contiguous().view(1, H, B * S, V_D)
@@ -198,6 +198,7 @@ class PagedAttention:
         self,
         block_mask: BlockMask,
         batch_idx: Optional[torch.Tensor] = None,
+        kv_len: Optional[torch.Tensor] = None,
     ) -> BlockMask:
         """
         Converts a logical block mask by mapping its logical kv indices to the corresponding
@@ -210,6 +211,8 @@ class PagedAttention:
                 batch dimension. This provides flexibility to convert a
                 block mask with smaller batch size than the page table;
                 shape :math:`(B)`.
+            kv_len (Optional[Tensor]): actual KV sequence length for upper bound check;
+                shape :math:`(B,)` to handle multiple batches.
         """
 
         B, H, ROWS, MAX_BLOCKS_IN_COL = block_mask.kv_indices.shape
@@ -261,7 +264,7 @@ class PagedAttention:
             .to(torch.int32)
         )
 
-        new_mask_mod = self.get_mask_mod(block_mask.mask_mod)
+        new_mask_mod = self.get_mask_mod(block_mask.mask_mod, kv_len)
 
         seq_lengths = (block_mask.seq_lengths[0], self.n_pages * self.page_size)
         return BlockMask.from_kv_blocks(
@@ -275,7 +278,9 @@ class PagedAttention:
         )
 
     def get_mask_mod(
-        self, mask_mod: Optional[_mask_mod_signature]
+        self,
+        mask_mod: Optional[_mask_mod_signature],
+        kv_len: Optional[torch.Tensor] = None,
     ) -> _mask_mod_signature:
         """
         Converts a mask_mod based on mapping from the physical block index to the logical
@@ -283,6 +288,7 @@ class PagedAttention:
 
         Args:
             mask_mod (_mask_mod_signature): mask_mod based on the logical block index.
+            kv_len (Optional[torch.Tensor]): actual KV sequence length for upper bound check.
         """
         if mask_mod is None:
             mask_mod = noop_mask
@@ -297,9 +303,14 @@ class PagedAttention:
             physical_kv_offset = physical_kv_idx % self.page_size
             logical_block_idx = self.physical_to_logical[b, physical_kv_block]
             logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
-            return torch.where(
-                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
-            )
+            live_block = logical_block_idx >= 0
+            within_upper_bound = (
+                logical_kv_idx < kv_len[b] if kv_len is not None else True
+            )
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+            return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False)
 
         return new_mask_mod
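
The net effect on the converted mask_mod is that a physical KV position is now admitted only when three conditions hold at once: its page is live, its recovered logical index is non-negative, and, when kv_len is given, that index falls below the batch's true KV length. A standalone restatement of that predicate, assuming tensor inputs shaped as in the method above:

import torch

def is_valid(logical_block_idx, logical_kv_idx, kv_len, b):
    # page is live (unallocated pages map to -1 in the page table)
    live_block = logical_block_idx >= 0
    # recovered logical index is non-negative
    within_lower_bound = logical_kv_idx >= 0
    # and, when kv_len is provided, below that batch's true KV length
    within_upper_bound = (
        logical_kv_idx < kv_len[b] if kv_len is not None else True
    )
    return live_block & within_upper_bound & within_lower_bound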