451 lines
17 KiB
Python
451 lines
17 KiB
Python
# SPDX-License-Identifier: Apache-2.0
|
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
|
|
|
from collections.abc import Iterable
|
|
from functools import partial
|
|
from itertools import islice
|
|
from typing import Any, Optional, Union
|
|
|
|
import torch
|
|
from torch import nn
|
|
from transformers import PretrainedConfig
|
|
|
|
from vllm.attention import Attention
|
|
from vllm.compilation.decorators import support_torch_compile
|
|
from vllm.config import CacheConfig, VllmConfig
|
|
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
|
|
get_tensor_model_parallel_world_size,
|
|
split_tensor_along_last_dim,
|
|
tensor_model_parallel_all_gather)
|
|
from vllm.model_executor.layers.activation import SiluAndMul
|
|
from vllm.model_executor.layers.layernorm import RMSNorm
|
|
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
|
|
QKVParallelLinear,
|
|
RowParallelLinear)
|
|
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
|
from vllm.model_executor.layers.pooler import DispatchPooler, Pooler
|
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
|
ParallelLMHead, VocabParallelEmbedding)
|
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
|
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
|
from vllm.sequence import IntermediateTensors
|
|
|
|
from .interfaces import SupportsLoRA, SupportsPP
|
|
from .interfaces_base import default_pooling_type
|
|
from .utils import (is_pp_missing_parameter,
|
|
make_empty_intermediate_tensors_factory, make_layers,
|
|
maybe_prefix)
|
|
|
|
|
|
class InternLM2MLP(nn.Module):
    """Gated SiLU feed-forward sub-block of an InternLM2 decoder layer.

    ``gate_up_proj`` is a column-parallel projection with two equal-width
    output shards (gate and up), ``act_fn`` combines them with SiLU-and-multiply,
    and ``w2`` is the row-parallel projection back to ``hidden_size``.

    Raises:
        ValueError: if ``hidden_act`` is anything other than ``"silu"``.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        # Two output shards of identical width: [gate, up].
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            [intermediate_size, intermediate_size],
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.gate_up_proj",
        )
        self.w2 = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.w2",
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")

    def forward(self, x):
        """Apply gate/up projection, SiLU gating, then the down projection."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        out, _ = self.w2(activated)
        return out
|
|
|
|
|
|
class InternLM2Attention(nn.Module):
    """Grouped-query self-attention for InternLM2 with a fused ``wqkv``.

    The fused projection output is regrouped in ``split_qkv`` as
    ``total_num_kv_heads`` groups of ``[q_0 .. q_{g-1}, k, v]``, where ``g`` is
    ``key_value_groups`` (query heads per KV head). ``wqkv`` is a
    ``QKVParallelLinear`` whose per-rank output is treated as ``[q | k | v]``
    pieces. ``split_qkv`` does three things:

    * gathers those pieces across tensor-parallel ranks,
    * restores the interleaved layout and separates q, k and v,
    * re-shards them for this rank.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        self.tp_size = get_tensor_model_parallel_world_size()
        self.tp_rank = get_tensor_model_parallel_rank()
        self.total_num_heads = num_heads
        assert self.total_num_heads % self.tp_size == 0
        self.num_heads = self.total_num_heads // self.tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= self.tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % self.tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert self.tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
        # Number of consecutive TP ranks that hold a copy of the same KV
        # head (1 when KV heads are partitioned rather than replicated).
        self.num_kv_head_replicas = max(
            1, self.tp_size // self.total_num_kv_heads)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        # Query heads per KV head in the full model. This must come from the
        # global head counts: with replicated KV heads the per-rank ratio
        # num_heads / num_kv_heads is larger than the true group size.
        self.key_value_groups = (self.total_num_heads //
                                 self.total_num_kv_heads)
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.wqkv = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wqkv",
        )
        self.wo = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.wo",
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attn",
        )

    def split_qkv(self, qkv: torch.Tensor):
        """Split this rank's fused ``wqkv`` output into local q, k and v.

        Args:
            qkv: Output of ``self.wqkv`` on this rank; dim 0 is the token
                dimension and the last dim has ``q_size + 2 * kv_size``
                features.

        Returns:
            ``(q, k, v)`` for this tensor-parallel rank, with last dimensions
            ``q_size``, ``kv_size`` and ``kv_size``.
        """
        seq_len = qkv.shape[0]
        replicas = self.num_kv_head_replicas
        if self.tp_size > 1:
            qkv_map = [self.q_size, self.kv_size, self.kv_size] * self.tp_size
            qkv = tensor_model_parallel_all_gather(qkv)
            qkv = torch.split(qkv, qkv_map, dim=-1)
            # Regroup the per-rank [q, k, v] pieces into [Q | K | V]. With
            # replicated KV heads, `replicas` consecutive ranks hold the
            # same k/v piece, so keep only one copy of each.
            qkv = (qkv[::3] + qkv[1::3][::replicas] +
                   qkv[2::3][::replicas])
            qkv = torch.cat(qkv, dim=-1)

        # Restore the per-KV-group layout [q_0 .. q_{g-1}, k, v].
        qkv = qkv.view(seq_len, self.total_num_kv_heads,
                       self.key_value_groups + 2, self.head_dim)
        q, k, v = torch.split(qkv, [self.key_value_groups, 1, 1], dim=-2)
        q = q.reshape(seq_len, self.total_num_heads * self.head_dim)
        k = k.reshape(seq_len, self.total_num_kv_heads * self.head_dim)
        v = v.reshape(seq_len, self.total_num_kv_heads * self.head_dim)

        if self.tp_size > 1:
            # Each rank owns `num_heads` query heads. The KV heads form
            # tp_size // replicas distinct shards, each shared by `replicas`
            # consecutive ranks.
            num_kv_shards = self.tp_size // replicas
            kv_rank = self.tp_rank // replicas
            split_q = partial(split_tensor_along_last_dim,
                              num_partitions=self.tp_size)
            split_kv = partial(split_tensor_along_last_dim,
                               num_partitions=num_kv_shards)
            q = split_q(q)[self.tp_rank]
            k = split_kv(k)[kv_rank]
            v = split_kv(v)[kv_rank]
        return q, k, v

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Project to q/k/v, apply rotary embeddings, attend, and project out."""
        qkv, _ = self.wqkv(hidden_states)
        q, k, v = self.split_qkv(qkv)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.wo(attn_output)
        return output
