Mirror of https://github.com/vllm-project/vllm.git
[Misc] Fix OLMoE model layers failing to load when tensor parallel size > 1 (#18828)
Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
vllm/model_executor/models/olmoe.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
             quant_config=quant_config,
             prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
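Context on what the patch does: under tensor parallelism each rank holds only a 1/tp_size slice of the q and k projections, while the checkpoint's QK-norm weights span the full projection width, so the norms cannot simply be applied to the rank-local shard. The new _apply_qk_norm therefore all-gathers q and k across ranks, applies the full-width RMSNorm once, and splits the result back so each rank keeps its own slice. Below is a minimal single-process sketch of that gather -> norm -> split round trip; it is not vLLM code: torch.cat and torch.chunk stand in for tensor_model_parallel_all_gather and split_tensor_along_last_dim, and rms_norm, q_norm_weight, and rank are illustrative names.

# Minimal single-process sketch of the gather -> norm -> split pattern
# introduced by _apply_qk_norm. Plain torch ops stand in for vLLM's
# distributed helpers; all names below are illustrative, not vLLM APIs.
import torch


def rms_norm(x: torch.Tensor, weight: torch.Tensor,
             eps: float = 1e-5) -> torch.Tensor:
    # RMSNorm over the last dimension, analogous to vLLM's RMSNorm layer.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight


tp_size = 4
num_heads, head_dim, num_tokens = 16, 64, 3

# The q_norm weight covers the *full* q projection width, which is why a
# rank-local shard (num_heads * head_dim // tp_size wide) cannot be
# normalized on its own.
q_full = torch.randn(num_tokens, num_heads * head_dim)
q_norm_weight = torch.randn(num_heads * head_dim)

# Under TP, each rank holds one slice of q along the last dimension.
shards = list(torch.chunk(q_full, tp_size, dim=-1))

# What each rank does in _apply_qk_norm: gather all slices, normalize the
# full-width tensor once, then take back only its own slice.
rank = 1                                      # an arbitrary example rank
gathered = torch.cat(shards, dim=-1)          # stands in for the all-gather
normed = rms_norm(gathered, q_norm_weight)
q_rank = torch.chunk(normed, tp_size, dim=-1)[rank]

# This matches normalizing the unsharded tensor directly, i.e. the
# tp_size == 1 behaviour the patch preserves.
reference = torch.chunk(rms_norm(q_full, q_norm_weight), tp_size, dim=-1)[rank]
assert torch.allclose(q_rank, reference)

The final assertion checks that the gather-then-split path produces the same rank-local result as normalizing the unsharded tensor, which is why the single-GPU output is unchanged by this fix.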