[mcore] moonlight (small model with deepseekv3 arch) (#1284)

achieve 74.3 at gsm8k, while moonlight reported as 77.4

still WIP with the performance diff
This commit is contained in:
Yan Bai
2025-05-28 17:10:29 +08:00
committed by GitHub
parent 8fe4950061
commit be47ac44b2
13 changed files with 653 additions and 25 deletions

View File

@ -0,0 +1,109 @@
set -x
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
# 0. download the model
huggingface-cli download moonshotai/Moonlight-16B-A3B-Instruct
# 1. convert the model to mcore format
# change the HF_MODEL_PATH and DIST_CKPT_PATH to your own path
HF_MODEL_PATH=/data/models/moonshotai/Moonlight-16B-A3B-Instruct
DIST_CKPT_PATH=/data/mcore_ckpt/Moonlight-16B-A3B-Instruct
python scripts/converter_hf_to_mcore.py --hf_model_path $HF_MODEL_PATH --output_path $DIST_CKPT_PATH
# 2. run the script
gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
train_files=$gsm8k_train_path
test_files=$gsm8k_test_path
ALL_OFFLOAD=${ALL_OFFLOAD:-False}
COMMON_PARAM_OFFLOAD=${COMMON_PARAM_OFFLOAD:-$ALL_OFFLOAD}
COMMON_GRAD_OFFLOAD=${COMMON_GRAD_OFFLOAD:-$ALL_OFFLOAD}
COMMON_OPTIMIZER_OFFLOAD=${COMMON_OPTIMIZER_OFFLOAD:-$ALL_OFFLOAD}
ACTOR_PARAM_OFFLOAD=${ACTOR_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
ACTOR_GRAD_OFFLOAD=${ACTOR_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
ACTOR_OPTIMIZER_OFFLOAD=${ACTOR_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
REF_PARAM_OFFLOAD=${REF_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_PARAM_OFFLOAD=${CRITIC_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
CRITIC_GRAD_OFFLOAD=${CRITIC_GRAD_OFFLOAD:-$COMMON_GRAD_OFFLOAD}
CRITIC_OPTIMIZER_OFFLOAD=${CRITIC_OPTIMIZER_OFFLOAD:-$COMMON_OPTIMIZER_OFFLOAD}
RM_PARAM_OFFLOAD=${RM_PARAM_OFFLOAD:-$COMMON_PARAM_OFFLOAD}
NODES=4
PP=2
TP=8
EP=8
ETP=1
VLLM_TP=4
# RAY_ADDRESS='auto' ray job submit --working-dir . --
python3 -m verl.trainer.main_ppo --config-path=./config --config-name='ppo_megatron_trainer'\
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.train_batch_size=1024 \
data.max_prompt_length=1024 \
data.max_response_length=512 \
data.filter_overlong_prompts=True \
data.truncation='error' \
+data.trust_remote_code=True \
actor_rollout_ref.model.path=$LLM \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
critic.optim.lr=1e-5 \
critic.model.path=$LLM \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb'] \
trainer.project_name='verl_megatron_gsm8k_examples' \
trainer.experiment_name='moonlight_16b_a3b_instruct_1node' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=$NODES \
trainer.save_freq=-1 \
trainer.test_freq=5 \
actor_rollout_ref.model.trust_remote_code=True \
critic.model.trust_remote_code=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=13 \
actor_rollout_ref.rollout.tensor_model_parallel_size=$VLLM_TP \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=$PP \
critic.megatron.pipeline_model_parallel_size=$PP \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=$TP \
critic.megatron.tensor_model_parallel_size=$TP \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.ref.megatron.expert_model_parallel_size=$EP \
critic.megatron.expert_model_parallel_size=$EP \
actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=$ETP \
critic.megatron.expert_tensor_parallel_size=$ETP \
actor_rollout_ref.actor.megatron.param_offload=${ACTOR_PARAM_OFFLOAD} \
actor_rollout_ref.actor.megatron.optimizer_offload=${ACTOR_OPTIMIZER_OFFLOAD} \
actor_rollout_ref.actor.megatron.grad_offload=${ACTOR_GRAD_OFFLOAD} \
actor_rollout_ref.ref.megatron.param_offload=${REF_PARAM_OFFLOAD} \
critic.megatron.param_offload=${CRITIC_PARAM_OFFLOAD} \
critic.megatron.optimizer_offload=${CRITIC_OPTIMIZER_OFFLOAD} \
critic.megatron.grad_offload=${CRITIC_GRAD_OFFLOAD} \
actor_rollout_ref.actor.megatron.use_dist_checkpointing=True \
actor_rollout_ref.ref.megatron.use_dist_checkpointing=True \
critic.megatron.use_dist_checkpointing=True \
actor_rollout_ref.actor.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
actor_rollout_ref.ref.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
critic.megatron.dist_checkpointing_path=$DIST_CKPT_PATH \
trainer.val_before_train=False \
trainer.total_epochs=100 $@

View File

@ -35,6 +35,7 @@ def _init_args():
parser.add_argument("--output_path", type=str, required=True, help="The path for the output mcore model")
parser.add_argument("--use_cpu_initialization", action="store_true", help="Whether to use cpu initialization")
parser.add_argument("--test", action="store_true", help="Whether to test the conversion")
parser.add_argument("--trust_remote_code", action="store_true", help="Whether to trust remote code")
args = parser.parse_args()
return args
@ -120,7 +121,7 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
v_bias = hf_layer.self_attn.v_proj.bias.view([num_key_value_heads, -1])
qkv_bias = torch.cat([q_bias, k_bias, v_bias], dim=1).view(-1).contiguous()
layer.self_attention.linear_qkv.bias.copy_(qkv_bias)
if hasattr(hf_layer.self_attn, "q_norm"):
layer.self_attention.q_layernorm.weight.copy_(hf_layer.self_attn.q_norm.weight.data)
layer.self_attention.k_layernorm.weight.copy_(hf_layer.self_attn.k_norm.weight.data)
@ -145,7 +146,72 @@ def convert_checkpoint_from_transformers_to_megatron(hf_model, model, hf_config)
model.output_layer.weight.copy_(hf_model.lm_head.weight)
def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False):
@torch.no_grad()
def convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model, hf_config, tfconfig):
warnings.warn("MTP model is not supported yet", stacklevel=2)
def safe_copy(
src_tensor: torch.Tensor,
dst_tensor: torch.Tensor,
skip_dtype_assert: bool = False,
):
if not skip_dtype_assert:
if src_tensor.dtype != dst_tensor.dtype:
raise ValueError(f"Get source dtype {src_tensor.dtype}, but target dtype {dst_tensor.dtype}")
assert src_tensor.shape == dst_tensor.shape
dst_tensor.data.copy_(src_tensor.data)
return src_tensor.numel()
model.embedding.word_embeddings.weight.copy_(hf_model.model.embed_tokens.weight)
for layer_idx, (layer, hf_layer) in enumerate(zip(model.decoder.layers, hf_model.model.layers)):
print(layer_idx)
layer.input_layernorm.weight.copy_(hf_layer.input_layernorm.weight)
if hf_config.q_lora_rank is None:
layer.self_attention.linear_q_proj.weight.copy_(hf_layer.self_attn.q_proj.weight)
else:
layer.self_attention.linear_q_down_proj.weight.copy_(hf_layer.self_attn.q_a_proj.weight)
layer.self_attention.linear_q_up_proj.weight.copy_(hf_layer.self_attn.q_b_proj.weight)
layer.self_attention.linear_q_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.q_a_layernorm.weight)
layer.self_attention.linear_kv_down_proj.weight.copy_(hf_layer.self_attn.kv_a_proj_with_mqa.weight)
layer.self_attention.linear_kv_up_proj.weight.copy_(hf_layer.self_attn.kv_b_proj.weight)
layer.self_attention.linear_kv_up_proj.layer_norm_weight.copy_(hf_layer.self_attn.kv_a_layernorm.weight)
layer.self_attention.linear_proj.weight.copy_(hf_layer.self_attn.o_proj.weight)
if not hasattr(layer.mlp, "router"):
layer.mlp.linear_fc1.layer_norm_weight.copy_(hf_layer.post_attention_layernorm.weight)
layer.mlp.linear_fc1.weight.copy_(torch.cat([hf_layer.mlp.gate_proj.weight, hf_layer.mlp.up_proj.weight]))
layer.mlp.linear_fc2.weight.copy_(hf_layer.mlp.down_proj.weight)
else:
layer.mlp.router.weight.copy_(hf_layer.mlp.gate.weight)
# NOTE: the e_score_correction_bias in mcore model will be initialized with bfloat16 and \
# recover to fp32 in the first forward. There is always a diff in the bias between two models (~0.3%)
safe_copy(hf_layer.mlp.gate.e_score_correction_bias, layer.mlp.router.expert_bias, skip_dtype_assert=True)
if tfconfig.moe_grouped_gemm:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
linear_fc1_weighti = getattr(layer.mlp.experts.linear_fc1, "weight" + str(i))
linear_fc1_weighti.copy_(fc1_weight)
linear_fc2_weighti = getattr(layer.mlp.experts.linear_fc2, "weight" + str(i))
linear_fc2_weighti.copy_(hf_expert.down_proj.weight)
else:
for i, hf_expert in enumerate(hf_layer.mlp.experts):
expert = layer.mlp.experts.local_experts[i]
fc1_weight = torch.cat([hf_expert.gate_proj.weight, hf_expert.up_proj.weight])
expert.linear_fc1.weight.copy_(fc1_weight)
expert.linear_fc2.weight.copy_(hf_expert.down_proj.weight)
layer.pre_mlp_layernorm.weight.copy_(hf_layer.post_attention_layernorm.weight)
shared_fc1_weight = torch.cat([hf_layer.mlp.shared_experts.gate_proj.weight, hf_layer.mlp.shared_experts.up_proj.weight])
layer.mlp.shared_experts.linear_fc1.weight.copy_(shared_fc1_weight)
layer.mlp.shared_experts.linear_fc2.weight.copy_(hf_layer.mlp.shared_experts.down_proj.weight)
model.decoder.final_layernorm.weight.copy_(hf_model.model.norm.weight)
if not hf_config.tie_word_embeddings:
model.output_layer.weight.copy_(hf_model.lm_head.weight)
def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False, test=False, trust_remote_code=False):
os.makedirs(output_path, exist_ok=True)
if len(os.listdir(output_path)) > 0 and not test:
print(f"Output path {output_path} is not empty, skipping conversion")
@ -200,12 +266,14 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
warnings.simplefilter("ignore")
# init hf model
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16)
hf_model = AutoModelForCausalLM.from_pretrained(hf_model_path, torch_dtype=torch.bfloat16, trust_remote_code=trust_remote_code)
hf_state_dict = hf_model.state_dict()
# load hf state dict to megatron model
if "Qwen2MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
elif "DeepseekV3ForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron_dpskv3(hf_model, model[0].module, hf_config, tfconfig=tfconfig)
elif "Qwen3MoeForCausalLM" in hf_config.architectures:
convert_checkpoint_from_transformers_to_megatron(hf_model, model[0].module, hf_config)
else:
@ -232,4 +300,4 @@ def convert_hf_to_mcore(hf_model_path, output_path, use_cpu_initialization=False
if __name__ == "__main__":
args = _init_args()
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test)
convert_hf_to_mcore(args.hf_model_path, args.output_path, args.use_cpu_initialization, args.test, args.trust_remote_code)

View File

@ -23,7 +23,7 @@ from megatron.core.transformer import MLATransformerConfig, TransformerConfig
from transformers import PretrainedConfig
def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> dict:
"""
Create a base TransformerConfig with common parameters across different model architectures.
TODO: (ycl) use dataclass or converter config?
@ -82,7 +82,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
base_config.update(override_transformer_config_kwargs)
print(f"Overridden TF init config: {base_config}")
return TransformerConfig(**base_config)
return base_config
def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
@ -90,11 +90,12 @@ def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **
qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
return _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
args = _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
return TransformerConfig(**args)
def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@ -121,10 +122,11 @@ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype,
add_qkv_bias=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)
def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@ -150,10 +152,11 @@ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype,
bias_dropout_fusion=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)
def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
return _get_base_transformer_config(
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
@ -178,11 +181,87 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype,
qk_layernorm=True,
**override_transformer_config_kwargs,
)
return TransformerConfig(**args)
def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:
# DeepseekV3ForCausalLM
raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
from megatron.core.transformer.enums import AttnBackend
from .patch_v012 import apply_patch
apply_patch()
mla_rope_config = {
"beta_fast": 32,
"beta_slow": 1,
"factor": 1,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "rope",
}
if "rope_scaling" in hf_config and hf_config.rope_scaling is not None:
mla_rope_config.update(hf_config.rope_scaling)
moe_layer_freq = [1] * hf_config.num_hidden_layers
for i in range(hf_config.first_k_dense_replace):
moe_layer_freq[i] = 0
args = _get_base_transformer_config(
hf_config=hf_config,
dtype=dtype,
use_cpu_initialization=False,
add_bias_linear=False,
attention_backend=AttnBackend.fused,
bf16=dtype is torch.bfloat16,
layernorm_epsilon=hf_config.rms_norm_eps,
ffn_hidden_size=hf_config.intermediate_size,
qk_layernorm=True,
# moe specific
moe_ffn_hidden_size=hf_config.moe_intermediate_size,
moe_token_dispatcher_type="alltoall",
moe_router_bias_update_rate=0.001,
moe_router_enable_expert_bias=True,
moe_router_topk=hf_config.num_experts_per_tok,
num_moe_experts=hf_config.n_routed_experts,
moe_shared_expert_intermediate_size=hf_config.moe_intermediate_size * hf_config.n_shared_experts,
moe_aux_loss_coeff=getattr(hf_config, "aux_loss_alpha", 0.001),
moe_router_load_balancing_type="seq_aux_loss",
moe_shared_expert_overlap=True,
# moe_permute_fusion=True, # need TE 2.1+
moe_grouped_gemm=True,
moe_router_score_function="sigmoid",
moe_router_pre_softmax=True,
moe_router_topk_scaling_factor=hf_config.routed_scaling_factor,
moe_layer_freq=moe_layer_freq,
# MLA
q_lora_rank=hf_config.q_lora_rank,
kv_lora_rank=hf_config.kv_lora_rank,
qk_head_dim=hf_config.qk_nope_head_dim,
qk_pos_emb_head_dim=hf_config.qk_rope_head_dim,
v_head_dim=hf_config.v_head_dim,
rotary_base=hf_config.rope_theta,
rotary_scaling_factor=mla_rope_config["factor"],
rope_type=mla_rope_config["type"],
mscale=mla_rope_config["mscale"],
mscale_all_dim=mla_rope_config["mscale_all_dim"],
max_position_embeddings=mla_rope_config["original_max_position_embeddings"],
beta_fast=mla_rope_config["beta_fast"],
beta_slow=mla_rope_config["beta_slow"],
# mcore 0.12 moe
moe_router_dtype="fp64",
disable_bf16_reduced_precision_matmul=True,
# other
# deallocate_pipeline_outputs=True,
# gradient_accumulation_fusion=True,
persist_layer_norm=True,
bias_activation_fusion=True,
bias_dropout_fusion=True,
**override_transformer_config_kwargs,
)
transformer_config = MLATransformerConfig(**args)
return transformer_config
def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:

View File

@ -156,6 +156,29 @@ class Qwen3MoEModel(BaseModelInitializer):
return model
class DeepseekV3Model(BaseModelInitializer):
"""Initializer for DeepseekV3 models."""
def get_transformer_layer_spec(self):
assert self.tfconfig.normalization == "RMSNorm", "only RMSNorm is supported for now"
transformer_layer_spec = get_gpt_decoder_block_spec(self.tfconfig, use_transformer_engine=True)
return transformer_layer_spec
def initialize(
self,
**kwargs,
):
freeze_moe_router = kwargs.get("freeze_moe_router", True)
if freeze_moe_router:
self.tfconfig.moe_router_load_balancing_type = "none"
model = super().initialize(**kwargs)
if freeze_moe_router:
for layer in model.decoder.layers:
if hasattr(layer.mlp, "router"):
layer.mlp.router.weight.requires_grad = False
return model
class Qwen25VLModel(BaseModelInitializer):
"""Initializer for Qwen2.5 VL models."""

View File

@ -0,0 +1,199 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# there is some bug in mcore 0.12, so we need to patch it
# 1. `get_query_key_value_tensors` in `multi_latent_attention.py` works wrong when packed_seq_params is not None
def apply_patch():
import torch
from megatron.core.transformer.multi_latent_attention import MLASelfAttention, apply_rotary_pos_emb, deprecate_inference_params, gather_from_sequence_parallel_region, gather_from_tensor_model_parallel_region, parallel_state, scatter_to_sequence_parallel_region, tensor_parallel
def patch_get_query_key_value_tensors(
self,
hidden_states,
key_value_states=None,
position_ids=None,
packed_seq_params=None,
inference_context=None,
*,
inference_params=None,
):
"""
Derives `query`, `key` and `value` tensors from `hidden_states`.
"""
# s = sequence length, b = batch size, h = hidden size, n = num attention heads
# Attention heads [s, b, n*h]
assert hidden_states.ndim == 3, f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"
inference_context = deprecate_inference_params(inference_context, inference_params)
# =========================================
# Prepare RoPE and seqlen related params
# =========================================
rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len(inference_context, None, hidden_states, self.config, packed_seq_params)
# rotary_pos_emb:[s, b, 1, 64]
mscale = 1.0
if self.config.rope_type == "rope":
packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == "thd"
rotary_pos_emb = self.rotary_pos_emb(rotary_seq_len, packed_seq=packed_seq)
else:
rotary_pos_emb, mscale = self.rotary_pos_emb(rotary_seq_len)
# =========================================
# QKV down projection and layernorm
# =========================================
if self.config.q_lora_rank is not None:
# if linear_q_down_proj is ColumnParallelLinear:
# q_compressed: [s, b, q_lora_rank / TP]
# elif linear_q_down_proj is Linear:
# q_compressed: [s / TP, b, q_lora_rank]
q_compressed, _ = self.linear_q_down_proj(hidden_states)
# When output is sharded (ColumnParallelLinear), two things are needed to be
# identical to a normal Linear.
# 1. Manually gather output to restore output dim q_lora_rank;
# 2. Scatter sequence back to s / TP if sequence-parallel since it was
# gathered by ColumnParallelLinear.
if q_compressed.size(-1) != self.config.q_lora_rank:
q_compressed = gather_from_tensor_model_parallel_region(q_compressed)
if self.config.sequence_parallel:
q_compressed = scatter_to_sequence_parallel_region(q_compressed)
q_compressed = self.q_layernorm(q_compressed)
else:
q_compressed = hidden_states
# if linear_kv_down_proj is ColumnParallelLinear:
# kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim) / TP]
# elif linear_kv_down_proj is Linear:
# kv_combined: [s / TP, b, (kv_lora_rank + qk_pos_emb_head_dim)]
kv_combined, _ = self.linear_kv_down_proj(hidden_states)
if kv_combined.size(-1) != self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim:
# kv_combined: [s, b, (kv_lora_rank + qk_pos_emb_head_dim)]
kv_combined = gather_from_tensor_model_parallel_region(kv_combined)
# kv_compressed:[s, b, kv_lora_rank], k_pos_emb: [s, b, qk_pos_emb_head_dim]
kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)
if self.config.sequence_parallel:
# kv_compressed:[s / TP, b, kv_lora_rank]
kv_compressed = scatter_to_sequence_parallel_region(kv_compressed)
else:
# kv_compressed:[s / TP, b, kv_lora_rank], k_pos_emb: [s / TP, b, qk_pos_emb_head_dim]
kv_compressed, k_pos_emb = torch.split(kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1)
if parallel_state.get_tensor_model_parallel_world_size() > 1:
# k_pos_emb: [s, b, qk_pos_emb_head_dim]
k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)
kv_compressed = self.kv_layernorm(kv_compressed)
# =========================================
# QKV up projection and RoPE apply
# =========================================
def qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb):
if self.config.q_lora_rank is not None:
q, _ = self.linear_q_up_proj(q_compressed)
else:
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
q, _ = self.linear_q_proj(q_compressed)
q_len, bsz, _ = q.size()
# q: [s, b, n, 192]
q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)
# kv: [s, b, 2048]
kv, _ = self.linear_kv_up_proj(kv_compressed)
# kv: [s, b, n, 256]
kv = kv.view(
q_len,
bsz,
self.num_attention_heads_per_partition,
self.config.qk_head_dim + self.config.v_head_dim,
)
if inference_context is not None:
# add offset to the sequence start for inference
sequence_start = inference_context.sequence_len_offset
sequence_end = sequence_start + q_len
rotary_pos_emb = rotary_pos_emb[sequence_start:sequence_end]
else:
# Shorten rotary_pos_emb to the sequence length when inference_params
# is not provided. This makes sure we can run forward directly with
# any sequence length. During training, the sequence length is always
# the full rotary_pos_emb length.
rotary_pos_emb = rotary_pos_emb[0:q_len]
# [s, b, 64] -> [s, b, 1, 64]
k_pos_emb = torch.unsqueeze(k_pos_emb, 2)
# q: [s, b, n, 128], q_pos_emb: [s, b, n, 64]
q_no_pe, q_pos_emb = torch.split(q, [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], dim=-1)
# k_no_pe: [s, b, n, 128], value: [s, b, n, 128]
k_no_pe, value = torch.split(kv, [self.config.qk_head_dim, self.config.v_head_dim], dim=-1)
if packed_seq_params is not None:
cu_seqlens_q = packed_seq_params.cu_seqlens_q
cu_seqlens_kv = packed_seq_params.cu_seqlens_kv
q_pos_emb = q_pos_emb.squeeze(1)
k_pos_emb = k_pos_emb.squeeze(1)
q_no_pe = q_no_pe.squeeze(1)
k_no_pe = k_no_pe.squeeze(1)
value = value.squeeze(1)
else:
cu_seqlens_q = cu_seqlens_kv = None
# q_pos_emb: [s, b, n, 64], k_pos_emb:[s, b, 1, 64]
q_pos_emb = apply_rotary_pos_emb(
q_pos_emb,
rotary_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_q,
mscale=mscale,
)
k_pos_emb = apply_rotary_pos_emb(
k_pos_emb,
rotary_pos_emb,
config=self.config,
cu_seqlens=cu_seqlens_kv,
mscale=mscale,
)
# query: [s, b, n, 192]
query = torch.cat([q_no_pe, q_pos_emb], dim=-1)
if packed_seq_params is not None:
k_pos_emb = k_pos_emb.expand(-1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
else:
# key: [s, b, n, 192]
k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
return query, key, value
if self.recompute_up_proj:
self.qkv_up_checkpoint = tensor_parallel.CheckpointWithoutOutput()
query, key, value = self.qkv_up_checkpoint.checkpoint(qkv_up_proj_and_rope_apply, q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)
else:
query, key, value = qkv_up_proj_and_rope_apply(q_compressed, kv_compressed, k_pos_emb, rotary_pos_emb)
return query, key, value
MLASelfAttention.get_query_key_value_tensors = patch_get_query_key_value_tensors

View File

@ -38,6 +38,7 @@ from .model_forward import (
)
from .model_initializer import (
BaseModelInitializer,
DeepseekV3Model,
DenseModel,
MixtralModel,
Qwen2MoEModel,
@ -46,6 +47,7 @@ from .model_initializer import (
)
from .weight_converter import (
McoreToHFWeightConverterDense,
McoreToHFWeightConverterDpskv3,
McoreToHFWeightConverterMixtral,
McoreToHFWeightConverterQwen2Moe,
McoreToHFWeightConverterQwen3Moe,
@ -83,7 +85,7 @@ MODEL_INITIALIZER_REGISTRY: Dict[SupportedModel, Type[BaseModelInitializer]] = {
SupportedModel.QWEN2: DenseModel,
SupportedModel.QWEN2_MOE: Qwen2MoEModel,
SupportedModel.MIXTRAL: MixtralModel,
SupportedModel.DEEPSEEK_V3: DenseModel,
SupportedModel.DEEPSEEK_V3: DeepseekV3Model,
SupportedModel.QWEN2_5_VL: Qwen25VLModel,
SupportedModel.LLAMA4: DenseModel,
SupportedModel.QWEN3: DenseModel,
@ -101,6 +103,7 @@ MODEL_FORWARD_REGISTRY: Dict[SupportedModel, Callable] = {
SupportedModel.LLAMA4: gptmodel_forward,
SupportedModel.QWEN3: gptmodel_forward,
SupportedModel.QWEN3_MOE: gptmodel_forward,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
}
# Registry for model weight converters
@ -109,6 +112,7 @@ MODEL_WEIGHT_CONVERTER_REGISTRY: Dict[SupportedModel, Type] = {
SupportedModel.QWEN2: McoreToHFWeightConverterDense,
SupportedModel.QWEN2_MOE: McoreToHFWeightConverterQwen2Moe,
SupportedModel.MIXTRAL: McoreToHFWeightConverterMixtral,
SupportedModel.DEEPSEEK_V3: McoreToHFWeightConverterDpskv3,
SupportedModel.QWEN3: McoreToHFWeightConverterDense,
SupportedModel.QWEN3_MOE: McoreToHFWeightConverterQwen3Moe,
}

View File

@ -379,7 +379,7 @@ def merge_megatron_ckpt_gptmodel(wrapped_models, config, dtype, is_value_model=F
f"{layer_name}.self_attn.k_norm.weight",
src_pp_rank=src_pp_rank,
)
_broadcast_tp_shard_tensor_qkv(
sync_layer.self_attention.linear_qkv.weight,
f"{layer_name}.self_attn.q_proj.weight",
@ -467,5 +467,9 @@ def merge_megatron_ckpt_gptmodel_qwen_moe(wrapped_models, config, dtype, is_valu
raise NotImplementedError("merge_megatron_ckpt_gptmodel_qwen_moe is not implemented")
def merge_megatron_ckpt_gptmodel_dpskv3(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
raise NotImplementedError("merge_megatron_ckpt_gptmodel_dpskv3 is not implemented")
def merge_megatron_ckpt_gptmodel_mixtral(wrapped_models, config, dtype, is_value_model=False, tie_word_embeddings=False):
raise NotImplementedError("merge_megatron_ckpt_gptmodel_mixtral is not implemented")

View File

@ -147,6 +147,133 @@ class McoreToHFWeightConverterQwen2Moe(McoreToHFWeightConverterDense):
return convert_names, params
class McoreToHFWeightConverterDpskv3(McoreToHFWeightConverterBase):
def _convert_attention_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
# mcore
# 'decoder.layers.0.input_layernorm.weight'
# 'decoder.layers.0.self_attention.linear_proj.weight'
# 'decoder.layers.0.self_attention.linear_q_proj.weight'
# 'decoder.layers.0.self_attention.linear_kv_down_proj.weight'
# 'decoder.layers.0.self_attention.linear_kv_up_proj.layer_norm_weight'
# 'decoder.layers.0.self_attention.linear_kv_up_proj.weight'
# 'decoder.layers.0.self_attention.linear_q_down_proj.weight'
# 'decoder.layers.0.self_attention.linear_q_up_proj.weight'
# 'decoder.layers.0.self_attention.linear_q_up_proj.layer_norm_weight'
# hf
# 'model.layers.0.input_layernorm.weight'
# 'model.layers.0.self_attn.o_proj.weight'
# 'model.layers.0.self_attn.q_proj.weight'
# 'model.layers.0.self_attn.kv_a_proj_with_mqa.weight'
# 'model.layers.0.self_attn.kv_a_layernorm.weight'
# 'model.layers.0.self_attn.kv_b_proj.weight'
# 'model.layers.0.self_attn.q_a_proj.weight'
# 'model.layers.0.self_attn.q_b_proj.weight'
# 'model.layers.0.self_attn.q_a_layernorm.weight'
name_map_after_layer = {
"input_layernorm.weight": "input_layernorm.weight",
"self_attention.linear_proj.weight": "self_attn.o_proj.weight",
"self_attention.linear_q_proj.weight": "self_attn.q_proj.weight",
"self_attention.linear_kv_down_proj.weight": "self_attn.kv_a_proj_with_mqa.weight",
"self_attention.linear_kv_up_proj.layer_norm_weight": "self_attn.kv_a_layernorm.weight",
"self_attention.linear_kv_up_proj.weight": "self_attn.kv_b_proj.weight",
"self_attention.linear_q_down_proj.weight": "self_attn.q_a_proj.weight",
"self_attention.linear_q_up_proj.weight": "self_attn.q_b_proj.weight",
"self_attention.linear_q_up_proj.layer_norm_weight": "self_attn.q_a_layernorm.weight",
}
assert len(params) == 1
convert_names = []
layer_number = name.split(".")[2]
name_after_layer = name.split(f".{layer_number}.")[1]
convert_names.append(f"model.layers.{layer_number}.{name_map_after_layer[name_after_layer]}")
return convert_names, params
def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
# mcore dense
# 'decoder.layers.0.mlp.linear_fc1.layer_norm_weight'
# 'decoder.layers.0.mlp.linear_fc2.weight'
# 'decoder.layers.0.mlp.linear_fc1.weight'
# ---
# 'decoder.layers.1.mlp.shared_experts.linear_fc1.weight'
# ---
# 'decoder.layers.1.mlp.shared_experts.linear_fc2.weight'
# hf dense
# 'model.layers.0.post_attention_layernorm.weight'
# 'model.layers.0.mlp.down_proj.weight'
# 'model.layers.0.mlp.gate_proj.weight'
# 'model.layers.0.mlp.up_proj.weight'
# 'model.layers.1.mlp.shared_experts.gate_proj.weight'
# 'model.layers.1.mlp.shared_experts.up_proj.weight'
# 'model.layers.1.mlp.shared_experts.down_proj.weight'
# mcore moe
# 'decoder.layers.1.pre_mlp_layernorm.weight'
# 'decoder.layers.1.mlp.router.weight'
# 'decoder.layers.1.mlp.router.expert_bias'
# 'decoder.layers.1.mlp.experts.linear_fc1.weight0'
# ---
# 'decoder.layers.1.mlp.experts.linear_fc2.weight0'
# hf moe
# 'model.layers.1.post_attention_layernorm.weight'
# 'model.layers.1.mlp.gate.weight'
# 'model.layers.1.mlp.gate.e_score_correction_bias'
# 'model.layers.1.mlp.experts.0.gate_proj.weight'
# 'model.layers.1.mlp.experts.0.up_proj.weight'
# 'model.layers.1.mlp.experts.0.down_proj.weight'
name_map_after_layer = {
"mlp.linear_fc1.layer_norm_weight": "post_attention_layernorm.weight",
"mlp.linear_fc2.weight": "mlp.down_proj.weight",
"mlp.shared_experts.linear_fc2.weight": "mlp.shared_experts.down_proj.weight",
"mlp.linear_fc1.weight": ["mlp.gate_proj.weight", "mlp.up_proj.weight"],
"mlp.shared_experts.linear_fc1.weight": ["mlp.shared_experts.gate_proj.weight", "mlp.shared_experts.up_proj.weight"],
"pre_mlp_layernorm.weight": "post_attention_layernorm.weight",
"mlp.router.weight": "mlp.gate.weight",
"mlp.router.expert_bias": "mlp.gate.e_score_correction_bias",
}
convert_names = []
layer_number = name.split(".")[2]
name_after_layer = name.split(f".{layer_number}.")[1]
if name_after_layer in name_map_after_layer:
mapped_name = name_map_after_layer[name_after_layer]
if isinstance(mapped_name, list):
assert len(params) == len(mapped_name)
for one in mapped_name:
convert_names.append(f"model.layers.{layer_number}.{one}")
else:
assert len(params) == 1
convert_names.append(f"model.layers.{layer_number}.{mapped_name}")
else:
if "mlp.experts.linear_fc1.weight" in name:
expert_id = name.split("weight")[-1]
convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight")
convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight")
assert len(params) == 2
elif "mlp.experts.linear_fc2.weight" in name:
expert_id = name.split("weight")[-1]
convert_names.append(f"model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight")
assert len(params) == 1
else:
raise NotImplementedError(f"Unsupported parameter name: {name}")
return convert_names, params
def convert_param(self, name: str, params_one_group: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
direct_name_mapping = {
"embedding.word_embeddings.weight": "model.embed_tokens.weight",
"decoder.final_layernorm.weight": "model.norm.weight",
"output_layer.weight": "lm_head.weight",
}
if name in direct_name_mapping:
return [direct_name_mapping[name]], [params_one_group[0]]
if "self_attention" in name or "input_layernorm.weight" in name:
return self._convert_attention_param(name, params_one_group)
elif "mlp" in name:
return self._convert_mlp_param(name, params_one_group)
else:
raise NotImplementedError(f"Unsupported parameter name: {name}")
class McoreToHFWeightConverterMixtral(McoreToHFWeightConverterDense):
def _convert_mlp_param(self, name: str, params: list[torch.Tensor]) -> tuple[list[str], list[torch.Tensor]]:
# decoder.layers.0.mlp.router.weight

View File

@ -27,13 +27,14 @@ def get_weight_loader(arch: str):
def get_weight_saver(arch: str):
from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe
from verl.models.mcore.saver import merge_megatron_ckpt_gptmodel, merge_megatron_ckpt_gptmodel_dpskv3, merge_megatron_ckpt_gptmodel_mixtral, merge_megatron_ckpt_gptmodel_qwen_moe
_MODEL_WEIGHT_MEGATRON_SAVER_REGISTRY = {
"LlamaForCausalLM": merge_megatron_ckpt_gptmodel,
"Qwen2ForCausalLM": merge_megatron_ckpt_gptmodel,
"MixtralForCausalLM": merge_megatron_ckpt_gptmodel_mixtral,
"Qwen2MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
"DeepseekV3ForCausalLM": merge_megatron_ckpt_gptmodel_dpskv3,
"Qwen3ForCausalLM": merge_megatron_ckpt_gptmodel,
"Qwen3MoeForCausalLM": merge_megatron_ckpt_gptmodel_qwen_moe,
}

View File

@ -39,7 +39,14 @@ class MegatronWorker(Worker):
info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank)
return info
def _init_hf_config_and_tf_config(self, model_path, dtype, override_model_config, override_transformer_config):
def _init_hf_config_and_tf_config(
self,
model_path,
dtype,
override_model_config,
override_transformer_config,
trust_remote_code=False,
):
from transformers import AutoConfig
from verl.models.mcore import hf_to_mcore_config
@ -49,10 +56,10 @@ class MegatronWorker(Worker):
# Step 1: initialize the tokenizer
self.local_path = copy_to_local(model_path)
self.tokenizer = hf_tokenizer(self.local_path)
self.tokenizer = hf_tokenizer(self.local_path, trust_remote_code=trust_remote_code)
# Step 2: get the hf
hf_config = AutoConfig.from_pretrained(self.local_path)
hf_config = AutoConfig.from_pretrained(self.local_path, trust_remote_code=trust_remote_code)
# Step 3: override the hf config
override_config_kwargs = {
@ -68,15 +75,19 @@ class MegatronWorker(Worker):
print(f"Model config after override: {hf_config}")
tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
def add_optimization_config_to_tf_config(tf_config, verl_model_config):
def add_optimization_config_to_tf_config(tf_config):
# add optimization config to tf_config, e.g. checkpointing
if verl_model_config.get("enable_gradient_checkpointing", False):
gradient_checkpointing_cfg = dict(verl_model_config.get("gradient_checkpointing_kwargs", dict()))
if self.config.model.get("enable_gradient_checkpointing", False):
gradient_checkpointing_cfg = dict(self.config.model.get("gradient_checkpointing_kwargs", dict()))
tf_config.recompute_method = gradient_checkpointing_cfg.get("activations_checkpoint_method", "full")
tf_config.recompute_granularity = gradient_checkpointing_cfg.get("activations_checkpoint_granularity", "full")
tf_config.recompute_num_layers = gradient_checkpointing_cfg.get("activations_checkpoint_num_layers", -1)
if megatron_config := self.config.get("megatron", {}):
if extra := megatron_config.get("extra", {}):
for k, v in extra.items():
setattr(tf_config, k, v)
add_optimization_config_to_tf_config(tf_config, self.config.model)
add_optimization_config_to_tf_config(tf_config)
print(f"TF config: {tf_config}")
self.hf_config = hf_config

View File

@ -37,6 +37,7 @@ actor_rollout_ref:
activations_checkpoint_granularity: null # 'selective' or 'full'
# 'full' will checkpoint the entire transformer layer and 'selective' only checkpoints memory intensive part of attention
activations_checkpoint_num_layers: null # not used with 'selective'
trust_remote_code: False
actor:
strategy: megatron # This is for backward-compatibility
ppo_mini_batch_size: 256
@ -190,6 +191,7 @@ critic:
moe_config:
freeze_moe_router: False
external_lib: ${actor_rollout_ref.model.external_lib}
trust_remote_code: False
enable_gradient_checkpointing: False
gradient_checkpointing_kwargs:
## Activation Checkpointing
@ -250,6 +252,7 @@ reward_model:
model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
trust_remote_code: False
external_lib: ${actor_rollout_ref.model.external_lib}
load_weight: True
micro_batch_size: null # will be deprecated, use micro_batch_size_per_gpu

View File

@ -794,7 +794,7 @@ def per_tensor_generator(actor_module, model_config, weight_converter, transform
else:
params = [param]
merge_params = default_tp_concat_fn(name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split)
merge_params = default_tp_concat_fn(layer_name_mapping, name, broad_pp_tensor, params, model_config, convert_qkv_gate_up_by_simple_split)
if not isinstance(merge_params, list):
merge_params = [merge_params]
converted_names, converted_params = weight_converter.convert_param(name, merge_params)

View File

@ -145,7 +145,7 @@ class ActorRolloutRefWorker(MegatronWorker):
from verl.utils.megatron_utils import get_model, init_megatron_optim_config
from verl.utils.model import get_generation_config, print_model_size
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False))
self.generation_config = get_generation_config(self.local_path)
def megatron_actor_model_provider(pre_process, post_process):
@ -629,7 +629,7 @@ class CriticWorker(MegatronWorker):
from verl.utils.megatron_utils import get_model, init_megatron_optim_config
from verl.utils.model import print_model_size
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False))
def megatron_critic_model_provider(pre_process, post_process):
from verl.models.mcore import init_mcore_model
@ -825,7 +825,7 @@ class RewardModelWorker(MegatronWorker):
from verl.utils.megatron_utils import get_model
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config, self.config.model.get("trust_remote_code", False))
def megatron_rm_model_provider(pre_process, post_process):
from verl.models.mcore import init_mcore_model