Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix ModernBert load & Enable sliding window attention for bidirectional attention. (#22637)
Signed-off-by: wang.yuqi <noooop@126.com>
Signed-off-by: Max de Bayser <mbayser@br.ibm.com>
Co-authored-by: Max de Bayser <mbayser@br.ibm.com>
@@ -4,10 +4,11 @@ from typing import Any

 import pytest

-from ...utils import (CLSPoolingEmbedModelInfo, EmbedModelInfo,
-                      LASTPoolingEmbedModelInfo, check_transformers_version)
+from ...utils import (CLSPoolingEmbedModelInfo, CLSPoolingRerankModelInfo,
+                      EmbedModelInfo, LASTPoolingEmbedModelInfo,
+                      RerankModelInfo, check_transformers_version)
 from .embed_utils import correctness_test_embed_models
-from .mteb_utils import mteb_test_embed_models
+from .mteb_utils import mteb_test_embed_models, mteb_test_rerank_models

 MODELS = [
     ########## BertModel
@@ -58,6 +59,14 @@ MODELS = [
         enable_test=False),
 ]

+RERANK_MODELS = [
+    # classifier_pooling: mean
+    CLSPoolingRerankModelInfo(
+        "Alibaba-NLP/gte-reranker-modernbert-base",
+        architecture="ModernBertForSequenceClassification",
+        enable_test=True),
+]
+

 @pytest.mark.parametrize("model_info", MODELS)
 def test_embed_models_mteb(hf_runner, vllm_runner,
@@ -88,3 +97,9 @@ def test_embed_models_correctness(hf_runner, vllm_runner,

     correctness_test_embed_models(hf_runner, vllm_runner, model_info,
                                   example_prompts, vllm_extra_kwargs)
+
+
+@pytest.mark.parametrize("model_info", RERANK_MODELS)
+def test_rerank_models_mteb(hf_runner, vllm_runner,
+                            model_info: RerankModelInfo) -> None:
+    mteb_test_rerank_models(hf_runner, vllm_runner, model_info)
@@ -26,8 +26,7 @@ from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors
 from vllm.tasks import PoolingTask

-from .interfaces import (SupportsCrossEncoding, SupportsV0Only,
-                         default_pooling_type)
+from .interfaces import SupportsCrossEncoding, default_pooling_type
 from .utils import WeightsMapper, maybe_prefix

@@ -93,16 +92,14 @@ class ModernBertAttention(nn.Module):
             bias=config.attention_bias,
         )

+        sliding_window = None
         if layer_id % config.global_attn_every_n_layers != 0:
-            self.local_attention = (config.local_attention // 2,
-                                    config.local_attention // 2)
+            sliding_window = config.local_attention // 2
+            rope_theta = config.local_rope_theta if config.local_rope_theta \
+                is not None else config.global_rope_theta
         else:
-            self.local_attention = (-1, -1)
+            rope_theta = config.global_rope_theta

-        rope_theta = config.global_rope_theta
-        if self.local_attention != (
-                -1, -1) and config.local_rope_theta is not None:
-            rope_theta = config.local_rope_theta
         self.rotary_emb = ModernBertRotaryEmbedding(config=config,
                                                     head_size=self.head_dim,
                                                     dim=self.head_dim,
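To make the new control flow concrete: ModernBERT interleaves global and local layers, and the rewrite above derives both the per-layer half-window and the RoPE base in a single branch instead of fixing up rope_theta after the fact. A standalone sketch of the resulting schedule, using assumed ModernBERT-base config values (the real ones come from the HuggingFace ModernBertConfig):

# Assumed config values, for illustration only.
global_attn_every_n_layers = 3
local_attention = 128          # total window width in tokens
global_rope_theta = 160_000.0
local_rope_theta = 10_000.0

for layer_id in range(6):
    if layer_id % global_attn_every_n_layers != 0:
        # Local layer: half the window on each side of the query token.
        sliding_window = local_attention // 2
        rope_theta = (local_rope_theta if local_rope_theta is not None
                      else global_rope_theta)
    else:
        # Global layer: full bidirectional attention, no window.
        sliding_window = None
        rope_theta = global_rope_theta
    print(layer_id, sliding_window, rope_theta)
# Layers 0 and 3 attend globally; layers 1, 2, 4 and 5 get a 64-token half-window.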
@@ -111,7 +108,8 @@ class ModernBertAttention(nn.Module):
             self.head_dim,
             self.scaling,
             prefix=f"{layer_id}.attn",
-            attn_type=AttentionType.ENCODER_ONLY)
+            attn_type=AttentionType.ENCODER_ONLY,
+            per_layer_sliding_window=sliding_window)
         self.Wo = RowParallelLinear(config.hidden_size,
                                     config.hidden_size,
                                     bias=config.attention_bias)
@@ -278,6 +276,7 @@ class ModernBertPooler(Pooler):
         return self.pooling.get_pooling_updates(task)

     def _head(self, pooled_output: torch.Tensor):
+        pooled_output = pooled_output.to(self.dense.weight.dtype)
         return self.norm(self.act(self.dense(pooled_output)))

     def forward(
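The added cast in _head defends against a dtype mismatch between the pooled hidden states and the dense head weights. A toy reproduction of the failure mode it prevents (dtypes assumed for illustration; in vLLM they depend on the loaded checkpoint):

import torch
import torch.nn as nn

dense = nn.Linear(8, 8).to(torch.bfloat16)   # head weights in reduced precision
pooled = torch.randn(1, 8)                   # pooled activations arrive as fp32

try:
    dense(pooled)                            # mat1/mat2 dtype mismatch
except RuntimeError as err:
    print(err)

out = dense(pooled.to(dense.weight.dtype))   # the cast the diff applies
print(out.dtype)                             # torch.bfloat16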
@@ -296,8 +295,7 @@ class ModernBertPooler(Pooler):


 @default_pooling_type("CLS")
-class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
-                                          SupportsCrossEncoding):
+class ModernBertForSequenceClassification(nn.Module, SupportsCrossEncoding):

     is_pooling_model = True

@@ -308,6 +306,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         self.model = ModernBertModel(vllm_config=vllm_config,
                                      prefix=maybe_prefix(prefix, "modernbert"))
         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+        self.pooling = ModernBertPooler(config)

         pooler_config = vllm_config.model_config.pooler_config
         assert pooler_config is not None
@@ -317,14 +316,14 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
             Pooler.for_encode(pooler_config),
             "classify":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_seq_cls(
                     vllm_config.model_config),
             ),
             "score":
             ClassifierPooler(
-                pooling=ModernBertPooler(config),
+                pooling=self.pooling,
                 classifier=self.classifier,
                 act_fn=ClassifierPooler.act_fn_for_cross_encoder(
                     vllm_config.model_config),
@@ -353,7 +352,7 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             if name.startswith("head"):
-                param = params_dict["_pooler.pooler." + name[len("head") + 1:]]
+                param = params_dict["pooling." + name[len("head") + 1:]]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
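Two coordinated changes make the head weights load: the pooler is now a named submodule (self.pooling, shared by the "classify" and "score" poolers instead of two separate ModernBertPooler instances), and the loader maps checkpoint tensors under head.* into that submodule. The old "_pooler.pooler." prefix no longer corresponded to any registered parameter, which broke loading. A reduced sketch of the mapping, with stand-in modules rather than the full vLLM classes:

import torch.nn as nn

class Head(nn.Module):                       # stand-in for ModernBertPooler
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(4, 2)
        self.pooling = Head()                # registered once, shared by both
                                             # the "classify" and "score" paths

params_dict = dict(Model().named_parameters())

ckpt_name = "head.dense.weight"              # name as stored in the checkpoint
vllm_name = "pooling." + ckpt_name[len("head") + 1:]
print(vllm_name, vllm_name in params_dict)   # pooling.dense.weight True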
@@ -368,5 +367,5 @@ class ModernBertForSequenceClassification(nn.Module, SupportsV0Only,
         return self.model(
             input_ids=input_ids,
             inputs_embeds=inputs_embeds,
-            position_ids=positions,
+            positions=positions,
         )
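The last one-liner in this file renames the keyword to match the model's forward signature: vLLM model forwards take positions, not position_ids. Assuming ModernBertModel.forward has no catch-all **kwargs, the old call fails immediately; a minimal illustration with a simplified signature:

def forward(input_ids=None, positions=None, inputs_embeds=None):  # simplified
    return positions

try:
    forward(input_ids=None, position_ids=None)   # the old keyword
except TypeError as err:
    print(err)   # forward() got an unexpected keyword argument 'position_ids'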
@@ -384,6 +384,8 @@ class FlashAttentionImpl(AttentionImpl):
         self.alibi_slopes = alibi_slopes
         if sliding_window is None:
             self.sliding_window = (-1, -1)
+        elif attn_type == AttentionType.ENCODER_ONLY:
+            self.sliding_window = (sliding_window - 1, sliding_window - 1)
         else:
             self.sliding_window = (sliding_window - 1, 0)
         self.kv_cache_dtype = kv_cache_dtype
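This new ENCODER_ONLY branch is what actually enables bidirectional sliding-window attention: a query may now look sliding_window - 1 tokens in both directions, while the decoder path keeps the right bound at 0 to stay causal. A standalone sketch of the mask semantics of the (left, right) tuples, assuming FlashAttention's convention that -1 means unbounded:

import torch

def window_mask(n: int, left: int, right: int) -> torch.Tensor:
    """True where query i may attend to key j; a bound of -1 is unlimited."""
    i = torch.arange(n).view(-1, 1)   # query positions
    j = torch.arange(n).view(1, -1)   # key positions
    mask = torch.ones(n, n, dtype=torch.bool)
    if left >= 0:
        mask &= (i - j) <= left       # limit lookback
    if right >= 0:
        mask &= (j - i) <= right      # limit lookahead
    return mask

w = 4                                  # assumed per-layer sliding_window
print(window_mask(6, w - 1, w - 1))    # ENCODER_ONLY: symmetric window
print(window_mask(6, w - 1, 0))        # decoder: causal sliding window
print(window_mask(6, -1, -1))          # sliding_window is None: full attention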
@@ -826,7 +826,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Prepare encoder attention metadata separately
         # (encoder layers are not in KV cache groups)
         if self.is_encoder_only_model:
-            common_attn_metadata, encoder_attn_metadata = \
+            per_layer_metadata = \
                 self._build_encoder_only_attn_metadata(
                     scheduler_output)
@@ -835,6 +836,8 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                 self.vllm_config, Attention)
             for layer_name, attn_module in attention_layers.items():
                 if attn_module.attn_type == AttentionType.ENCODER_ONLY:
+                    common_attn_metadata, encoder_attn_metadata = \
+                        per_layer_metadata[layer_name]
                     attn_metadata[layer_name] = encoder_attn_metadata

         # Prepare the attention metadata for each KV cache group and make layers
@@ -2683,30 +2686,41 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         # Check if model is encoder-only
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
-        attn_specs = list[AttentionSpec]()
-        for attn_module in attn_layers.values():
+        attn_specs: dict[AttentionSpec, list[str]] = defaultdict(list)
+        for layer_name, attn_module in attn_layers.items():

             if attn_module.attn_type == AttentionType.ENCODER_ONLY:
-                assert attn_module.sliding_window is None, "Sliding "
-                "window attention is not supported for encoder-only models"
-                attn_specs.append(
-                    FullAttentionSpec(block_size=block_size,
-                                      num_kv_heads=attn_module.num_kv_heads,
-                                      head_size=attn_module.head_size,
-                                      dtype=self.kv_cache_dtype,
-                                      use_mla=use_mla))
+                if attn_module.sliding_window is None:
+                    attn_spec: AttentionSpec = FullAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        use_mla=use_mla)
+                else:
+                    attn_spec = SlidingWindowSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        sliding_window=attn_module.sliding_window,
+                        use_mla=use_mla)
+                attn_specs[attn_spec].append(layer_name)
             else:
                 raise ValueError("Expected only encoder-only layers")

         if len(attn_specs) > 0:
-            assert len(attn_specs) == len(attn_layers), \
-                "All or none of the layers are expected to be encoder-only"
-
-            attn_backends = get_attn_backends_for_layers(attn_layers.keys())
-
-            self.attn_groups.append(
-                create_attn_groups(attn_backends, attn_specs[0]))
+            total_layers = 0
+            for attn_spec, layer_names in attn_specs.items():
+
+                attn_backends = get_attn_backends_for_layers(layer_names)
+                total_layers += len(layer_names)
+
+                self.attn_groups.append(
+                    create_attn_groups(attn_backends, attn_spec))
+            assert total_layers == len(attn_layers), \
+                "All or none of the layers are expected to be encoder-only"
             self.is_encoder_only_model = True

     def calculate_reorder_batch_threshold(self) -> None:
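Instead of asserting a single spec for all encoder layers, the runner now buckets layer names by their AttentionSpec, so a ModernBERT-style model can carry a full-attention group and a sliding-window group side by side. The dict-key trick relies on specs being hashable; a reduced sketch with a stand-in frozen dataclass (the real FullAttentionSpec and SlidingWindowSpec carry more fields):

from collections import defaultdict
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)        # frozen -> hashable -> usable as a dict key
class Spec:                    # stand-in for the AttentionSpec subclasses
    sliding_window: Optional[int] = None

# Assumed toy layout: every third layer is global, the rest use a 64-token window.
attn_layers = {f"layers.{i}.attn": Spec(None if i % 3 == 0 else 64)
               for i in range(6)}

attn_specs: dict[Spec, list[str]] = defaultdict(list)
for layer_name, spec in attn_layers.items():
    attn_specs[spec].append(layer_name)

# One attention group per distinct spec; the assert mirrors the new invariant.
total_layers = sum(len(names) for names in attn_specs.values())
assert total_layers == len(attn_layers)
print({s.sliding_window: names for s, names in attn_specs.items()})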
@@ -3071,7 +3085,7 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):

     def _build_encoder_only_attn_metadata(
         self, scheduler_output: "SchedulerOutput") -> \
-            tuple[CommonAttentionMetadata, Any]:
+            dict[str, tuple[CommonAttentionMetadata, Any]]:
         """Prepare encoder attention metadata for encoder-only models.

         Args:
@@ -3088,10 +3102,6 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
         tokens = [scheduler_output.num_scheduled_tokens[i] for i in req_ids]
         max_num_scheduled_tokens = max(tokens)

-        # Use the first attention metadata builder
-        # to create encoder attention metadata
-        builder = self.attn_groups[0][0].metadata_builder
-
         dummy_block_table = torch.zeros((num_reqs, 1),
                                         dtype=torch.int32,
                                         device=self.device)
@@ -3099,22 +3109,38 @@ class GPUModelRunner(LoRAModelRunnerMixin, KVConnectorModelRunnerMixin):
                                          dtype=torch.int32,
                                          device=self.device)

-        common_metadata = CommonAttentionMetadata(
-            query_start_loc=self.query_start_loc[:num_reqs + 1],
-            query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
-            seq_lens=self.seq_lens[:num_reqs],
-            seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
-            num_computed_tokens_cpu=self.input_batch.
-            num_computed_tokens_cpu_tensor[:num_reqs],
-            num_reqs=num_reqs,
-            num_actual_tokens=total_num_scheduled_tokens,
-            max_query_len=max_num_scheduled_tokens,
-            block_table_tensor=dummy_block_table,
-            slot_mapping=dummy_slot_mapping,
-            causal=False,
-        )
-
-        return common_metadata, builder.build(
-            common_prefix_len=0,  # No cascade for encoder
-            common_attn_metadata=common_metadata,
-        )
+        group_metadata = dict[str, tuple[CommonAttentionMetadata, Any]]()
+
+        for attn_group_list in self.attn_groups:
+
+            assert len(attn_group_list) == 1
+            attn_group = attn_group_list[0]
+
+            # Use the first attention metadata builder
+            # to create encoder attention metadata
+            builder = attn_group.metadata_builder
+
+            common_metadata = CommonAttentionMetadata(
+                query_start_loc=self.query_start_loc[:num_reqs + 1],
+                query_start_loc_cpu=self.query_start_loc_cpu[:num_reqs + 1],
+                seq_lens=self.seq_lens[:num_reqs],
+                seq_lens_cpu=self.seq_lens_cpu[:num_reqs],
+                num_computed_tokens_cpu=self.input_batch.
+                num_computed_tokens_cpu_tensor[:num_reqs],
+                num_reqs=num_reqs,
+                num_actual_tokens=total_num_scheduled_tokens,
+                max_query_len=max_num_scheduled_tokens,
+                block_table_tensor=dummy_block_table,
+                slot_mapping=dummy_slot_mapping,
+                causal=False,
+            )
+
+            metadata = builder.build(
+                common_prefix_len=0,  # No cascade for encoder
+                common_attn_metadata=common_metadata,
+            )
+
+            for layer_name in attn_group.layer_names:
+                group_metadata[layer_name] = (common_metadata, metadata)
+
+        return group_metadata
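The builder now loops over the attention groups created above, builds metadata once per group, and fans the result out to every layer in that group; the caller (the earlier hunk at @@ -835,6 +836,8 @@) then looks each layer up by name. A condensed sketch of that flow with placeholder types:

from dataclasses import dataclass
from typing import Any

@dataclass
class AttnGroup:                       # stand-in for the runner's group object
    layer_names: list[str]
    backend: str                       # stand-in for the metadata builder

def build_metadata(group: AttnGroup) -> Any:
    return f"metadata<{group.backend}>"          # one build(...) per group

attn_groups = [AttnGroup(["layers.0.attn", "layers.3.attn"], "full"),
               AttnGroup(["layers.1.attn", "layers.2.attn"], "swa")]

group_metadata: dict[str, Any] = {}
for group in attn_groups:
    metadata = build_metadata(group)             # shared within the group
    for layer_name in group.layer_names:
        group_metadata[layer_name] = metadata    # per-layer lookup for callers

print(group_metadata)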