mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix][AMD] Compatible patch for AITER lib after 04/20 (#17912)
Signed-off-by: Qiang Li <qiang.li2@amd.com>
This commit is contained in:
@ -53,7 +53,7 @@ class AiterMLABackend(MLACommonBackend):
|
||||
|
||||
@dataclass
|
||||
class AiterMLAMetadata(MLACommonMetadata):
|
||||
# The following 4 tensors are for current version of AITER MLA
|
||||
# The following 5 tensors are for current version of AITER MLA
|
||||
block_table_bound: Optional[torch.Tensor] = None
|
||||
# The indptr of the paged kv cache, shape: [batch_size + 1]
|
||||
paged_kv_indptr: Optional[torch.Tensor] = None
|
||||
@ -63,6 +63,10 @@ class AiterMLAMetadata(MLACommonMetadata):
|
||||
# the paged kv cache, shape: [batch_size]
|
||||
paged_kv_last_page_lens: Optional[torch.Tensor] = None
|
||||
|
||||
# This is just to make new AITER MLA API work
|
||||
# -- MTP support is not added yet.
|
||||
qo_indptr: Optional[torch.Tensor] = None
|
||||
|
||||
@property
|
||||
def prefill_metadata(self):
|
||||
prefill_metadata = super().prefill_metadata
|
||||
@ -74,6 +78,7 @@ class AiterMLAMetadata(MLACommonMetadata):
|
||||
prefill_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
prefill_metadata.block_table_bound = self.block_table_bound
|
||||
prefill_metadata.qo_indptr = self.qo_indptr
|
||||
|
||||
# update the cache
|
||||
self._cached_prefill_metadata = self.__class__(
|
||||
@ -93,6 +98,7 @@ class AiterMLAMetadata(MLACommonMetadata):
|
||||
decode_metadata\
|
||||
.paged_kv_last_page_lens = self.paged_kv_last_page_lens
|
||||
decode_metadata.block_table_bound = self.block_table_bound
|
||||
decode_metadata.qo_indptr = self.qo_indptr
|
||||
|
||||
# update the cache
|
||||
self._cached_decode_metadata = self.__class__(
|
||||
@ -136,6 +142,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
self.paged_kv_indptr: list[int] = [0]
|
||||
self.paged_kv_last_page_lens: list[int] = []
|
||||
self.total_blocks = 0
|
||||
self.qo_indptr: list[int] = [0]
|
||||
|
||||
def _add_seq_group(self, inter_data, chunked_prefill_enabled: bool,
|
||||
prefix_cache_hit: bool):
|
||||
@ -208,6 +215,7 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
self.paged_kv_indices.extend(block_table[:block_table_bound])
|
||||
self.paged_kv_indptr.append(self.paged_kv_indptr[-1] +
|
||||
block_table_bound)
|
||||
self.qo_indptr.append(self.qo_indptr[-1] + 1)
|
||||
|
||||
last_page_len = seq_len % self.block_size
|
||||
if last_page_len == 0:
|
||||
@ -226,6 +234,8 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
self.paged_kv_indptr.extend([last_paged_kv_indptr] *
|
||||
cuda_graph_pad_size)
|
||||
self.paged_kv_last_page_lens.extend([0] * cuda_graph_pad_size)
|
||||
last_qo_indptr = self.qo_indptr[-1]
|
||||
self.qo_indptr.extend([last_qo_indptr] * cuda_graph_pad_size)
|
||||
|
||||
# For current version of AITER MLA
|
||||
if len(self.paged_kv_indptr) > 0:
|
||||
@ -245,16 +255,22 @@ class AiterMLAMetadataBuilder(MLACommonMetadataBuilder[AiterMLAMetadata]):
|
||||
1,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
|
||||
qo_indptr = torch.tensor(self.qo_indptr,
|
||||
device=device,
|
||||
dtype=torch.int)
|
||||
else:
|
||||
paged_kv_indices_tensor = None
|
||||
paged_kv_indptr_tensor = None
|
||||
paged_kv_last_page_lens_tensor = None
|
||||
block_table_bound_tensor = None
|
||||
qo_indptr = None
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr_tensor
|
||||
metadata.paged_kv_indices = paged_kv_indices_tensor
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens_tensor
|
||||
metadata.block_table_bound = block_table_bound_tensor
|
||||
metadata.qo_indptr = qo_indptr
|
||||
|
||||
return metadata
|
||||
|
||||
@ -263,14 +279,17 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
|
||||
@contextmanager
|
||||
def graph_capture(self, max_batch_size: int):
|
||||
kv_indices, kv_indptr, last_page_lens = get_aiter_mla_metadata(
|
||||
max_batch_size=max_batch_size,
|
||||
block_size=self.runner.block_size,
|
||||
max_block_per_batch=self.runner.get_max_block_per_batch(),
|
||||
device=self.runner.device)
|
||||
kv_indices, kv_indptr, last_page_lens, qo_indptr = \
|
||||
get_aiter_mla_metadata(
|
||||
max_batch_size=max_batch_size,
|
||||
block_size=self.runner.block_size,
|
||||
max_block_per_batch=\
|
||||
self.runner.get_max_block_per_batch(),
|
||||
device=self.runner.device)
|
||||
self._paged_kv_indices_tensor = kv_indices
|
||||
self._paged_kv_indptr_tensor = kv_indptr
|
||||
self._paged_kv_last_page_lens_tensor = last_page_lens
|
||||
self._qo_indptr_tensor = qo_indptr
|
||||
|
||||
with super().graph_capture(max_batch_size):
|
||||
yield
|
||||
@ -278,6 +297,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
del self._paged_kv_indices_tensor
|
||||
del self._paged_kv_indptr_tensor
|
||||
del self._paged_kv_last_page_lens_tensor
|
||||
del self._qo_indptr_tensor
|
||||
|
||||
def graph_capture_get_metadata_for_batch(
|
||||
self,
|
||||
@ -291,10 +311,12 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
paged_kv_indices = self._paged_kv_indices_tensor
|
||||
paged_kv_last_page_lens = self._paged_kv_last_page_lens_tensor[:
|
||||
batch_size]
|
||||
qo_indptr = self._qo_indptr_tensor[:batch_size + 1]
|
||||
|
||||
metadata.paged_kv_indptr = paged_kv_indptr
|
||||
metadata.paged_kv_indices = paged_kv_indices
|
||||
metadata.paged_kv_last_page_lens = paged_kv_last_page_lens
|
||||
metadata.qo_indptr = qo_indptr
|
||||
|
||||
return metadata
|
||||
|
||||
@ -311,6 +333,7 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
input_buffers[
|
||||
"paged_kv_last_page_lens"] = attn_metadata.\
|
||||
decode_metadata.paged_kv_last_page_lens
|
||||
input_buffers['qo_indptr'] = attn_metadata.qo_indptr
|
||||
|
||||
return input_buffers
|
||||
|
||||
@ -330,6 +353,8 @@ class AiterMLAState(MLACommonState[AiterMLAMetadata]):
|
||||
input_buffers["paged_kv_last_page_lens"].copy_(
|
||||
attn_metadata.decode_metadata.paged_kv_last_page_lens,
|
||||
non_blocking=True)
|
||||
input_buffers["qo_indptr"].copy_(
|
||||
attn_metadata.decode_metadata.qo_indptr, non_blocking=True)
|
||||
|
||||
|
||||
class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
@ -370,11 +395,9 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
softmax_scale: float, return_softmax_lse: bool,
|
||||
**kwargs) -> Union[tuple[torch.Tensor, ...], torch.Tensor]:
|
||||
output = self.flash_attn_varlen_func(
|
||||
q=q,
|
||||
k=k,
|
||||
v=v,
|
||||
softmax_scale=softmax_scale,
|
||||
return_lse=return_softmax_lse,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@ -394,7 +417,7 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
B = q_nope.shape[0]
|
||||
|
||||
q = torch.cat([q_nope, q_pe], dim=-1)
|
||||
o = torch.zeros(B,
|
||||
o = torch.empty(B,
|
||||
self.num_heads,
|
||||
self.kv_lora_rank,
|
||||
dtype=q.dtype,
|
||||
@ -403,6 +426,8 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
|
||||
kv_buffer = kv_c_and_k_pe_cache.unsqueeze(2)
|
||||
|
||||
aiter_mla_decode_fwd(q, kv_buffer, o, self.scale,
|
||||
attn_metadata.qo_indptr,
|
||||
attn_metadata.max_query_len,
|
||||
attn_metadata.paged_kv_indptr,
|
||||
attn_metadata.paged_kv_indices,
|
||||
attn_metadata.paged_kv_last_page_lens)
|
||||
|
@ -20,7 +20,8 @@ def get_aiter_mla_metadata(max_batch_size: int, block_size: int,
|
||||
paged_kv_last_page_lens = torch.full((max_batch_size, ),
|
||||
block_size,
|
||||
dtype=torch.int32)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens
|
||||
qo_indptr = torch.zeros(max_batch_size + 1, dtype=torch.int, device=device)
|
||||
return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_lens, qo_indptr
|
||||
|
||||
|
||||
def aiter_mla_decode_fwd(
|
||||
@ -28,6 +29,8 @@ def aiter_mla_decode_fwd(
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
sm_scale: float,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
@ -38,6 +41,8 @@ def aiter_mla_decode_fwd(
|
||||
kv_buffer.view(
|
||||
-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
max_seqlen_qo,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
@ -49,6 +54,8 @@ def mla_decode_fwd_impl(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
@ -60,9 +67,11 @@ def mla_decode_fwd_impl(
|
||||
mla_decode_fwd(q,
|
||||
kv_buffer.view(-1, 1, 1, q.shape[-1]),
|
||||
o,
|
||||
qo_indptr,
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
max_seqlen_qo,
|
||||
sm_scale=sm_scale,
|
||||
logit_cap=logit_cap)
|
||||
|
||||
@ -71,6 +80,8 @@ def mla_decode_fwd_fake(
|
||||
q: torch.Tensor,
|
||||
kv_buffer: torch.Tensor,
|
||||
o: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
max_seqlen_qo: int,
|
||||
kv_indptr: Optional[torch.Tensor] = None,
|
||||
kv_indices: Optional[torch.Tensor] = None,
|
||||
kv_last_page_lens: Optional[torch.Tensor] = None,
|
||||
|
@ -123,10 +123,11 @@ def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
|
||||
|
||||
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
|
||||
sorted_weight_buf, sorted_expert_ids,
|
||||
num_valid_ids, topk, w1_scale.view(local_E, -1),
|
||||
w2_scale.view(local_E, -1),
|
||||
a1_scale.t().contiguous(), *block_shape,
|
||||
smooth_scale)
|
||||
num_valid_ids, topk,
|
||||
a1_scale.t().contiguous(),
|
||||
w1_scale.view(local_E, -1),
|
||||
w2_scale.view(local_E,
|
||||
-1), *block_shape, smooth_scale)
|
||||
|
||||
return out_asm
|
||||
|
||||
|
Reference in New Issue
Block a user