[Misc] fix OLMoE model layer that can't load when tp > 1 (#18828)

Signed-off-by: rongfu.leng <rongfu.leng@daocloud.io>
Author: rongfu.leng <rongfu.leng@daocloud.io>
Date: 2025-05-29 01:36:21 +08:00
Committed by: GitHub
Commit: c68b5c63eb (parent fced756923)


@@ -13,6 +13,7 @@
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
 from collections.abc import Iterable
+from functools import partial
 from typing import Any, Optional, Union
 
 import torch
@@ -22,7 +23,10 @@ from transformers import PretrainedConfig
 from vllm.attention import Attention
 from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, VllmConfig
-from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              tensor_model_parallel_all_gather)
+from vllm.distributed.utils import split_tensor_along_last_dim
 from vllm.logger import init_logger
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -140,8 +144,11 @@ class OlmoeAttention(nn.Module):
            bias=False,
            quant_config=quant_config,
        )
-        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
-        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.tp_size = tp_size
+        self.tp_rank = get_tensor_model_parallel_rank()
+        self.q_norm = RMSNorm(self.total_num_heads * self.head_dim, eps=1e-5)
+        self.k_norm = RMSNorm(self.total_num_kv_heads * self.head_dim,
+                              eps=1e-5)
         self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
@@ -165,6 +172,20 @@ class OlmoeAttention(nn.Module):
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")
 
+    def _apply_qk_norm(self, q: torch.Tensor,
+                       k: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm(q)
+        k = self.k_norm(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(
         self,
         positions: torch.Tensor,
@@ -172,7 +193,7 @@ class OlmoeAttention(nn.Module):
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
-        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self._apply_qk_norm(q, k)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v)
         output, _ = self.o_proj(attn_output)
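Note on the change: the QK norms are now sized by the total (unsharded) head width, and when tensor parallelism is enabled the sharded q/k are all-gathered across TP ranks before RMSNorm is applied, then split again so each rank keeps only its own slice. Below is a minimal, single-process sketch of that gather -> norm -> re-split pattern, not the vLLM implementation: SimpleRMSNorm and fake_all_gather are hypothetical stand-ins for vLLM's RMSNorm and tensor_model_parallel_all_gather so the example runs without an initialized tensor-parallel group.

# Single-process sketch of the gather -> norm -> re-split pattern used by
# _apply_qk_norm above. SimpleRMSNorm and fake_all_gather are hypothetical
# stand-ins; no distributed process group is required.
import torch
from torch import nn


class SimpleRMSNorm(nn.Module):
    """Simplified RMSNorm over the last dimension (illustration only)."""

    def __init__(self, dim: int, eps: float = 1e-5) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(var + self.eps) * self.weight


def fake_all_gather(shards: list[torch.Tensor]) -> torch.Tensor:
    # Stand-in for tensor_model_parallel_all_gather: concatenate the
    # per-rank shards along the last (head) dimension.
    return torch.cat(shards, dim=-1)


tp_size, total_heads, head_dim, num_tokens = 2, 8, 16, 4
q_norm = SimpleRMSNorm(total_heads * head_dim)  # weight spans all heads

# Each TP rank only holds total_heads / tp_size heads of q after qkv_proj.
q_shards = [torch.randn(num_tokens, total_heads * head_dim // tp_size)
            for _ in range(tp_size)]

full_q = fake_all_gather(q_shards)              # [num_tokens, total_heads * head_dim]
normed = q_norm(full_q)                         # normalize over the full head width
resplit = torch.chunk(normed, tp_size, dim=-1)  # each rank keeps only its slice
print(resplit[0].shape)                         # torch.Size([4, 64])

Normalizing after the gather keeps the RMS statistics and the norm weight over the full q/k width regardless of the TP degree, which is consistent with the commit's goal of letting the checkpoint's norm weights load when tp > 1.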