Support for LlamaForSequenceClassification (#20807)

Signed-off-by: thechaos16 <thechaos16@gmail.com>
Minkyu Kim
2025-07-13 16:09:34 +09:00
committed by GitHub
parent 99b4f080d8
commit bd4c1e6fdb
3 changed files with 7 additions and 1 deletion

tests/models/registry.py

@@ -330,6 +330,7 @@ _CROSS_ENCODER_EXAMPLE_MODELS = {
hf_overrides={"architectures": ["GemmaForSequenceClassification"], # noqa: E501
"classifier_from_token": ["Yes"], # noqa: E501
"method": "no_post_processing"}), # noqa: E501
"LlamaForSequenceClassification": _HfExamplesInfo("Skywork/Skywork-Reward-V2-Llama-3.2-1B"), # noqa: E501
"ModernBertForSequenceClassification": _HfExamplesInfo("Alibaba-NLP/gte-reranker-modernbert-base", v0_only=True), # noqa: E501
"RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base", v0_only=True), # noqa: E501
"XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3", v0_only=True), # noqa: E501

vllm/model_executor/models/llama.py

@@ -49,6 +49,7 @@ from vllm.model_executor.model_loader.weight_utils import (
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors
+from .adapters import as_seq_cls_model
from .interfaces import SupportsLoRA, SupportsPP
from .utils import (AutoWeightsLoader, PPMissingLayer, extract_layer_index,
is_pp_missing_parameter,
@@ -645,3 +646,6 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
name = name.replace(item, mapping[item])
return name, loaded_weight
+LlamaForSequenceClassification = as_seq_cls_model(LlamaForCausalLM)
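The new architecture is not hand-written; it is produced by wrapping LlamaForCausalLM with the as_seq_cls_model adapter. As a minimal, self-contained sketch of the idea behind such an adapter (hypothetical names only; vLLM's real adapter also handles weight loading, pooling configuration, and engine integration):

# Illustrative sketch of the adapter pattern, not vLLM's as_seq_cls_model.
import torch
import torch.nn as nn

class TinyCausalLM(nn.Module):
    """Stand-in for a decoder-only LM that exposes per-token hidden states."""
    def __init__(self, vocab_size: int = 128, hidden_size: int = 16):
        super().__init__()
        self.hidden_size = hidden_size
        self.embed = nn.Embedding(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def hidden_states(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed(input_ids)  # [batch, seq, hidden]

def as_seq_cls_model_sketch(causal_lm_cls):
    """Wrap a causal LM class so it scores sequences instead of generating."""
    class SeqClsModel(causal_lm_cls):
        def __init__(self, *args, num_labels: int = 1, **kwargs):
            super().__init__(*args, **kwargs)
            # Replace next-token prediction with a small classification head.
            self.score = nn.Linear(self.hidden_size, num_labels, bias=False)

        def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
            # Pool the last token's hidden state and project to class logits
            # (a single scalar per sequence for a reward model).
            last_hidden = self.hidden_states(input_ids)[:, -1, :]
            return self.score(last_hidden)

    return SeqClsModel

TinySeqCls = as_seq_cls_model_sketch(TinyCausalLM)
print(TinySeqCls()(torch.randint(0, 128, (2, 5))).shape)  # torch.Size([2, 1])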

vllm/model_executor/models/registry.py

@@ -183,7 +183,8 @@ _CROSS_ENCODER_MODELS = {
"GemmaForSequenceClassification": ("gemma", "GemmaForSequenceClassification"), # noqa: E501
"Qwen2ForSequenceClassification": ("qwen2", "Qwen2ForSequenceClassification"), # noqa: E501
"Qwen3ForSequenceClassification": ("qwen3", "Qwen3ForSequenceClassification"), # noqa: E501
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501
"LlamaForSequenceClassification": ("llama", "LlamaForSequenceClassification"), # noqa: E501
"JinaVLForRanking": ("jina_vl", "JinaVLForSequenceClassification"), # noqa: E501,
}
_MULTIMODAL_MODELS = {
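For reference, the (module, class) pairs in this registry resolve to classes under vllm.model_executor.models. A rough illustration of the lookup the new entry enables (simplified; vLLM's actual ModelRegistry uses lazy imports and additional fallback logic):

# Simplified illustration only, not vLLM's ModelRegistry implementation.
import importlib

def resolve_architecture(arch: str, table: dict[str, tuple[str, str]]):
    module_name, class_name = table[arch]
    module = importlib.import_module(f"vllm.model_executor.models.{module_name}")
    return getattr(module, class_name)

# With the entry added above, a checkpoint whose config.json lists
# "LlamaForSequenceClassification" can now be resolved:
# model_cls = resolve_architecture("LlamaForSequenceClassification", _CROSS_ENCODER_MODELS)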