verl/recipe/deepeyes/run_deepeyes_grpo.sh
Minghui Jia 9f4161e250 [recipe] feat: add deepeyes recipe (#2398)
### What does this PR do?

This PR introduces a complete training recipe for [DeepEyes:
Incentivizing "Thinking with Images" via Reinforcement
Learning](https://arxiv.org/abs/2505.14362).

The core feature is support for multi-turn visual tools, specifically the
`ImageZoomInTool`, integrated with a custom reward function based on the
"LLM-as-a-Judge" pattern to evaluate model performance.
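
For context, the sketch below shows the general shape of an LLM-as-a-Judge reward: the policy's answer is sent to an external judge model, whose verdict is mapped to a scalar reward. The model name, prompt, and function names are illustrative assumptions rather than the recipe's actual implementation; only the `LLM_AS_A_JUDGE_BASE` environment variable comes from the run script.

```python
# Minimal LLM-as-a-Judge reward sketch (illustrative; not the recipe's actual code).
# Assumes an OpenAI-compatible endpoint at LLM_AS_A_JUDGE_BASE, as set in the run script.
import os

from openai import OpenAI

client = OpenAI(base_url=os.environ["LLM_AS_A_JUDGE_BASE"], api_key="EMPTY")

JUDGE_PROMPT = (
    "Question: {question}\n"
    "Reference answer: {reference}\n"
    "Model answer: {answer}\n"
    "Reply with exactly one word: correct or incorrect."
)


def judge_reward(question: str, reference: str, answer: str) -> float:
    """Return 1.0 if the judge deems the answer correct, else 0.0."""
    resp = client.chat.completions.create(
        model="judge",  # hypothetical model name served by the judge endpoint
        messages=[{
            "role": "user",
            "content": JUDGE_PROMPT.format(question=question, reference=reference, answer=answer),
        }],
        temperature=0.0,
    )
    verdict = resp.choices[0].message.content.strip().lower()
    return 1.0 if verdict.startswith("correct") else 0.0
```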

Additionally, to better monitor and analyze the model's tool-use
behavior, this PR adds functionality to track tool call counts during
training and report these metrics to logging backends such as wandb.
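
As a rough illustration of that bookkeeping, a batch-level aggregation might look like the following; the sample fields and metric keys here are hypothetical, not verl's actual names:

```python
# Hypothetical tool-call aggregation (field and metric names are illustrative).
from collections import Counter


def tool_call_metrics(samples: list[dict]) -> dict[str, float]:
    """Aggregate per-sample tool call counts into step-level metrics."""
    counts = [s.get("tool_call_count", 0) for s in samples]
    per_tool: Counter = Counter()
    for s in samples:
        per_tool.update(s.get("tool_calls_by_name", {}))
    metrics = {
        "tools/calls_per_sample_mean": sum(counts) / max(len(counts), 1),
        "tools/calls_per_step_total": float(sum(counts)),
    }
    metrics.update({f"tools/{name}_calls": float(n) for name, n in per_tool.items()})
    return metrics  # e.g. handed to wandb.log(metrics) once per training step
```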

### API and Usage Example

The primary change is the new training recipe for DeepEyes. Users can
start a training run by using the provided configuration file.

1. Preprocess the dataset to add the tool-related `extra_info` fields (a sketch of the resulting row layout follows these steps):
```bash
python recipe/deepeyes/deepeyes47k_preprocess.py --dataset_dir <path_to_raw_dataset> --save_dir <path_to_processed_data>
```
2. Start the GRPO training:
```bash
bash recipe/deepeyes/run_deepeyes_grpo.sh
```
The training process automatically loads the `ImageZoomInTool` and the
custom reward function defined in the recipe.
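
For reference, one processed row could carry tool metadata roughly like the sketch below. The field names follow verl's multi-turn tooling conventions as I understand them; check `deepeyes47k_preprocess.py` for the actual schema.

```python
# Hedged sketch of a processed dataset row after step 1 (illustrative field names).
row = {
    "prompt": [{"role": "user", "content": "Where is the red car in the image?"}],
    "images": ["<image bytes or path>"],
    "extra_info": {
        "need_tools_kwargs": True,
        "tools_kwargs": {
            "image_zoom_in_tool": {  # must match the tool name in the YAML tool config
                "create_kwargs": {"image": "<image bytes or path>"},
            },
        },
    },
}
```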


### Design & Code Changes

- **DeepEyes Recipe Integration**: Added a new recipe directory with
data preprocessing, tool config, and a custom reward function for
DeepEyes.
- **Visual Tool Support**: Implemented `ImageZoomInTool` with robust
bbox validation and resizing (see the sketch after this list).
- **Tool Call Statistics**: Modified the rollout and metrics code to
track and log tool call counts per sample and per step.
- **Bug Fixes**: Fixed image byte handling and ensured special tokens
are preserved during decoding for tool call formatting.
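
For a sense of what the bbox handling involves, here is a minimal sketch, assuming a PIL image and an absolute-coordinate `(x1, y1, x2, y2)` box; the actual `ImageZoomInTool` implementation in the recipe may differ in details such as thresholds and error reporting.

```python
# Illustrative bbox validation/zoom sketch (not the recipe's exact implementation).
from PIL import Image


def zoom_in(image: Image.Image, bbox: tuple[float, float, float, float],
            min_side: int = 28) -> Image.Image:
    """Clamp an (x1, y1, x2, y2) bbox to the image, reject degenerate boxes, and crop."""
    w, h = image.size
    x1, y1, x2, y2 = bbox
    # Clamp coordinates to the image bounds.
    x1, x2 = max(0, min(x1, w)), max(0, min(x2, w))
    y1, y2 = max(0, min(y1, h)), max(0, min(y2, h))
    if x2 - x1 < 1 or y2 - y1 < 1:
        raise ValueError(f"degenerate bbox after clamping: {(x1, y1, x2, y2)}")
    crop = image.crop((x1, y1, x2, y2))
    # Upscale tiny crops so the vision encoder receives a usable resolution.
    if min(crop.size) < min_side:
        scale = min_side / min(crop.size)
        crop = crop.resize((int(crop.width * scale), int(crop.height * scale)))
    return crop
```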

### Checklist Before Submitting

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Co-authored-by: Maxwell-Jia <mr.minghui.jia@gamil.com>
Co-authored-by: xieck13 <xieck13@gmail.com>
Co-authored-by: Claude <noreply@anthropic.com>
2025-08-12 09:51:58 +08:00


#!/bin/bash
set -x
# OpenAI-compatible base URL of the LLM-as-a-Judge server used by the custom reward function
export LLM_AS_A_JUDGE_BASE="your llm-as-a-judge server/v1"
export WANDB_API_KEY="your wandb key"

PROJECT_NAME="your_project_name"
EXPERIMENT_NAME="your_experiment_name"

# Placeholder paths: point these at your workspace, processed dataset, and base model.
BASEDIR=base_dir
SAVE_CHECKPOINT_DIR=${BASEDIR}/verl_checkpoints
DATASET_TRAIN=${BASEDIR}/dataset/train.parquet
DATASET_VAL=${BASEDIR}/dataset/val.parquet
REF_MODEL_PATH=ref_model_path

# Create the log directory so the final tee does not fail.
mkdir -p ./logs

PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
--config-path=${BASEDIR}/recipe/deepeyes/configs \
--config-name='deepeyes_multiturn_grpo' \
data.train_files=${DATASET_TRAIN} \
data.val_files=[${DATASET_VAL}] \
data.train_batch_size=128 \
data.max_prompt_length=8192 \
data.max_response_length=16384 \
data.return_raw_chat=True \
data.filter_overlong_prompts=True \
algorithm.adv_estimator=grpo \
algorithm.kl_ctrl.kl_coef=0.0 \
actor_rollout_ref.model.path=${REF_MODEL_PATH} \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.model.use_fused_kernels=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.actor.kl_loss_coef=0.0 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0.0 \
actor_rollout_ref.actor.checkpoint.save_contents=['model','hf_model','optimizer','extra'] \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=1 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.mode=async \
actor_rollout_ref.rollout.n=8 \
actor_rollout_ref.rollout.max_num_batched_tokens=32768 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.8 \
actor_rollout_ref.rollout.enforce_eager=True \
actor_rollout_ref.rollout.free_cache_engine=True \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.actor.fsdp_config.param_offload=True \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
actor_rollout_ref.rollout.multi_turn.enable=True \
actor_rollout_ref.rollout.multi_turn.max_assistant_turns=5 \
actor_rollout_ref.rollout.multi_turn.max_user_turns=5 \
actor_rollout_ref.rollout.multi_turn.max_parallel_calls=1 \
actor_rollout_ref.rollout.multi_turn.tool_config_path=recipe/deepeyes/configs/image_zoom_in_tool_config.yaml \
trainer.critic_warmup=0 \
trainer.logger=['console','wandb','tensorboard'] \
trainer.val_before_train=False \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=8 \
trainer.test_freq=80 \
trainer.project_name=${PROJECT_NAME} \
trainer.experiment_name=${EXPERIMENT_NAME} \
trainer.default_local_dir=${SAVE_CHECKPOINT_DIR}/${PROJECT_NAME}/${EXPERIMENT_NAME} \
+trainer.tensorboard_dir=${SAVE_CHECKPOINT_DIR}/logs/tensorboard \
+trainer.rl_logging_board_dir=${SAVE_CHECKPOINT_DIR}/logs/rl_logging_board \
trainer.total_epochs=1 2>&1 | tee ./logs/${EXPERIMENT_NAME}.log