Mirror of https://github.com/volcengine/verl.git, synced 2025-10-20 21:53:50 +08:00
### What does this PR do?

I followed the instructions at https://verl.readthedocs.io/en/latest/start/quickstart.html to run the PPO example on my devbox, which uses zsh. However, I got the error `zsh: no matches found: trainer.logger=[console]` because `[]` is interpreted as a glob pattern in zsh.

```
(verl) ➜ verl git:(20250713-devbox-2-tmux0-verl-2) ✗ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.train_batch_size=256 \
data.max_prompt_length=512 \
data.max_response_length=256 \
actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=64 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
critic.optim.lr=1e-5 \
critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
critic.ppo_micro_batch_size_per_gpu=4 \
algorithm.kl_ctrl.kl_coef=0.001 \
trainer.logger=['console'] \
trainer.val_before_train=False \
trainer.n_gpus_per_node=1 \
trainer.nnodes=1 \
trainer.save_freq=10 \
trainer.test_freq=10 \
trainer.total_epochs=15 2>&1 | tee verl_demo.log
zsh: no matches found: trainer.logger=[console]
```

This PR has 3 changes:

* `trainer.logger=['console']` -> `trainer.logger=console`
* `trainer.logger=['console','wandb']` -> `trainer.logger='["console","wandb"]'`
* `trainer.logger=['console','tensorboard']` -> `trainer.logger='["console","tensorboard"]'`

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

* `trainer.logger=console` (zsh)
  <img width="898" height="564" alt="image" src="https://github.com/user-attachments/assets/a957a493-75e6-462b-9974-6b1c4cdf5a80" />
* ``trainer.logger='["console","wandb"]'`` (zsh)
  <img width="870" height="565" alt="image" src="https://github.com/user-attachments/assets/e20613bf-2ccc-4653-b23f-90edc3d568d1" />
* `trainer.logger=console` (bash)

```bash
ubuntu@ip-xxx-xx-x-xxx:~/verl$ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
> data.train_files=$HOME/data/gsm8k/train.parquet \
> data.val_files=$HOME/data/gsm8k/test.parquet \
> data.train_batch_size=256 \
> data.max_prompt_length=512 \
> data.max_response_length=256 \
> actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
> actor_rollout_ref.actor.optim.lr=1e-6 \
> actor_rollout_ref.actor.ppo_mini_batch_size=64 \
> actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
> actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
> actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
> actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
> actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
> critic.optim.lr=1e-5 \
> critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
> critic.ppo_micro_batch_size_per_gpu=4 \
> algorithm.kl_ctrl.kl_coef=0.001 \
> trainer.logger=console \
> trainer.val_before_train=False \
> trainer.n_gpus_per_node=1 \
> trainer.nnodes=1 \
> trainer.save_freq=10 \
> trainer.test_freq=10 \
> trainer.total_epochs=15 2>&1 | tee verl_demo.log
2025-07-14 02:52:27,669 INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
(TaskRunner pid=1799248) TaskRunner hostname: ip-172-31-9-244, PID: 1799248
(TaskRunner pid=1799248) {'actor_rollout_ref': {'actor': {'checkpoint': {'load_contents': ['model',
(TaskRunner pid=1799248) 'optimizer',
(TaskRunner pid=1799248) 'extra'],
(TaskRunner pid=1799248) 'save_contents': ['model',
(TaskRunner pid=1799248) 'optimizer',
(TaskRunner pid=1799248) 'extra']},
```

* `trainer.logger='["console","wandb"]'` (bash)

```bash
ubuntu@ip-xxx-xx-x-xxx:~/verl$ PYTHONUNBUFFERED=1 python3 -m verl.trainer.main_ppo \
> data.train_files=$HOME/data/gsm8k/train.parquet \
> data.val_files=$HOME/data/gsm8k/test.parquet \
> data.train_batch_size=256 \
> data.max_prompt_length=512 \
> data.max_response_length=256 \
> actor_rollout_ref.model.path=Qwen/Qwen2.5-0.5B-Instruct \
> actor_rollout_ref.actor.optim.lr=1e-6 \
> actor_rollout_ref.actor.ppo_mini_batch_size=64 \
> actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=4 \
> actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=8 \
> actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
> actor_rollout_ref.rollout.gpu_memory_utilization=0.4 \
> actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
> critic.optim.lr=1e-5 \
> critic.model.path=Qwen/Qwen2.5-0.5B-Instruct \
> critic.ppo_micro_batch_size_per_gpu=4 \
> algorithm.kl_ctrl.kl_coef=0.001 \
> trainer.logger='["console","wandb"]' \
> trainer.val_before_train=False \
> trainer.n_gpus_per_node=1 \
> trainer.nnodes=1 \
> trainer.save_freq=10 \
> trainer.test_freq=10 \
> trainer.total_epochs=15 2>&1 | tee verl_demo.log
2025-07-14 02:54:13,989 INFO worker.py:1908 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
(TaskRunner pid=1805000) TaskRunner hostname: ip-172-31-9-244, PID: 1805000
(TaskRunner pid=1805000) {'actor_rollout_ref': {'actor': {'checkpoint': {'load_contents': ['model',
(TaskRunner pid=1805000) 'optimizer',
(TaskRunner pid=1805000) 'extra'],
(TaskRunner pid=1805000) 'save_contents': ['model',
(TaskRunner pid=1805000) 'optimizer',
(TaskRunner pid=1805000) 'extra']},
```

### API and Usage Example

No

### Design & Code Changes

No

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Signed-off-by: Kai-Hsun Chen <kaihsun@anyscale.com>
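The glob expansion described above is not unique to zsh. When a pattern matches nothing, zsh (with its default `NOMATCH` option) refuses to run the command, while bash leaves the word unchanged. The minimal bash sketch below walks through three cases. The `show_args` helper and the scratch directory are illustrative assumptions and are not part of verl or its launcher.

```bash
#!/usr/bin/env bash
# show_args is a hypothetical helper: it prints each argument on its own line.
show_args() { printf '%s\n' "$@"; }

# Work in an empty scratch directory so the glob results are predictable.
demo_dir=$(mktemp -d)
cd "$demo_dir" || exit 1

# 1) Unquoted, no matching file: bash passes the word through unchanged
#    (zsh's default NOMATCH behavior aborts with "no matches found" instead).
unquoted_nomatch=$(show_args trainer.logger=[console])

# 2) Unquoted, matching file present: [console] is a bracket expression that
#    matches ONE character from {c,o,n,s,l,e}, so the word silently becomes
#    the file name.
touch 'trainer.logger=c'
unquoted_match=$(show_args trainer.logger=[console])

# 3) Quoted: pathname expansion is skipped, and after quote removal the
#    program receives the literal list string.
quoted=$(show_args trainer.logger='["console","wandb"]')

printf 'no match: %s\nmatch:    %s\nquoted:   %s\n' \
  "$unquoted_nomatch" "$unquoted_match" "$quoted"

# Clean up the scratch directory.
cd - >/dev/null && rm -rf "$demo_dir"
```

Case 2 is the quiet failure. In bash the command still runs, but the trainer would receive `trainer.logger=c`. Quoting the whole value, as this PR does, avoids both failure modes in either shell.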
61 lines
2.4 KiB
Bash
set -x

gsm8k_train_path=$HOME/data/gsm8k/train.parquet
gsm8k_test_path=$HOME/data/gsm8k/test.parquet
math_train_path=$HOME/data/math/train.parquet
math_test_path=$HOME/data/math/test.parquet

train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

# For async rollout mode, the dataset should return raw chat messages.
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
    return_raw_chat="True"
fi

python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=gae \
    data.train_files="$train_files" \
    data.val_files="$test_files" \
    data.return_raw_chat=${return_raw_chat:-False} \
    data.train_batch_size=4096 \
    data.max_prompt_length=4096 \
    data.max_response_length=4096 \
    data.filter_overlong_prompts=True \
    data.truncation='error' \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=512 \
    actor_rollout_ref.actor.use_dynamic_bsz=True \
    actor_rollout_ref.actor.ppo_max_token_len_per_gpu=24000 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.actor.use_kl_loss=False \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.mode=$rollout_mode \
    actor_rollout_ref.rollout.multi_turn.format=hermes \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
    actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2-7B-Instruct \
    critic.model.enable_gradient_checkpointing=True \
    critic.ppo_max_token_len_per_gpu=98304 \
    critic.model.fsdp_config.param_offload=False \
    critic.model.fsdp_config.optimizer_offload=False \
    algorithm.use_kl_in_reward=False \
    trainer.critic_warmup=0 \
    trainer.logger='["console","wandb"]' \
    trainer.project_name='verl_example_gsm8k' \
    trainer.experiment_name='qwen2-7b_function_rm_bsz8k_p4k_r4k_seq_packing' \
    trainer.n_gpus_per_node=8 \
    trainer.val_before_train=False \
    trainer.nnodes=1 \
    trainer.save_freq=20 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 "$@"
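The `train_files` and `test_files` variables hold a bracketed list literal that Hydra parses into a list of paths. This is why the script passes them as `data.train_files="$train_files"`. The bash sketch below rebuilds that string and counts how many arguments each quoting style produces. It substitutes a hypothetical `demo_home` for `$HOME` so the printed paths are predictable.

```bash
#!/usr/bin/env bash
# Hypothetical stand-in for $HOME so the printed paths are predictable.
demo_home=/home/user
gsm8k_train_path=$demo_home/data/gsm8k/train.parquet
math_train_path=$demo_home/data/math/train.parquet

# Same construction as the script: $-variables expand inside the double
# quotes, while the single quotes and brackets stay literal characters.
train_files="['$gsm8k_train_path', '$math_train_path']"

# Quoted expansion: one argument, exactly as Hydra should receive it.
set -- data.train_files="$train_files"
quoted_count=$#
quoted_arg=$1

# Unquoted expansion: word splitting at the space after the comma breaks
# the list into two separate, malformed overrides.
set -- data.train_files=$train_files
unquoted_count=$#

echo "quoted:   $quoted_count arg(s) -> $quoted_arg"
echo "unquoted: $unquoted_count arg(s)"
```

The space after each comma in the list literal is harmless to Hydra, but only if the shell delivers the whole override as a single argument.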
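With `use_dynamic_bsz=True`, micro-batches are, as we understand verl's dynamic batching, sized by a per-GPU token budget (`ppo_max_token_len_per_gpu`, `log_prob_max_token_len_per_gpu`) rather than by a fixed number of sequences. The arithmetic below uses only values from the script. It assumes the worst case, a sample that fills both `max_prompt_length` and `max_response_length`, and reports how many such samples fit in each budget.

```bash
#!/usr/bin/env bash
# Values copied from the script above.
max_prompt_length=4096
max_response_length=4096
actor_max_tokens=24000    # actor ppo_max_token_len_per_gpu (rollout log-prob uses 24000 too)
critic_max_tokens=98304   # critic.ppo_max_token_len_per_gpu

# Longest possible sample: full prompt plus full response.
max_seq_len=$((max_prompt_length + max_response_length))

# Integer division: how many worst-case samples fit in one micro-batch.
actor_full_seqs=$((actor_max_tokens / max_seq_len))
critic_full_seqs=$((critic_max_tokens / max_seq_len))

# A budget below max_seq_len could never hold even one worst-case sample.
for budget in "$actor_max_tokens" "$critic_max_tokens"; do
  if (( budget < max_seq_len )); then
    echo "budget $budget is smaller than max_seq_len $max_seq_len" >&2
    exit 1
  fi
done

echo "max_seq_len=$max_seq_len actor=$actor_full_seqs critic=$critic_full_seqs"
```

Under this reading, the actor and rollout log-prob budgets each hold at most two worst-case samples per micro-batch. The critic's budget is exactly 12 times the maximum sample length, so it holds twelve. Shorter samples pack more densely in both cases.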