### What does this PR do?

This PR introduces a complete training recipe for [DeepEyes: Incentivizing "Thinking with Images" via Reinforcement Learning](https://arxiv.org/abs/2505.14362). The core feature is support for multi-turn visual tools, specifically the `ImageZoomInTool`, integrated with a custom reward function based on the "LLM-as-a-Judge" pattern to evaluate model performance. Additionally, to better monitor and analyze the model's tool-use behavior, this PR adds functionality to track tool call counts during training and report these metrics to logging systems such as wandb.

### API and Usage Example

The primary change is the new training recipe for DeepEyes. Users can start a training run using the provided configuration file.

1. Preprocess the dataset. This step attaches the tool-related `extra_info` each sample needs (a hedged sketch of what this can look like follows the checklist below):

   ```bash
   python recipe/deepeyes/deepeyes47k_preprocess.py --dataset_dir <path_to_raw_dataset> --save_dir <path_to_processed_data>
   ```

2. Start the PPO training:

   ```bash
   bash recipe/deepeyes/run_deepeyes_grpo.sh
   ```

The training process automatically loads the `ImageZoomInTool` and the custom reward function defined in the recipe.

### Design & Code Changes

- **DeepEyes Recipe Integration**: Added a new recipe directory with data preprocessing, a tool config, and a custom reward function for DeepEyes.
- **Visual Tool Support**: Implemented `ImageZoomInTool` with robust bbox validation and resizing (see the sketch after the checklist).
- **Tool Call Statistics**: Modified the rollout and metrics code to track and log tool call counts per sample and per step.
- **Bug Fixes**: Fixed image byte handling and ensured special tokens are preserved during decoding for tool call formatting.

### Checklist Before Submitting

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [ ] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: Maxwell-Jia <mr.minghui.jia@gamil.com>
Co-authored-by: xieck13 <xieck13@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
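For reference, here is a minimal sketch of the kind of tool-related `extra_info` the preprocessing step can attach to a sample. The field names follow verl's common multi-turn tool pattern (`need_tools_kwargs`, `tools_kwargs` keyed by tool name) but are an assumption here; the authoritative schema is whatever `recipe/deepeyes/deepeyes47k_preprocess.py` actually emits.

```python
# Hedged sketch: the exact keys are an assumption, not taken from the recipe.
def build_extra_info(sample_id: str, answer: str, image_path: str) -> dict:
    """Tool-related extra_info attached to one dataset row."""
    return {
        "index": sample_id,
        "answer": answer,  # ground truth consumed by the LLM-as-a-Judge reward
        "need_tools_kwargs": True,  # tells the rollout to instantiate tools
        "tools_kwargs": {
            # Keyed by tool name; create_kwargs are passed when the tool
            # instance is created for this sample.
            "image_zoom_in_tool": {
                "create_kwargs": {"image": image_path},
            },
        },
    }
```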
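The "Visual Tool Support" bullet mentions robust bbox validation and resizing. The following is a minimal, self-contained sketch of that logic (clamp coordinates, re-order flipped corners, reject degenerate boxes, upscale tiny crops), not the recipe's actual implementation; the `min_size` default is illustrative.

```python
# Hedged sketch of ImageZoomInTool-style bbox validation and crop/resize.
from PIL import Image


def zoom_in(image: Image.Image, bbox: tuple[float, float, float, float],
            min_size: int = 28) -> Image.Image:
    """Crop bbox = (x1, y1, x2, y2) out of image after validating it."""
    w, h = image.size
    x1, y1, x2, y2 = bbox
    # Clamp each coordinate into the image, then re-order flipped corners.
    x1, x2 = sorted((min(max(x1, 0), w), min(max(x2, 0), w)))
    y1, y2 = sorted((min(max(y1, 0), h), min(max(y2, 0), h)))
    if x2 - x1 < 1 or y2 - y1 < 1:
        raise ValueError(f"degenerate bbox after clamping: {bbox}")
    crop = image.crop((int(x1), int(y1), int(x2), int(y2)))
    # Upscale very small crops so the vision encoder gets a usable patch.
    if min(crop.size) < min_size:
        scale = min_size / min(crop.size)
        crop = crop.resize((round(crop.width * scale),
                            round(crop.height * scale)))
    return crop
```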
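Likewise, a hedged sketch of an "LLM-as-a-Judge" reward in the style the recipe's custom reward function uses: post the question, reference answer, and rollout to an OpenAI-compatible endpoint (the `LLM_AS_A_JUDGE_BASE` variable exported in the run script) and map the verdict to a scalar reward. The judge model name and prompt below are placeholders, not the recipe's.

```python
# Hedged sketch of a yes/no LLM-as-a-Judge reward over an OpenAI-compatible API.
import os

import requests


def judge_reward(question: str, ground_truth: str, response: str) -> float:
    """Ask the judge endpoint for a yes/no verdict and map it to {1.0, 0.0}."""
    prompt = (
        "Given a question, a reference answer, and a model answer, reply "
        "with exactly 'yes' if the model answer is correct, otherwise 'no'.\n"
        f"Question: {question}\nReference: {ground_truth}\nModel: {response}"
    )
    resp = requests.post(
        f"{os.environ['LLM_AS_A_JUDGE_BASE']}/chat/completions",
        json={
            "model": "judge-model",  # placeholder model name
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.0,
        },
        timeout=60,
    )
    resp.raise_for_status()
    verdict = resp.json()["choices"][0]["message"]["content"].strip().lower()
    return 1.0 if verdict.startswith("yes") else 0.0
```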
The run script referenced above, `recipe/deepeyes/run_deepeyes_grpo.sh` (74 lines, 3.2 KiB):
```bash
#!/bin/bash
set -x

# OpenAI-compatible /v1 endpoint used by the LLM-as-a-Judge reward function.
export LLM_AS_A_JUDGE_BASE="your llm-as-a-judge server/v1"
export WANDB_API_KEY="your wandb key"

PROJECT_NAME="your_project_name"
EXPERIMENT_NAME="your_experiment_name"

BASEDIR=base_dir
SAVE_CHECKPOINT_DIR=${BASEDIR}/verl_checkpoints
DATASET_TRAIN=${BASEDIR}/dataset/train.parquet
DATASET_VAL=${BASEDIR}/dataset/val.parquet

REF_MODEL_PATH=ref_model_path

# Ensure the tee target directory below exists before the run starts.
mkdir -p ./logs

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
    --config-path=${BASEDIR}/recipe/deepeyes/configs \
    --config-name='deepeyes_multiturn_grpo' \
    data.train_files=${DATASET_TRAIN} \
    data.val_files=[${DATASET_VAL}] \
    data.train_batch_size=128 \
    data.max_prompt_length=8192 \
    data.max_response_length=16384 \
    data.return_raw_chat=True \
    data.filter_overlong_prompts=True \
    algorithm.adv_estimator=grpo \
    algorithm.kl_ctrl.kl_coef=0.0 \
    actor_rollout_ref.model.path=${REF_MODEL_PATH} \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.use_fused_kernels=True \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.actor.kl_loss_coef=0.0 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.actor.entropy_coeff=0.0 \
    actor_rollout_ref.actor.checkpoint.save_contents=['model','hf_model','optimizer','extra'] \
    actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.rollout.name=sglang \
    actor_rollout_ref.rollout.mode=async \
    actor_rollout_ref.rollout.n=8 \
    actor_rollout_ref.rollout.max_num_batched_tokens=32768 \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
    actor_rollout_ref.rollout.enforce_eager=True \
    actor_rollout_ref.rollout.free_cache_engine=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=True \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.multi_turn.enable=True \
    actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \
    actor_rollout_ref.rollout.multi_turn.max_user_turns=5 \
    actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \
    actor_rollout_ref.rollout.multi_turn.tool_config_path=recipe/deepeyes/configs/image_zoom_in_tool_config.yaml \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb','tensorboard'] \
    trainer.val_before_train=False \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=8 \
    trainer.test_freq=80 \
    trainer.project_name=${PROJECT_NAME} \
    trainer.experiment_name=${EXPERIMENT_NAME} \
    trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \
    +trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \
    +trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \
    trainer.total_epochs=1 2>&1 | tee ./logs/${EXPERIMENT_NAME}.log
```
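For context on the "Tool Call Statistics" change, a minimal sketch of the kind of per-step aggregation such tracking reports; the metric keys are illustrative, not verl's actual internal names. The resulting dict would be merged into the step metrics that `trainer.logger` forwards to wandb/tensorboard.

```python
# Hedged sketch of per-step tool-call statistics with illustrative metric keys.
import numpy as np


def tool_call_metrics(calls_per_sample: list[int]) -> dict[str, float]:
    """Aggregate tool-call counts across the rollouts of one training step."""
    counts = np.asarray(calls_per_sample, dtype=np.float64)
    return {
        "tools/call_count/mean": float(counts.mean()),
        "tools/call_count/max": float(counts.max()),
        "tools/call_count/min": float(counts.min()),
        # Fraction of rollouts that invoked the tool at least once.
        "tools/call_rate": float((counts > 0).mean()),
    }


# e.g. merged into the step metrics dict before logging:
# metrics.update(tool_call_metrics([0, 2, 1, 3]))
```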