Mirror of https://github.com/vllm-project/vllm.git
Synced 2025-10-20 23:03:52 +08:00

Commit: [PERF] Add conv1d metadata to GDN attn (#25105)
Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
This commit is contained in:
@@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.platforms import current_platform
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.mamba2_attn import (
     Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@@ -45,8 +46,8 @@ class Mamba2Metadata:
     """
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(


 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
                     mamba2_metadata: Union[Mamba2Metadata,
-                                           Mamba2AttentionMetadata]):
+                                           Mamba2AttentionMetadata,
+                                           GDNAttentionMetadata]):
     """
     this is triggered upon handling a new input at the first layer
     """
@@ -35,6 +35,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     mamba_v2_sharded_weight_loader)
 from vllm.model_executor.layers.mamba.mamba_utils import (
@@ -414,6 +415,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):

         assert isinstance(attn_metadata, dict)
         attn_metadata = attn_metadata[self.prefix]
+        conv_metadata = attn_metadata
         assert isinstance(attn_metadata, GDNAttentionMetadata)
         has_initial_state = attn_metadata.has_initial_state
         spec_query_start_loc = attn_metadata.spec_query_start_loc
@@ -475,10 +477,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):

         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
+            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
+            if conv_metadata.cu_seqlen is None:
+                conv_metadata = update_metadata(mixed_qkv_non_spec_T,
+                                                non_spec_query_start_loc,
+                                                conv_metadata)
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
             mixed_qkv_non_spec = causal_conv1d_fn(
-                mixed_qkv_non_spec.transpose(0, 1),
+                mixed_qkv_non_spec_T,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
@@ -486,6 +493,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
                 has_initial_state=has_initial_state,
                 cache_indices=non_spec_state_indices_tensor,
                 query_start_loc=non_spec_query_start_loc,
+                metadata=conv_metadata,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
             mixed_qkv_non_spec = causal_conv1d_update(
@@ -50,6 +50,12 @@ class GDNAttentionMetadata:
         Tensor] = None  # shape: [num_prefill_tokens + num_decode_tokens,]
     num_accepted_tokens: Optional[torch.Tensor] = None  # shape: [batch,]

+    # The following attributes are for triton implementation of causal_conv1d
+    nums_dict: Optional[dict] = None
+    cu_seqlen: Optional[int] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None
+

 class GDNAttentionMetadataBuilder(
         AttentionMetadataBuilder[GDNAttentionMetadata]):
@@ -132,8 +132,8 @@ class Mamba2AttentionMetadata:
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 class Mamba2AttentionMetadataBuilder(
@@ -34,8 +34,8 @@ class ShortConvAttentionMetadata:
     # For causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 class ShortConvAttentionMetadataBuilder(
Reference in New Issue
Block a user