mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[megatron] feat: support qwen3vl (#3763)
### What does this PR do? > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. support training qwen3vl with megatron 1. add an image with vllm0.11 and nemo's dedicated megatron that support gpt-oss with optimized fused kernels. 2. add a script of training qwen3vl-30b with megatron 3. necessary changes to support qwen3vl megatron. (just register forward functions, the modeling is through mbridge) ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. <img width="372" height="314" alt="image" src="https://github.com/user-attachments/assets/f1126e46-51a9-4e00-958f-5d034b8f94bd" /> ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
This commit is contained in:
@ -36,6 +36,8 @@ For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyo
|
||||
|
||||
For SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.
|
||||
|
||||
For latest vLLM with Megatron, please refer to [iseekyan/verl](https://hub.docker.com/r/iseekyan/verl) repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
|
||||
|
||||
See files under ``docker/`` for NGC-based image or if you want to build your own.
|
||||
|
||||
Note that For aws instances with EFA net interface (Sagemaker AI Pod), you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``
|
||||
|
@ -0,0 +1,15 @@
|
||||
FROM nvcr.io/nvidia/nemo:25.07.gpt_oss
|
||||
|
||||
RUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm
|
||||
|
||||
RUN pip install setuptools_scm
|
||||
|
||||
RUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .
|
||||
|
||||
RUN pip install cbor2 setproctitle blake3 openai_harmony pybase64 msgspec partial_json_parser py-cpuinfo diskcache gguf
|
||||
|
||||
RUN pip install --upgrade transformers tokenizers
|
||||
|
||||
RUN pip install codetiming tensordict mathruler pylatexenc
|
||||
|
||||
RUN pip3 install --no-cache-dir mbridge
|
@ -79,7 +79,7 @@ For latest vLLM with FSDP, please refer to `hiyouga/verl <https://hub.docker.com
|
||||
|
||||
For latest SGLang with FSDP, please refer to `hebiaobuaa/verl <https://hub.docker.com/r/hebiaobuaa/verl>`_ repository and the latest version is ``hebiaobuaa/verl:app-verl0.5-sglang0.4.9.post6-mcore0.12.2-te2.2`` which is provided by SGLang RL Group.
|
||||
|
||||
For latest vLLM with Megatron, please refer to `iseekyan/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.15.0-te2.7`
|
||||
For latest vLLM with Megatron, please refer to `iseekyan/verl <https://hub.docker.com/r/iseekyan/verl>`_ repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
|
||||
|
||||
See files under ``docker/`` for NGC-based image or if you want to build your own.
|
||||
|
||||
|
79
examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh
Normal file
79
examples/grpo_trainer/run_qwen3_vl-30b-megatron.sh
Normal file
@ -0,0 +1,79 @@
|
||||
set -x
|
||||
ENGINE=${1:-vllm}
|
||||
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
|
||||
|
||||
# VLLM version >= 0.11.0 for qwen3-vl support, recommend to use container docker://iseekyan/verl:nemo.gptoss_vllm0.11.0
|
||||
# pip install -U git+https://github.com/ISEEKYAN/mbridge.git # for latest mbridge
|
||||
# pip install -U transformers # for qwen3-vl support
|
||||
# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 # for megatron-lm0.13.1
|
||||
|
||||
|
||||
export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP
|
||||
|
||||
|
||||
HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"}
|
||||
|
||||
|
||||
train_path=$HOME/data/geo3k/train.parquet
|
||||
test_path=$HOME/data/geo3k/test.parquet
|
||||
|
||||
python3 -m verl.trainer.main_ppo --config-path=config \
|
||||
--config-name='ppo_megatron_trainer.yaml'\
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files="$train_path" \
|
||||
data.val_files="$test_path" \
|
||||
data.train_batch_size=512 \
|
||||
data.max_prompt_length=1024 \
|
||||
data.max_response_length=2048 \
|
||||
data.filter_overlong_prompts=True \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.path=$HF_MODEL_PATH \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
|
||||
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
|
||||
actor_rollout_ref.actor.use_kl_loss=True \
|
||||
actor_rollout_ref.actor.kl_loss_coef=0.01 \
|
||||
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=True \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \
|
||||
actor_rollout_ref.rollout.name=$ENGINE \
|
||||
+actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
|
||||
actor_rollout_ref.rollout.n=5 \
|
||||
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
|
||||
actor_rollout_ref.actor.megatron.use_mbridge=True \
|
||||
actor_rollout_ref.actor.megatron.param_offload=True \
|
||||
actor_rollout_ref.actor.megatron.optimizer_offload=True \
|
||||
actor_rollout_ref.actor.megatron.grad_offload=True \
|
||||
actor_rollout_ref.ref.megatron.param_offload=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
|
||||
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
|
||||
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
|
||||
algorithm.use_kl_in_reward=False \
|
||||
trainer.critic_warmup=0 \
|
||||
trainer.logger='["console","wandb"]' \
|
||||
trainer.project_name='verl_grpo_example_geo3k' \
|
||||
trainer.experiment_name='qwen3_vl_30b_megatron' \
|
||||
trainer.n_gpus_per_node=8 \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=20 \
|
||||
trainer.test_freq=5 \
|
||||
trainer.total_epochs=15 $@
|
@ -113,7 +113,7 @@ def gptmodel_forward_qwen2_5_vl(
|
||||
output_orig = model(
|
||||
input_ids=input_ids_rmpad,
|
||||
attention_mask=None,
|
||||
position_ids=position_ids,
|
||||
position_ids=None, # model will calculate position_ids
|
||||
packed_seq_params=packed_seq_params,
|
||||
pixel_values=pixel_values,
|
||||
image_grid_thw=image_grid_thw,
|
||||
|
@ -74,6 +74,7 @@ class SupportedModel(Enum):
|
||||
GLM4_MOE = "Glm4MoeForCausalLM"
|
||||
|
||||
QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
|
||||
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
|
||||
|
||||
|
||||
# Registry for model configuration converters
|
||||
@ -118,6 +119,7 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
|
||||
SupportedModel.QWEN3: gptmodel_forward,
|
||||
SupportedModel.QWEN3_MOE: gptmodel_forward,
|
||||
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
|
||||
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_qwen2_5_vl,
|
||||
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
|
||||
SupportedModel.GLM4_MOE: gptmodel_forward,
|
||||
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
|
||||
@ -131,6 +133,7 @@ MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
|
||||
SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
|
||||
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
|
||||
SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
|
||||
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding,
|
||||
SupportedModel.LLAMA4: gptmodel_forward_no_padding,
|
||||
SupportedModel.QWEN3: gptmodel_forward_no_padding,
|
||||
SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
|
||||
@ -148,6 +151,7 @@ MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
|
||||
SupportedModel.MIXTRAL: fused_forward_gptmodel,
|
||||
SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
|
||||
SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
|
||||
SupportedModel.QWEN3_MOE_VL: fused_forward_qwen2_5_vl,
|
||||
SupportedModel.LLAMA4: fused_forward_gptmodel,
|
||||
SupportedModel.QWEN3: fused_forward_gptmodel,
|
||||
SupportedModel.QWEN3_MOE: fused_forward_gptmodel,
|
||||
|
Reference in New Issue
Block a user