mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Model] Support InternLM2 Reward models (#11571)
Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
This commit is contained in:
@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
|
||||
- Example HF Models
|
||||
- :ref:`LoRA <lora-adapter>`
|
||||
- :ref:`PP <distributed-serving>`
|
||||
* - :code:`InternLM2ForRewardModel`
|
||||
- InternLM2-based
|
||||
- :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
|
||||
- ✅︎
|
||||
- ✅︎
|
||||
* - :code:`LlamaForCausalLM`
|
||||
- Llama-based
|
||||
- :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
|
||||
|
@ -140,6 +140,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
|
||||
"BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
|
||||
"Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
|
||||
"GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
|
||||
"InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
|
||||
trust_remote_code=True),
|
||||
"JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"), # noqa: E501
|
||||
"LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
|
||||
"MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),
|
||||
|
@ -18,14 +18,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.pooler import Pooler, PoolingType
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.pooling_metadata import PoolingMetadata
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.sequence import IntermediateTensors
|
||||
from vllm.sequence import IntermediateTensors, PoolerOutput
|
||||
|
||||
from .interfaces import SupportsLoRA, SupportsPP
|
||||
from .utils import (is_pp_missing_parameter,
|
||||
@ -433,3 +435,59 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
|
||||
weight_loader(param, loaded_weight)
|
||||
loaded_params.add(name)
|
||||
return loaded_params
|
||||
|
||||
|
||||
class InternLM2ForRewardModel(InternLM2ForCausalLM):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
vllm_config: VllmConfig,
|
||||
prefix: str = "",
|
||||
model_type: Type[InternLM2Model] = InternLM2Model,
|
||||
):
|
||||
super().__init__(vllm_config=vllm_config,
|
||||
prefix=prefix,
|
||||
model_type=model_type)
|
||||
|
||||
for attr in ("output", "logits_processor", "sampler"):
|
||||
delattr(self, attr)
|
||||
|
||||
config = vllm_config.model_config.hf_config
|
||||
self.v_head = RowParallelLinear(
|
||||
config.hidden_size,
|
||||
1,
|
||||
bias=False,
|
||||
input_is_parallel=False,
|
||||
prefix=maybe_prefix(prefix, "v_head"),
|
||||
)
|
||||
|
||||
pooler_config = vllm_config.model_config.pooler_config
|
||||
self._pooler = Pooler.from_config_with_defaults(
|
||||
pooler_config,
|
||||
pooling_type=PoolingType.ALL,
|
||||
normalize=False,
|
||||
softmax=False,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
attn_metadata, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
logits, _ = self.v_head(hidden_states)
|
||||
return logits
|
||||
|
||||
def pooler(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
pooling_metadata: PoolingMetadata,
|
||||
) -> Optional[PoolerOutput]:
|
||||
return self._pooler(hidden_states, pooling_metadata)
|
||||
|
@ -113,6 +113,7 @@ _EMBEDDING_MODELS = {
|
||||
"Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
|
||||
"GlmForCausalLM": ("glm", "GlmForCausalLM"),
|
||||
"GritLM": ("gritlm", "GritLM"),
|
||||
"InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
|
||||
"JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"), # noqa: E501
|
||||
"LlamaModel": ("llama", "LlamaForCausalLM"),
|
||||
**{
|
||||
|
Reference in New Issue
Block a user