Compare commits

...

3 Commits

Author SHA1 Message Date
02a752aa88 lint 2025-08-21 23:12:40 -07:00
f4424796b5 nit 2025-08-21 23:12:23 -07:00
773fd6b146 init 2025-08-20 10:35:24 -07:00
2 changed files with 33 additions and 4 deletions


@@ -1526,6 +1526,19 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1):
         self.run_test(score_mod, device=device)
         self.run_test_with_paged_attention(score_mod, device=device)
+        self.run_test_with_paged_attention(
+            score_mod=score_mod,
+            dtype=torch.bfloat16,
+            Q_B=4,
+            Q_H=1,
+            Q_S=1,
+            QK_D=16,
+            KV_B=4,
+            KV_H=1,
+            KV_S=64,
+            V_D=16,
+            device=device,
+        )

     @supported_platform
     @patch.object(torch._inductor.config, "max_autotune", True)
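
The added call pins down a decode-style configuration: 4 sequences, a single head, one query token attending to a 64-token KV cache, head dim 16, in bfloat16. As rough orientation (a minimal sketch, not the harness itself; what run_test_with_paged_attention does internally is assumed here), the non-paged reference such a case corresponds to looks like this, assuming a CUDA device:

import torch
from torch.nn.attention.flex_attention import flex_attention

# Decode-style shapes from the new test call: Q_B=4, Q_H=1, Q_S=1, QK_D=16,
# KV_S=64 -- one query token per sequence against a 64-token KV cache.
B, H, Q_S, KV_S, D = 4, 1, 1, 64, 16
q = torch.randn(B, H, Q_S, D, dtype=torch.bfloat16, device="cuda")
k = torch.randn(B, H, KV_S, D, dtype=torch.bfloat16, device="cuda")
v = torch.randn(B, H, KV_S, D, dtype=torch.bfloat16, device="cuda")

def score_mod(score, b, h, q_idx, kv_idx):
    return score  # placeholder; the test sweeps many score_mods

# Non-paged reference; the paged run is expected to match it numerically.
ref = flex_attention(q, k, v, score_mod=score_mod)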


@@ -55,6 +55,10 @@ class PagedAttention:
         # capacity: batch_idx -> allocated sequence length
         self.capacity = torch.zeros(max_batch_size, dtype=torch.int64, device=device)
+        self.seq_lens: torch.Tensor = torch.zeros(
+            max_batch_size, dtype=torch.int64, device=device
+        )

         # index of empty pages that is available for allocation
         self.empty_pages = list(range(n_pages - 1, -1, -1))
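
capacity is always a multiple of page_size, so the last physical page of a sequence usually has slots beyond the tokens actually written; the new seq_lens tensor records the real per-batch length so those slots can be excluded later. A toy mirror of the bookkeeping (hypothetical numbers, not the class itself):

import torch

page_size, n_pages, max_batch_size = 16, 8, 2
capacity = torch.zeros(max_batch_size, dtype=torch.int64)
seq_lens = torch.zeros(max_batch_size, dtype=torch.int64)
empty_pages = list(range(n_pages - 1, -1, -1))  # free physical pages, popped like a stack

# Reserving 20 tokens for batch 0 needs ceil(20 / 16) = 2 pages:
pages_for_batch0 = [empty_pages.pop(), empty_pages.pop()]
capacity[0] = 2 * page_size  # 32 physical slots exist ...
seq_lens[0] = 20             # ... but only 20 of them hold real tokens
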
@@ -72,9 +76,11 @@ class PagedAttention:
             batch_idx (Tensor): batch index to be reserved; shape :math:`(1)`.
             seq_len (Tensor): minimum capacity for the given batch; shape :math:`(1)`.
         """
+        self.seq_lens[batch_idx] = torch.max(seq_len, self.seq_lens[batch_idx])
         if seq_len <= self.capacity[batch_idx]:
             return
+        self.seq_lens[batch_idx] = seq_len

         num_pages_to_allocate = _cdiv(
             seq_len - self.capacity[batch_idx], self.page_size
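
reserve only allocates the shortfall: it rounds seq_len - capacity[batch_idx] up to whole pages. A worked sketch of that arithmetic with plain integers (_cdiv is the module's ceiling-division helper; the stand-in below assumes that behavior):

def _cdiv(a: int, b: int) -> int:
    # ceiling division: smallest integer >= a / b
    return (a + b - 1) // b

page_size = 16
capacity_b = 32   # hypothetical: two pages already allocated to this batch
seq_len = 70      # requested minimum capacity

num_pages_to_allocate = _cdiv(seq_len - capacity_b, page_size)  # ceil(38 / 16) = 3
new_capacity = capacity_b + num_pages_to_allocate * page_size   # 32 + 48 = 80
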
@@ -297,9 +303,13 @@ class PagedAttention:
             physical_kv_offset = physical_kv_idx % self.page_size
             logical_block_idx = self.physical_to_logical[b, physical_kv_block]
             logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
-            return torch.where(
-                logical_block_idx >= 0, mask_mod(b, h, q_idx, logical_kv_idx), False
-            )
+            live_block = logical_block_idx >= 0
+            within_upper_bound = logical_kv_idx < self.seq_lens[b]
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
+            return torch.where(is_valid, mask_mod(b, h, q_idx, logical_kv_idx), False)

         return new_mask_mod
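
The converted mask_mod translates each physical KV slot back to its logical position through the per-batch page table, then admits it only if the page is live, the logical index is non-negative, and (new in this change) it falls below the recorded seq_lens[b], so stale slots in a partially filled or reused page can no longer leak through. A self-contained toy of the index math, with hypothetical page-table values:

import torch

page_size = 4
# physical page 2 holds logical block 0, physical page 0 holds logical block 1,
# pages 1 and 3 are unassigned (-1)
physical_to_logical = torch.tensor([1, -1, 0, -1])
seq_len = 6  # only 6 logical positions contain real tokens

physical_kv_idx = torch.arange(16)                  # every physical slot
physical_kv_block = physical_kv_idx // page_size
physical_kv_offset = physical_kv_idx % page_size
logical_block_idx = physical_to_logical[physical_kv_block]
logical_kv_idx = logical_block_idx * page_size + physical_kv_offset

live_block = logical_block_idx >= 0
within_upper_bound = logical_kv_idx < seq_len       # rejects logical positions 6 and 7
within_lower_bound = logical_kv_idx >= 0
is_valid = live_block & within_upper_bound & within_lower_bound
# only valid slots are forwarded to the user's mask_mod; the rest stay masked
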
@@ -327,8 +337,14 @@ class PagedAttention:
             physical_kv_offset = physical_kv_idx % self.page_size
             logical_block_idx = self.physical_to_logical[b, physical_kv_block]
             logical_kv_idx = logical_block_idx * self.page_size + physical_kv_offset
+            live_block = logical_block_idx >= 0
+            within_upper_bound = logical_kv_idx < self.seq_lens[b]
+            within_lower_bound = logical_kv_idx >= 0
+            is_valid = live_block & within_upper_bound & within_lower_bound
             return torch.where(
-                logical_block_idx >= 0,
+                is_valid,
                 score_mod(score, b, h, q_idx, logical_kv_idx),
                 float("-inf"),
             )
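
The score_mod converter applies the same validity test, but instead of returning False it substitutes -inf, so invalid physical slots get exactly zero weight after softmax. A self-contained toy of that substitution (identity score_mod, made-up validity mask):

import torch

logical_kv_idx = torch.tensor([0, 1, 2, 3, -1, -1])               # hypothetical mapping
is_valid = torch.tensor([True, True, True, False, False, False])  # e.g. seq_len == 3

def score_mod(score, b, h, q_idx, kv_idx):
    return score  # identity; real mods add biases computed from the logical kv_idx

scores = torch.zeros(6)
masked = torch.where(is_valid, score_mod(scores, 0, 0, 0, logical_kv_idx), float("-inf"))
weights = masked.softmax(dim=-1)  # only the three valid slots get nonzero probability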