|
|
|
|
|
|
class InternLMDecoderLayer(nn.Module):
    """One InternLM2 transformer block: normed attention, then normed MLP.

    The residual stream is carried next to the hidden states and handed to
    the ``RMSNorm`` calls, which return both a normalized tensor and an
    updated residual. ``forward`` never adds the residual explicitly.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        hidden_size = config.hidden_size
        self.hidden_size = hidden_size
        # Optional RoPE settings fall back to fixed defaults when the HF
        # config does not define them.
        self.attention = InternLM2Attention(
            hidden_size=hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config,
                                            "max_position_embeddings",
                                            8192),
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.feed_forward = InternLM2MLP(
            hidden_size=hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.feed_forward",
        )
        self.attention_norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)
        self.ffn_norm = RMSNorm(hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Run one block.

        Args:
            positions: Token positions forwarded to the attention's rotary
                embedding.
            hidden_states: Input activations for this block.
            residual: Residual stream from the previous block, or ``None`` for
                the first block (the input itself then becomes the residual).

        Returns:
            ``(hidden_states, residual)`` to feed into the next block.
        """
        # Attention sub-block.
        if residual is not None:
            normed, residual = self.attention_norm(hidden_states, residual)
        else:
            residual = hidden_states
            normed = self.attention_norm(hidden_states)
        attn_out = self.attention(positions=positions, hidden_states=normed)

        # Feed-forward sub-block.
        normed, residual = self.ffn_norm(attn_out, residual)
        return self.feed_forward(normed), residual
