# Copyright (c) 2025, HUAWEI CORPORATION. All rights reserved.
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains a Megatron style Hybrid Model that shares the weights of the actor with the inference engine.
"""
import hashlib
import re
from functools import partial
from typing import List
import torch
import torch.distributed as dist
import numpy as np
from torch.distributed import new_group
import vllm.distributed.parallel_state as ps
from mindspeed_rl.workers.resharding.memory_buffer import build_model_weight_buffer, calc_padded_numel
import mindspeed_rl.workers.resharding.utils
from mindspeed_rl.workers.resharding.utils import get_tensor_parallel_partition_dim, \
_build_infer_param_dict, get_tp_allgather_group, \
get_tp_allgather_world_size, is_tensor_parallel_param, get_tp_group, is_fake_tp_param
from mindspeed_rl.utils.loggers import Loggers
from mindspeed_rl.utils.utils import is_multimodal
logger = Loggers(__name__)
class MegatronStyleVllmWeightContainer:
def __init__(self, megatron_model, vllm_model, model_config, infer_tensor_parallel_size,
infer_pipeline_parallel_size,
infer_expert_parallel_size,
num_layer_list,
moe_tp_extend_ep=False,
parallel_state=None,
weight_adaptor=None,
noop_layers=None,
eplb_map=None,
global_redundant_expert_num=0,
infer_local_num_experts=-1,) -> None:
""" Megatron style vllm weight container.
Arguments:
megatron_model (nn.Module or nn.ModuleList): Megatron model used for training
vllm_model (nn.Module or nn.ModuleList): VLLM model used for inference
model_config (MegatronConfig): Model configuration
infer_tensor_parallel_size (int): Inference tensor parallel size
infer_pipeline_parallel_size (int): Inference pipeline parallel size
infer_expert_parallel_size (int): Inference expert parallel size
num_layer_list (str): a list of layer counts per pipeline stage, separated by commas; e.g., 4,4,4,4.
moe_tp_extend_ep (bool): Whether tensor parallel ranks are folded into expert parallelism for MoE experts (TP-extend-EP).
parallel_state (ModuleType): Megatron parallel state of the model.
weight_adaptor (WeightAdaptor): Provides a set of tools to transfer from training weight to inference weight.
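noop_layers (str): comma-separated global indices of no-op (padding) layers added to the hidden layer count.
eplb_map (Tensor): EPLB expert placement map indexed as [moe_layer, rank, expert_slot], if EPLB is enabled.
global_redundant_expert_num (int): total number of redundant experts across all inference EP ranks.
infer_local_num_experts (int): number of experts held by one inference rank, redundancy included.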
"""
self.vllm_model = vllm_model
self.model_config = model_config
self.megatron_model = megatron_model
self.parallel_state = parallel_state
self.weight_adaptor = weight_adaptor
# The HF model config is obtained from config.json under the tokenizer path
self._num_hidden_layers = self.model_config.num_hidden_layers
self._noop_layers = None
if hasattr(self.model_config, "first_k_dense_replace"):
self._first_k_dense = self.model_config.first_k_dense_replace
else:
self._first_k_dense = 0
# number of real MoE layers
self._moe_layers = self._num_hidden_layers - self._first_k_dense
if noop_layers is not None:
self._noop_layers = [int(layer_idx) for layer_idx in noop_layers.split(',')]
self._num_hidden_layers += len(self._noop_layers)
# EPLB
self.eplb_map = eplb_map
self.global_redundant_expert_num = global_redundant_expert_num
self.infer_local_num_experts = infer_local_num_experts
self.per_rank_redundant_expert_num = self.global_redundant_expert_num // infer_expert_parallel_size
# self.global_redundant_expert_num: total number of redundant experts; self.infer_local_num_experts: experts per NPU including redundancy; self.per_rank_redundant_expert_num: redundant experts per NPU
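# Illustrative numbers (not from the original code): 32 redundant experts in total with infer EP=8 gives 4 redundant experts per rank.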
# pp configs
self._pp_rank = self.parallel_state.get_pipeline_model_parallel_rank()
self._pp_group = self.parallel_state.get_pipeline_model_parallel_group()
self._pp_size = self.parallel_state.get_pipeline_model_parallel_world_size()
self._world_size = dist.get_world_size()
## vpp
self._num_layer_list = self._build_num_layer_list(num_layer_list)
self._vpp_rank = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK else 0
self._vpp_size = self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE if self.parallel_state._VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE else 1
self._vpp_layer_list = self._build_vpp_layer_list(self._num_layer_list)
## _noop_layers
self._global2local_map = self._build_global2local_map(self._vpp_layer_list, self._vpp_size, self._noop_layers) if self._noop_layers is not None else None
# tp configs
self._tp_size = self.parallel_state.get_tensor_model_parallel_world_size()
self._tp_group = self.parallel_state.get_tensor_model_parallel_group()
# ep configs
self._ep_size = self.parallel_state.get_expert_model_parallel_world_size()
if moe_tp_extend_ep:
self._ep_group = self.parallel_state.get_tensor_and_expert_parallel_group()
self._ep_size = self._tp_size * self._ep_size
else:
self._ep_group = self.parallel_state.get_expert_model_parallel_group()
if hasattr(self.model_config, "n_routed_experts"):
self.num_experts = self.model_config.n_routed_experts
self.num_local_experts = self.num_experts // self._ep_size
elif hasattr(self.model_config, "num_experts"):
self.num_experts = self.model_config.num_experts
self.num_local_experts = self.num_experts // self._ep_size
# infer configs
self._infer_tp_size = infer_tensor_parallel_size
self._infer_pp_size = infer_pipeline_parallel_size
self._infer_ep_size = infer_expert_parallel_size
self._infer_ep_size_raw = infer_expert_parallel_size
self.moe_tp_extend_ep = moe_tp_extend_ep
# TODO: infer_expert_tensor_parallel_size and num_process are currently fixed.
self.infer_expert_tensor_parallel_size = 1
self.num_process = 1
self._infer_ep_size = self._infer_ep_size * self._infer_tp_size
self.experts_memory_expand_N = self._infer_ep_size // self._ep_size
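# experts_memory_expand_N is the factor by which the inference EP domain (EP * TP) exceeds the training one.
# Illustrative example (not from the original code): train EP*TP=4 and infer EP*TP=8 gives N=2; each training rank's
# expert weights are packed into an N-times-larger buffer and every inference rank later keeps one of the N slices.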
# validate parallel configs
self._validate_parallel_config()
self._rank = dist.get_rank()
self._init_tensor_model_parallel_allgather_group()
self._init_pipeline_model_parallel_allgather_group()
self._init_tensor_model_parallel_split_group()
self._init_weight_buffers()
def _validate_parallel_config(self):
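# Illustrative examples of these checks (not exhaustive): train TP=8 -> infer TP=4 is allowed (8 % 4 == 0);
# train TP=2 -> infer TP=8 additionally requires world_size // pp_size to be >= 8 and a multiple of 8.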
if self._infer_pp_size != 1:
raise ValueError("infer_pp_size != 1 not supported yet")
if self._infer_ep_size % self._ep_size != 0:
raise ValueError("The training expert size should be divisibled by the inference expert size.")
if self._ep_size > 1 and not self.moe_tp_extend_ep:
raise ValueError("To enable training EP, you need to enable moe_tp_extend_ep and use GroupedMLP.")
if self._pp_size < self._infer_pp_size:
raise ValueError(
"The training pipeline parallel size should be greater than or equal to the inference pipeline "
"parallel size.")
if self._pp_size % self._infer_pp_size != 0:
raise ValueError(
"The training pipeline parallel size should be an integer multiple of the inference pipeline parallel "
"size.")
if self._tp_size > self._infer_tp_size and self._tp_size % self._infer_tp_size != 0:
raise ValueError(
"The training tensor parallel size should be an integer multiple of the inference tensor parallel size.")
# For tp increase, train_tp * dp >= infer_tp, train_tp * dp % infer_tp == 0
if self._tp_size < self._infer_tp_size:
if (self._world_size // self._pp_size < self._infer_tp_size or
(self._world_size // self._pp_size) % self._infer_tp_size != 0):
raise ValueError(
f"Do not support split train tp size {self._tp_size} to infer tp size {self._infer_tp_size} "
f"with train dp size {(self._world_size // (self._tp_size * self._pp_size))}.")
def get_infer_params(self):
"""
return the whole weight state dict for vllm, but in megatron style and names,
needs megatron weight loader to further transfer for vllm
"""
self._update_weight_buffers_intra_pp()
self._update_weight_buffers_inter_pp()
# Preconditions for running _update_weight_buffers_ep + _send_receive_experts
if(self.moe_tp_extend_ep and self._infer_ep_size >= self._ep_size):
self._update_weight_buffers_ep()
self._send_receive_experts()
# Reuse the TP-EP group construction
# Obtain the initial expert placement map eplb_map_initial
if self.eplb_map is not None:
tensor_model_parallel_size = self._infer_tp_size
expert_model_parallel_size = self._infer_ep_size_raw
tensor_and_expert_group_size = tensor_model_parallel_size * expert_model_parallel_size
# args: number of NPUs, total number of experts, experts per NPU, number of MoE layers
eplb_map_initial = self.create_initial_map(self._world_size, self.num_experts, int(self.num_experts / tensor_and_expert_group_size), self._moe_layers)
cur_rank = dist.get_rank()
eplb_map_initial = eplb_map_initial.to(self.eplb_map.device)
self._send_receive_redundancy_experts(eplb_map_initial, self.eplb_map, tensor_and_expert_group_size)
params = self._get_all_params()
params = _build_infer_param_dict(params=params)
return params
def create_initial_map(self, world_size: int, num_experts: int, EP_group_num: int, moe_layer: int):
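# Illustrative output (not from the original code): world_size=4, num_experts=8, EP_group_num=2 ->
# every layer row maps device d to experts [2*d, 2*d + 1], i.e. [[0, 1], [2, 3], [4, 5], [6, 7]].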
# Create the per-device expert placement tensor
eplb_map_initial = torch.zeros((moe_layer, world_size, EP_group_num), dtype=torch.int32)
# Compute each device's expert assignment (all layers share the same assignment)
per_device_experts = []
for device_id in range(world_size):
start_expert = (device_id * EP_group_num) % num_experts
experts = list(range(start_expert, start_expert + EP_group_num))
per_device_experts.append(experts)
for layer_idx in range(moe_layer):
for device_id in range(world_size):
# Write the expert list into the corresponding slot of the tensor
eplb_map_initial[layer_idx, device_id] = torch.tensor(per_device_experts[device_id], dtype=torch.int32)
return eplb_map_initial
def _send_receive_redundancy_experts(self, eplb_map_initial, eplb_map, tensor_and_expert_group_size):
# Global rank of the current process; it decides which ranks to exchange data with and acts as the destination rank.
cur_rank = dist.get_rank()
# Number of experts held by one NPU
ep_group = int(self.num_experts / tensor_and_expert_group_size)
for layer_id in range(eplb_map.size(0)):
for cur_pp_rank in range(self._pp_size):
for memory_buffer in self.weight_buffers[cur_pp_rank].memory_buffers.values():
for name in sorted(memory_buffer.tensor_indices.keys()):
op_list = []
layer_id_replace = layer_id + self._first_k_dense
# Select expert weights only
if "mlp.experts" in name and f"layers.{layer_id_replace}." in name:
tensor_to_send = memory_buffer.get_by_name(name)
tensor_to_replace = torch.empty_like(tensor_to_send)
ex_group_initial_rank = cur_rank // tensor_and_expert_group_size * tensor_and_expert_group_size
# ====== receive ======
for idx, target_expert in enumerate(eplb_map[layer_id, cur_rank, :]):
# ex_group_initial_rank is the first NPU rank of the destination rank's TP-EP group; src_rank is the rank that currently owns this expert
src_rank = ex_group_initial_rank + target_expert // ep_group
# Position of the required expert on its source rank
src_idx = target_expert % ep_group
if src_rank == cur_rank:
# Same rank: no communication needed, copy directly
tensor_to_replace[idx] = tensor_to_send[src_idx]
else:
recv_op = dist.P2POp(dist.irecv, tensor_to_replace[idx], src_rank, tag=idx)
op_list.append(recv_op)
# ====== send ======
# Iterate over all ranks in this rank's TP-EP communication group
for dst_rank in range(ex_group_initial_rank, ex_group_initial_rank + tensor_and_expert_group_size):
# Target expert table of each rank in the communication group
for idx, target_expert in enumerate(eplb_map[layer_id, dst_rank, :]):
if target_expert in eplb_map_initial[layer_id, cur_rank, :]:
if dst_rank != cur_rank:
send_op = dist.P2POp(dist.isend, tensor_to_send[target_expert % ep_group], dst_rank, tag=idx)
op_list.append(send_op)
# Perform the batched point-to-point communication
if op_list:
reqs = dist.batch_isend_irecv(op_list)
for req in reqs:
req.wait()
memory_buffer.copy_by_name(name, tensor_to_replace)
def _build_num_layer_list(self, num_layer_list):
if num_layer_list:
# multimodal num_layer_list is a list of lists, including vit and llm layers
if isinstance(num_layer_list[0], List):
return num_layer_list
return [int(num_layers) for num_layers in num_layer_list.split(',')]
if self._num_hidden_layers % self._pp_size != 0:
raise ValueError("num_layers % pp_size == 0, please specify num_layer_list")
return [self._num_hidden_layers // self._pp_size for _ in range(self._pp_size)]
def _build_vpp_layer_list(self, num_layer_list):
if self._vpp_size <= 1:
return num_layer_list
for layers_in_pp_rank in num_layer_list:
if layers_in_pp_rank % self._vpp_size != 0:
raise ValueError("num_layers_per_pp % vpp_size != 0, please specify pp_size and vpp_size")
return [int(layers_in_pp_rank / self._vpp_size) for layers_in_pp_rank in num_layer_list]
def _build_global2local_map(self, layer_list, vpp_size, noop_layers):
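# Maps each non-noop global layer index (enumerated in vpp order) to its local index inside its vpp chunk.
# Illustrative example (not from the original code): layer_list=[2, 2], vpp_size=2, noop_layers={7}
# -> global layers 0..6 (7 is skipped) map to local indices [0, 1, 0, 1, 0, 1, 0].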
stage_layers_num = sum(layer_list)
glb2local_map = []
for vpp_rank in range(vpp_size):
start_layer = vpp_rank * stage_layers_num
for _, layers_in_vpp_rank in enumerate(layer_list):
layer_idx_list = [
layer_idx for layer_idx in range(start_layer, start_layer + layers_in_vpp_rank)
if layer_idx not in noop_layers
]
glb2local_map += [layer_idx % layers_in_vpp_rank for layer_idx in layer_idx_list]
start_layer += layers_in_vpp_rank
return glb2local_map
def _unwrap_megatron_model(self, model):
"""
Remove consecutive 'module.' prefixes from the model based on the state_dict's first key.
This method only removes 'module.' from the beginning of the key and ignores other occurrences.
"""
unwraped_model = []
for model_chunk in model:
first_key = list(dict(model_chunk.named_parameters()).keys())[0]
while first_key.startswith("module."):
model_chunk = model_chunk.module
first_key = first_key[len("module."):]
unwraped_model.append(model_chunk)
return unwraped_model
def _init_weight_buffers(self):
"""
Build buffers from the vLLM state dict. In total, train pp_size buffers are built; each buffer corresponds to one pack of Megatron weights.
Returns a list of buffers and a reference dict megatron_param_name -> buffer.
"""
# Get the weight names inside each PP stage
vllm_names = list(dict(self.vllm_model.named_parameters()).keys())
if is_multimodal():
layers_num = [sum(num_layer_list) for num_layer_list in self._num_layer_list]
else:
layers_num = sum(self._num_layer_list)
self.weight_names_per_pp = self.weight_adaptor.get_weight_names_per_pp(self._vpp_layer_list, vllm_names,
layers_num, self._vpp_size, self._noop_layers)
self.weight_buffers = build_model_weight_buffer(self.vllm_model, self.weight_names_per_pp,
self.weight_adaptor.get_weight_buffer_meta
)
def trans_ep_params_to_tp(self, megatron_param, name):
"""
Transfer a GroupedMLP from EP to TP. Currently, assert EP==TP.
e.g. EP=2 -> TP=2
Assume we have 4 experts in total.
We here note e0 for expert 0, and [a0, b0] for the tensor parallel weight for expert 0,
so we can denote the first-half weights of all 4 experts as a0-3.
For EP to TP transfer, what we actually need to do is:
[[e0-1], [e2-3]] -> [[a0-3], [b0-3]]
We first build a matrix: each column is a rank before the transfer, and each row is a rank after it.
          ep0    ep1
[
(a0-3)    a0-1,  a2-3,
(b0-3)    b0-1,  b2-3,
]
When we get this matrix, we only need to do an All2All on the EP group to transfer from EP to TP.
So, for ep_rank 0 we need to build [a0-1, b0-1] from [e0-1], i.e.
[e0-1] <=> [e0, e1] <=> [a0, b0, a1, b1] -> [a0, a1, b0, b1] <=> [a0-1, b0-1]
For DSv3 model, this function only handles decoder.layers.x.mlp.experts.weight1 and
decoder.layers.x.mlp.experts.weight2.
In which, weight 1 is cut by column and contains both gate and up;
and weight 2 is cut by row.
"""
# the true ep size, equal to ep_size * tp_size when tp_extend_ep
if self._ep_size == 1:
return megatron_param
tp_size = self._infer_tp_size
num_experts = self.num_local_experts
# 1. build ep_tp matrix buffer
# For megatron param [e0, e1], we make it [a0, a1, b0, b1], in which e0 == [a0, b0]
# weight1: column cut, be like [g0, u0, g1, u1, ...]
if 'weight1' in name:
hidden_size = megatron_param.shape[0]
megatron_param = torch.cat(megatron_param.view(num_experts, hidden_size, -1).unbind(0), dim=1)
# We can treat both the gate and the up weight as 2 independent experts.
num_experts *= 2
# weight2: row cut, be like [ d0, d1, d2, ...]^T
elif 'weight2' in name:
hidden_size = megatron_param.shape[1]
megatron_param = torch.cat(megatron_param.view(num_experts, -1, hidden_size).unbind(0), dim=0)
# transpose params to handle uniformly with column cut
megatron_param = megatron_param.t()
else:
return megatron_param
# chunk to tp * ep parts
chunks = torch.chunk(megatron_param, tp_size * num_experts, dim=1)
# re-select by tp-ep order
# e.g. TP=2 num_experts=4, old order [1,2,3,4,5,6,7,8], new order [1,3,5,7,2,4,6,8]
new_order = []
for i in range(tp_size):
for j in range(num_experts):
new_order.append(chunks[i + j * tp_size])
reordered_x = torch.cat(new_order, dim=1)
final_chunks = torch.chunk(reordered_x, tp_size, dim=1)
# 2. do AlltoAll communication
input_tensor_list = [chunk.contiguous() for chunk in final_chunks]
output_tensor_list = [torch.empty_like(chunk) for chunk in input_tensor_list]
torch.distributed.all_to_all(
output_tensor_list,
input_tensor_list,
group=self._ep_group,
async_op=False
)
total_experts = self.num_local_experts * tp_size
res = torch.cat(output_tensor_list, dim=1).reshape(hidden_size, total_experts, -1)
if 'weight2' in name:
return res.permute(1, 2, 0).contiguous()
return res.permute(1, 0, 2).contiguous()
def _update_weight_buffers_intra_pp(self):
"""
Here, we only update the current training pp_rank's buffer.
"""
def _transfer_from_megatron_division(megatron_param, name):
"""
Deal with the tp_param form train_tp to infer_tp.
"""
infer_param = self.allgather_tp_param(megatron_param, name)
infer_param = self.split_tp_params(infer_param, name)
infer_param = self.trans_ep_params_to_tp(infer_param, name)
return infer_param
pp_rank = self._pp_rank
weight_names = self.weight_names_per_pp[pp_rank]
weight_names_meta = self.weight_adaptor.convert_weight_name_meta(weight_names)
true_megatron_model = self._unwrap_megatron_model(self.megatron_model)
normal_layer_func = partial(self.weight_adaptor.global2local_layer, num_layer_list=self._vpp_layer_list, global2local_map=self._global2local_map)
name_pairs = sorted(list(set([(name, vpp_rank, self.weight_adaptor.replace_name_i2t(normal_layer_func(name, vpp_rank=vpp_rank)))
for vpp_rank, names_per_vpp in enumerate(weight_names_meta) for name in names_per_vpp])))
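# name_pairs: sorted, de-duplicated triples of (vllm weight name, vpp chunk index, corresponding megatron parameter name).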
# Check that the linear_fc1 and linear_fc2 weight shapes satisfy the expected relation (fc1 holds both gate and up params, so it is twice the size of fc2). Models that do not satisfy this are not supported.
for _, vpp_rank, megatron_name in name_pairs:
if not megatron_name.startswith("image_encoder") and megatron_name.endswith("linear_fc1.weight"):
fc2_name = megatron_name.replace("linear_fc1", "linear_fc2")
megatron_param_fc1 = dict(true_megatron_model[vpp_rank].named_parameters())[megatron_name]
megatron_param_fc2 = dict(true_megatron_model[vpp_rank].named_parameters())[fc2_name]
if megatron_param_fc1.shape[0] * megatron_param_fc1.shape[1] != megatron_param_fc2.shape[0] * \
megatron_param_fc2.shape[1] * 2:
raise ValueError("Only implemented for Llama model which linear_fc1 contains gate and up params.")
weight_buffer = self.weight_buffers[pp_rank]
megatron_params_dict = {}
for vpp_rank in range(self._vpp_size):
megatron_params_dict.update({vpp_rank: dict(true_megatron_model[vpp_rank].named_buffers())})
megatron_params_dict[vpp_rank].update(true_megatron_model[vpp_rank].named_parameters())
megatron_params_dict[vpp_rank] = self.weight_adaptor.adjust_megatron_param_dict(megatron_params_dict[vpp_rank], self._tp_size)
for hf_name, vpp_rank, megatron_name in name_pairs:
if((self._infer_ep_size > 1 or self._ep_size > 1) and "mlp.experts" in megatron_name):
pass
else:
megatron_param = megatron_params_dict[vpp_rank][megatron_name]
param = _transfer_from_megatron_division(megatron_param, megatron_name)
weight_buffer.copy_by_name(hf_name, param)
def _update_weight_buffers_ep(self):
# Build temporary experts_memory_buffers
for cur_pp_rank in range(self._pp_size):
pp_rank = self._pp_rank
from mindspeed_rl.workers.resharding.memory_buffer import build_experts_memory_buffer, get_weight_buffer_meta_from_buffer
# Step 1: set up a temporary experts buffer for the current PP rank
combined_names_per_pp = []
vpp_stages = self.weight_names_per_pp[cur_pp_rank]
for weight_names_per_stage in vpp_stages:
combined_names_per_pp.extend(weight_names_per_stage)
self.weight_buffer_meta = self.weight_adaptor.get_weight_buffer_meta(self.vllm_model, combined_names_per_pp)
self.experts_weight_buffer_meta = get_weight_buffer_meta_from_buffer(self.weight_buffer_meta)
self.experts_memory_buffers = build_experts_memory_buffer(self.experts_weight_buffer_meta, self.experts_memory_expand_N)
expert_pernode = int(self.num_experts / self._infer_ep_size_raw / self._infer_tp_size)
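# expert_pernode: experts per inference NPU excluding redundancy, e.g. (illustrative) 256 experts with infer EP=8 and infer TP=4 -> 8 experts per NPU.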
# Step 2: copy the corresponding weights into the experts buffer
if(cur_pp_rank == pp_rank):
weight_names = self.weight_names_per_pp[pp_rank]
weight_names_meta = self.weight_adaptor.convert_weight_name_meta(weight_names)
normal_layer_func = partial(self.weight_adaptor.global2local_layer, num_layer_list=self._vpp_layer_list, global2local_map=self._global2local_map)
name_pairs = sorted(list(set([(name, vpp_rank, self.weight_adaptor.replace_name_i2t(normal_layer_func(name, vpp_rank=vpp_rank)))
for vpp_rank, names_per_vpp in enumerate(weight_names_meta) for name in names_per_vpp])))
true_megatron_model = self._unwrap_megatron_model(self.megatron_model)
megatron_params_dict = {}
# Collect all weights of the current PP stage
for vpp_rank in range(self._vpp_size):
megatron_params_dict.update({vpp_rank: dict(true_megatron_model[vpp_rank].named_buffers())})
megatron_params_dict[vpp_rank].update(true_megatron_model[vpp_rank].named_parameters())
megatron_params_dict[vpp_rank] = self.weight_adaptor.adjust_megatron_param_dict(megatron_params_dict[vpp_rank], self._tp_size)
for hf_name, vpp_rank, megatron_name in name_pairs:
if((self._infer_ep_size > 1 or self._ep_size > 1) and "mlp.experts" in megatron_name):
megatron_param = megatron_params_dict[vpp_rank][megatron_name]
dtype = self.experts_weight_buffer_meta[hf_name]['dtype']
self.experts_memory_buffers[dtype].copy_by_name_smallershape(hf_name, megatron_param, expert_pernode)
# Step 3: the subsequent operations can be reused across PP ranks
global_src = dist.get_global_rank(group=self._pp_group, group_rank=cur_pp_rank)
# Broadcast the expert weights stored in the experts memory buffer
for dtype, experts_memory_buffer in self.experts_memory_buffers.items():
dist.broadcast(tensor=experts_memory_buffer.data, src=global_src, group=self._pp_group, async_op=False)
ep_expand_rank = self._rank // self._ep_size
# Get the corresponding dtype
for name, tensor_indices_value in sorted(experts_memory_buffer.tensor_indices.items()):
shape = tensor_indices_value[1]  # this shape is expanded by the factor N
index = ep_expand_rank % self.experts_memory_expand_N
experts_tensor = experts_memory_buffer.get_by_name(name)
experts_tensor_reshape = experts_tensor.view(shape)
weight_tensor_infer = experts_tensor_reshape[index]
self.weight_buffers[cur_pp_rank].copy_by_name(name, weight_tensor_infer)
# Release the experts buffers
experts_memory_buffer = None
self.experts_memory_buffers[dtype] = None
for memory_buffer in self.experts_memory_buffers.values():
memory_buffer = None
self.experts_memory_buffers = None
def _update_weight_buffers_inter_pp(self):
"""
Update weight buffers by gathering weights from other pp stage.
"""
for cur_pp_rank in range(self._pp_size):
global_src = dist.get_global_rank(group=self._pp_group, group_rank=cur_pp_rank)
for memory_buffer in self.weight_buffers[cur_pp_rank].memory_buffers.values():
dist.broadcast(tensor=memory_buffer.data, src=global_src, group=self._pp_group, async_op=False)
def get_expert_router(self, cur_rank, train_tp_ep_size, infer_tp_ep_size, world_size):
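# Routes each rank's expert weights from the training EP layout to the inference EP layout within one TP-EP group.
# Illustrative example (not from the original code): train_tp_ep_size=2, infer_tp_ep_size=4, world_size=4 ->
# dev_array=[[0, 1], [2, 3]], src_router=[0, 2, 1, 3]; rank 1 gets src=2/dst=2 and rank 2 gets src=1/dst=1 (they swap),
# while ranks 0 and 3 keep their own weights.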
for tp_ep_group_id in range(world_size // infer_tp_ep_size):
tp_ep_group = [i for i in range(tp_ep_group_id * infer_tp_ep_size, (tp_ep_group_id + 1) * infer_tp_ep_size)]
if cur_rank in tp_ep_group:
self.INFER_TP_EP_GROUP = tp_ep_group
stride = infer_tp_ep_size // train_tp_ep_size
dev_array = np.array(self.INFER_TP_EP_GROUP).reshape(stride, train_tp_ep_size)
src_router = np.squeeze(dev_array.transpose().reshape(1, infer_tp_ep_size)).tolist()
src = src_router[cur_rank % infer_tp_ep_size]
dst = self.INFER_TP_EP_GROUP[src_router.index(cur_rank)]
return src, dst
def _send_receive_experts(self):
cur_rank = dist.get_rank()
src_rank, dst_rank = self.get_expert_router(cur_rank, self._ep_size, self._infer_ep_size, self._world_size)
for cur_pp_rank in range(self._pp_size):
for memory_buffer in self.weight_buffers[cur_pp_rank].memory_buffers.values():
for name in sorted(memory_buffer.tensor_indices.keys()):
if "mlp.experts" in name:
# Do the send/receive
tensor_to_send = memory_buffer.get_by_name(name)
tensor_to_replace = torch.empty_like(tensor_to_send)
send_op = dist.P2POp(dist.isend, tensor_to_send, dst_rank)
recv_op = dist.P2POp(dist.irecv, tensor_to_replace, src_rank)
reqs = dist.batch_isend_irecv([send_op, recv_op])
for req in reqs:
req.wait()
memory_buffer.copy_by_name(name, tensor_to_replace)
def _get_all_params(self):
"""Get all the parameters of the models in all pp ranks
Returns:
params: List[List[Dict[str, Tensor]]]: a list of parameters in all pp, where each is a list of dict
tensors of each model chunk
"""
params = []
for pp_rank in range(self._pp_size):
params.append([])
params[pp_rank].append({})
model_chunk_idx = 0
weight_buffer = self.weight_buffers[pp_rank]
for name in weight_buffer.weight_names:
if 'lora' in name:
raise ValueError("not support lora now")
params[pp_rank][model_chunk_idx][name] = weight_buffer[name]
return params
def _init_tensor_model_parallel_allgather_group(self):
if self._tp_size < self._infer_tp_size:
return
if self._tp_size % self._infer_tp_size != 0:
raise ValueError("self._tp_size must be divisible by self._infer_tp_size")
tp_allgather_size = self._tp_size
if mindspeed_rl.workers.resharding.utils._TP_ALLGATHER_GROUP is not None:
raise RuntimeError("Group for allgather tensor model parallel weight is already initialized")
num_groups = self._world_size // tp_allgather_size
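# Illustrative grouping (not from the original code): world_size=16, train TP=8 -> two allgather groups, ranks 0-7 and 8-15.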
for i in range(num_groups):
ranks = range(i * tp_allgather_size, (i + 1) * tp_allgather_size)
group = new_group(ranks=ranks)
if self._rank in ranks:
mindspeed_rl.workers.resharding.utils._TP_ALLGATHER_GROUP = group
def _init_pipeline_model_parallel_allgather_group(self):
if self._pp_size < self._infer_pp_size:
raise NotImplementedError("Not implemented for infer_pp > train_pp")
if self._pp_size % self._infer_pp_size != 0:
raise ValueError(
"Pipeline model parallel size must be a multiple of inference pipeline model parallel size")
pp_allgather_size = self._pp_size // self._infer_pp_size
if mindspeed_rl.workers.resharding.utils._PP_ALLGATHER_GROUP is not None:
raise RuntimeError("Group for allgather pipeline model parallel weight is already initialized")
global_pp_group_ranks_list = []
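# Illustrative grouping (not from the original code): world_size=8, train PP=4, infer PP=1 ->
# pp groups [0, 2, 4, 6] and [1, 3, 5, 7]; with infer PP=1 each whole pp group becomes one allgather group.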
for pp_group_index in range(self._world_size // self._pp_size):
self_pp_group_ranks_list = []
for ranks in range(pp_group_index, pp_group_index + self._world_size, self._world_size // self._pp_size):
self_pp_group_ranks_list.append(ranks)
global_pp_group_ranks_list.append(self_pp_group_ranks_list)
for pp_group_ranks in global_pp_group_ranks_list:
splited_pp_group_ranks = np.array_split(pp_group_ranks, self._infer_pp_size)
for ranks in splited_pp_group_ranks:
cur_group = new_group(ranks=ranks)
if self._rank in ranks:
mindspeed_rl.workers.resharding.utils._PP_ALLGATHER_GROUP = cur_group
def _init_tensor_model_parallel_split_group(self):
if self._tp_size >= self._infer_tp_size:
return
if self._infer_tp_size % self._tp_size != 0:
raise ValueError("self._infer_tp_size must be a multiple of self._tp_size")
if mindspeed_rl.workers.resharding.utils._TP_GROUP is not None:
raise RuntimeError("Group for tensor model parallel weight is already initialized")
if self._infer_tp_size > self._tp_size:
mindspeed_rl.workers.resharding.utils._TP_GROUP = self.parallel_state.get_tensor_model_parallel_group()
def _default_tp_concat_fn(self, name, param, infer_params):
"""
name: name of the parameter
param: training parameter (provides the tensor-parallel partition dim)
infer_params (List[torch.Tensor]): a list of parameter shards all-gathered from the TP allgather group
We keep the concat function here instead of in the model definition so that it is model-agnostic.
If a model doesn't fit this default, we can throw an error to force the user to disable the TP HybridEngine.
"""
if 'projector' not in name and 'linear_fc1' in name:
# linear_fc1 packs gate_proj and up_proj; split each shard before concatenating
gate_lst = []
up_lst = []
for infer_param in infer_params:
gate, up = infer_param.chunk(2)
gate_lst.append(gate)
up_lst.append(up)
gate = torch.cat(gate_lst, dim=0)
up = torch.cat(up_lst, dim=0)
infer_params = torch.cat((gate, up), dim=0)
else:
# concat tensor
infer_params = torch.cat(infer_params, dim=get_tensor_parallel_partition_dim(param))
return infer_params
def split_tp_params(self, param, name):
"""
name: name of the parameter
param: training parameter
1. get full train params through allgather
2. split train_tp params into groups (size: infer_tp_size)
3. return the corresponding param from group based on infer tp rank
"""
if self._infer_tp_size <= self._tp_size or is_fake_tp_param(name, self.moe_tp_extend_ep):
return param
tp_group = get_tp_group()
if is_tensor_parallel_param(param):
if self._tp_size > 1:
# allocate a new tensor with proper size
infer_params = [torch.empty_like(param) for _ in range(self._tp_size)]
torch.distributed.all_gather(infer_params, param, group=tp_group)
else:
infer_params = [param]
if "linear_fc1.weight" in name:
# linear_fc1 packs gate_proj and up_proj; split each gathered shard into gate and up
gate_lst = []
up_lst = []
for infer_param in infer_params:
gate, up = infer_param.chunk(2)
gate_lst.append(gate)
up_lst.append(up)
gate = torch.cat(gate_lst, dim=0)
up = torch.cat(up_lst, dim=0)
gate_splits = torch.chunk(gate, self._infer_tp_size, dim=0)
up_splits = torch.chunk(up, self._infer_tp_size, dim=0)
new_params_list = [
torch.cat([gate_splits[i], up_splits[i]], dim=0)
for i in range(self._infer_tp_size)
]
else:
partition_dim = get_tensor_parallel_partition_dim(param)
infer_params = torch.cat(infer_params, dim=partition_dim)
split_params = torch.chunk(infer_params, self._infer_tp_size, dim=partition_dim)
new_params_list = list(split_params)
# make_list
param_list = new_params_list
else:
param_list = [param] * self._infer_tp_size
global_rank = self._rank
infer_tp_rank_in_group = global_rank % self._infer_tp_size
return param_list[infer_tp_rank_in_group]
def allgather_tp_param(self, param, name):
if self._tp_size <= self._infer_tp_size:
return param
tp_allgather_size = get_tp_allgather_world_size()
tp_allgather_group = get_tp_allgather_group()
infer_param = param
if tp_allgather_size <= 1 or is_fake_tp_param(name, self.moe_tp_extend_ep):
return infer_param
if is_tensor_parallel_param(param):
# allocate a new tensor with proper size
infer_param = [torch.empty_like(param) for _ in range(tp_allgather_size)]
torch.distributed.all_gather(infer_param, param, group=tp_allgather_group)
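# Keep only the slice of gathered shards that belongs to this rank's infer TP position before concatenating,
# e.g. (illustrative) train TP=8, infer TP=2: part_len=4, ranks with (rank % infer_tp) == 0 concatenate shards 0-3,
# the others concatenate shards 4-7.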
part_len = len(infer_param) // self._infer_tp_size
start = self._rank % self._infer_tp_size
part_param = infer_param[part_len * start:part_len * (start + 1)]
infer_param = self._default_tp_concat_fn(name, param, part_param)
return infer_param