[Perf] Improve MLA on V1 (#14540)

Signed-off-by: simon-mo <simon.mo@hey.com>
Author: Simon Mo
Date: 2025-03-10 12:06:58 -07:00
Committed by: GitHub
Parent: 92b0ce2ac7
Commit: fb0acb6c72


@@ -223,6 +223,7 @@ from vllm.model_executor.layers.quantization.utils.fp8_utils import (
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down
try:
@@ -471,18 +472,23 @@ class MLACommonMetadataBuilder(Generic[M]):
common_prefix_len: int) -> M:
assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
# function. We should avoid GPU -> CPU sync as much as possible because
# it blocks until all previously launched kernels have finished.
device = self.runner.device
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
device, non_blocking=True)
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
seq_lens_cpu = self.runner.seq_lens_cpu[:num_reqs]
seq_lens = seq_lens_cpu.to(device, non_blocking=True)
max_query_len = seq_lens_cpu.max().item()
prefill_metadata = None
if self._num_prefills > 0:
reqs_start = self._num_decodes # prefill_start
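The hunk above keeps a CPU copy of `seq_lens` so that scalar reductions such as `max()` stay on the host, and ships tensors to the GPU with `non_blocking=True` copies. A minimal standalone sketch of that pattern follows; the tensor values and the helper name `build_seq_len_metadata` are illustrative and not part of vLLM.

import torch

def build_seq_len_metadata(seq_lens_cpu: torch.Tensor, device: torch.device):
    # Scalar metadata is derived from the CPU tensor: .max().item() here is a
    # cheap host-side reduction, whereas calling it on a CUDA tensor would
    # block until every previously launched kernel has finished.
    max_query_len = seq_lens_cpu.max().item()
    # The tensor itself goes to the GPU asynchronously; with pinned host
    # memory the copy can overlap with work already queued on the device.
    seq_lens = seq_lens_cpu.to(device, non_blocking=True)
    return seq_lens, max_query_len

if __name__ == "__main__":
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    seq_lens_cpu = torch.tensor([7, 128, 33], dtype=torch.int32,
                                pin_memory=use_cuda)
    seq_lens, max_query_len = build_seq_len_metadata(seq_lens_cpu, device)
    print(seq_lens.device, max_query_len)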
@@ -490,24 +496,22 @@ class MLACommonMetadataBuilder(Generic[M]):
context_lens_cpu = self.runner.input_batch.\
num_computed_tokens_cpu_tensor[reqs_start:num_reqs]
context_lens = context_lens_cpu.to(device, non_blocking=True)
max_context_len_cpu = context_lens_cpu.max().item()
num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
chunked_context_metadata = None
if self.chunked_prefill_enabled and self._num_prefills > 0 \
and context_lens.max() > 0:
and max_context_len_cpu > 0:
# NOTE: it is recommended that you read the `Chunked Prefill` section
# in the comment at the top of the file before trying to
# understand the following code
num_prefills_with_context = (context_lens > 0).sum().item()
# currently we allocate an equal amount of workspace for each
# prefill in the batch; we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk = \
self.chunked_prefill_workspace_size \
// num_prefills_with_context
max_context_chunk = (self.chunked_prefill_workspace_size //
num_prefills_with_context_cpu)
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
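The sizing logic above (continued by the rounding to `self.page_size` at the top of the next hunk) splits the chunked-prefill workspace evenly across the prefills that have context, using scalars already computed on the CPU. A plain-integer sketch of that arithmetic, with made-up values and local reimplementations of `cdiv`/`round_down` from `vllm.utils`:

def cdiv(a: int, b: int) -> int:
    # Ceiling division, as in vllm.utils.cdiv.
    return -(-a // b)

def round_down(x: int, multiple: int) -> int:
    # Round down to the nearest multiple, as in vllm.utils.round_down.
    return x - (x % multiple)

workspace_size = 8192            # chunked_prefill_workspace_size (assumed value)
num_prefills_with_context = 4    # from (context_lens_cpu > 0).sum().item()
max_context_len = 700            # from context_lens_cpu.max().item()
page_size = 16

# Equal share of the workspace per prefill with context, aligned down to the
# page size so a KV-cache page never has to be split across chunks.
max_context_chunk = round_down(workspace_size // num_prefills_with_context,
                               page_size)
assert max_context_chunk > 0
num_chunks = cdiv(max_context_len, max_context_chunk)
print(max_context_chunk, num_chunks)   # 2048, 1 with the numbers above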
@@ -516,30 +520,35 @@ class MLACommonMetadataBuilder(Generic[M]):
self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens.max(), max_context_chunk)
num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks
# like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# Note(simon): this is done on CPU because of downstream usage
# of `tolist()`.
chunk_starts = \
torch.arange(num_chunks, device=device, dtype=torch.int32) \
torch.arange(num_chunks, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens.unsqueeze(0),
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
torch.int32)
zero = torch.zeros(num_chunks,
dtype=torch.int32,
device=device).unsqueeze(-1)
cu_seq_lens_cpu = torch.zeros(num_chunks,
self._num_prefills + 1,
dtype=torch.int32,
pin_memory=True)
torch.cumsum(chunk_seq_lens,
dim=1,
out=cu_seq_lens_cpu[:, 1:],
dtype=torch.int32)
chunked_context_metadata = \
MLACommonPrefillMetadata.ChunkedContextMetadata(
cu_seq_lens=torch.cat(
[zero, _chunk_cu_seq_lens], dim=1),
starts=chunk_starts,
cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
starts=chunk_starts.to(device, non_blocking=True),
seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
workspace=self.chunked_prefill_workspace,
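A standalone sketch of the CPU-side chunk bookkeeping introduced above: the start/end/length tensors and the pinned cumulative-sum buffer are all built on the host, so the only device-facing work is a single `non_blocking` copy. Shapes and context lengths below are illustrative.

import torch

num_chunks = 3
num_prefills = 4
max_context_chunk = 256
context_lens_cpu = torch.tensor([100, 300, 600, 700], dtype=torch.int32)

# [[0,0,0,0], [256,256,256,256], [512,512,512,512]] -- built on CPU because
# the metadata below needs Python lists (tolist()) anyway.
chunk_starts = (torch.arange(num_chunks, dtype=torch.int32)
                .unsqueeze(1).expand(-1, num_prefills) * max_context_chunk)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0),
                       chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

# Preallocate the cumulative-sum buffer in pinned memory with a leading zero
# column and fill columns 1.. in place; a later .to(device, non_blocking=True)
# copy can then overlap with GPU work instead of forcing a sync.
pin = torch.cuda.is_available()
cu_seq_lens_cpu = torch.zeros(num_chunks, num_prefills + 1,
                              dtype=torch.int32, pin_memory=pin)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:],
             dtype=torch.int32)

print(chunk_seq_lens)
print(cu_seq_lens_cpu)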
@@ -553,7 +562,7 @@ class MLACommonMetadataBuilder(Generic[M]):
block_table=block_table[reqs_start:, ...],
query_start_loc=query_start_loc[reqs_start:] -
query_start_loc[reqs_start],
max_query_len=seq_lens[reqs_start:].max().item(),
max_query_len=max_query_len,
chunked_context=chunked_context_metadata,
)
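`max_query_len` is now reused from the CPU-side computation instead of calling `.max().item()` on the device-resident `seq_lens`, which stalls until the GPU queue drains. A rough, non-rigorous illustration of that stall, assuming a CUDA device is present; the matmul loop is just filler work:

import time
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
    a = torch.randn(4096, 4096, device=device)
    seq_lens_gpu = torch.randint(1, 1024, (256,), device=device)
    seq_lens_cpu = seq_lens_gpu.to("cpu")   # one-off copy for the comparison
    torch.cuda.synchronize()

    # Enqueue GPU work, then reduce the GPU tensor: .item() cannot return
    # until the matmuls above have drained from the stream.
    for _ in range(20):
        a = a @ a.T / a.norm()
    t0 = time.perf_counter()
    _ = seq_lens_gpu.max().item()
    gpu_side = time.perf_counter() - t0

    # Same enqueued work, but the reduction runs on the CPU copy and returns
    # without waiting for the device.
    for _ in range(20):
        a = a @ a.T / a.norm()
    t0 = time.perf_counter()
    _ = seq_lens_cpu.max().item()
    cpu_side = time.perf_counter() - t0
    torch.cuda.synchronize()

    print(f"max().item() on GPU tensor: {gpu_side * 1e3:.2f} ms "
          f"(includes waiting for queued kernels)")
    print(f"max().item() on CPU tensor: {cpu_side * 1e3:.2f} ms")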
@@ -629,7 +638,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly
# TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb._forward_method
self.rotary_emb = rotary_emb.forward_native
if current_platform.is_cuda():
self.rotary_emb = rotary_emb.forward_cuda
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
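The constructor now binds the rotary embedding's concrete forward method once, instead of going through the custom-op dispatch (`rotary_emb._forward_method`) on every call from inside an attention op. The toy class below only illustrates that binding pattern; it is not vLLM's `RotaryEmbedding`, and `torch.cuda.is_available()` stands in for `current_platform.is_cuda()`.

import torch

class ToyRotaryEmbedding(torch.nn.Module):
    # Placeholder implementations; the real class applies rotary position
    # embeddings to q and k.
    def forward_native(self, positions, q, k):
        return q, k

    def forward_cuda(self, positions, q, k):
        return q, k

def bind_rope_forward(rotary_emb):
    # Default to the pure-PyTorch path and switch to the CUDA kernel when the
    # platform supports it; the bound method is then called directly later,
    # skipping per-call dispatch overhead.
    fwd = rotary_emb.forward_native
    if torch.cuda.is_available():  # stand-in for current_platform.is_cuda()
        fwd = rotary_emb.forward_cuda
    return fwd

rope_fn = bind_rope_forward(ToyRotaryEmbedding())
q = torch.zeros(2, 4, 8)
k = torch.zeros(2, 1, 8)
q, k = rope_fn(torch.arange(2), q, k)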
@@ -1043,17 +1054,20 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe, decode_k_pe)
attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
decode_k_pe)
if has_prefill:
assert attn_metadata.prefill is not None
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
attn_metadata.prefill.input_positions, prefill_q_pe,
prefill_k_pe)
attn_metadata.prefill.input_positions,
prefill_q_pe.contiguous(), prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
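The rope inputs are now made `.contiguous()` before the call: `prefill_q[..., self.qk_nope_head_dim:]` is a strided view of the fused q projection, and the CUDA rotary kernel expects densely packed rows. A small sketch of what that slice looks like; the head dimensions are illustrative, not the model's actual sizes.

import torch

num_tokens, num_heads = 8, 16
qk_nope_head_dim, qk_rope_head_dim = 128, 64

prefill_q = torch.randn(num_tokens, num_heads,
                        qk_nope_head_dim + qk_rope_head_dim)
# Slicing the trailing rope dimensions out of the fused projection yields a
# strided (non-contiguous) view over the larger buffer.
prefill_q_pe = prefill_q[..., qk_nope_head_dim:]

print(prefill_q_pe.is_contiguous())               # False: strided view
print(prefill_q_pe.contiguous().is_contiguous())  # True: packed copy

# The diff keeps the in-place write-back (`prefill_q_pe[...] = ...`), so the
# rotated values still land inside the original prefill_q buffer even though
# the kernel consumed a contiguous copy.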