|
|
|
|
|
|
@support_torch_compile
class InternLM2Model(nn.Module):
    """InternLM2 backbone: token embeddings, decoder layers and final norm.

    Pipeline-parallel aware. Each rank iterates only over
    ``self.layers[start_layer:end_layer]`` (the bounds come from
    ``make_layers``). Every rank except the last returns ``IntermediateTensors``
    holding ``hidden_states`` and ``residual`` instead of normalized output.
    """

    def __init__(
            self,
            *,
            vllm_config: VllmConfig,
            prefix: str = "",
            layer_type: type[InternLMDecoderLayer] = InternLMDecoderLayer):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = hf_config
        self.vocab_size = hf_config.vocab_size
        self.tok_embeddings = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
        )

        def build_layer(prefix: str) -> nn.Module:
            # Factory handed to make_layers; `prefix` is the per-layer name.
            return layer_type(hf_config,
                              cache_config,
                              quant_config,
                              prefix=prefix)

        self.start_layer, self.end_layer, self.layers = make_layers(
            hf_config.num_hidden_layers,
            build_layer,
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(hf_config.hidden_size, eps=hf_config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], hf_config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.tok_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run this rank's slice of the decoder stack.

        On the first PP rank the input is ``inputs_embeds`` when given,
        otherwise the embedding of ``input_ids``. Later ranks resume from
        ``intermediate_tensors``.
        """
        pp_group = get_pp_group()
        if not pp_group.is_first_rank:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        else:
            hidden_states = (inputs_embeds if inputs_embeds is not None else
                             self.get_input_embeddings(input_ids))
            residual = None

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(positions, hidden_states, residual)

        if pp_group.is_last_rank:
            hidden_states, _ = self.norm(hidden_states, residual)
            return hidden_states
        return IntermediateTensors({
            "hidden_states": hidden_states,
            "residual": residual
        })
|
|
|
|
|
|
class InternLM2ForCausalLM(nn.Module, SupportsPP, SupportsLoRA):
    """InternLM2 causal language model: backbone plus vocabulary head.

    Checkpoint weights ``w1``/``w3`` are loaded as the two shards of the fused
    ``gate_up_proj`` parameter. ``wqkv`` keeps its checkpoint name.
    """

    packed_modules_mapping = {
        "wqkv": ["wqkv"],
        "gate_up_proj": ["w1", "w3"],
    }

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = "",
                 model_type: type[InternLM2Model] = InternLM2Model):
        """Build the backbone, the LM head and the logits processor.

        Args:
            vllm_config: Global config; the HF model config, quantization
                config and LoRA config are read from it.
            prefix: Parameter-name prefix for this module.
            model_type: Backbone class to instantiate, so subclasses can swap
                in a different decoder stack.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config

        self.config = config
        self.quant_config = quant_config
        self.lora_config = lora_config

        self.model = model_type(vllm_config=vllm_config,
                                prefix=maybe_prefix(prefix, "model"))
        self.output = ParallelLMHead(config.vocab_size,
                                     config.hidden_size,
                                     quant_config=quant_config,
                                     prefix=maybe_prefix(prefix, "output"))
        if self.config.tie_word_embeddings:
            # Share a single weight tensor between input embedding and head.
            self.output.weight = self.model.tok_embeddings.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the backbone's embedding table."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and return its output unchanged.

        With the default ``InternLM2Model`` backbone, the last pipeline-parallel
        rank gets the final hidden states. Earlier ranks get
        ``IntermediateTensors``, hence the ``Union`` return type.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        """Project hidden states to vocabulary logits via ``self.output``."""
        logits = self.logits_processor(self.output, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        """Load checkpoint ``(name, tensor)`` pairs into this module.

        ``w1``/``w3`` are renamed to ``gate_up_proj`` and loaded as shards 0
        and 1. Every other weight goes through the parameter's own
        ``weight_loader`` (or ``default_weight_loader``).

        Skipped weights:
        * ``rotary_emb.inv_freq`` buffers;
        * ``.bias`` tensors with no matching parameter (extra GPTQ biases);
        * parameters reported missing on this PP rank.

        Returns:
            The set of parameter names that reached the bottom of the loop.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # The two `continue`s below only advance this inner loop.
                # Normally no later mapping matches the renamed weight, so
                # control reaches the `else:` branch, which re-applies the
                # same skip checks and then skips the weight.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
|
|
|
|
@default_pooling_type("ALL")
class InternLM2ForRewardModel(InternLM2ForCausalLM):
    """InternLM2 reward model built on the causal-LM backbone.

    The parent's vocabulary head and logits processor are removed. They are
    replaced by ``v_head``, a linear projection from ``hidden_size`` to a
    single output feature, whose result is exposed through the ``encode``
    pooler.
    """

    is_pooling_model = True

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        model_type: type[InternLM2Model] = InternLM2Model,
    ):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         model_type=model_type)

        # The LM head and logits processor built by the parent are not used.
        for attr in ("output", "logits_processor"):
            delattr(self, attr)

        config = vllm_config.model_config.hf_config
        self.v_head = RowParallelLinear(
            config.hidden_size,
            1,
            bias=False,
            input_is_parallel=False,
            prefix=maybe_prefix(prefix, "v_head"),
        )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None

        self.pooler = DispatchPooler(
            {"encode": Pooler.for_encode(pooler_config)}, )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        """Run the backbone and, on the final pipeline stage, the value head.

        Returns:
            ``v_head`` output computed from the final hidden states. On
            non-final pipeline-parallel ranks, the backbone's
            ``IntermediateTensors`` are returned unchanged.
        """
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        # Non-final PP ranks receive IntermediateTensors from the backbone.
        # Pass them on for the next stage instead of applying v_head to them.
        if isinstance(hidden_states, IntermediateTensors):
            return hidden_states
        logits, _ = self.v_head(hidden_states)
        return logits
|