mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Add Qwen2 PRM model support (#12202)
Signed-off-by: Isotr0py <2037008807@qq.com>
This commit is contained in:
@ -470,6 +470,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
||||
- `Qwen/Qwen2.5-Math-RM-72B`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - `Qwen2ForProcessRewardModel`
|
||||
- Qwen2-based
|
||||
- `Qwen/Qwen2.5-Math-PRM-7B`, `Qwen/Qwen2.5-Math-PRM-72B`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
```
|
||||
|
||||
If your model is not in the above list, we will try to automatically convert the model using
|
||||
|
@ -17,14 +17,15 @@ from ..utils import check_embeddings_close
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
pytest.param("sentence-transformers/all-MiniLM-L12-v2"),
|
||||
pytest.param("intfloat/multilingual-e5-large"),
|
||||
# [Encoder-decoder]
|
||||
pytest.param("intfloat/e5-mistral-7b-instruct",
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
# [Decoder-only]
|
||||
pytest.param("BAAI/bge-multilingual-gemma2",
|
||||
marks=[pytest.mark.core_model]),
|
||||
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
pytest.param("intfloat/e5-mistral-7b-instruct",
|
||||
marks=[pytest.mark.core_model, pytest.mark.cpu_model]),
|
||||
pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"),
|
||||
pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"),
|
||||
pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
# [Encoder-decoder]
|
||||
pytest.param("sentence-transformers/stsb-roberta-base-v2"),
|
||||
],
|
||||
)
|
||||
|
@ -155,6 +155,7 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||
"Qwen2Model": _HfExamplesInfo("ssmits/Qwen2-7B-Instruct-embed-base"),
|
||||
"Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"),
|
||||
"Qwen2ForProcessRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-PRM-7B"),
|
||||
"Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501
|
||||
"RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501
|
||||
"RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501
|
||||
|
@ -12,7 +12,7 @@ from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType, SimplePooler
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
@ -32,7 +32,7 @@ class ReLU(nn.Module):
|
||||
return self.activation(input)
|
||||
|
||||
|
||||
class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
class Qwen2RewardBaseModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
@ -60,7 +60,6 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
@ -74,14 +73,11 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
config.hidden_size,
|
||||
quant_config=quant_config),
|
||||
ReLU(),
|
||||
RowParallelLinear(config.hidden_size, 1,
|
||||
RowParallelLinear(config.hidden_size,
|
||||
config.num_labels,
|
||||
quant_config=quant_config),
|
||||
)
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
self._pooler: SimplePooler
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
@ -115,3 +111,31 @@ class Qwen2ForRewardModel(nn.Module, SupportsLoRA, SupportsPP):
|
||||
loader = AutoWeightsLoader(self,
|
||||
ignore_unexpected_prefixes=["lm_head."])
|
||||
return loader.load_weights(weights)
|
||||
|
||||
|
||||
class Qwen2ForRewardModel(Qwen2RewardBaseModel):
|
||||
|
||||
def __init__(self, *, vllm_config, prefix=""):
|
||||
vllm_config.model_config.hf_config.num_labels = 1
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
softmax=False)
|
||||
|
||||
|
||||
class Qwen2ForProcessRewardModel(Qwen2RewardBaseModel):
|
||||
|
||||
def __init__(self, *, vllm_config, prefix=""):
|
||||
vllm_config.model_config.hf_config.num_labels = 2
|
||||
super().__init__(vllm_config=vllm_config, prefix=prefix)
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.STEP,
|
||||
normalize=False,
|
||||
softmax=True,
|
||||
step_tag_id=151651,
|
||||
)
|
||||
|
@ -127,6 +127,7 @@ _EMBEDDING_MODELS = {
|
||||
"Qwen2Model": ("qwen2", "Qwen2EmbeddingModel"),
|
||||
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
|
||||
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
|
||||
"Qwen2ForProcessRewardModel": ("qwen2_rm", "Qwen2ForProcessRewardModel"),
|
||||
"TeleChat2ForCausalLM": ("telechat2", "TeleChat2ForCausalLM"),
|
||||
# [Multimodal]
|
||||
"LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501
|
||||
|
Reference in New Issue
Block a user