[V1] [Hybrid] Move MiniMaxLinearAttention into layers/mamba (#23831)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Co-authored-by: Cyrus Leung <tlleungac@connect.ust.hk>
This commit is contained in:
Thomas Parnell
2025-08-30 09:16:15 +02:00
committed by GitHub
parent f1bddbd852
commit 4071c76cf3
2 changed files with 448 additions and 410 deletions

View File

@ -0,0 +1,442 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import math
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
from typing import TYPE_CHECKING
import torch
import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from vllm import envs
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, get_current_vllm_config
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
import torch
import torch.distributed
from vllm.model_executor.models.minimax_cache import MinimaxCacheParams
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
param: nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
tp_world = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]
x = x.to(orig_dtype) * weight
return x
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: Optional[int] = None,
**kwargs) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
assert self.model_config is not None
assert self.cache_config is not None
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, int, int], ...]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
def __init__(
self,
hidden_size: int,
hidden_inner_size: int,
num_heads: int,
head_dim: int,
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
prefix: str = "linear_attn",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.BLOCK = block_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.total_num_heads = num_heads
self.hidden_inner_size = hidden_inner_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
assert self.total_num_heads % self.tp_size == 0
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size * 3,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.output_gate = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.output_gate",
)
self.out_proj = RowParallelLinear(
self.hidden_inner_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = MiniMaxText01RMSNormTP(
self.hidden_inner_size,
eps=1e-5,
)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
return slopes
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
assert kv_caches is not None
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
output[:num_actual_tokens], _ = self.out_proj(hidden)
def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)
def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -1,45 +1,37 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
import math
from collections.abc import Iterable
from itertools import islice
from typing import TYPE_CHECKING, Optional, Union
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionBackend
pass
import regex as re
import torch
import torch.distributed
import torch.nn.functional as F
from einops import rearrange
from torch import nn
from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.lightning_attn import (
lightning_attention, linear_decode_forward_triton)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
MergedColumnParallelLinear,
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.mamba.abstract import MambaBase
from vllm.model_executor.layers.mamba.linear_attn import (
MiniMaxText01LinearAttention)
from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
@ -50,10 +42,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
from .interfaces import HasInnerState, IsHybrid
from .minimax_cache import MinimaxCacheManager, MinimaxCacheParams
@ -87,66 +76,6 @@ def weight_loader_with_alias(alias: str):
return wrapper
class MiniMaxText01RMSNormTP(CustomOp):
name = "MiniMaxText01RMSNormTP"
def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
super().__init__()
self.tp_world = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
self.weight = nn.Parameter(torch.ones(int(hidden_size /
self.tp_world)))
self.weight.weight_loader = self.weight_loader
self.variance_epsilon = eps
return
@staticmethod
def weight_loader(
param: nn.Parameter,
loaded_weight: torch.Tensor,
) -> None:
tp_world = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
shard_size = loaded_weight.shape[0] // tp_world
shard = slice(tp_rank * shard_size, (tp_rank + 1) * shard_size)
param.data.copy_(loaded_weight[shard])
return
def _forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
orig_dtype = x.dtype
x = x.to(torch.float32)
variance = x.pow(2).mean(dim=-1, keepdim=True, dtype=torch.float32)
if self.tp_world > 1:
variance = tensor_model_parallel_all_reduce(
variance) / self.tp_world
x = x * torch.rsqrt(variance + self.variance_epsilon)
weight = self.weight
if x.size(-1) != self.weight.size(0):
if self.weight.size(0) < x.size(-1):
repeat_count = (x.size(-1) + self.weight.size(0)) // x.size(-1)
full_weight = self.weight.repeat(repeat_count)
weight = full_weight[:x.size(-1)]
else:
weight = self.weight[:x.size(-1)]
x = x.to(orig_dtype) * weight
return x
def forward(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
assert residual is None, "RMSNorm does not support residual connection."
return self._forward(x)
class MiniMaxText01MLP(nn.Module):
def __init__(
@ -253,307 +182,6 @@ class MiniMaxText01MoE(nn.Module):
return final_hidden
class MiniMaxText01LinearKernel:
@staticmethod
def jit_linear_forward_prefix(q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
kv_caches: torch.Tensor,
slope_rate: torch.Tensor,
block_size: int,
layer_idx: int = None,
**kwargs) -> torch.Tensor:
slope_rate = slope_rate.to(torch.float32)
should_pad_dim = q.dim() == 3
if should_pad_dim:
q = q.unsqueeze(0)
k = k.unsqueeze(0)
v = v.unsqueeze(0)
b, h, n, d = q.shape
e = d
kv_history = kv_caches.reshape(1, h, d, e).contiguous()
output, kv_history = lightning_attention(q,
k,
v,
slope_rate,
block_size=block_size,
kv_history=kv_history)
kv_caches.copy_(kv_history[:, :, -1, :, :].reshape(h, d, e))
assert output.shape[0] == 1, "batch size must be 1"
return rearrange(output.squeeze(0), "h n d -> n (h d)")
class MiniMaxText01LinearAttention(nn.Module, MambaBase):
@property
def mamba_type(self) -> str:
return "linear_attention"
def get_attn_backend(self) -> type["AttentionBackend"]:
from vllm.v1.attention.backends.linear_attn import (
LinearAttentionBackend)
return LinearAttentionBackend
def get_state_dtype(self) -> tuple[torch.dtype]:
return MambaStateDtypeCalculator.linear_attention_state_dtype(
self.model_config.dtype,
self.cache_config.mamba_cache_dtype,
)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=self.num_heads,
tp_size=self.tp_size,
head_dim=self.head_dim)
def __init__(
self,
hidden_size: int,
hidden_inner_size: int,
num_heads: int,
head_dim: int,
max_position: int,
block_size: int,
num_hidden_layer: int,
model_config: Optional[ModelConfig] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
layer_idx: int = 0,
linear_layer_idx: int = 0,
prefix: str = "linear_attn",
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.BLOCK = block_size
self.hidden_size = hidden_size
self.num_heads = num_heads
self.head_dim = head_dim
self.total_num_heads = num_heads
self.hidden_inner_size = hidden_inner_size
self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank()
assert self.total_num_heads % self.tp_size == 0
self.tp_heads = self.total_num_heads // self.tp_size
self.qkv_size = self.num_heads * self.head_dim
self.tp_hidden = self.head_dim * self.tp_heads
self.model_config = model_config
self.cache_config = cache_config
self.prefix = prefix
self.qkv_proj = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size * 3,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.output_gate = ColumnParallelLinear(
hidden_size,
self.hidden_inner_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.output_gate",
)
self.out_proj = RowParallelLinear(
self.hidden_inner_size,
hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.out_proj",
)
self.norm = MiniMaxText01RMSNormTP(
self.hidden_inner_size,
eps=1e-5,
)
slope_rate = MiniMaxText01LinearAttention._build_slope_tensor(
self.num_heads)
if num_hidden_layer <= 1:
self.slope_rate = slope_rate * (1 + 1e-5)
else:
self.slope_rate = slope_rate * (1 - layer_idx /
(num_hidden_layer - 1) + 1e-5)
self.tp_slope = self.slope_rate[self.tp_rank *
self.tp_heads:(self.tp_rank + 1) *
self.tp_heads].contiguous()
if envs.VLLM_USE_V1:
compilation_config = get_current_vllm_config().compilation_config
if prefix in compilation_config.static_forward_context:
raise ValueError(f"Duplicate layer name: {prefix}")
compilation_config.static_forward_context[prefix] = self
@staticmethod
def weight_direct_load(param: torch.Tensor,
loaded_weight: torch.Tensor) -> None:
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
return
@staticmethod
def _build_slope_tensor(n_attention_heads: int):
def get_slopes(n):
def get_slopes_power_of_2(n):
start = 2**(-(2**-(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
return get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2**math.floor(math.log2(n))
return (get_slopes_power_of_2(closest_power_of_2) + get_slopes(
2 * closest_power_of_2)[0::2][:n - closest_power_of_2])
slopes = torch.tensor(get_slopes(n_attention_heads),
dtype=torch.float32).reshape(
n_attention_heads, 1, 1)
return slopes
def _prefill_and_mix_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
hidden = []
for _prefill_idx in range(getattr(attn_metadata, "num_prefills", 0)):
if _prefill_idx >= len(attn_metadata.query_start_loc):
break
if _prefill_idx >= len(state_indices_tensor):
break
# prefills are packed at end of batch in V1
offset = attn_metadata.num_decode_tokens if envs.VLLM_USE_V1 else 0
_start = attn_metadata.query_start_loc[offset + _prefill_idx]
_end = attn_metadata.query_start_loc[offset + _prefill_idx + 1]
slot_id = state_indices_tensor[offset + _prefill_idx]
qs = q[_start:_end].transpose(0, 1).contiguous()
ks = k[_start:_end].transpose(0, 1).contiguous()
vs = v[_start:_end].transpose(0, 1).contiguous()
slice_layer_cache = kv_cache[slot_id, ...]
out_slice = MiniMaxText01LinearKernel.jit_linear_forward_prefix(
qs,
ks,
vs,
slice_layer_cache,
self.tp_slope,
self.BLOCK,
layer_idx=self.layer_idx)
hidden.append(out_slice.contiguous())
if attn_metadata.num_decode_tokens > 0:
hidden_decode = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
if envs.VLLM_USE_V1:
hidden.insert(0, hidden_decode)
else:
hidden.append(hidden_decode)
if not hidden:
return torch.empty((0, q.size(-1)), device=q.device, dtype=q.dtype)
hidden = torch.concat(hidden, dim=0).contiguous()
return hidden
def _decode_infer(self, q, k, v, kv_cache, state_indices_tensor,
attn_metadata):
if not envs.VLLM_USE_V1:
q = q[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
k = k[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
v = v[attn_metadata.num_prefill_tokens:].unsqueeze(2).contiguous()
num_prefills = getattr(attn_metadata, "num_prefills", 0)
slot_id = state_indices_tensor[num_prefills:]
else:
q = q[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
k = k[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
v = v[:attn_metadata.num_decode_tokens].unsqueeze(2).contiguous()
slot_id = state_indices_tensor[:attn_metadata.num_decodes]
hidden = linear_decode_forward_triton(q, k, v, kv_cache, self.tp_slope,
slot_id, 32)
return hidden
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
num_prefills = getattr(attn_metadata, "num_prefills", 0)
if num_prefills > 0:
num_decode_tokens = getattr(attn_metadata,
"num_decode_tokens", 0)
for prefill_idx in range(num_prefills):
q_start = attn_metadata.query_start_loc[
num_decode_tokens + prefill_idx]
q_end = attn_metadata.query_start_loc[num_decode_tokens
+ prefill_idx +
1]
query_len = q_end - q_start
context_len = attn_metadata.seq_lens[
num_decode_tokens + prefill_idx] - query_len
if context_len == 0:
block_to_clear = state_indices_tensor[
num_decode_tokens + prefill_idx]
kv_cache[block_to_clear, ...] = 0
else:
kv_cache = kv_caches.minimax_cache
state_indices_tensor = kv_caches.state_indices_tensor
decode_only = getattr(attn_metadata, "num_prefills", 0) == 0
if attn_metadata is None:
hidden = torch.empty((q.shape[0], q.shape[1] * q.shape[2]),
device=q.device,
dtype=q.dtype)
else:
if not decode_only:
hidden = self._prefill_and_mix_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
else:
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
output[:num_actual_tokens], _ = self.out_proj(hidden)
class MiniMaxText01Attention(nn.Module):
def __init__(
@ -1397,35 +1025,3 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)
def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)