[Bugfix] Enable PP with AITER+V1 (#19822)

Signed-off-by: Qiang Li <qiang.li2@amd.com>
Author: qli88
Date:   2025-06-19 23:43:20 -05:00
Committed by: GitHub
Parent: e41bf15cd0
Commit: e3a3e4db46
2 changed files with 3 additions and 11 deletions


@@ -45,7 +45,6 @@ def fused_add_rms_norm(
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
                        variance_epsilon: float) -> torch.Tensor:
    import aiter as rocm_aiter
    if x.dim() > 2:
        x_original_shape = x.shape

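For context, a minimal pure-PyTorch sketch of the pattern this helper uses (illustrative only, not the commit's exact code; the real path dispatches to AITER's fused RMSNorm kernel): inputs with more than two dimensions are flattened to 2-D, normalized, and reshaped back.

    import torch

    def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor,
                     variance_epsilon: float) -> torch.Tensor:
        # Flatten shapes like (batch, seq, hidden) to 2-D, normalize,
        # then restore the original shape.
        if x.dim() > 2:
            x_original_shape = x.shape
            out = rms_norm_ref(x.reshape(-1, x_original_shape[-1]),
                               weight, variance_epsilon)
            return out.reshape(x_original_shape)
        # Standard RMSNorm: scale by the reciprocal root-mean-square.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(variance + variance_epsilon) * weight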

@@ -201,16 +201,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
         kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
-        if self.num_heads == 16:
-            # AITER MLA decode kernel only supports
-            # max_seqlen_q=1 when using 16 heads.
-            max_seqlen_qo = 1
-        else:
-            # AITER MLA decode kernel handles arbitrary
-            # max_seqlen_q values when using 128 heads.
-            assert attn_metadata.prefill is not None
-            max_seqlen_qo = attn_metadata.prefill.max_query_len
+        # max_seqlen_qo must be 1 except for MTP
+        # TODO: Find the best value for MTP
+        max_seqlen_qo = 1
         aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
                              attn_metadata.decode.qo_indptr, max_seqlen_qo,
                              attn_metadata.decode.paged_kv_indptr,
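To illustrate what max_seqlen_qo means here (an illustration with hypothetical values, not code from the commit): qo_indptr is the cumulative count of query tokens per sequence, so in plain decode every sequence contributes one query token and the maximum query length is 1, while MTP (multi-token prediction) speculative decoding can push it above 1.

    import torch

    # Hypothetical qo_indptr: 4 sequences; the last one carries an extra
    # MTP draft token, so its query length is 2 rather than 1.
    qo_indptr = torch.tensor([0, 1, 2, 3, 5])
    query_lens = qo_indptr[1:] - qo_indptr[:-1]   # per-sequence query lengths
    max_seqlen_qo = int(query_lens.max())         # 1 for plain decode, 2 here
    print(query_lens.tolist(), max_seqlen_qo)     # [1, 1, 1, 2] 2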