[V1] [Hybrid] Enable compile and piecewise CUDA graph for MiniMax-Text models (#22589)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
@ -339,6 +339,7 @@ class CompilationConfig:
|
||||
"vllm.mamba_mixer2",
|
||||
"vllm.mamba_mixer",
|
||||
"vllm.short_conv",
|
||||
"vllm.linear_attention",
|
||||
]
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
|
@ -1,7 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Inference-only MiniMaxText01 model."""
|
||||
import copy
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
@ -19,13 +18,14 @@ from transformers import MiniMaxConfig
|
||||
|
||||
from vllm import envs
|
||||
from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.compilation.decorators import support_torch_compile
|
||||
from vllm.config import (CacheConfig, ModelConfig, VllmConfig,
|
||||
get_current_vllm_config)
|
||||
from vllm.distributed.communication_op import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_pp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.forward_context import ForwardContext, get_forward_context
|
||||
from vllm.model_executor.custom_op import CustomOp
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
@ -43,12 +43,15 @@ from vllm.model_executor.layers.mamba.mamba_utils import (
|
||||
MambaStateDtypeCalculator, MambaStateShapeCalculator)
|
||||
from vllm.model_executor.layers.quantization.base_config import (
|
||||
QuantizationConfig)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.utils import direct_register_custom_op
|
||||
from vllm.v1.attention.backends.linear_attn import LinearAttentionMetadata
|
||||
|
||||
from .interfaces import HasInnerState, IsHybrid
|
||||
@ -143,61 +146,6 @@ class MiniMaxText01RMSNormTP(CustomOp):
|
||||
return self._forward(x)
|
||||
|
||||
|
||||
class MiniMaxText01RotaryEmbedding(CustomOp):
|
||||
name = "MiniMaxText01RotaryEmbedding"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
head_size: int,
|
||||
rotary_dim: int,
|
||||
max_position: int,
|
||||
base: float,
|
||||
is_neox_style: bool,
|
||||
cache_dtype: torch.dtype,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.head_size = head_size
|
||||
self.rotary_dim = rotary_dim
|
||||
self.max_position_embeddings = max_position
|
||||
self.base = base
|
||||
self.is_neox_style = is_neox_style
|
||||
self.cache_dtype = cache_dtype
|
||||
cache = self._compute_cos_sin_cache().to(cache_dtype)
|
||||
self.register_buffer("cos_sin_cache", cache, persistent=False)
|
||||
|
||||
def _compute_inv_freq(self, base: float) -> torch.Tensor:
|
||||
"""Compute the inverse frequency."""
|
||||
inv_freq = 1.0 / (base**(torch.arange(
|
||||
0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim))
|
||||
return inv_freq
|
||||
|
||||
def _compute_cos_sin_cache(self) -> torch.Tensor:
|
||||
"""Compute the cos and sin cache."""
|
||||
inv_freq = self._compute_inv_freq(self.base)
|
||||
t = torch.arange(self.max_position_embeddings, dtype=torch.float)
|
||||
freqs = torch.einsum("i,j -> ij", t, inv_freq)
|
||||
cos = freqs.cos()
|
||||
sin = freqs.sin()
|
||||
cache = torch.cat((cos, sin), dim=-1)
|
||||
return cache
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
from vllm import _custom_ops as ops
|
||||
self.cos_sin_cache = self.cos_sin_cache.to(positions.device)
|
||||
query_cast = query.to(self.cache_dtype)
|
||||
key_cast = key.to(self.cache_dtype)
|
||||
ops.rotary_embedding(positions, query_cast, key_cast, self.head_size,
|
||||
self.cos_sin_cache, self.is_neox_style)
|
||||
query = query_cast.to(query.dtype)
|
||||
key = key_cast.to(key.dtype)
|
||||
return query, key
|
||||
|
||||
|
||||
class MiniMaxText01MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@ -526,20 +474,40 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
slot_id, 32)
|
||||
return hidden
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
|
||||
kv_caches: MinimaxCacheParams, **kwargs) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: MinimaxCacheParams) -> None:
|
||||
if not envs.VLLM_USE_V1:
|
||||
self._forward(hidden_states, output, positions, kv_caches)
|
||||
else:
|
||||
torch.ops.vllm.linear_attention(
|
||||
hidden_states,
|
||||
output,
|
||||
positions,
|
||||
self.prefix,
|
||||
)
|
||||
|
||||
def _forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: Optional[MinimaxCacheParams]) -> None:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata: AttentionMetadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1 and attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
num_actual_tokens = attn_metadata.num_prefill_tokens + \
|
||||
attn_metadata.num_decode_tokens
|
||||
else:
|
||||
num_actual_tokens = hidden_states.shape[0]
|
||||
|
||||
qkv, _ = self.qkv_proj(hidden_states[:num_actual_tokens])
|
||||
qkv32 = qkv.to(torch.float32)
|
||||
qkvact = torch.nn.functional.silu(qkv32)
|
||||
qkvact = qkvact.view((qkv.shape[0], self.tp_heads, -1))
|
||||
q, k, v = torch.split(qkvact, [self.head_dim] * 3, dim=-1)
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
assert isinstance(attn_metadata, dict)
|
||||
attn_metadata = attn_metadata[self.prefix]
|
||||
assert isinstance(attn_metadata, LinearAttentionMetadata)
|
||||
kv_cache = self.kv_cache[forward_context.virtual_engine][0]
|
||||
state_indices_tensor = attn_metadata.state_indices_tensor
|
||||
|
||||
@ -578,13 +546,11 @@ class MiniMaxText01LinearAttention(nn.Module, MambaBase):
|
||||
hidden = self._decode_infer(q, k, v, kv_cache,
|
||||
state_indices_tensor,
|
||||
attn_metadata)
|
||||
|
||||
hidden = self.norm._forward(hidden)
|
||||
gate, _ = self.output_gate(hidden_states)
|
||||
gate, _ = self.output_gate(hidden_states[:num_actual_tokens])
|
||||
hidden = F.sigmoid(gate) * hidden
|
||||
hidden = hidden.to(hidden_states.dtype)
|
||||
hidden, _ = self.out_proj(hidden)
|
||||
return hidden
|
||||
output[:num_actual_tokens], _ = self.out_proj(hidden)
|
||||
|
||||
|
||||
class MiniMaxText01Attention(nn.Module):
|
||||
@ -652,23 +618,23 @@ class MiniMaxText01Attention(nn.Module):
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.attn",
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
head_size=self.head_dim,
|
||||
rotary_dim=rotary_dim,
|
||||
max_position=max_position,
|
||||
base=int(rope_theta),
|
||||
is_neox_style=True,
|
||||
dtype=torch.float32,
|
||||
)
|
||||
return
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor, positions: torch.Tensor,
|
||||
**kwargs) -> torch.Tensor:
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
def forward(self, hidden_states: torch.Tensor, output: torch.Tensor,
|
||||
positions: torch.Tensor, **kwargs) -> None:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if envs.VLLM_USE_V1:
|
||||
if attn_metadata is not None:
|
||||
q, k = attn_metadata[f"{self.prefix}.attn"].rotary_emb(
|
||||
positions, q, k)
|
||||
else:
|
||||
q, k = attn_metadata.rotary_emb(positions, q, k)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
attn_output = self.attn(q, k, v)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
output[:], _ = self.o_proj(attn_output)
|
||||
|
||||
|
||||
class MiniMaxText01DecoderLayer(nn.Module):
|
||||
@ -816,16 +782,15 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
is_warmup: bool = False,
|
||||
**kwargs) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
forward_context = get_forward_context()
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
layernorm_input = hidden_states
|
||||
layernorm_output = self.input_layernorm(layernorm_input)
|
||||
residual = layernorm_output if self.postnorm else layernorm_input
|
||||
self_attention_output = self.self_attn(
|
||||
self_attention_output = torch.empty_like(layernorm_output)
|
||||
self.self_attn(
|
||||
hidden_states=layernorm_output,
|
||||
output=self_attention_output,
|
||||
positions=positions,
|
||||
kv_caches=kv_caches,
|
||||
attn_metadata=attn_metadata,
|
||||
)
|
||||
|
||||
residual = residual * self.layernorm_attention_alpha
|
||||
@ -839,8 +804,8 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
if self.expert_num == 1:
|
||||
hidden_states = self.mlp(layernorm_output)
|
||||
else:
|
||||
moe_hidden_states = self.block_sparse_moe(
|
||||
copy.deepcopy(layernorm_output))
|
||||
moe_layernorm_output = layernorm_output.clone()
|
||||
moe_hidden_states = self.block_sparse_moe(moe_layernorm_output)
|
||||
if self.shared_moe:
|
||||
before_moe_dtype = layernorm_output.dtype
|
||||
moe_hidden_fp32 = moe_hidden_states.to(torch.float32)
|
||||
@ -878,18 +843,16 @@ class MiniMaxText01DecoderLayer(nn.Module):
|
||||
return
|
||||
|
||||
|
||||
@support_torch_compile
|
||||
class MiniMaxText01Model(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MiniMaxConfig,
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
scheduler_config=None,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config: MiniMaxConfig = vllm_config.model_config.hf_config
|
||||
model_config = vllm_config.model_config
|
||||
quant_config = vllm_config.quant_config
|
||||
cache_config = vllm_config.cache_config
|
||||
scheduler_config = vllm_config.scheduler_config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
@ -976,24 +939,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
self.minimax_cache = MinimaxCacheManager(
|
||||
dtype=torch.float32, cache_shape=self.cache_shape)
|
||||
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
head_dim = getattr(config, "head_dim", None)
|
||||
if head_dim is None:
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
if hasattr(config, "max_model_len") and isinstance(
|
||||
config.max_model_len, int):
|
||||
max_position_embeddings = min(config.max_position_embeddings,
|
||||
config.max_model_len)
|
||||
self.rotary_emb = MiniMaxText01RotaryEmbedding(
|
||||
head_dim,
|
||||
rotary_dim=config.rotary_dim
|
||||
if hasattr(config, "rotary_dim") else head_dim,
|
||||
max_position=max_position_embeddings,
|
||||
base=int(rope_theta),
|
||||
is_neox_style=True,
|
||||
cache_dtype=torch.float32,
|
||||
)
|
||||
|
||||
norm_kwargs = {}
|
||||
if hasattr(config, "rms_norm_eps"):
|
||||
norm_kwargs["eps"] = config.rms_norm_eps
|
||||
@ -1043,12 +988,11 @@ class MiniMaxText01Model(nn.Module):
|
||||
attn_metadata = forward_context.attn_metadata
|
||||
if not envs.VLLM_USE_V1 and attn_metadata is None:
|
||||
return None
|
||||
if "request_ids_to_seq_ids" not in kwargs:
|
||||
kwargs["request_ids_to_seq_ids"] = {}
|
||||
if "finished_requests_ids" not in kwargs:
|
||||
kwargs["finished_requests_ids"] = []
|
||||
|
||||
if not envs.VLLM_USE_V1:
|
||||
if "request_ids_to_seq_ids" not in kwargs:
|
||||
kwargs["request_ids_to_seq_ids"] = {}
|
||||
if "finished_requests_ids" not in kwargs:
|
||||
kwargs["finished_requests_ids"] = []
|
||||
(
|
||||
minimax_cache_tensors,
|
||||
state_indices_tensor,
|
||||
@ -1077,16 +1021,6 @@ class MiniMaxText01Model(nn.Module):
|
||||
|
||||
for i in range(self.start_layer, self.end_layer):
|
||||
layer = self.layers[i]
|
||||
if attn_metadata is not None:
|
||||
# TODO (tdoublep): this whole thing with the rotary_emb is
|
||||
# weird. we shouldn't be passing it via attn_metadata imo.
|
||||
if envs.VLLM_USE_V1:
|
||||
if isinstance(layer.self_attn, MiniMaxText01Attention):
|
||||
attn_metadata[layer.prefix +
|
||||
".attn"].rotary_emb = self.rotary_emb
|
||||
else:
|
||||
attn_metadata.rotary_emb = self.rotary_emb
|
||||
|
||||
_caches = None
|
||||
if not envs.VLLM_USE_V1 and isinstance(
|
||||
layer.self_attn, MiniMaxText01LinearAttention):
|
||||
@ -1120,7 +1054,6 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
@ -1133,13 +1066,8 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
self.unpadded_vocab_size = self.config.vocab_size
|
||||
if hasattr(vllm_config.model_config, "max_model_len"):
|
||||
self.config.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.model = MiniMaxText01Model(
|
||||
self.config,
|
||||
model_config=vllm_config.model_config,
|
||||
cache_config=vllm_config.cache_config,
|
||||
quant_config=quant_config,
|
||||
scheduler_config=vllm_config.scheduler_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
self.model = MiniMaxText01Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
if get_pp_group().is_last_rank:
|
||||
self.lm_head = ParallelLMHead(
|
||||
self.unpadded_vocab_size,
|
||||
@ -1469,3 +1397,35 @@ class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
|
||||
tp_size=parallel_config.tensor_parallel_size,
|
||||
head_dim=hf_config.head_dim,
|
||||
)
|
||||
|
||||
|
||||
def linear_attention(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
forward_context: ForwardContext = get_forward_context()
|
||||
self = forward_context.no_compile_layers[layer_name]
|
||||
self._forward(hidden_states=hidden_states,
|
||||
output=output,
|
||||
positions=positions,
|
||||
kv_caches=None)
|
||||
|
||||
|
||||
def linear_attention_fake(
|
||||
hidden_states: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
layer_name: str,
|
||||
) -> None:
|
||||
return
|
||||
|
||||
|
||||
direct_register_custom_op(
|
||||
op_name="linear_attention",
|
||||
op_func=linear_attention,
|
||||
mutates_args=["output"],
|
||||
fake_impl=linear_attention_fake,
|
||||
dispatch_key=current_platform.dispatch_key,
|
||||
)
|
||||
|
Reference in New Issue
Block a user