Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers (#21194)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
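In short, this change makes the Mamba2 mixer visible to torch.compile as a custom op (vllm.mamba_mixer2), adds that op to CompilationConfig.splitting_ops so the piecewise backend splits the graph around it, moves the hybrid decoder layers to an in-place output-buffer calling convention, and removes the model-runner check that forced eager mode for mamba models. The sketch below shows the user-visible effect under stated assumptions: the model name is only an example of a hybrid mamba2 checkpoint, and the arguments mirror the test change in this diff rather than any new API.

from vllm import LLM, SamplingParams

# Hybrid mamba models no longer need enforce_eager=True; prefix caching is
# still unsupported, matching the remaining model-runner check at the end of
# this diff.
llm = LLM(
    model="ibm-ai-platform/Bamba-9B",  # illustrative hybrid mamba2 model
    enforce_eager=False,
    enable_prefix_caching=False,
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)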
@@ -104,7 +104,6 @@ def test_models(
         m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
         with vllm_runner(model,
                          max_num_seqs=MAX_NUM_SEQS,
-                         enforce_eager=True,
                          enable_prefix_caching=False) as vllm_model:
             vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
                 example_prompts, max_tokens, num_logprobs)
@@ -4312,6 +4312,7 @@ class CompilationConfig:
         self.splitting_ops = [] if self.full_cuda_graph else [
             "vllm.unified_attention",
             "vllm.unified_attention_with_output",
+            "vllm.mamba_mixer2",
         ]
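For context, splitting_ops is the list of ops at which the piecewise compilation backend cuts the traced graph: the subgraphs between the listed ops are captured as CUDA graphs, while the listed ops themselves run outside the capture. A minimal sketch of setting the field explicitly, assuming only the CompilationConfig fields visible in the hunk above (passing the list by hand is redundant with the default shown there):

from vllm.config import CompilationConfig

# Equivalent to the default list built above when full_cuda_graph is False.
compilation_config = CompilationConfig(
    splitting_ops=[
        "vllm.unified_attention",
        "vllm.unified_attention_with_output",
        "vllm.mamba_mixer2",
    ],
)
print(compilation_config.splitting_ops)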
@@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
                               get_tensor_model_parallel_world_size,
                               tensor_model_parallel_all_gather,
                               tensor_model_parallel_all_reduce)
-from vllm.forward_context import get_forward_context
+from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
@@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
     LoaderFunction, composed_weight_loader, sharded_weight_loader)
 from vllm.model_executor.models.mamba_cache import MambaCacheParams
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.platforms import current_platform
+from vllm.utils import direct_register_custom_op
 from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata

 # Added by the IBM Team, 2024
@@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
     def forward_native(
         self,
         hidden_states: torch.Tensor,
-        conv_state: torch.Tensor,
-        ssm_state: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
     ):
         pass

+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        output: torch.Tensor,
+        mamba_cache_params: MambaCacheParams,
+        mamba2_metadata: Mamba2Metadata,
+        mup_vector: Optional[torch.Tensor] = None,
+    ):
+        if not envs.VLLM_USE_V1:
+            CustomOp.forward(self, hidden_states, output, mamba_cache_params,
+                             mamba2_metadata, mup_vector)
+        else:
+            torch.ops.vllm.mamba_mixer2(
+                hidden_states,
+                output,
+                self.prefix,
+                mup_vector,
+            )
+
     def forward_cuda(
         self,
         hidden_states: torch.Tensor,
+        output: torch.Tensor,
         mamba_cache_params: MambaCacheParams,
         mamba2_metadata: Mamba2Metadata,
         mup_vector: Optional[torch.Tensor] = None,
@@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
         num_prefill_tokens = attn_metadata.num_prefill_tokens  # token count
         has_prefill = num_prefills > 0
         has_decode = num_decodes > 0
+        num_actual_tokens = num_prefill_tokens + num_decodes

         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         # Separate prefill and decode by splitting varlen input
@@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
         # NOTE: V0 put prefill before decode, v1 puts decode before prefill
         if envs.VLLM_USE_V1:
             hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
-                hidden_states_B_C,
+                hidden_states_B_C[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             dt_d, dt_p = torch.split(
-                dt,
+                dt[:num_actual_tokens],
                 [num_decodes, num_prefill_tokens],
                 dim=0,
             )
             # Split along batch dimension
             state_indices_tensor_d, state_indices_tensor_p = torch.split(
-                state_indices_tensor,
+                state_indices_tensor[:num_actual_tokens],
                 [num_decodes, num_prefills],
                 dim=0,
             )
@@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
         # GatedRMSNorm internally applying SiLU to the gate
         # SiLU is applied internally before normalization, unlike standard
         # norm usage
-        hidden_states = self.norm(hidden_states, gate)
+        hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])

         # 5. Final linear projection
-        out, _ = self.out_proj(hidden_states)
-        return out
+        output[:num_actual_tokens], _ = self.out_proj(hidden_states)

     def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
         return get_mamba_state_shape(
@@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
             state_size=self.ssm_state_size,
             conv_kernel=self.conv_kernel_size,
         )
+
+
+def mamba_mixer2(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    forward_context: ForwardContext = get_forward_context()
+    self = forward_context.no_compile_layers[layer_name]
+    self.forward_cuda(hidden_states=hidden_states,
+                      output=output,
+                      mamba_cache_params=None,
+                      mamba2_metadata=None,
+                      mup_vector=mup_vector)
+
+
+def mamba_mixer2_fake(
+    hidden_states: torch.Tensor,
+    output: torch.Tensor,
+    layer_name: str,
+    mup_vector: Optional[torch.Tensor] = None,
+) -> None:
+    return
+
+
+direct_register_custom_op(
+    op_name="mamba_mixer2",
+    op_func=mamba_mixer2,
+    mutates_args=["output"],
+    fake_impl=mamba_mixer2_fake,
+    dispatch_key=current_platform.dispatch_key,
+)
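The registration above follows the pattern vLLM uses for compiler-visible ops that write into a caller-provided buffer: the real implementation mutates output, mutates_args declares that in-place write to torch.compile, and the fake implementation lets tracing proceed without running the kernel. Below is a self-contained sketch of the same pattern with made-up names (my_scale_inplace is illustrative, not part of this commit):

import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op


def my_scale_inplace(x: torch.Tensor, out: torch.Tensor) -> None:
    # Real implementation: fill the caller-provided buffer in place.
    out.copy_(x * 2.0)


def my_scale_inplace_fake(x: torch.Tensor, out: torch.Tensor) -> None:
    # Shape and dtype are already carried by `out`, so tracing needs no work.
    return


direct_register_custom_op(
    op_name="my_scale_inplace",
    op_func=my_scale_inplace,
    mutates_args=["out"],
    fake_impl=my_scale_inplace_fake,
    dispatch_key=current_platform.dispatch_key,
)

# Callers pre-allocate the buffer and invoke the registered op, mirroring the
# decoder-layer changes further down in this diff:
#   out = torch.empty_like(x)
#   torch.ops.vllm.my_scale_inplace(x, out)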
@@ -11,6 +11,7 @@ from transformers import BambaConfig

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
             hidden_states, residual = self.input_layernorm(
                 hidden_states, residual)

-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
         # Fully Connected
-        hidden_states, residual = self.pre_ff_layernorm(
-            hidden_states, residual)
+        hidden_states, residual = self.pre_ff_layernorm(output, residual)
         hidden_states = self.feed_forward(hidden_states)
         return hidden_states, residual
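Every hybrid decoder layer below gets the same treatment as BambaMixerDecoderLayer here: instead of consuming the mixer's return value, the layer pre-allocates a result buffer with torch.empty_like and the mixer fills it in place through the mutating custom op, so the buffer's shape is fixed and graph-friendly. A tiny self-contained illustration of the convention (mixer_inplace is a stand-in, not vLLM code):

import torch


def mixer_inplace(hidden_states: torch.Tensor, output: torch.Tensor) -> None:
    # Stand-in for MambaMixer2.forward(); the result lands in `output`.
    output.copy_(torch.tanh(hidden_states))


hidden_states = torch.randn(4, 8)
output = torch.empty_like(hidden_states)  # shape known before the call
mixer_inplace(hidden_states, output)
hidden_states = output  # downstream code reads the buffer, not a return value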
@@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
         self.max_position_embeddings = max_position_embeddings

         if hasattr(config, "partial_rotary_factor"):
-            rotary_dim = self.head_dim * config.partial_rotary_factor
+            rotary_dim = int(self.head_dim * config.partial_rotary_factor)
         elif hasattr(config, "attn_rotary_emb"):
             rotary_dim = config.attn_rotary_emb  # for backward compatibility
         else:
@@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
 }


+@support_torch_compile
 class BambaModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -10,6 +10,7 @@ from transformers import FalconH1Config

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
         mamba2_metadata: Mamba2Metadata,
         **kwargs,
     ):
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
             mup_vector=self.mup_vector,
         )
-        return hidden_states, residual
+        return output, residual


 class FalconH1AttentionDecoderLayer(nn.Module):
@@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
         return hidden_states


+@support_torch_compile
 class FalconH1Model(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
     ):
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
-        hidden_states = self.mamba(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        hidden_states = residual + hidden_states * self.residual_multiplier
+        output = torch.empty_like(hidden_states)
+        self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        hidden_states = residual + output * self.residual_multiplier

         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
 }


+@support_torch_compile
 class GraniteMoeHybridModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -10,6 +10,7 @@ from transformers import MambaConfig

 from vllm import envs
 from vllm.attention.backends.abstract import AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import VllmConfig
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual


+@support_torch_compile
 class Mamba2Model(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -25,6 +25,7 @@ from torch import nn

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.distributed.parallel_state import get_pp_group
@@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
         else:
             hidden_states, residual = self.norm(hidden_states, residual)

-        hidden_states = self.mixer(hidden_states, mamba_cache_params,
-                                   mamba2_metadata)
-        return hidden_states, residual
+        output = torch.empty_like(hidden_states)
+        self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
+        return output, residual


 class NemotronHAttention(nn.Module):
@@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
 }


+@support_torch_compile
 class NemotronHModel(nn.Module):

     def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
@@ -17,6 +17,7 @@ from transformers import Zamba2Config

 from vllm import envs
 from vllm.attention.layer import Attention
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.forward_context import get_forward_context
@@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
         hidden_states = self.input_layernorm(hidden_states)

         # Process through Mamba mixer
-        hidden_states = self.mamba(
+        output = torch.empty_like(hidden_states)
+        self.mamba(
             hidden_states,
+            output,
             mamba_cache_params=mamba_cache_params,
             mamba2_metadata=mamba2_metadata,
         )

         # residual connection after mamba
-        hidden_states = residual + hidden_states
+        hidden_states = residual + output

         return hidden_states
@@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
         return layer_outputs


+@support_torch_compile
 class Zamba2Model(nn.Module):
     """Core Zamba2 model combining transformer and Mamba architectures.
@@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
             if self.vllm_config.speculative_config is not None:
                 raise NotImplementedError(
                     "Mamba with speculative decoding is not supported yet.")
-            if not self.vllm_config.model_config.enforce_eager:
-                raise NotImplementedError(
-                    "Mamba with cuda graph is not supported yet.")
             if self.vllm_config.cache_config.enable_prefix_caching:
                 raise NotImplementedError(
                     "Prefix caching is not supported for Mamba yet.")