mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Support for LlamaForSequenceClassification (#20807)
Signed-off-by: thechaos16 <thechaos16@gmail.com>
This commit is contained in:
@ -330,6 +330,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
|
||||
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
|
||||
"classifier_from_token": ["Yes"], # noqa: E501
|
||||
"method": "no_post_processing"}), # noqa: E501
|
||||
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
|
||||
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
|
||||
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
|
||||
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501
|
||||
|
@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
from .adapters import as_seq_cls_model
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
|
||||
is_pp_missing_parameter,
|
||||
@ -645,3 +646,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
name = name.replace(item, mapping[item])
|
||||
|
||||
return name, loaded_weight
|
||||
|
||||
|
||||
# Module-level model class obtained by passing LlamaForCausalLM through the
# as_seq_cls_model adapter (imported from .adapters).
# NOTE(review): presumably this is the class the "LlamaForSequenceClassification"
# registry entry resolves to -- confirm against the model registry.
LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)
|
||||
|
@ -183,7 +183,8 @@ _CROSS_ENCODER_MODELS = {
|
||||
"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
|
||||
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
|
||||
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
|
||||
"LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501
|
||||
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
|
||||
}
|
||||
|
||||
_MULTIMODAL_MODELS = {
|
||||
|
Reference in New Issue
Block a user