[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models (#22589)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Thomas Parnell authored on 2025-08-27 19:05:16 +02:00, committed by GitHub
parent 52883ed084
commit dd58932280
2 changed files with 98 additions and 137 deletions


@@ -339,6 +339,7 @@ class CompilationConfig:
"vllm.mamba_mixer2",
"vllm.mamba_mixer",
"vllm.short_conv",
"vllm.linear_attention",
]
def compute_hash(self) -> str:
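
Registering "vllm.linear_attention" as a splitting op means piecewise CUDA graph capture treats each linear-attention call as a graph-break point: the stateful attention op runs outside the captured graphs while the regions around it are compiled and captured. A usage sketch follows; the model id and capture sizes are illustrative assumptions, not part of this commit:

    from vllm import LLM
    from vllm.config import CompilationConfig

    # Illustrative sketch: with "vllm.linear_attention" in the splitting
    # ops, piecewise compilation keeps the stateful attention call out of
    # the captured CUDA graphs while compiling everything around it.
    llm = LLM(
        model="MiniMaxAI/MiniMax-Text-01",  # assumed model id
        compilation_config=CompilationConfig(
            level=3,  # piecewise compilation with CUDA graphs
            cudagraph_capture_sizes=[1, 2, 4, 8],
        ),
    )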


@@ -1,7 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Inference-only MiniMaxText01 model."""
import copy
import math
from collections.abc import Iterable
from typing import TYPE_CHECKING, Optional, Union
@@ -19,13 +18,14 @@ from transformers import MiniMaxConfig
from vllm import envs
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
get_current_vllm_config)
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
@@ -43,12 +43,15 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
MambaStateDtypeCalculator, MambaStateShapeCalculator)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import maybe_prefix
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
from .interfaces import HasInnerState, IsHybrid
@@ -143,61 +146,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
return self._forward(x)
class MiniMaxText01RotaryEmbedding(CustomOp):
name = "MiniMaxText01RotaryEmbedding"
def __init__(
self,
head_size: int,
rotary_dim: int,
max_position: int,
base: float,
is_neox_style: bool,
cache_dtype: torch.dtype,
) -> None:
super().__init__()
self.head_size = head_size
self.rotary_dim = rotary_dim
self.max_position_embeddings = max_position
self.base = base
self.is_neox_style = is_neox_style
self.cache_dtype = cache_dtype
cache = self._compute_cos_sin_cache().to(cache_dtype)
self.register_buffer("cos_sin_cache", cache, persistent=False)
def _compute_inv_freq(self, base: float) -> torch.Tensor:
"""Compute the inverse frequency."""
inv_freq = 1.0 / (base**(torch.arange(
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
return inv_freq
def _compute_cos_sin_cache(self) -> torch.Tensor:
"""Compute the cos and sin cache."""
inv_freq = self._compute_inv_freq(self.base)
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
freqs = torch.einsum("i,j -> ij", t, inv_freq)
cos = freqs.cos()
sin = freqs.sin()
cache = torch.cat((cos, sin), dim=-1)
return cache
def forward(
self,
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
query_cast = query.to(self.cache_dtype)
key_cast = key.to(self.cache_dtype)
ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
self.cos_sin_cache, self.is_neox_style)
query = query_cast.to(query.dtype)
key = key_cast.to(key.dtype)
return query, key
class MiniMaxText01MLP(nn.Module):
def __init__(
@@ -526,20 +474,40 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
slot_id, 32)
return hidden
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: MinimaxCacheParams) -> None:
if not envs.VLLM_USE_V1:
self._forward(hidden_states, output, positions, kv_caches)
else:
torch.ops.vllm.linear_attention(
hidden_states,
output,
positions,
self.prefix,
)
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor,
kv_caches: Optional[MinimaxCacheParams]) -> None:
forward_context = get_forward_context()
attn_metadata: AttentionMetadata = forward_context.attn_metadata
if envs.VLLM_USE_V1 and attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
num_actual_tokens = attn_metadata.num_prefill_tokens + \
attn_metadata.num_decode_tokens
else:
num_actual_tokens = hidden_states.shape[0]
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
qkv32 = qkv.to(torch.float32)
qkvact = torch.nn.functional.silu(qkv32)
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
if envs.VLLM_USE_V1:
if attn_metadata is not None:
assert isinstance(attn_metadata, dict)
attn_metadata = attn_metadata[self.prefix]
assert isinstance(attn_metadata, LinearAttentionMetadata)
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
state_indices_tensor = attn_metadata.state_indices_tensor
@@ -578,13 +546,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
hidden = self._decode_infer(q, k, v, kv_cache,
state_indices_tensor,
attn_metadata)
hidden = self.norm._forward(hidden)
gate, _ = self.output_gate(hidden_states)
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
hidden = F.sigmoid(gate) * hidden
hidden = hidden.to(hidden_states.dtype)
hidden, _ = self.out_proj(hidden)
return hidden
output[:num_actual_tokens], _ = self.out_proj(hidden)
class MiniMaxText01Attention(nn.Module):
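
The reworked forward writes into a caller-provided output buffer and slices work to num_actual_tokens, because piecewise CUDA graphs are captured at fixed batch sizes and the input may carry padding beyond the real token count. A runnable sketch of that convention, with made-up sizes (not part of the diff):

    import torch

    # Padded-buffer convention (illustrative sizes): buffers are sized for
    # the captured graph's batch, so only the first num_actual_tokens rows
    # hold real data; the padded tail is left untouched.
    padded_tokens, hidden_size = 8, 16
    hidden_states = torch.randn(padded_tokens, hidden_size)
    output = torch.empty_like(hidden_states)

    num_actual_tokens = 5  # derived from attention metadata in the real code
    result = 2.0 * hidden_states[:num_actual_tokens]  # stand-in for attention
    output[:num_actual_tokens] = result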
@@ -652,23 +618,23 @@ class MiniMaxText01Attention(nn.Module):
quant_config=quant_config,
prefix=f"{prefix}.attn",
)
self.rotary_emb = get_rope(
head_size=self.head_dim,
rotary_dim=rotary_dim,
max_position=max_position,
base=int(rope_theta),
is_neox_style=True,
dtype=torch.float32,
)
return
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
**kwargs) -> torch.Tensor:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
positions: torch.Tensor, **kwargs) -> None:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
if envs.VLLM_USE_V1:
if attn_metadata is not None:
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
positions, q, k)
else:
q, k = attn_metadata.rotary_emb(positions, q, k)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v)
output, _ = self.o_proj(attn_output)
return output
output[:], _ = self.o_proj(attn_output)
class MiniMaxText01DecoderLayer(nn.Module):
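
The attention layer now owns its rotary embedding via the shared get_rope helper, instead of the deleted bespoke class that was threaded through attention metadata. A sketch of equivalent standalone usage, with illustrative sizes in place of the HF config values:

    import torch
    from vllm.model_executor.layers.rotary_embedding import get_rope

    # Illustrative sizes; the real values come from the model's HF config.
    rotary_emb = get_rope(
        head_size=128,
        rotary_dim=64,
        max_position=8192,
        base=10000,
        is_neox_style=True,
        dtype=torch.float32,
    )
    # Applied inside the layer as: q, k = rotary_emb(positions, q, k)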
@@ -816,16 +782,15 @@ class MiniMaxText01DecoderLayer(nn.Module):
is_warmup: bool = False,
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
forward_context = get_forward_context()
attn_metadata = forward_context.attn_metadata
layernorm_input = hidden_states
layernorm_output = self.input_layernorm(layernorm_input)
residual = layernorm_output if self.postnorm else layernorm_input
self_attention_output = self.self_attn(
self_attention_output = torch.empty_like(layernorm_output)
self.self_attn(
hidden_states=layernorm_output,
output=self_attention_output,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
)
residual = residual * self.layernorm_attention_alpha
@@ -839,8 +804,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
if self.expert_num == 1:
hidden_states = self.mlp(layernorm_output)
else:
moe_hidden_states = self.block_sparse_moe(
copy.deepcopy(layernorm_output))
moe_layernorm_output = layernorm_output.clone()
moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
if self.shared_moe:
before_moe_dtype = layernorm_output.dtype
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
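
The switch from copy.deepcopy to Tensor.clone() matters for compilation: clone() is a single traceable ATen op, while deepcopy goes through Python object machinery that, in our understanding, torch.compile cannot trace through cleanly. A minimal sketch with a toy function:

    import torch

    # Sketch: clone() composes with torch.compile as an ordinary traced op;
    # copy.deepcopy on a tensor inside the compiled region would likely
    # force a graph break and defeat piecewise capture.
    @torch.compile
    def moe_branch(x: torch.Tensor) -> torch.Tensor:
        branch_input = x.clone()   # safe, traceable copy of the activations
        return branch_input * 2.0  # stand-in for the MoE computation

    moe_branch(torch.randn(4, 8))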
@@ -878,18 +843,16 @@ class MiniMaxText01DecoderLayer(nn.Module):
return
@support_torch_compile
class MiniMaxText01Model(nn.Module):
def __init__(
self,
config: MiniMaxConfig,
model_config: Optional[ModelConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
cache_config: Optional[CacheConfig] = None,
scheduler_config=None,
prefix: str = "",
) -> None:
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config: MiniMaxConfig = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
quant_config = vllm_config.quant_config
cache_config = vllm_config.cache_config
scheduler_config = vllm_config.scheduler_config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
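
The constructor is reshaped because @support_torch_compile expects the standard vLLM model signature, __init__(self, *, vllm_config: VllmConfig, prefix: str = ""), from which the sub-configs are unpacked. A minimal sketch of that contract, using a toy module rather than the real model:

    import torch
    from torch import nn

    from vllm.compilation.decorators import support_torch_compile
    from vllm.config import VllmConfig

    @support_torch_compile
    class ToyModel(nn.Module):
        # The decorator relies on this keyword-only signature to wire the
        # compilation settings in vllm_config into the compiled forward.
        def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
            super().__init__()
            hidden = vllm_config.model_config.hf_config.hidden_size
            self.proj = nn.Linear(hidden, hidden)

        def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
            return self.proj(hidden_states)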
@@ -976,24 +939,6 @@ class MiniMaxText01Model(nn.Module):
self.minimax_cache = MinimaxCacheManager(
dtype=torch.float32, cache_shape=self.cache_shape)
rope_theta = getattr(config, "rope_theta", 10000)
head_dim = getattr(config, "head_dim", None)
if head_dim is None:
head_dim = config.hidden_size // config.num_attention_heads
if hasattr(config, "max_model_len") and isinstance(
config.max_model_len, int):
max_position_embeddings = min(config.max_position_embeddings,
config.max_model_len)
self.rotary_emb = MiniMaxText01RotaryEmbedding(
head_dim,
rotary_dim=config.rotary_dim
if hasattr(config, "rotary_dim") else head_dim,
max_position=max_position_embeddings,
base=int(rope_theta),
is_neox_style=True,
cache_dtype=torch.float32,
)
norm_kwargs = {}
if hasattr(config, "rms_norm_eps"):
norm_kwargs["eps"] = config.rms_norm_eps
@@ -1043,12 +988,11 @@ class MiniMaxText01Model(nn.Module):
attn_metadata = forward_context.attn_metadata
if not envs.VLLM_USE_V1 and attn_metadata is None:
return None
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
if not envs.VLLM_USE_V1:
if "request_ids_to_seq_ids" not in kwargs:
kwargs["request_ids_to_seq_ids"] = {}
if "finished_requests_ids" not in kwargs:
kwargs["finished_requests_ids"] = []
(
minimax_cache_tensors,
state_indices_tensor,
@@ -1077,16 +1021,6 @@ class MiniMaxText01Model(nn.Module):
for i in range(self.start_layer, self.end_layer):
layer = self.layers[i]
if attn_metadata is not None:
# TODO (tdoublep): this whole thing with the rotary_emb is
# weird. we shouldn't be passing it via attn_metadata imo.
if envs.VLLM_USE_V1:
if isinstance(layer.self_attn, MiniMaxText01Attention):
attn_metadata[layer.prefix +
".attn"].rotary_emb = self.rotary_emb
else:
attn_metadata.rotary_emb = self.rotary_emb
_caches = None
if not envs.VLLM_USE_V1 and isinstance(
layer.self_attn, MiniMaxText01LinearAttention):
@@ -1120,7 +1054,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
lora_config = vllm_config.lora_config
self.config = config
self.lora_config = lora_config
@@ -1133,13 +1066,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
self.unpadded_vocab_size = self.config.vocab_size
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
self.config,
model_config=vllm_config.model_config,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
scheduler_config=vllm_config.scheduler_config,
prefix=maybe_prefix(prefix, "model"))
self.model = MiniMaxText01Model(vllm_config=vllm_config,
prefix=maybe_prefix(prefix, "model"))
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
self.unpadded_vocab_size,
@@ -1469,3 +1397,35 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
def linear_attention(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self._forward(hidden_states=hidden_states,
output=output,
positions=positions,
kv_caches=None)
def linear_attention_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
positions: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="linear_attention",
op_func=linear_attention,
mutates_args=["output"],
fake_impl=linear_attention_fake,
dispatch_key=current_platform.dispatch_key,
)
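
For reference, the same registration pattern in isolation: the eager function does the work, the fake impl gives torch.compile a shape-only stand-in for tracing, and mutates_args tells the compiler which buffer the op writes. A self-contained sketch with a hypothetical op (not part of this commit):

    import torch

    from vllm.platforms import current_platform
    from vllm.utils import direct_register_custom_op

    def scale_copy(x: torch.Tensor, out: torch.Tensor, scale: float) -> None:
        # Eager implementation: the only side effect is the write into `out`.
        out.copy_(x * scale)

    def scale_copy_fake(x: torch.Tensor, out: torch.Tensor,
                        scale: float) -> None:
        # Fake (meta) implementation: no computation, just a valid trace stub.
        return

    direct_register_custom_op(
        op_name="scale_copy",
        op_func=scale_copy,
        mutates_args=["out"],
        fake_impl=scale_copy_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # Invoked through the registered namespace:
    # torch.ops.vllm.scale_copy(x, out, 2.0)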