diff --git a/.github/workflows/e2e_ppo_trainer_megatron.yml b/.github/workflows/e2e_ppo_trainer_megatron.yml
index 0605acf64..42ad40207 100644
--- a/.github/workflows/e2e_ppo_trainer_megatron.yml
+++ b/.github/workflows/e2e_ppo_trainer_megatron.yml
@@ -124,6 +124,48 @@ jobs:
       - name: clean up
         run: |
           rm -rf checkpoints
+  e2e_ppo_trainer_megatron-qwen3:
+    runs-on: [L20x8]
+    timeout-minutes: 30 # Increase this timeout value as needed
+    env:
+      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
+      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
+      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
+      HF_ENDPOINT: "https://hf-mirror.com"
+      HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
+    container:
+      image: whatcanyousee/verl:ngc-cu124-vllm0.8.5-sglang0.4.6-mcore0.12.0-te2.2
+      options: --gpus all --shm-size=10g
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+        with:
+          fetch-depth: 0
+      - name: Install the current repository
+        run: |
+          pip3 install --no-deps -e .[test]
+      - name: Prepare GSM8K dataset
+        run: |
+          python3 examples/data_preprocess/gsm8k.py
+      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving
+        run: |
+          ray stop --force
+          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
+        run: |
+          ray stop --force
+          RESUME_MODE=auto MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+      - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic)
+        run: |
+          exp_name="qwen3-0.6b-megatron-gsm8k-minimal"
+          python scripts/model_merger.py test --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
+          python scripts/model_merger.py test --backend megatron --is-value-model --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
+      - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3)
+        run: |
+          ray stop --force
+          ADV_ESTIMATOR=grpo MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+      - name: clean up
+        run: |
+          rm -rf checkpoints
   e2e_ppo_trainer_megatron-different-train-infer-tp-qwen:
     runs-on: [L20x8]
     timeout-minutes: 30 # Increase this timeout value as needed
@@ -190,7 +232,7 @@
       - name: clean up
         run: |
           rm -rf checkpoints
-  e2e_ppo_trainer_megatron-different-train-infer-tp-deepseek:
+  e2e_ppo_trainer_megatron-qwen-override-transformer-config:
     runs-on: [L20x8]
     timeout-minutes: 30 # Increase this timeout value as needed
     env:
@@ -212,18 +254,24 @@
       - name: Prepare GSM8K dataset
         run: |
           python3 examples/data_preprocess/gsm8k.py
-      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp > infer tp
+      - name: Prepare dist_ckpt of Qwen2.5-0.5B, uneven layer distribution only supports dist_ckpt
+        run: |
+          python3 scripts/converter_hf_to_mcore.py --hf_model_path ${HOME}/models/Qwen/Qwen2.5-0.5B --output_path checkpoints/verl-test/qwen2.5-0.5b-megatron
+      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen)
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=2 INFER_TP=1 bash tests/e2e/run_ppo_trainer_megatron.sh
-      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen) with train tp < infer tp
+          SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 SKIP_SAVE_HF_MODEL=1 bash tests/e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=8 +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=4 actor_rollout_ref.actor.megatron.use_dist_checkpointing=true actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron actor_rollout_ref.ref.megatron.use_dist_checkpointing=true actor_rollout_ref.ref.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron critic.megatron.use_dist_checkpointing=true critic.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron reward_model.megatron.use_dist_checkpointing=true reward_model.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron
+          cp -r checkpoints checkpoints-dut
+          SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 bash tests/e2e/run_ppo_trainer_megatron.sh
+      - name: Test Megatron checkpoints merging function (Qwen Actor and Critic)
         run: |
-          ray stop --force
-          VAL_BEFORE_TRAIN=True MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct TEST_FREQ=1 SAVE_FREQ=1 TRAIN_TP=1 INFER_TP=2 bash tests/e2e/run_ppo_trainer_megatron.sh
+          exp_name="qwen2.5-0.5b-megatron-gsm8k-minimal"
+          python scripts/model_merger.py test --backend megatron --tie-word-embedding --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path Qwen/Qwen2.5-0.5B
+          python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints-dut/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path Qwen/Qwen2.5-0.5B
       - name: clean up
         run: |
           rm -rf checkpoints
-  e2e_ppo_trainer_megatron-qwen3:
+  e2e_ppo_trainer_megatron-deepseek-override-transformer-config:
     runs-on: [L20x8]
     timeout-minutes: 30 # Increase this timeout value as needed
     env:
@@ -245,23 +293,16 @@
       - name: Prepare GSM8K dataset
         run: |
           python3 examples/data_preprocess/gsm8k.py
-      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) with validation and saving
+      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (DeepSeek)
         run: |
           ray stop --force
-          VAL_BEFORE_TRAIN=True TEST_FREQ=1 SAVE_FREQ=1 MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
-      - name: Running GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3) after resuming
+          SAVE_FREQ=1 MODEL_ID=deepseek-ai/deepseek-coder-1.3b-instruct COMMON_PP=2 COMMON_VPP=null bash tests/e2e/run_ppo_trainer_megatron.sh +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_embedding_in_pipeline_split=true +actor_rollout_ref.actor.megatron.override_transformer_config.account_for_loss_in_pipeline_split=true
+      - name: Test Megatron checkpoints merging function (DeepSeek Actor and Critic)
         run: |
-          ray stop --force
-          RESUME_MODE=auto MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
-      - name: Test Megatron checkpoints merging function (Qwen3 Actor and Critic)
-        run: |
-          exp_name="qwen3-0.6b-megatron-gsm8k-minimal"
-          python scripts/model_merger.py test --backend megatron --tie-word-embedding --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface
-          python scripts/model_merger.py test --backend megatron --is-value-model --hf_model_path Qwen/Qwen3-0.6B --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface
-      - name: Running GRPO GSM8K E2E training tests with 3D parallelism on 8 L20 GPUs with Megatron (Qwen3)
-        run: |
-          ray stop --force
-          ADV_ESTIMATOR=grpo MODEL_ID=Qwen/Qwen3-0.6B bash tests/e2e/run_ppo_trainer_megatron.sh
+          exp_name="deepseek-coder-1.3b-instruct-megatron-gsm8k-minimal"
+          python scripts/model_merger.py test --backend megatron --local_dir checkpoints/verl-test/${exp_name}/global_step_1/actor --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/actor/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct
+          python scripts/model_merger.py test --backend megatron --is-value-model --local_dir checkpoints/verl-test/${exp_name}/global_step_1/critic --test_hf_dir checkpoints/verl-test/${exp_name}/global_step_1/critic/huggingface --hf_model_path deepseek-ai/deepseek-coder-1.3b-instruct
       - name: clean up
         run: |
           rm -rf checkpoints
+
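
The renamed `e2e_ppo_trainer_megatron-qwen-override-transformer-config` job drives an uneven pipeline split purely through Hydra append overrides (`+actor_rollout_ref.actor.megatron.override_transformer_config....`). A rough sketch of what that split means for the Qwen job, assuming Qwen2.5-0.5B's 24 decoder layers and the job's `COMMON_PP=4` with the first stage pinned to 8 layers and the last to 4:

```python
# Sketch only: layer distribution implied by the overrides in the Qwen CI job above.
# Assumes 24 decoder layers (Qwen2.5-0.5B) and pipeline_model_parallel_size=4.
num_layers, pp_size, first, last = 24, 4, 8, 4

middle_stages = pp_size - 2                        # stages without a pinned layer count
per_middle = (num_layers - first - last) // middle_stages
split = [first] + [per_middle] * middle_stages + [last]

assert sum(split) == num_layers
print(split)  # [8, 6, 6, 4]
```
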
diff --git a/.vscode/settings.json b/.vscode/settings.json
index c717de9c1..705533538 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -10,5 +10,6 @@
         "string_view": "cpp",
         "initializer_list": "cpp",
         "utility": "cpp"
-    }
+    },
+    "iis.configDir": ""
 }
\ No newline at end of file
diff --git a/scripts/model_merger.py b/scripts/model_merger.py
index 2258d7647..aa0c2e5d2 100644
--- a/scripts/model_merger.py
+++ b/scripts/model_merger.py
@@ -524,7 +524,7 @@ class MegatronModelMerger(BaseModelMerger):
                 raise RuntimeError(f"key: {name} not exist in state_dict")
             param = ref_state_dict[name]
             assert loaded_weight.dtype == param.dtype
-            torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)
+            torch.testing.assert_close(loaded_weight, param, atol=1e-2, rtol=5e-2)
 
     def _replace_name(self, megatron_name: str, name_mapping: list[tuple[str, str]]) -> str:
         for m_name, v_name in name_mapping:
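
The merger test now compares merged Megatron weights against the saved HuggingFace weights with looser tolerances. `torch.testing.assert_close` allows a per-element difference of `atol + rtol * |expected|`, so the change mainly relaxes the check for large-magnitude values; a minimal sketch with made-up numbers:

```python
import torch

expected = torch.tensor([1.0, 100.0])
actual = expected + torch.tensor([0.009, 4.9])

# Allowed mismatch per element is atol + rtol * |expected|:
# 1e-2 + 5e-2 * 1.0 = 0.06 and 1e-2 + 5e-2 * 100.0 = 5.01, so this passes.
torch.testing.assert_close(actual, expected, atol=1e-2, rtol=5e-2)

# The previous atol=rtol=1e-4 setting would have rejected the same pair.
```
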
diff --git a/tests/e2e/run_ppo_trainer_megatron.sh b/tests/e2e/run_ppo_trainer_megatron.sh
index 25b2283d0..a70db50ad 100644
--- a/tests/e2e/run_ppo_trainer_megatron.sh
+++ b/tests/e2e/run_ppo_trainer_megatron.sh
@@ -55,6 +55,12 @@ RM_VPP=${RM_VPP:-$COMMON_VPP}
 RM_CP=${RM_CP:-$COMMON_CP}
 RM_TP=${RM_TP:-$TRAIN_TP}
 
+CHECKPOINT_CONTENTS=['model','hf_model','optimizer','extra']
+SKIP_SAVE_HF_MODEL=${SKIP_SAVE_HF_MODEL:-0}
+if [ $SKIP_SAVE_HF_MODEL -eq 1 ]; then
+    CHECKPOINT_CONTENTS=['model','optimizer','extra']
+fi
+
 exp_name="$(basename "${MODEL_ID,,}")-megatron-gsm8k-minimal"
 
 python3 -m verl.trainer.main_ppo --config-path=config \
@@ -78,7 +84,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     actor_rollout_ref.actor.use_kl_loss=True \
     actor_rollout_ref.actor.kl_loss_coef=0.001 \
     actor_rollout_ref.actor.kl_loss_type=low_var_kl \
-    actor_rollout_ref.actor.checkpoint.contents=['model','hf_model','optimizer','extra'] \
+    actor_rollout_ref.actor.checkpoint.contents=$CHECKPOINT_CONTENTS \
     actor_rollout_ref.rollout.name=vllm \
     actor_rollout_ref.rollout.tensor_model_parallel_size=$ROLLOUT_TP \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
@@ -97,7 +103,7 @@ python3 -m verl.trainer.main_ppo --config-path=config \
     critic.megatron.virtual_pipeline_model_parallel_size=$CRITIC_VPP \
     critic.megatron.context_parallel_size=$CRITIC_CP \
     critic.megatron.tensor_model_parallel_size=$CRITIC_TP \
-    critic.checkpoint.contents=['model','hf_model','optimizer','extra'] \
+    critic.checkpoint.contents=$CHECKPOINT_CONTENTS \
     reward_model.enable=True \
     reward_model.model.path="${MODEL_PATH}" \
     reward_model.micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
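
A local run that mirrors the new CI job might look like the following; this assumes, as the CI steps do, that `run_ppo_trainer_megatron.sh` forwards any extra command-line arguments on to `verl.trainer.main_ppo`, and that the Hydra `+key=value` form appends keys under the otherwise empty `override_transformer_config` mapping:

```bash
# Hypothetical local invocation mirroring the CI job: skip the HF copy of the
# checkpoint and force an uneven PP=4 split via override_transformer_config.
SAVE_FREQ=1 COMMON_PP=4 COMMON_VPP=null COMMON_CP=1 SKIP_SAVE_HF_MODEL=1 \
bash tests/e2e/run_ppo_trainer_megatron.sh \
    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_first_pipeline_stage=8 \
    +actor_rollout_ref.actor.megatron.override_transformer_config.num_layers_in_last_pipeline_stage=4 \
    actor_rollout_ref.actor.megatron.use_dist_checkpointing=true \
    actor_rollout_ref.actor.megatron.dist_checkpointing_path=checkpoints/verl-test/qwen2.5-0.5b-megatron
```
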
diff --git a/tests/e2e/run_prime.sh b/tests/e2e/run_prime.sh
index fcc4db150..da7664af3 100644
--- a/tests/e2e/run_prime.sh
+++ b/tests/e2e/run_prime.sh
@@ -34,6 +34,7 @@ python3 -m recipe.prime.main_prime \
    actor_rollout_ref.model.path="${MODEL_PATH}" \
    actor_rollout_ref.actor.optim.lr=5e-7 \
    actor_rollout_ref.model.use_remove_padding=True \
+   actor_rollout_ref.model.use_fused_kernels=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${train_traj_micro_bsz_per_gpu} \
    actor_rollout_ref.model.enable_gradient_checkpointing=False \
diff --git a/verl/models/mcore/config_converter.py b/verl/models/mcore/config_converter.py
index 73bb5d8cb..70a3f7926 100644
--- a/verl/models/mcore/config_converter.py
+++ b/verl/models/mcore/config_converter.py
@@ -23,7 +23,7 @@ from megatron.core.transformer import MLATransformerConfig, TransformerConfig
 from transformers import PretrainedConfig
 
 
-def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **kwargs) -> TransformerConfig:
+def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     """
     Create a base TransformerConfig with common parameters across different model architectures.
     TODO: (ycl) use dataclass or converter config?
@@ -31,7 +31,7 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
     Args:
         hf_config: HuggingFace model configuration
         dtype: Data type for the model
-        **kwargs: Additional parameters to override defaults
+        override_transformer_config_kwargs: Additional parameters to override defaults
 
     Returns:
         TransformerConfig with common parameters
@@ -79,28 +79,21 @@ def _get_base_transformer_config(hf_config: PretrainedConfig, dtype: torch.dtype
     }
 
     # Update with any provided overrides
-    base_config.update(kwargs)
+    base_config.update(override_transformer_config_kwargs)
 
     print(f"Overridden TF init config: {base_config}")
 
     return TransformerConfig(**base_config)
 
 
-def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_dense(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # for LlamaForCausalLM or Qwen2ForCausalLM
     qkv_bias = True if "Qwen2ForCausalLM" in hf_config.architectures else getattr(hf_config, "attention_bias", False)
     qk_layernorm = True if "Qwen3ForCausalLM" in hf_config.architectures else False
-    return _get_base_transformer_config(
-        hf_config=hf_config,
-        dtype=dtype,
-        use_cpu_initialization=False,
-        add_bias_linear=False,
-        add_qkv_bias=qkv_bias,
-        qk_layernorm=qk_layernorm,
-    )
+    return _get_base_transformer_config(hf_config=hf_config, dtype=dtype, use_cpu_initialization=False, add_bias_linear=False, add_qkv_bias=qkv_bias, qk_layernorm=qk_layernorm, **override_transformer_config_kwargs)
 
 
-def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     return _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
@@ -126,10 +119,11 @@ def hf_to_mcore_config_qwen2moe(hf_config: PretrainedConfig, dtype: torch.dtype)
         # Qwen specific
         moe_router_pre_softmax=True,
         add_qkv_bias=True,
+        **override_transformer_config_kwargs,
     )
 
 
-def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     return _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
@@ -154,10 +148,11 @@ def hf_to_mcore_config_mixtral(hf_config: PretrainedConfig, dtype: torch.dtype)
         apply_rope_fusion=True,
         bias_activation_fusion=True,
         bias_dropout_fusion=True,
+        **override_transformer_config_kwargs,
     )
 
 
-def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     return _get_base_transformer_config(
         hf_config=hf_config,
         dtype=dtype,
@@ -181,19 +176,20 @@ def hf_to_mcore_config_qwen3moe(hf_config: PretrainedConfig, dtype: torch.dtype)
         # Qwen specific
         moe_router_pre_softmax=False,
         qk_layernorm=True,
+        **override_transformer_config_kwargs,
     )
 
 
-def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype) -> MLATransformerConfig:
+def hf_to_mcore_config_dpskv3(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> MLATransformerConfig:
     # DeepseekV3ForCausalLM
     raise NotImplementedError("DeepseekV3ForCausalLM is not supported yet")
 
 
-def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_qwen2_5_vl(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Qwen2_5_VLForConditionalGeneration
     raise NotImplementedError("Qwen2_5_VLForConditionalGeneration is not supported yet")
 
 
-def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config_llama4(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     # Llama4ForConditionalGeneration
     raise NotImplementedError("Llama4ForConditionalGeneration is not supported yet")
diff --git a/verl/models/mcore/registry.py b/verl/models/mcore/registry.py
index a3b987a91..9a3861024 100644
--- a/verl/models/mcore/registry.py
+++ b/verl/models/mcore/registry.py
@@ -122,10 +122,10 @@ def get_supported_model(model_type: str) -> SupportedModel:
         raise NotImplementedError(f"Model Type: {model_type} not supported. Supported models: {supported_models}") from err
 
 
-def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype) -> TransformerConfig:
+def hf_to_mcore_config(hf_config: PretrainedConfig, dtype: torch.dtype, **override_transformer_config_kwargs) -> TransformerConfig:
     assert len(hf_config.architectures) == 1, "Only one architecture is supported for now"
     model = get_supported_model(hf_config.architectures[0])
-    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype)
+    return MODEL_CONFIG_CONVERTER_REGISTRY[model](hf_config, dtype, **override_transformer_config_kwargs)
 
 
 def init_mcore_model(
diff --git a/verl/single_controller/base/megatron/worker.py b/verl/single_controller/base/megatron/worker.py
index 48365de08..3f7f5c2aa 100644
--- a/verl/single_controller/base/megatron/worker.py
+++ b/verl/single_controller/base/megatron/worker.py
@@ -39,7 +39,7 @@ class MegatronWorker(Worker):
         info = DistRankInfo(tp_rank=tp_rank, dp_rank=dp_rank, pp_rank=pp_rank, cp_rank=cp_rank)
         return info
 
-    def _init_hf_config_and_tf_config(self, model_path, dtype, override_model_config):
+    def _init_hf_config_and_tf_config(self, model_path, dtype, override_model_config, override_transformer_config):
         from transformers import AutoConfig
 
         from verl.models.mcore import hf_to_mcore_config
@@ -66,7 +66,7 @@ class MegatronWorker(Worker):
         self.architectures = getattr(hf_config, "architectures", None)
         if self.rank == 0:
             print(f"Model config after override: {hf_config}")
-        tf_config = hf_to_mcore_config(hf_config, dtype)
+        tf_config = hf_to_mcore_config(hf_config, dtype, **override_transformer_config)
 
         def add_optimization_config_to_tf_config(tf_config, verl_model_config):
             # add optimization config to tf_config, e.g. checkpointing
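
Taken together, the registry and worker changes mean the override dictionary from the trainer config is splatted all the way into the Megatron `TransformerConfig`. A minimal sketch of that flow, assuming an HF config that maps to one of the supported dense architectures and that every override key is a valid `TransformerConfig` field:

```python
# Sketch only: how override_transformer_config reaches the Megatron TransformerConfig.
import torch
from omegaconf import OmegaConf
from transformers import AutoConfig

from verl.models.mcore import hf_to_mcore_config

megatron_cfg = OmegaConf.create({"override_transformer_config": {"num_layers_in_first_pipeline_stage": 8}})
overrides = OmegaConf.to_container(megatron_cfg.override_transformer_config, resolve=True)

hf_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-0.5B")
tf_config = hf_to_mcore_config(hf_config, torch.bfloat16, **overrides)
print(tf_config.num_layers_in_first_pipeline_stage)  # 8
```
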
diff --git a/verl/trainer/config/ppo_megatron_trainer.yaml b/verl/trainer/config/ppo_megatron_trainer.yaml
index e3cbaf9e3..a8e9b2501 100644
--- a/verl/trainer/config/ppo_megatron_trainer.yaml
+++ b/verl/trainer/config/ppo_megatron_trainer.yaml
@@ -82,6 +82,7 @@ actor_rollout_ref:
       use_dist_checkpointing: False
       dist_checkpointing_path: null
       seed: 1
+      override_transformer_config: {} # additional transformer config like: num_layers_in_first(/last)_pipeline_stage
     profile: # profile the actor model in `update_policy`
       use_profile: False # open it when you want to profile the actor model
       profile_ranks: null # list, you can specify the ranks to profile
@@ -107,6 +108,7 @@ actor_rollout_ref:
       use_dist_checkpointing: False
       dist_checkpointing_path: null
       seed: 1
+      override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
     profile:
       use_profile: False
       profile_ranks: null
@@ -202,6 +204,7 @@ critic:
     use_dist_checkpointing: False
     dist_checkpointing_path: null
     seed: 1
+    override_transformer_config: ${actor_rollout_ref.actor.megatron.override_transformer_config}
   load_weight: True
   ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
   ppo_micro_batch_size: null # will be deprecated, use ppo_micro_batch_size_per_gpu
@@ -235,6 +238,7 @@ reward_model:
     use_dist_checkpointing: False
     dist_checkpointing_path: null
     seed: 1
+    override_transformer_config: {}
   model:
     input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
     path: ~/models/FsfairX-LLaMA3-RM-v0.1
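
Because the ref and critic entries interpolate the actor's value (the reward model keeps its own, empty mapping), setting the overrides once under the actor is enough. In config-file form the new key would be used like this, with the values taken from the CI jobs above:

```yaml
actor_rollout_ref:
  actor:
    megatron:
      override_transformer_config:
        num_layers_in_first_pipeline_stage: 8
        num_layers_in_last_pipeline_stage: 4
```
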
diff --git a/verl/utils/megatron_utils.py b/verl/utils/megatron_utils.py
index 6980ad33c..508ce6608 100644
--- a/verl/utils/megatron_utils.py
+++ b/verl/utils/megatron_utils.py
@@ -785,3 +785,104 @@ def per_tensor_generator(actor_module, model_config, weight_converter, layer_nam
             converted_names, converted_params = weight_converter.convert_param(cur_name, infer_params)
 
             yield from zip(converted_names, converted_params)
+
+
+def get_transformer_layer_offset(pipeline_rank, vp_rank, config: TransformerConfig):
+    """
+    Get the index offset of any pipeline stage, given the level of pipelining.
+
+    Takes pipeline_rank and vp_rank as explicit arguments so that the layer offset of
+    any pipeline stage can be computed, not only the current one; the original
+    Megatron-LM helper only returns the offset for the current pipeline stage.
+
+    Extension of https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/transformer_layer.py::get_transformer_layer_offset
+    """
+    if config.pipeline_model_parallel_size > 1:
+        if config.num_layers_in_first_pipeline_stage is not None or config.num_layers_in_last_pipeline_stage is not None:
+            # Calculate number of pipeline stages to distribute the remaining Transformer
+            # layers after deducting the Transformer layers in the first or the last stages
+            middle_pipeline_stages = config.pipeline_model_parallel_size
+            middle_pipeline_stages -= sum(
+                [
+                    1 if x is not None else 0
+                    for x in (
+                        config.num_layers_in_first_pipeline_stage,
+                        config.num_layers_in_last_pipeline_stage,
+                    )
+                ]
+            )
+
+            # Calculate layers to distribute in each pipeline stage. If the
+            # num_layers_in_first_pipeline_stage and num_layers_in_last_pipeline_stage
+            # are not set, we will not enable uneven pipeline. All layers will be treated
+            # as middle layers.
+            num_layers_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage
+            num_layers_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage
+
+            middle_num_layers = config.num_layers - num_layers_in_first_pipeline_stage - num_layers_in_last_pipeline_stage
+
+            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:
+                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
+
+                # Calculate number of layers in each virtual model chunk
+                # If the num_layers_in_first_pipeline_stage and
+                # num_layers_in_last_pipeline_stage are not set, all pipeline stages
+                # will be treated as middle pipeline stages in the calculation
+                num_layers_per_virtual_model_chunk_in_first_pipeline_stage = 0 if config.num_layers_in_first_pipeline_stage is None else config.num_layers_in_first_pipeline_stage // vp_size
+
+                num_layers_per_virtual_model_chunk_in_last_pipeline_stage = 0 if config.num_layers_in_last_pipeline_stage is None else config.num_layers_in_last_pipeline_stage // vp_size
+
+                num_layers_per_virtual_model_chunk_in_middle_pipeline_stage = middle_num_layers // vp_size
+
+                # First stage + middle stage + last stage
+                total_virtual_chunks = num_layers_per_virtual_model_chunk_in_first_pipeline_stage + num_layers_per_virtual_model_chunk_in_middle_pipeline_stage + num_layers_per_virtual_model_chunk_in_last_pipeline_stage
+
+                # Calculate the layer offset with interleaved uneven pipeline parallelism
+                if pipeline_rank == 0:
+                    offset = vp_rank * total_virtual_chunks
+                else:
+                    offset = vp_rank * total_virtual_chunks + num_layers_per_virtual_model_chunk_in_first_pipeline_stage + (pipeline_rank - 1) * (num_layers_per_virtual_model_chunk_in_middle_pipeline_stage // middle_pipeline_stages)
+            else:
+                if middle_pipeline_stages > 0:
+                    num_layers_per_pipeline_rank = middle_num_layers // middle_pipeline_stages
+                else:
+                    num_layers_per_pipeline_rank = 0
+
+                middle_pipeline_rank = pipeline_rank if config.num_layers_in_first_pipeline_stage is None else pipeline_rank - 1
+
+                if pipeline_rank == 0:
+                    offset = 0
+                else:
+                    offset = (middle_pipeline_rank * num_layers_per_pipeline_rank) + num_layers_in_first_pipeline_stage
+        else:
+            num_layers = config.num_layers
+
+            # Increase the number of layers by one if we include the embedding (loss)
+            # layer into pipeline parallelism partition and placement
+            if config.account_for_embedding_in_pipeline_split:
+                num_layers += 1
+
+            if config.account_for_loss_in_pipeline_split:
+                num_layers += 1
+
+            num_layers_per_pipeline_rank = num_layers // config.pipeline_model_parallel_size
+
+            if mpu.get_virtual_pipeline_model_parallel_world_size() is not None:
+                vp_size = mpu.get_virtual_pipeline_model_parallel_world_size()
+
+                num_layers_per_virtual_rank = num_layers_per_pipeline_rank // vp_size
+                total_virtual_chunks = num_layers // vp_size
+                offset = vp_rank * total_virtual_chunks + (pipeline_rank * num_layers_per_virtual_rank)
+
+                # Reduce the offset of embedding layer from the total layer number
+                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():
+                    offset -= 1
+            else:
+                offset = pipeline_rank * num_layers_per_pipeline_rank
+
+                # Reduce the offset of embedding layer from the total layer number
+                if config.account_for_embedding_in_pipeline_split and not mpu.is_pipeline_first_stage():
+                    offset -= 1
+    else:
+        offset = 0
+    return offset
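
For the non-interleaved uneven split, the branch above reduces to simple arithmetic. A standalone re-derivation for the 24-layer, `PP=4`, first=8 / last=4 configuration used in CI (a simplified mirror of the code rather than a call into it, since the real helper also consults `mpu` for the virtual-pipeline world size):

```python
# Simplified mirror of the non-interleaved uneven-PP branch of
# get_transformer_layer_offset (assumes no virtual pipeline parallelism).
def layer_offset(pipeline_rank: int, pp_size: int, num_layers: int, first: int, last: int) -> int:
    middle_stages = pp_size - 2
    layers_per_middle_stage = (num_layers - first - last) // middle_stages
    if pipeline_rank == 0:
        return 0
    return first + (pipeline_rank - 1) * layers_per_middle_stage

print([layer_offset(r, 4, 24, 8, 4) for r in range(4)])  # [0, 8, 14, 20]
```
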
diff --git a/verl/utils/model.py b/verl/utils/model.py
index c59b79e8b..b964bc638 100644
--- a/verl/utils/model.py
+++ b/verl/utils/model.py
@@ -205,20 +205,13 @@ def compute_position_id_with_mask(mask):
     return torch.clip(torch.cumsum(mask, dim=-1) - 1, min=0, max=None)
 
 
-def normalize_model_name(name, pp_rank, vpp_rank, pp_size, vpp_size, num_layers, layer_name="layers"):
+def normalize_model_name(name, pp_rank, vpp_rank, transformer_config, layer_name="layers"):
     """
     Transform the model name in each model_chunk in each pp stage into the name in inference engine
     """
-    if vpp_size > 1:
-        # print(f'try to bind vpp params to inference engine...')
-        layers_per_pp = num_layers // pp_size
-        layers_per_vpp = layers_per_pp // vpp_size
-        pp_offset = layers_per_vpp * pp_rank
-        vpp_offset = (layers_per_vpp * pp_size) * vpp_rank
-        layer_offset = pp_offset + vpp_offset
-    else:
-        layers_per_pp = num_layers // pp_size
-        layer_offset = layers_per_pp * pp_rank
+    from verl.utils.megatron_utils import get_transformer_layer_offset
+
+    layer_offset = get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config)
 
     if layer_name in name:  # belong to an intermediate layer
         split_name = name.split(".")
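
`normalize_model_name` keeps its contract of rewriting a chunk-local layer index into the global index used by the inference engine, but the offset now comes from `get_transformer_layer_offset`, so uneven and interleaved layouts share one code path. A hedged sketch of that renaming contract, assuming the usual `decoder.layers.<idx>....` parameter naming:

```python
# Sketch of the renaming contract only; the real helper derives layer_offset from
# get_transformer_layer_offset(pp_rank, vpp_rank, transformer_config).
def rename(local_name: str, layer_offset: int) -> str:
    parts = local_name.split(".")
    idx = parts.index("layers") + 1
    parts[idx] = str(int(parts[idx]) + layer_offset)
    return ".".join(parts)

# On pp_rank 2 of the 24-layer split above the offset is 14:
print(rename("decoder.layers.1.mlp.linear_fc1.weight", 14))  # decoder.layers.15.mlp.linear_fc1.weight
```
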
diff --git a/verl/workers/megatron_workers.py b/verl/workers/megatron_workers.py
index cb4466faf..8d89700d1 100644
--- a/verl/workers/megatron_workers.py
+++ b/verl/workers/megatron_workers.py
@@ -136,14 +136,14 @@ class ActorRolloutRefWorker(MegatronWorker):
                 self.config.ref.ppo_micro_batch_size_per_gpu = self.config.ref.ppo_micro_batch_size
             self._ref_is_offload_param = self.config.ref.megatron.get("param_offload", False)
 
-    def _build_model_optimizer(self, model_path, optim_config, override_model_config):
+    def _build_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):
         from megatron.core.models.gpt.gpt_model import ModelType
 
         from verl.utils.megatron.optimizer import get_megatron_optimizer
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
         from verl.utils.model import get_generation_config, print_model_size
 
-        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config)
+        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
         self.generation_config = get_generation_config(self.local_path)
 
         def megatron_actor_model_provider(pre_process, post_process):
@@ -248,6 +248,7 @@ class ActorRolloutRefWorker(MegatronWorker):
             sharding_manager = MegatronVLLMShardingManager(
                 inference_engine=rollout.inference_engine,
                 model_config=self.actor_model_config,
+                transformer_config=self.tf_config,
                 layer_name_mapping=layer_name_mapping,
                 actor_module=self.actor.actor_module,
                 weight_converter=weight_converter,
@@ -297,6 +298,12 @@ class ActorRolloutRefWorker(MegatronWorker):
         from verl.utils.torch_dtypes import PrecisionType
 
         override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        if self._is_actor:
+            override_transformer_config = OmegaConf.to_container(self.config.actor.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True)
+        elif self._is_ref:
+            override_transformer_config = OmegaConf.to_container(self.config.ref.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True)
+        else:
+            override_transformer_config = None
         self.param_dtype = torch.bfloat16
         log_gpu_memory_usage("Before init actor model and optimizer", logger=logger)
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
@@ -307,6 +314,7 @@ class ActorRolloutRefWorker(MegatronWorker):
                 model_path=self.config.model.path,
                 optim_config=optim_config,
                 override_model_config=override_model_config,
+                override_transformer_config=override_transformer_config,
             )
             if self._is_offload_param:
                 offload_megatron_model_to_cpu(self.actor_module)
@@ -335,6 +343,7 @@ class ActorRolloutRefWorker(MegatronWorker):
                 model_path=self.config.model.path,
                 optim_config=None,
                 override_model_config=override_model_config,
+                override_transformer_config=override_transformer_config,
             )
             log_gpu_memory_usage("After ref model init", logger=logger)
             self.ref_policy = MegatronPPOActor(
@@ -545,14 +554,14 @@ class CriticWorker(MegatronWorker):
 
     # TODO(sgm): support critic model offload
 
-    def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config):
+    def _build_critic_model_optimizer(self, model_path, optim_config, override_model_config, override_transformer_config):
         from megatron.core.models.gpt.gpt_model import ModelType
 
         from verl.utils.megatron.optimizer import get_megatron_optimizer
         from verl.utils.megatron_utils import get_model, init_megatron_optim_config
         from verl.utils.model import print_model_size
 
-        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config)
+        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
 
         def megatron_critic_model_provider(pre_process, post_process):
             from verl.models.mcore import init_mcore_model
@@ -603,12 +612,14 @@ class CriticWorker(MegatronWorker):
             importlib.import_module(self.config.model.external_lib)
 
         override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_transformer_config = OmegaConf.to_container(self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True)
         self.param_dtype = torch.bfloat16
         self.dtype = PrecisionType.to_dtype(self.param_dtype)
         self.critic_module, self.critic_optimizer, self.critic_model_config, critic_optimizer_config = self._build_critic_model_optimizer(
             model_path=self.config.model.path,
             optim_config=self.config.optim,
             override_model_config=override_model_config,
+            override_transformer_config=override_transformer_config,
         )
         if self._is_offload_param:
             offload_megatron_model_to_cpu(self.critic_module)
@@ -737,12 +748,12 @@ class RewardModelWorker(MegatronWorker):
             self.config.micro_batch_size //= mpu.get_data_parallel_world_size()
             self.config.micro_batch_size_per_gpu = self.config.micro_batch_size
 
-    def _build_rm_model(self, model_path, override_model_config):
+    def _build_rm_model(self, model_path, override_model_config, override_transformer_config):
         from megatron.core.models.gpt.gpt_model import ModelType
 
         from verl.utils.megatron_utils import get_model
 
-        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config)
+        self._init_hf_config_and_tf_config(model_path, self.dtype, override_model_config, override_transformer_config)
 
         def megatron_rm_model_provider(pre_process, post_process):
             from verl.models.mcore import init_mcore_model
@@ -792,6 +803,7 @@ class RewardModelWorker(MegatronWorker):
             importlib.import_module(self.config.model.external_lib)
 
         override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))
+        override_transformer_config = OmegaConf.to_container(self.config.megatron.get("override_transformer_config", OmegaConf.create()), resolve=True)
 
         sft_tokenizer_local_path = copy_to_local(self.config.model.input_tokenizer)
         sft_tokenizer = hf_tokenizer(sft_tokenizer_local_path)
@@ -807,6 +819,7 @@ class RewardModelWorker(MegatronWorker):
         reward_model_module, reward_model_config = self._build_rm_model(
             model_path=self.config.model.path,
             override_model_config=override_model_config,
+            override_transformer_config=override_transformer_config,
         )
         # FIXME(sgm): reward model param offload is implemented in MegatronRewardModel
         # should be implemented in workers
diff --git a/verl/workers/sharding_manager/megatron_vllm.py b/verl/workers/sharding_manager/megatron_vllm.py
index 1401c582f..3c2767123 100644
--- a/verl/workers/sharding_manager/megatron_vllm.py
+++ b/verl/workers/sharding_manager/megatron_vllm.py
@@ -274,6 +274,7 @@ class MegatronVLLMShardingManager(BaseShardingManager):
         actor_module: nn.ModuleList,
         inference_engine: LLM,
         model_config,
+        transformer_config,
         layer_name_mapping,
         weight_converter: McoreToHFWeightConverterBase,
         module: AllGatherPPModel = None,
@@ -283,6 +284,7 @@ class MegatronVLLMShardingManager(BaseShardingManager):
         self.actor_module = actor_module
         self.inference_engine = inference_engine
         self.model_config = model_config
+        self.transformer_config = transformer_config
         self.layer_name_mapping = layer_name_mapping
         self.weight_converter = weight_converter
         self.module = module
@@ -313,7 +315,6 @@ class MegatronVLLMShardingManager(BaseShardingManager):
         from megatron.core import parallel_state as mpu
 
         pp_rank = mpu.get_pipeline_model_parallel_rank()
-        pp_size = mpu.get_pipeline_model_parallel_world_size()
         vpp_size = len(self.actor_module)
 
         all_gather_group = self.train_tp_group
@@ -347,7 +348,7 @@ class MegatronVLLMShardingManager(BaseShardingManager):
                         cur_name, cur_tensor = next(gen_func)
                     except StopIteration:
                         cur_name, cur_tensor = None, None
-                    cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, pp_size, vpp_size, self.model_config.num_hidden_layers)
+                    cur_name = normalize_model_name(name, cur_pp_rank, scan_vpp_idx, self.transformer_config)
                 else:
                     cur_tensor, cur_name = None, None