[V1] [Hybrid] Enable piecewise CUDA Graph for mamba layers (#21194)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
Thomas Parnell
2025-07-19 21:27:21 +02:00
committed by GitHub
parent 9f414a12ad
commit 881e3cbe3b
10 changed files with 100 additions and 31 deletions

View File

@ -104,7 +104,6 @@ def test_models(
m.setenv("VLLM_ATTENTION_BACKEND", "FLASHINFER")
with vllm_runner(model,
max_num_seqs=MAX_NUM_SEQS,
enforce_eager=True,
enable_prefix_caching=False) as vllm_model:
vllm_v1_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

View File

@ -4312,6 +4312,7 @@ class CompilationConfig:
self.splitting_ops = [] if self.full_cuda_graph else [
"vllm.unified_attention",
"vllm.unified_attention_with_output",
"vllm.mamba_mixer2",
]

View File

@ -13,7 +13,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_gather,
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
@ -33,6 +33,8 @@ from vllm.model_executor.model_loader.weight_utils import (
LoaderFunction, composed_weight_loader, sharded_weight_loader)
from vllm.model_executor.models.mamba_cache import MambaCacheParams
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
from vllm.v1.attention.backends.mamba_attn import Mamba2AttentionMetadata
# Added by the IBM Team, 2024
@ -424,14 +426,36 @@ class MambaMixer2(MambaBase, CustomOp):
def forward_native(
self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
pass
def forward(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
):
if not envs.VLLM_USE_V1:
CustomOp.forward(self, hidden_states, output, mamba_cache_params,
mamba2_metadata, mup_vector)
else:
torch.ops.vllm.mamba_mixer2(
hidden_states,
output,
self.prefix,
mup_vector,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
output: torch.Tensor,
mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
@ -517,6 +541,7 @@ class MambaMixer2(MambaBase, CustomOp):
num_prefill_tokens = attn_metadata.num_prefill_tokens # token count
has_prefill = num_prefills > 0
has_decode = num_decodes > 0
num_actual_tokens = num_prefill_tokens + num_decodes
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
# Separate prefill and decode by splitting varlen input
@ -524,18 +549,18 @@ class MambaMixer2(MambaBase, CustomOp):
# NOTE: V0 put prefill before decode, v1 puts decode before prefill
if envs.VLLM_USE_V1:
hidden_states_B_C_d, hidden_states_B_C_p = torch.split(
hidden_states_B_C,
hidden_states_B_C[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
dt_d, dt_p = torch.split(
dt,
dt[:num_actual_tokens],
[num_decodes, num_prefill_tokens],
dim=0,
)
# Split along batch dimension
state_indices_tensor_d, state_indices_tensor_p = torch.split(
state_indices_tensor,
state_indices_tensor[:num_actual_tokens],
[num_decodes, num_prefills],
dim=0,
)
@ -696,11 +721,10 @@ class MambaMixer2(MambaBase, CustomOp):
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate)
hidden_states = self.norm(hidden_states, gate[:num_actual_tokens])
# 5. Final linear projection
out, _ = self.out_proj(hidden_states)
return out
output[:num_actual_tokens], _ = self.out_proj(hidden_states)
def get_state_shape(self) -> tuple[tuple[int, ...], tuple[int, ...]]:
return get_mamba_state_shape(
@ -712,3 +736,36 @@ class MambaMixer2(MambaBase, CustomOp):
state_size=self.ssm_state_size,
conv_kernel=self.conv_kernel_size,
)
def mamba_mixer2(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
self.forward_cuda(hidden_states=hidden_states,
output=output,
mamba_cache_params=None,
mamba2_metadata=None,
mup_vector=mup_vector)
def mamba_mixer2_fake(
hidden_states: torch.Tensor,
output: torch.Tensor,
layer_name: str,
mup_vector: Optional[torch.Tensor] = None,
) -> None:
return
direct_register_custom_op(
op_name="mamba_mixer2",
op_func=mamba_mixer2,
mutates_args=["output"],
fake_impl=mamba_mixer2_fake,
dispatch_key=current_platform.dispatch_key,
)

View File

@ -11,6 +11,7 @@ from transformers import BambaConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@ -122,11 +123,10 @@ class BambaMixerDecoderLayer(nn.Module):
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states = self.mamba(hidden_states, mamba_cache_params,
mamba2_metadata)
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
# Fully Connected
hidden_states, residual = self.pre_ff_layernorm(
hidden_states, residual)
hidden_states, residual = self.pre_ff_layernorm(output, residual)
hidden_states = self.feed_forward(hidden_states)
return hidden_states, residual
@ -169,7 +169,7 @@ class BambaAttentionDecoderLayer(nn.Module):
self.max_position_embeddings = max_position_embeddings
if hasattr(config, "partial_rotary_factor"):
rotary_dim = self.head_dim * config.partial_rotary_factor
rotary_dim = int(self.head_dim * config.partial_rotary_factor)
elif hasattr(config, "attn_rotary_emb"):
rotary_dim = config.attn_rotary_emb # for backward compatibility
else:
@ -258,6 +258,7 @@ ALL_DECODER_LAYER_TYPES = {
}
@support_torch_compile
class BambaModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -10,6 +10,7 @@ from transformers import FalconH1Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@ -179,13 +180,15 @@ class FalconH1SSMDecoderLayer(nn.Module):
mamba2_metadata: Mamba2Metadata,
**kwargs,
):
hidden_states = self.mamba(
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
mamba_cache_params,
mamba2_metadata=mamba2_metadata,
mup_vector=self.mup_vector,
)
return hidden_states, residual
return output, residual
class FalconH1AttentionDecoderLayer(nn.Module):
@ -398,6 +401,7 @@ class FalconH1ParallelHybrid(nn.Module):
return hidden_states
@support_torch_compile
class FalconH1Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -11,6 +11,7 @@ from transformers import GraniteMoeHybridConfig
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@ -104,9 +105,9 @@ class GraniteMoeHybridMambaDecoderLayer(nn.Module):
):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.mamba(hidden_states, mamba_cache_params,
mamba2_metadata)
hidden_states = residual + hidden_states * self.residual_multiplier
output = torch.empty_like(hidden_states)
self.mamba(hidden_states, output, mamba_cache_params, mamba2_metadata)
hidden_states = residual + output * self.residual_multiplier
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
@ -307,6 +308,7 @@ ALL_DECODER_LAYER_TYPES = {
}
@support_torch_compile
class GraniteMoeHybridModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -10,6 +10,7 @@ from transformers import MambaConfig
from vllm import envs
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import get_forward_context
@ -79,11 +80,12 @@ class Mamba2DecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba2_metadata)
return hidden_states, residual
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
return output, residual
@support_torch_compile
class Mamba2Model(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -25,6 +25,7 @@ from torch import nn
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed.parallel_state import get_pp_group
@ -172,9 +173,9 @@ class NemotronHMambaDecoderLayer(nn.Module):
else:
hidden_states, residual = self.norm(hidden_states, residual)
hidden_states = self.mixer(hidden_states, mamba_cache_params,
mamba2_metadata)
return hidden_states, residual
output = torch.empty_like(hidden_states)
self.mixer(hidden_states, output, mamba_cache_params, mamba2_metadata)
return output, residual
class NemotronHAttention(nn.Module):
@ -292,6 +293,7 @@ ALL_DECODER_LAYER_TYPES = {
}
@support_torch_compile
class NemotronHModel(nn.Module):
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

View File

@ -17,6 +17,7 @@ from transformers import Zamba2Config
from vllm import envs
from vllm.attention.layer import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.forward_context import get_forward_context
@ -548,14 +549,16 @@ class Zamba2MambaDecoderLayer(nn.Module):
hidden_states = self.input_layernorm(hidden_states)
# Process through Mamba mixer
hidden_states = self.mamba(
output = torch.empty_like(hidden_states)
self.mamba(
hidden_states,
output,
mamba_cache_params=mamba_cache_params,
mamba2_metadata=mamba2_metadata,
)
# residual connection after mamba
hidden_states = residual + hidden_states
hidden_states = residual + output
return hidden_states
@ -646,6 +649,7 @@ class Zamba2HybridLayer(nn.Module):
return layer_outputs
@support_torch_compile
class Zamba2Model(nn.Module):
"""Core Zamba2 model combining transformer and Mamba architectures.

View File

@ -2753,9 +2753,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
if self.vllm_config.speculative_config is not None:
raise NotImplementedError(
"Mamba with speculative decoding is not supported yet.")
if not self.vllm_config.model_config.enforce_eager:
raise NotImplementedError(
"Mamba with cuda graph is not supported yet.")
if self.vllm_config.cache_config.enable_prefix_caching:
raise NotImplementedError(
"Prefix caching is not supported for Mamba yet.")