[Bugfix] fix tmp_out and exp_sums dimensions (#17438)

Signed-off-by: Hui Liu <96135754+hliuca@users.noreply.github.com>
This commit is contained in:
Hui Liu
2025-05-02 09:44:07 -07:00
committed by GitHub
parent cb234955df
commit 4c33d67321

View File

@ -289,7 +289,7 @@ def chunked_prefill_paged_decode(
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = query.shape[0]
total_num_seq = block_table.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),