# SPDX-License-Identifier: Apache-2.0 import functools from abc import abstractmethod from dataclasses import dataclass from typing import Any, Dict, Generic, List, Optional, Tuple import torch from compressed_tensors.quantization import QuantizationStrategy from vllm import _custom_ops as ops from vllm import envs from vllm.attention.backends.abstract import (AttentionLayer, AttentionMetadata, MLAAttentionImpl, T) from vllm.attention.backends.utils import get_flash_attn_version from vllm.distributed import (get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.linear import (ColumnParallelLinear, LinearBase, RowParallelLinear, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 CompressedTensorsLinearMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( CompressedTensorsW8A8Fp8) from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod from vllm.model_executor.layers.quantization.utils.fp8_utils import ( apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8) from vllm.model_executor.layers.quantization.utils.quant_utils import ( scaled_quantize) from vllm.model_executor.layers.rotary_embedding import ( DeepseekScalingRotaryEmbedding, RotaryEmbedding) try: from vllm.vllm_flash_attn import flash_attn_varlen_func except ImportError: from flash_attn import flash_attn_varlen_func @dataclass class MLACommonMetadata(AttentionMetadata): # Input positions for rotrary embeddings since for MLA the rotary # position embeddings are applied inside the attention backend input_positions: torch.Tensor class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): """ Common class for implementing repeated parts Main reference: DeepseekV2 paper, and FlashInfer Implementation (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). Deepseek's MLA attention works the following way: * Use a single latent vector to represent the entire KV cache. * The attention "simulates" a multi-head attention, while the compute is similar to multi-query attention. * The dataflow is as follows, * B: batch/sequence length * H: hidden size * N: number of attention heads * Lq: latent dimension for Q * Lkv: latent dimension for K/V * P: nope dimension, P+R is the actual head_dim in common attention. * R: rope dimension, this slide of the head_dim goes through rope. * V: V head dim. * kv_c: latent/compressed KV * q_c: latent/compressed Q # # Outside the MLA attention backend # 1. The hidden states (B, H) are projected down into cq (B, Lq) and kv_c_k_pe (B, Lkv+R). 2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq and kv_c are normalized. # # Inside the MLA attention backend # * if prefill: 3. The q_c is then projected up into the multi-head version. * q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope (B, N, P) and q_pe (B, N, R). 4. q_pe, k_pe are then passed through rotary embeddings. 5. kv_c and k_pe are concatenated and inserted into the cache 6. The kv_c is then projected up into the multi-head version. * kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope dimensions for K and V, which is split into k_nope (B, N, P) and v (B, N, V). 7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from q_nope, q_pe, k_nope, k_pe. 8. Attention is computued with q, k, v. 9. The attention computation returns (B, N, V), which is projected back to (B, H) using out projection. * if decode: 3. Here's the change, we do not perform up the full up projection for q_c, and there is no up projection at all for kv_c. This is achieved by the technique of "weight absorption". The paper says "Fortunately, due to the associative law of matrix multiplication, we can absorb WUK into WUQ, and WUV into WO" * The q up projection turns (B, Lq) into (B, N, (P+R)), we split it into W_UQ (Lq, N, P) and W_QR (Lq, N, R). * The kv_c up projection turns (B, Lkv) into (B, N, (P+V)), we split it into W_UK (Lkv, N, P) and W_UV (Lkv, N, V). * The out projection shape W_O (N*V, H) turns (B, N, V) into (B, H). * We can precompute the product of W_UQ and W_UK into W_UQ_UK (Lq, N, Lkv), which is possible due to QK^T operation in attention. * We can precompute the product of W_UV and W_O into W_UV_O (N, Lkv, H), which is possible due to V@O as the "epilogue" of attention 4. We still need to compute q_pe (B, N, R) by applying W_QR to q_latent. 5. q_pe, k_pe are then passed through rotary embeddings. 6. kv_c and k_pe are concatenated and inserted into the cache 7. By applying W_UQ_UK to q_latent, we have the new q_nope of shape (B, N, Lkv). 8. q (B, N, (Lkv+R)), k (B, (Lkv+R)) are assembled from q_nope, q_pe, kv_a, k_pe. v (B, Lkv) is exactly the same vector as kv_a. 9. The attention is computed with q, k, v. Note that we just performed a MQA attention with (LKv+R) as our head dim. 10. The KV cache is updated using the new entries k (B, N, (Lkv+R)), which included the v and rope values. 11. The attention computation returns (B, N, Lkv), which is projected back to (B, H) using W_UV_O. From @tsu-bin's calculation, we only want to use the absorption technique for decode. The prefill algorithm should still use the up-projected MHA for less flops and memory usage. """ def __init__( self, num_heads: int, head_size: int, scale: float, num_kv_heads: int, alibi_slopes: Optional[List[float]], sliding_window: Optional[int], kv_cache_dtype: str, blocksparse_params: Optional[Dict[str, Any]], logits_soft_cap: Optional[float], attn_type: str, # MLA Specific Arguments q_lora_rank: Optional[int], kv_lora_rank: int, qk_nope_head_dim: int, qk_rope_head_dim: int, qk_head_dim: int, v_head_dim: int, rotary_emb: RotaryEmbedding, # q_proj should be q_b_proj if q_lora_rank is not None, but from an # attention backend perspective we rely on the layer to pass in the # correct matrix q_proj: ColumnParallelLinear, kv_b_proj: ColumnParallelLinear, o_proj: RowParallelLinear, ) -> None: self.num_heads = num_heads self.head_size = head_size self.scale = float(scale) self.num_kv_heads = num_kv_heads self.kv_cache_dtype = kv_cache_dtype self.q_lora_rank = q_lora_rank self.kv_lora_rank = kv_lora_rank self.qk_nope_head_dim = qk_nope_head_dim self.qk_rope_head_dim = qk_rope_head_dim self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.rotary_emb = rotary_emb self.use_yarn_rope = isinstance(rotary_emb, DeepseekScalingRotaryEmbedding) self.q_proj = q_proj self.kv_b_proj = kv_b_proj self.o_proj = o_proj self.vllm_flash_attn_version = get_flash_attn_version() # Handle the differences between the flash_attn_varlen from flash_attn # and the one from vllm_flash_attn. The former is used on RoCM and the # latter has an additional parameter to control FA2 vs FA3 self.flash_attn_varlen_func = flash_attn_varlen_func if self.vllm_flash_attn_version is not None: self.flash_attn_varlen_func = \ functools.partial(flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version) def _v_up_proj_and_o_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: if is_fp8(self.W_UV_O): output_parallel = apply_fp8_linear_generic( x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, self.reqaunt_input_group_shape, self.reqaunt_weight_group_shape) else: output_parallel = torch.matmul(x.flatten(start_dim=1), self.W_UV_O) if self.tp_size > 1: output = tensor_model_parallel_all_reduce(output_parallel) else: output = output_parallel return output else: x = torch.einsum("bnl,lnv->bnv", x, self.W_UV) return self.o_proj(x.reshape(-1, self.num_heads * self.v_head_dim))[0] def _q_proj_and_k_up_proj(self, x): if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: if is_fp8(self.W_Q_UK): return apply_fp8_linear_generic( x, self.W_Q_UK, self.W_Q_UK_scales, self.reqaunt_input_group_shape, self.reqaunt_weight_group_shape).view( -1, self.num_heads, self.kv_lora_rank) return torch.matmul(x, self.W_Q_UK)\ .view(-1, self.num_heads, self.kv_lora_rank) else: x = torch.matmul(x, self.W_Q)\ .view(-1, self.num_heads, self.qk_nope_head_dim) return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\ .view(-1, self.num_heads, self.kv_lora_rank) def process_weights_after_loading(self, act_dtype: torch.dtype): # TODO(lucas) This is very gross, we need a more wide scale refactor of # all the FP8 code with a more standard way of # defining schemes/group-shapes, we should also potentially force # quant_methods to support a decompress function # # returns input_group_shape, weight_group_shape def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \ Tuple[Tuple[int, int], Tuple[int, int]]: if isinstance(layer.quant_method, Fp8LinearMethod): if layer.quant_method.block_quant: weight_block_size = \ layer.quant_method.quant_config.weight_block_size # per-token-group (1, X), block-quantized (X, Y) return (1, weight_block_size[-1]), weight_block_size else: return (-1, -1), (-1, -1) # per-tensor, per-tensor elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\ and isinstance(layer.scheme, CompressedTensorsW8A8Fp8): # this is hacky but we always assume the for # CompressedTensorsW8A8Fp8 the input is dynamic per-token # we ignore if it is static-per-tensor since we are going to # requantize after later anyways strategy = layer.scheme.strategy if strategy == QuantizationStrategy.TENSOR: return (1, -1), (-1, -1) # per-token, per-tensor elif strategy == QuantizationStrategy.CHANNEL: return (1, -1), (-1, 1) # per-token, per-channel else: raise NotImplementedError( f"QuantizationStrategy.{strategy} is not supported for " "fp8 MLA, please run with VLLM_MLA_DISABLE=1") else: raise NotImplementedError( "Can't determine scale group shapes for " f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1" ) def get_layer_weight(layer): if hasattr(layer, "weight"): return layer.weight elif hasattr(layer, "qweight"): return layer.qweight else: raise AttributeError( f"Layer '{layer}' has neither weight nor qweight") def get_and_maybe_dequant_weights(layer: LinearBase): if not isinstance(layer.quant_method, UnquantizedLinearMethod): # NOTE: This should only be used offline, since it's O(N^3) eye = torch.eye(layer.input_size_per_partition, dtype=act_dtype, device=get_layer_weight(layer).device) dequant_weights = layer.quant_method.apply(layer, eye, bias=None) del eye # standardize to (output, input) return dequant_weights.T return layer.weight weight_dtype = get_layer_weight(self.kv_b_proj).dtype assert get_layer_weight(self.o_proj).dtype == weight_dtype assert get_layer_weight(self.q_proj).dtype == weight_dtype kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T assert kv_b_proj_weight.shape == ( self.kv_lora_rank, self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( f"{kv_b_proj_weight.shape=}, " f"{self.kv_lora_rank=}, " f"{self.num_heads=}, " f"{self.qk_nope_head_dim=}, " f"{self.v_head_dim=}") kv_b_proj_weight = kv_b_proj_weight.view( self.kv_lora_rank, self.num_heads, self.qk_nope_head_dim + self.v_head_dim, ) W_UK, W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ .view(-1, self.num_heads, self.qk_head_dim) # can be W_Q or W_UQ depending q_lora_rank, the former if # q_lora_rank is None, the latter otherwise. From the Attention backend # perspective though we call these both W_Q and rely on the layer # to pass in the correct matrix W_Q = q_proj_weight[..., :self.qk_nope_head_dim] self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\ .flatten(start_dim=1).contiguous() # W_QR is small so for simplicity we dont bother requantizing it self.W_QR = self.W_QR.to(act_dtype) if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION if is_fp8(weight_dtype) and requantization_enabled: # This assumes it wise to requantize using the same group shapes # (i.e. strategy, per-tensor, per-channel, block etc.) that the # weights were originally quantized requant_input_group_shape, requant_weight_group_shape = \ get_scale_group_shapes_for_fp8(self.q_proj) assert (requant_input_group_shape, requant_weight_group_shape)\ == get_scale_group_shapes_for_fp8(self.kv_b_proj) assert (requant_input_group_shape, requant_weight_group_shape)\ == get_scale_group_shapes_for_fp8(self.o_proj) self.reqaunt_input_group_shape = requant_input_group_shape self.reqaunt_weight_group_shape = requant_weight_group_shape # # Perform matrix-absorption following # https://github.com/flashinfer-ai/flashinfer/pull/551 # for decode, as a result we end up with absorbed weights for decode # and another copy of raw weights for prefill. # self.W_UK, self.W_UV = kv_b_proj_weight.split( [self.qk_nope_head_dim, self.v_head_dim], dim=-1) # We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK # depending q_lora_rank, the former if q_lora_rank is None, the # latter otherwise # basically if q_lora_rank is none we are absorbing into q_proj # instead of UQ W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\ .flatten(start_dim=1).contiguous() if is_fp8(weight_dtype) and requantization_enabled: W_Q_UK, W_Q_UK_scales = scaled_quantize( W_Q_UK, self.reqaunt_weight_group_shape, quant_dtype=current_platform_fp8_dtype) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_Q_UK = W_Q_UK.T.contiguous() self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous() else: self.W_Q_UK = W_Q_UK.to(act_dtype) W_O = get_and_maybe_dequant_weights(self.o_proj)\ .view(-1, self.num_heads, self.v_head_dim) W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\ .flatten(start_dim=0, end_dim=1).contiguous() if is_fp8(weight_dtype) and requantization_enabled: W_UV_O, W_UV_O_scales = scaled_quantize( W_UV_O, self.reqaunt_weight_group_shape, quant_dtype=current_platform_fp8_dtype) # For FP8 save the transpose so we can use # `apply_w8a8_block_fp8_linear` directly self.W_UV_O = W_UV_O.T.contiguous() self.W_UV_O_scales = W_UV_O_scales.T.contiguous() else: self.W_UV_O = W_UV_O.to(act_dtype) self.tp_size = get_tensor_model_parallel_world_size() else: if is_fp8(weight_dtype): raise NotImplementedError( "Currently fp8 requires matrix absorption") self.W_UV = W_UV self.W_UK = W_UK self.W_Q = W_Q.flatten(start_dim=1) @abstractmethod def _forward_prefill( self, q: torch.Tensor, kv_c_normed: torch.Tensor, k_pe: torch.Tensor, attn_metadata: T, ) -> torch.Tensor: raise NotImplementedError @abstractmethod def _forward_decode( self, q_nope: torch.Tensor, q_pe: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: T, ) -> torch.Tensor: raise NotImplementedError def forward( self, layer: AttentionLayer, hidden_states_or_q_c: torch.Tensor, # query in unified attn k_c_normed: torch.Tensor, # key in unified attn k_pe: torch.Tensor, # value in unified attn kv_cache: torch.Tensor, attn_metadata: T, output: Optional[torch.Tensor] = None, ) -> torch.Tensor: if output is not None: raise NotImplementedError( "output is not yet supported for MLAImplBase") is_decode = attn_metadata.decode_metadata is not None is_prefill = attn_metadata.prefill_metadata is not None if (is_decode and is_prefill): raise NotImplementedError( "chunked prefill is not supported for MLAImplBase") # Restore head dim (for rotary embedding) k_pe = k_pe.unsqueeze(1) assert hasattr(attn_metadata, "input_positions") if is_decode: q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c) q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\ .view(-1, self.num_heads, self.qk_rope_head_dim) q_pe, k_pe = self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe) else: assert is_prefill q = self.q_proj(hidden_states_or_q_c)[0]\ .view(-1, self.num_heads, self.qk_head_dim) # TODO(lucas): there must be a nicer way to write this line q[..., self.qk_nope_head_dim:], k_pe = \ self.rotary_emb( attn_metadata.input_positions, q[..., self.qk_nope_head_dim:], k_pe) # write the latent and rope to kv cache if kv_cache.numel() > 0: ops.concat_and_cache_mla( k_c_normed, k_pe.squeeze(1), kv_cache, attn_metadata.slot_mapping.flatten(), kv_cache_dtype=self.kv_cache_dtype, scale=layer._k_scale, ) if attn_metadata.prefill_metadata is not None: return self._forward_prefill(q, k_c_normed, k_pe, attn_metadata) if attn_metadata.decode_metadata is not None: return self._forward_decode(q_nope, q_pe, kv_cache, attn_metadata) # Optional common flash-attn based prefill def _forward_prefill_flash( self, q: torch.Tensor, k_c_normed: torch.Tensor, k_pe: torch.Tensor, seq_start_loc: torch.Tensor, max_prefill_seq_len: int, ) -> torch.Tensor: kv_nope = self.kv_b_proj(k_c_normed)[0]\ .view(-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = kv_nope\ .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1) # For MLA the v head dim is smaller than qk head dim so we pad out # v with 0s to match the qk head dim v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]], value=0) attn_output = self.flash_attn_varlen_func( q=q, k=k, v=v_padded, cu_seqlens_q=seq_start_loc, cu_seqlens_k=seq_start_loc, max_seqlen_q=max_prefill_seq_len, max_seqlen_k=max_prefill_seq_len, softmax_scale=self.scale, causal=True, ) attn_output = attn_output\ .view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\ .reshape(-1, self.num_heads * v.shape[-1]) return self.o_proj(attn_output)[0]