[PERF] Add conv1d metadata to GDN attn (#25105)

Signed-off-by: Vadim Gimpelson <vadim.gimpelson@gmail.com>
Author: Vadim Gimpelson
Date: 2025-09-18 18:27:49 +04:00
Committed by: GitHub
Parent: 01a583fea4
Commit: 072d7e53e5
5 changed files with 24 additions and 8 deletions
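In short: GDNAttentionMetadata gains the same four causal_conv1d fields that Mamba2AttentionMetadata and ShortConvAttentionMetadata already carry, update_metadata() accepts GDNAttentionMetadata, and the Qwen3-Next GDN prefill path fills that metadata once per input (when cu_seqlen is still None) and passes it to causal_conv1d_fn via metadata=, so the Triton kernel's per-batch setup is computed at the first layer and reused by every later layer. A minimal sketch of the pattern (ConvMeta, compute_launch_params and gdn_layer below are hypothetical names, not vLLM APIs):

# Minimal sketch of the compute-once / reuse-per-layer pattern (hypothetical
# names; ConvMeta, compute_launch_params and gdn_layer are not vLLM APIs).
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ConvMeta:
    # Mirrors the four fields this commit adds to GDNAttentionMetadata.
    nums_dict: Optional[dict] = None
    cu_seqlen: Optional[int] = None
    batch_ptr: Optional[torch.Tensor] = None
    token_chunk_offset_ptr: Optional[torch.Tensor] = None


def compute_launch_params(x: torch.Tensor, query_start_loc: torch.Tensor,
                          meta: ConvMeta) -> ConvMeta:
    # Stand-in for update_metadata(): derive the conv1d kernel's per-batch
    # launch parameters from the current input layout (details here are
    # illustrative only).
    meta.cu_seqlen = x.shape[-1]
    meta.nums_dict = {"num_seqs": query_start_loc.numel() - 1}
    return meta


def gdn_layer(x: torch.Tensor, query_start_loc: torch.Tensor,
              conv_meta: ConvMeta) -> torch.Tensor:
    # All layers of one forward pass share the same conv_meta object; only the
    # first layer to reach this point pays the metadata-computation cost.
    if conv_meta.cu_seqlen is None:
        conv_meta = compute_launch_params(x, query_start_loc, conv_meta)
    # causal_conv1d_fn(x, ..., metadata=conv_meta) would reuse the cached setup
    return x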

vllm/model_executor/layers/mamba/mamba2_metadata.py

@@ -11,6 +11,7 @@ from vllm.attention.backends.placeholder_attn import (
     PlaceholderAttentionMetadata)
 from vllm.attention.backends.utils import PAD_SLOT_ID
 from vllm.platforms import current_platform
+from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadata
 from vllm.v1.attention.backends.mamba2_attn import (
     Mamba2AttentionMetadata, _query_start_loc_to_chunk_indices_offsets)
@@ -45,8 +46,8 @@ class Mamba2Metadata:
     """
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
@@ -117,7 +118,8 @@ def prepare_mamba2_metadata(

 def update_metadata(x: torch.Tensor, query_start_loc: torch.Tensor,
                     mamba2_metadata: Union[Mamba2Metadata,
-                                           Mamba2AttentionMetadata]):
+                                           Mamba2AttentionMetadata,
+                                           GDNAttentionMetadata]):
     """
     this is triggered upon handling a new input at the first layer
     """

vllm/model_executor/models/qwen3_next.py

@@ -35,6 +35,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.mamba.abstract import MambaBase
+from vllm.model_executor.layers.mamba.mamba2_metadata import update_metadata
 from vllm.model_executor.layers.mamba.mamba_mixer2 import (
     mamba_v2_sharded_weight_loader)
 from vllm.model_executor.layers.mamba.mamba_utils import (
@@ -414,6 +415,7 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
             assert isinstance(attn_metadata, dict)
             attn_metadata = attn_metadata[self.prefix]
+            conv_metadata = attn_metadata
             assert isinstance(attn_metadata, GDNAttentionMetadata)
             has_initial_state = attn_metadata.has_initial_state
             spec_query_start_loc = attn_metadata.spec_query_start_loc
@@ -475,10 +477,15 @@ class Qwen3NextGatedDeltaNet(nn.Module, MambaBase):
         # 2.2: process the remaining part
         if attn_metadata.num_prefills > 0:
+            mixed_qkv_non_spec_T = mixed_qkv_non_spec.transpose(0, 1)
+            if conv_metadata.cu_seqlen is None:
+                conv_metadata = update_metadata(mixed_qkv_non_spec_T,
+                                                non_spec_query_start_loc,
+                                                conv_metadata)
             # - "cache_indices" updates the conv_state cache in positions
             #   pointed to by "mamba_cache_params.state_indices_tensor"
             mixed_qkv_non_spec = causal_conv1d_fn(
-                mixed_qkv_non_spec.transpose(0, 1),
+                mixed_qkv_non_spec_T,
                 conv_weights,
                 self.conv1d.bias,
                 activation=self.activation,
@@ -486,6 +493,7 @@
                 has_initial_state=has_initial_state,
                 cache_indices=non_spec_state_indices_tensor,
                 query_start_loc=non_spec_query_start_loc,
+                metadata=conv_metadata,
             ).transpose(0, 1)
         elif attn_metadata.num_decodes > 0:
             mixed_qkv_non_spec = causal_conv1d_update(
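Note on the prefill branch above: the new GDNAttentionMetadata fields default to None and the builder does not set them, so conv_metadata.cu_seqlen is None marks "not yet filled for this input"; the first GDN layer that reaches this branch calls update_metadata, and later layers pass the already populated conv_metadata straight into causal_conv1d_fn. The decode path (causal_conv1d_update) is unchanged and does not take this metadata.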

vllm/v1/attention/backends/gdn_attn.py

@@ -50,6 +50,12 @@ class GDNAttentionMetadata:
         Tensor] = None  # shape: [num_prefill_tokens + num_decode_tokens,]
     num_accepted_tokens: Optional[torch.Tensor] = None  # shape: [batch,]
+
+    # The following attributes are for triton implementation of causal_conv1d
+    nums_dict: Optional[dict] = None
+    cu_seqlen: Optional[int] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 class GDNAttentionMetadataBuilder(
         AttentionMetadataBuilder[GDNAttentionMetadata]):

vllm/v1/attention/backends/mamba2_attn.py

@@ -132,8 +132,8 @@ class Mamba2AttentionMetadata:
     # The following attributes are for triton implementation of causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 class Mamba2AttentionMetadataBuilder(

vllm/v1/attention/backends/short_conv_attn.py

@@ -34,8 +34,8 @@ class ShortConvAttentionMetadata:
     # For causal_conv1d
     nums_dict: Optional[dict] = None
     cu_seqlen: Optional[int] = None
-    batch_ptr: Optional[torch.tensor] = None
-    token_chunk_offset_ptr: Optional[torch.tensor] = None
+    batch_ptr: Optional[torch.Tensor] = None
+    token_chunk_offset_ptr: Optional[torch.Tensor] = None


 class ShortConvAttentionMetadataBuilder(