[megatron] feat: support qwen3vl (#3763)

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

Support training Qwen3-VL with Megatron.

1. Add a Docker image with vLLM 0.11 and NeMo's dedicated Megatron build, which supports gpt-oss with optimized fused kernels.
2. Add a script for training Qwen3-VL-30B with Megatron.
3. Make the necessary changes to support Qwen3-VL on Megatron (only forward functions are registered; the modeling itself goes through mbridge).


### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
<img width="372" height="314" alt="image"
src="https://github.com/user-attachments/assets/f1126e46-51a9-4e00-958f-5d034b8f94bd"
/>

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

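There is no user-facing API change: Qwen3-VL is resolved through the existing Megatron model registry, and training is launched with the usual `main_ppo` entry point (see the script below). A minimal sketch of the registry lookup follows; the import path is an assumption, while the registry and function names come from the diff below:

```python
# Sketch: Qwen3-VL MoE resolves to the existing Qwen2.5-VL forward functions
# through the Megatron model registry, so no new API surface is introduced.
# The import path below is an assumption; names are taken from the diff.
from verl.models.mcore.registry import MODEL_FORWARD_REGISTRY, SupportedModel

model = SupportedModel("Qwen3VLMoeForConditionalGeneration")  # SupportedModel.QWEN3_MOE_VL
forward_fn = MODEL_FORWARD_REGISTRY[model]  # -> gptmodel_forward_qwen2_5_vl
```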

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
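- Register `Qwen3VLMoeForConditionalGeneration` (`QWEN3_MOE_VL`) in the three Megatron forward registries, reusing the existing Qwen2.5-VL forward functions (`gptmodel_forward_qwen2_5_vl`, `gptmodel_forward_no_padding`, `fused_forward_qwen2_5_vl`); model construction itself is handled by mbridge.
- Pass `position_ids=None` in `gptmodel_forward_qwen2_5_vl` and let the model compute position ids itself.
- Add a Dockerfile based on `nvcr.io/nvidia/nemo:25.07.gpt_oss` with vLLM v0.11.0 built from source, and update the image references in the docs.
- Add a GRPO training script for Qwen3-VL-30B-A3B-Instruct on the geo3k dataset.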

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit `33eb86f54f` (parent `67f9a21b8e`) by Yan Bai, committed by GitHub on 2025-10-15 10:19:22 +08:00.
6 changed files with 102 additions and 2 deletions.


@@ -36,6 +36,8 @@ For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyo
For SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.
+For latest vLLM with Megatron, please refer to [iseekyan/verl](https://hub.docker.com/r/iseekyan/verl) repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
See files under ``docker/`` for NGC-based image or if you want to build your own.
Note that for AWS instances with an EFA network interface (SageMaker AI Pod), you need to install the EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``


@@ -0,0 +1,15 @@
# Base image: NeMo container whose Megatron build supports gpt-oss with fused kernels
FROM nvcr.io/nvidia/nemo:25.07.gpt_oss
# Build vLLM v0.11.0 from source (editable install, no dependency resolution)
RUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm
RUN pip install setuptools_scm
RUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .
# vLLM runtime dependencies, installed manually because of --no-deps above
RUN pip install cbor2 setproctitle blake3 openai_harmony pybase64 msgspec partial_json_parser py-cpuinfo diskcache gguf
# Latest transformers/tokenizers for Qwen3-VL support
RUN pip install --upgrade transformers tokenizers
# verl dependencies and mbridge (HF <-> Megatron model bridging)
RUN pip install codetiming tensordict mathruler pylatexenc
RUN pip3 install --no-cache-dir mbridge
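Presumably this is the recipe behind the published ``iseekyan/verl:nemo.gptoss_vllm0.11.0`` image referenced in the docs above and in the training script below.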


@@ -79,7 +79,7 @@ For latest vLLM with FSDP, please refer to `hiyouga/verl <https://hub.docker.com
For latest SGLang with FSDP, please refer to `hebiaobuaa/verl <https://hub.docker.com/r/hebiaobuaa/verl>`_ repository and the latest version is ``hebiaobuaa/verl:app-verl0.5-sglang0.4.9.post6-mcore0.12.2-te2.2`` which is provided by SGLang RL Group.
-For latest vLLM with Megatron, please refer to `iseekyan/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.15.0-te2.7`
+For latest vLLM with Megatron, please refer to `iseekyan/verl <https://hub.docker.com/r/iseekyan/verl>`_ repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
See files under ``docker/`` for NGC-based image or if you want to build your own.


@@ -0,0 +1,79 @@
set -x
ENGINE=${1:-vllm}
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
# vLLM >= 0.11.0 is required for Qwen3-VL support; the container docker://iseekyan/verl:nemo.gptoss_vllm0.11.0 is recommended
# pip install -U git+https://github.com/ISEEKYAN/mbridge.git # for latest mbridge
# pip install -U transformers # for qwen3-vl support
# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 # for megatron-lm 0.13.1
export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm 0.11.0 with TP
HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"}
train_path=$HOME/data/geo3k/train.parquet
test_path=$HOME/data/geo3k/test.parquet
python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml' \
algorithm.adv_estimator=grpo \
data.train_files="$train_path" \
data.val_files="$test_path" \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \
actor_rollout_ref.rollout.name=$ENGINE \
+actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.megatron.use_mbridge=True \
actor_rollout_ref.actor.megatron.param_offload=True \
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.grad_offload=True \
actor_rollout_ref.ref.megatron.param_offload=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen3_vl_30b_megatron' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 "$@"
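A note on the configuration: the aggressive memory settings (param/grad/optimizer offload, `optimizer_offload_fraction=1` with a CPU optimizer, and full uniform recomputation) are presumably what lets the 30B-A3B MoE model train on a single 8-GPU node (EP=8, TP=4) at these sequence lengths, while `moe_router_dtype=fp32` keeps router computations in full precision and DeepEP with the flex token dispatcher handles expert-parallel token routing.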


@@ -113,7 +113,7 @@ def gptmodel_forward_qwen2_5_vl(
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
-position_ids=position_ids,
+position_ids=None,  # model will calculate position_ids
packed_seq_params=packed_seq_params,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,
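With this change, the VL model derives its position ids (presumably the multimodal M-RoPE ids) internally from the packed inputs instead of receiving them from the caller.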


@@ -74,6 +74,7 @@ class SupportedModel(Enum):
GLM4_MOE = "Glm4MoeForCausalLM"
QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
+QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
# Registry for model configuration converters
@@ -118,6 +119,7 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.QWEN3: gptmodel_forward,
SupportedModel.QWEN3_MOE: gptmodel_forward,
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
+SupportedModel.QWEN3_MOE_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.GLM4_MOE: gptmodel_forward,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
@@ -131,6 +133,7 @@ MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
+SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding,
SupportedModel.LLAMA4: gptmodel_forward_no_padding,
SupportedModel.QWEN3: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
@@ -148,6 +151,7 @@ MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: fused_forward_gptmodel,
SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
+SupportedModel.QWEN3_MOE_VL: fused_forward_qwen2_5_vl,
SupportedModel.LLAMA4: fused_forward_gptmodel,
SupportedModel.QWEN3: fused_forward_gptmodel,
SupportedModel.QWEN3_MOE: fused_forward_gptmodel,
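For reference, the whole model-support change reduces to the registry pattern below; a self-contained sketch, simplified from the diff with the real forward function stubbed out:

```python
from enum import Enum
from typing import Callable


def gptmodel_forward_qwen2_5_vl(*args, **kwargs):
    """Stub standing in for verl's real Qwen2.5-VL forward function."""
    ...


class SupportedModel(Enum):
    QWEN2_5_VL = "Qwen2_5_VLForConditionalGeneration"
    QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"  # new in this PR


# Qwen3-VL MoE reuses the Qwen2.5-VL forward implementation; the real diff adds
# the same entry to the padded, no-padding, and fused-forward registries.
MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
    SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
    SupportedModel.QWEN3_MOE_VL: gptmodel_forward_qwen2_5_vl,
}

# Lookup by HF architecture name, as the trainer would do:
fn = MODEL_FORWARD_REGISTRY[SupportedModel("Qwen3VLMoeForConditionalGeneration")]
assert fn is gptmodel_forward_qwen2_5_vl
```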