[Model] Add Qwen2 PRM model support (#12202)

Signed-off-by: Isotr0py <2037008807@qq.com>
Isotr0py authored on 2025-01-20 14:59:46 +08:00, committed by GitHub
parent 0974c9bc5c
commit 83609791d2
5 changed files with 45 additions and 13 deletions

docs/source/models/supported_models.md

@@ -470,6 +470,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
   - `Qwen/Qwen2.5-Math-RM-72B`, etc.
   - ✅︎
   - ✅︎
+* - `Qwen2ForProcessRewardModel`
+  - Qwen2-based
+  - `Qwen/Qwen2.5-Math-PRM-7B`, `Qwen/Qwen2.5-Math-PRM-72B`, etc.
+  - ✅︎
+  - ✅︎
 ```

 If your model is not in the above list, we will try to automatically convert the model using
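For context beyond the diff: a minimal sketch of querying the newly supported PRM through vLLM's pooling API. This example is not part of the commit; the `<extra_0>` step separator follows the Qwen2.5-Math-PRM model card, and the `task="reward"` / `output.outputs.data` usage assumes the pooling-runner API of this vLLM vintage.

```python
from vllm import LLM

# Hypothetical usage sketch, not part of this commit.
llm = LLM(model="Qwen/Qwen2.5-Math-PRM-7B", task="reward")

# Qwen2.5-Math-PRM expects reasoning steps separated by the special
# "<extra_0>" token (id 151651, the step_tag_id configured in qwen2_rm.py).
question = "What is 2 + 3 * 4?"
steps = ["Multiply first: 3 * 4 = 12.", "Then add: 2 + 12 = 14."]
prompt = question + "\n" + "<extra_0>".join(steps) + "<extra_0>"

(output,) = llm.encode(prompt)
# With STEP pooling and softmax over num_labels=2, this yields one
# [P(incorrect), P(correct)] row per step tag.
print(output.outputs.data)
```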

tests/models/embedding/language/test_embedding.py

@@ -17,14 +17,15 @@ from ..utils import check_embeddings_close
                  marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
     pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
     pytest.param("intfloat/multilingual-e5-large"),
-    # [Encoder-decoder]
-    pytest.param("intfloat/e5-mistral-7b-instruct",
-                 marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
+    # [Decoder-only]
     pytest.param("BAAI/bge-multilingual-gemma2",
                  marks=[pytest.mark.core_model]),
-    pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
+    pytest.param("intfloat/e5-mistral-7b-instruct",
+                 marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
     pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
     pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
+    pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
+    # [Encoder-decoder]
     pytest.param("sentence-transformers/stsb-roberta-base-v2"),
 ],
 )

tests/models/registry.py

@@ -155,6 +155,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
     "Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
     "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
+    "Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
     "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"),  # noqa: E501
     "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"),  # noqa: E501
     "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"),  # noqa: E501

vllm/model_executor/models/qwen2_rm.py

@@ -12,7 +12,7 @@ from vllm.attention import AttentionMetadata
 from vllm.config import VllmConfig
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                RowParallelLinear)
-from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
 from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.sequence import IntermediateTensors, PoolerOutput
@@ -32,7 +32,7 @@ class ReLU(nn.Module):
         return self.activation(input)


-class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
+class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
@@ -60,7 +60,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
         config = vllm_config.model_config.hf_config
         quant_config = vllm_config.quant_config
         lora_config = vllm_config.lora_config
-        pooler_config = vllm_config.model_config.pooler_config

         self.config = config
         self.lora_config = lora_config
@@ -74,14 +73,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
                                   config.hidden_size,
                                   quant_config=quant_config),
             ReLU(),
-            RowParallelLinear(config.hidden_size, 1,
+            RowParallelLinear(config.hidden_size,
+                              config.num_labels,
                               quant_config=quant_config),
         )
-        self._pooler = Pooler.from_config_with_defaults(
-            pooler_config,
-            pooling_type=PoolingType.ALL,
-            normalize=False,
-            softmax=False)
+        self._pooler: SimplePooler
         self.make_empty_intermediate_tensors = (
             self.model.make_empty_intermediate_tensors)
@@ -115,3 +111,31 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
         return loader.load_weights(weights)
+
+
+class Qwen2ForRewardModel(Qwen2RewardBaseModel):
+
+    def __init__(self, *, vllm_config, prefix=""):
+        vllm_config.model_config.hf_config.num_labels = 1
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.ALL,
+            normalize=False,
+            softmax=False)
+
+
+class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
+
+    def __init__(self, *, vllm_config, prefix=""):
+        vllm_config.model_config.hf_config.num_labels = 2
+        super().__init__(vllm_config=vllm_config, prefix=prefix)
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.STEP,
+            normalize=False,
+            softmax=True,
+            step_tag_id=151651,
+        )
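The two subclasses differ only in score-head width and pooler: the outcome reward model keeps ALL pooling over a single logit, while the process reward model pools at step-tag positions and softmaxes its two logits. A rough illustration (my sketch of the idea, not vLLM's actual `PoolingType.STEP` implementation) of what the PRM pooler computes per sequence:

```python
import torch


def step_pool(score_logits: torch.Tensor, token_ids: torch.Tensor,
              step_tag_id: int = 151651) -> torch.Tensor:
    """score_logits: [seq_len, num_labels] from the score head.

    Returns one softmax-normalized row per step-tag position, i.e.
    a [num_steps, num_labels] tensor of per-step probabilities.
    """
    mask = token_ids == step_tag_id        # locate the "<extra_0>" tokens
    return score_logits[mask].softmax(dim=-1)
```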

vllm/model_executor/models/registry.py

@@ -127,6 +127,7 @@ _EMBEDDING_MODELS = {
     "Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
     "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
+    "Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
     "TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
     # [Multimodal]
     "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"),  # noqa: E501