[Model] Support InternLM2 Reward models (#11571)

Signed-off-by: Isotr0py <2037008807@qq.com>
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
Author: Isotr0py
Date: 2024-12-28 14:14:10 +08:00
Committed by: GitHub
Commit: d34be24bb1 (parent: b5cbe8eeb3)
4 changed files with 67 additions and 1 deletion


@@ -450,6 +450,11 @@ of the whole prompt are extracted from the normalized hidden state corresponding
     - Example HF Models
     - :ref:`LoRA <lora-adapter>`
     - :ref:`PP <distributed-serving>`
+  * - :code:`InternLM2ForRewardModel`
+    - InternLM2-based
+    - :code:`internlm/internlm2-1_8b-reward`, :code:`internlm/internlm2-7b-reward`, etc.
+    - ✅︎
+    - ✅︎
   * - :code:`LlamaForCausalLM`
     - Llama-based
     - :code:`peiyi9979/math-shepherd-mistral-7b-prm`, etc.
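
A quick usage sketch for reviewers (not part of this diff): loading the new model through vLLM's pooling runner. The `task="reward"` option and the `output.outputs.data` field are assumptions based on vLLM's pooling-model API around the time of this commit.

from vllm import LLM

# Hypothetical usage sketch: run the reward model through the pooling
# runner; InternLM2 checkpoints require trust_remote_code=True.
llm = LLM(model="internlm/internlm2-1_8b-reward",
          task="reward",
          trust_remote_code=True)

# encode() skips sampling entirely; with ALL pooling the returned data
# holds one score per prompt token.
(output,) = llm.encode("vLLM is a high-throughput inference engine.")
print(output.outputs.data.shape)  # expected: (num_prompt_tokens, 1)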


@@ -140,6 +140,8 @@ _EMBEDDING_EXAMPLE_MODELS = {
     "BertModel": _HfExamplesInfo("BAAI/bge-base-en-v1.5"),
     "Gemma2Model": _HfExamplesInfo("BAAI/bge-multilingual-gemma2"),
     "GritLM": _HfExamplesInfo("parasail-ai/GritLM-7B-vllm"),
+    "InternLM2ForRewardModel": _HfExamplesInfo("internlm/internlm2-1_8b-reward",
+                                               trust_remote_code=True),
     "JambaForSequenceClassification": _HfExamplesInfo("ai21labs/Jamba-tiny-reward-dev"),  # noqa: E501
     "LlamaModel": _HfExamplesInfo("llama", is_available_online=False),
     "MistralModel": _HfExamplesInfo("intfloat/e5-mistral-7b-instruct"),


@@ -18,14 +18,16 @@ from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead, VocabParallelEmbedding)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
 from vllm.model_executor.sampling_metadata import SamplingMetadata
-from vllm.sequence import IntermediateTensors
+from vllm.sequence import IntermediateTensors, PoolerOutput
 from .interfaces import SupportsLoRA, SupportsPP
 from .utils import (is_pp_missing_parameter,
@@ -433,3 +435,59 @@ class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
         return loaded_params
+
+
+class InternLM2ForRewardModel(InternLM2ForCausalLM):
+
+    def __init__(
+        self,
+        *,
+        vllm_config: VllmConfig,
+        prefix: str = "",
+        model_type: Type[InternLM2Model] = InternLM2Model,
+    ):
+        super().__init__(vllm_config=vllm_config,
+                         prefix=prefix,
+                         model_type=model_type)
+
+        for attr in ("output", "logits_processor", "sampler"):
+            delattr(self, attr)
+
+        config = vllm_config.model_config.hf_config
+        self.v_head = RowParallelLinear(
+            config.hidden_size,
+            1,
+            bias=False,
+            input_is_parallel=False,
+            prefix=maybe_prefix(prefix, "v_head"),
+        )
+
+        pooler_config = vllm_config.model_config.pooler_config
+        self._pooler = Pooler.from_config_with_defaults(
+            pooler_config,
+            pooling_type=PoolingType.ALL,
+            normalize=False,
+            softmax=False,
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, kv_caches,
+                                   attn_metadata, intermediate_tensors,
+                                   inputs_embeds)
+        logits, _ = self.v_head(hidden_states)
+        return logits
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
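
To make the `v_head` + `PoolingType.ALL` combination concrete, here is a self-contained sketch (plain PyTorch, not vLLM code; sizes are illustrative): the head projects every hidden state to a scalar, so the pooler emits one score per token, and reward models of this kind conventionally read the sequence-level reward from the final position.

import torch

# Stand-in for the reward head: hidden_size -> 1, no bias (as in v_head).
hidden_size, num_tokens = 2048, 7
v_head = torch.nn.Linear(hidden_size, 1, bias=False)

hidden_states = torch.randn(num_tokens, hidden_size)  # final hidden states
per_token_scores = v_head(hidden_states)              # shape: (num_tokens, 1)

# ALL pooling keeps every position; the score at the last token is the
# usual sequence-level reward.
sequence_reward = per_token_scores[-1].item()
print(per_token_scores.shape, sequence_reward)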


@@ -113,6 +113,7 @@ _EMBEDDING_MODELS = {
     "Gemma2Model": ("gemma2", "Gemma2ForCausalLM"),
     "GlmForCausalLM": ("glm", "GlmForCausalLM"),
     "GritLM": ("gritlm", "GritLM"),
+    "InternLM2ForRewardModel": ("internlm2", "InternLM2ForRewardModel"),
     "JambaForSequenceClassification": ("jamba", "JambaForSequenceClassification"),  # noqa: E501
     "LlamaModel": ("llama", "LlamaForCausalLM"),
     **{