[Bugfix] Enable PP with AITER+V1 (#19822)
Signed-off-by: Qiang Li <qiang.li2@amd.com>
This commit is contained in:
@ -45,7 +45,6 @@ def fused_add_rms_norm(
|
||||
|
||||
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
|
||||
variance_epsilon: float) -> torch.Tensor:
|
||||
|
||||
import aiter as rocm_aiter
|
||||
if x.dim() > 2:
|
||||
x_original_shape = x.shape
|
||||
|
@ -201,16 +201,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
if self.num_heads == 16:
|
||||
# AITER MLA decode kernel only supports
|
||||
# max_seqlen_q=1 when using 16 heads.
|
||||
max_seqlen_qo = 1
|
||||
else:
|
||||
# AITER MLA decode Kernel handles arbitrary
|
||||
# max_seqlen_q values when using 128 heads.
|
||||
assert attn_metadata.prefill is not None
|
||||
max_seqlen_qo = attn_metadata.prefill.max_query_len
|
||||
|
||||
# max_seqlen_qo must be 1 except for MTP
|
||||
# TODO: Find the best value for MTP
|
||||
max_seqlen_qo = 1
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.decode.qo_indptr, max_seqlen_qo,
|
||||
attn_metadata.decode.paged_kv_indptr,
|
||||
|
Reference in New Issue
Block a user