[trainer, recipe] feat: fully async training recipe (#2981)

### What does this PR do?

To implement a fully asynchronous training workflow, we further split
the training process into a Trainer and a Rollouter, building on the
existing one-step-off-policy code, with samples transmitted via a
message queue.

We will continue to integrate partial rollout to mitigate the impact of
long-tail samples on training.

Related PRs: https://github.com/volcengine/verl/pull/2231,
https://github.com/volcengine/verl/pull/2200
### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
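The recipe adds a new entrypoint, `recipe.fully_async_policy.fully_async_main`, launched with command-line overrides; the full shell Quick Start lives in `docs/advance/fully_async.md` (added by this PR). Below is a minimal, hypothetical launch sketch: the override keys are taken from that doc, and the concrete values are illustrative only.

```python
# Hypothetical sketch, not the recipe's official launcher: invoke the new
# entrypoint with a few of the documented overrides via subprocess.
import subprocess

subprocess.run(
    [
        "python", "-m", "recipe.fully_async_policy.fully_async_main",
        "data.train_batch_size=0",                      # not effective in fully async mode
        "data.gen_batch_size=1",                        # streaming: one sample at a time
        "actor_rollout_ref.rollout.mode=async",
        "actor_rollout_ref.rollout.calculate_log_probs=True",
        "async_training.staleness_threshold=0.3",       # freshness control
        "async_training.trigger_parameter_sync_step=4",
        "async_training.partial_rollout=True",
    ],
    check=True,
)
```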

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <machi04@meituan.com>
Co-authored-by: wangshulin02 <wangshulin02@meituan.com>
Co-authored-by: arron <arron@MBP-2G17FXQ05P-2332.local>
Co-authored-by: wangshulin02 <953550366@qq.com>
Co-authored-by: hadoop-ai-search <hadoop-ai-search@set-zw04-mlp-codelab-pc1189.mt>
Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com>
Co-authored-by: arron <arron@MBP-VH9RV7LTJC-1907.local>
Co-authored-by: arron <arron@MBP-JFQXPWR11F-1943.local>
Commit b25bb7d4f3 (parent dd8864f9ee), authored by arron on 2025-10-17 22:29:18 +08:00 and committed by GitHub.
39 changed files with 6292 additions and 35 deletions

@ -0,0 +1,149 @@
# # Tests layout
# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:
# - `tests/trainer` for testing functionality related to `verl/trainer`
# - `tests/models` for testing functionality related to `verl/models`
# - ...
# There are a few folders with `special_` prefix, created for special purposes:
# - `special_distributed`: unit tests that must run with multiple GPUs
# - `special_e2e`: end-to-end tests with training/generation scripts
# - `special_npu`: tests for NPUs
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of tests designed to run in dedicated environments
# Accelerators for tests
# - By default tests are run with GPUs available, except for the ones under `special_npu` and any test script whose name ends with `on_cpu.py`.
# - Test scripts with the `on_cpu.py` suffix are run on CPU resources in a Linux environment.
# # Workflow layout
# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:
# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `pre-commit.yml`, `doc.yml`
# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`
# 3. End-to-end tests: `e2e_*.yml`
# 4. Unit tests
#   - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`
#   - `gpu_unit_tests.yml`, run pytest on all test scripts whose file names do not end with the `on_cpu.py` suffix
#   - Since cpu/gpu unit tests by default run all tests under `tests`, please make sure tests are manually excluded in them when
#     - a new workflow yaml is added to `.github/workflows`
#     - new tests are added to a workflow mentioned in 2.
name: e2e_fully_async_policy
on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  # For push, for now only anti-patterns are specified so it is more conservative
  # and achieves higher coverage.
  push:
    branches:
      - main
      - v0.*
    paths:
      - "**/*.py"
      - "!**/*.md"
      - "!**/*.sh"
      # Other entrypoints
      - "!examples/*trainer*"
      - "!tests/**"
      - "!verl/trainer/main_*.py"
      - "!verl/trainer/fsdp_sft_trainer.py"
      - "!recipe/**"
      - "recipe/fully_async_policy"
  pull_request:
    branches:
      - main
      - v0.*
    paths:
      - "**/*.py"
      - "!**/*.md"
      - "!**/*.sh"
      # Other entrypoints
      - "!examples/**"
      - "!tests/**"
      - "!verl/trainer/main_*.py"
      - "!verl/trainer/fsdp_sft_trainer.py"
      # Other recipes
      - "!recipe/**"
      # Home
      - "recipe/fully_async_policy"
      # Entrypoints
      - ".github/workflows/e2e_fully_async_policy.yml"
      - "examples/data_preprocess/gsm8k.py"
      - "tests/special_e2e/run_fully_async_policy.sh"
# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}
# Declare permissions just read content.
permissions:
  contents: read
env:
  IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2"
  DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner"
  TRANSFORMERS_VERSION: "4.56.2"
jobs:
  setup:
    if: github.repository_owner == 'volcengine'
    runs-on: ubuntu-latest
    outputs:
      runner-label: ${{ steps.create-runner.outputs.runner-label }}
      mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }}
    steps:
      - uses: actions/checkout@v4
      - id: create-runner
        uses: volcengine/vemlp-github-runner@v1
        with:
          mode: "create"
          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
          mlp-image: "${{ env.IMAGE }}"
  # Test FSDP2 strategy
  e2e_fully_async_policy_fsdp2:
    needs: setup
    runs-on: [ "${{ needs.setup.outputs.runner-label || 'L20x8' }}" ]
    timeout-minutes: 10 # Increase timeout for async training
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
      HF_ENDPOINT: "https://hf-mirror.com"
      HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
      ACTOR_STRATEGY: "fsdp2"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install --no-deps -e .[test,gpu]
          pip3 install transformers==$TRANSFORMERS_VERSION
      - name: Prepare GSM8K dataset
        run: |
          python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
      - name: Running the E2E test with fully_async_policy algorithm (FSDP2)
        run: |
          ray stop --force
          bash tests/special_e2e/run_fully_async_policy.sh
  cleanup:
    runs-on: ubuntu-latest
    needs:
      [
        setup,
        e2e_fully_async_policy_fsdp2
      ]
    if: always()
    steps:
      - id: destroy-runner
        uses: volcengine/vemlp-github-runner@v1
        with:
          mode: "destroy"
          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
          mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}"

docs/advance/fully_async.md

@ -0,0 +1,428 @@
# Recipe: Fully Async Policy Async Trainer
**Author:** `https://github.com/meituan-search`
Last updated: 10/17/2025.
This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,
supporting asynchronous sample generation and training.
Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs,
without significantly affecting the results.
## Introduction
### Background
Compared to the colocate architecture, the separated rollout/train architecture can allocate resources more
flexibly and support more flexible training logic, addressing the low GPU utilization and low training efficiency
caused by long-tail samples.
one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by
using a separated architecture and running rollout and training one step apart.
However, it always trains on data that is exactly one step stale, which is not flexible enough and cannot
completely eliminate the impact of long-tail samples on training efficiency.
Other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow have implemented asynchronous and streaming
training on top of the separated architecture and achieved gains.
We drew on their methods and implemented them in verl. fully_async_policy supports asynchronous, streaming, and partial
rollout training.
By reasonably setting resource allocation, parameter synchronization frequency, and other parameters, fully_async_policy
can significantly improve training efficiency.
> Magistral https://arxiv.org/abs/2506.10910
>
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language
> Reasoning https://arxiv.org/abs/2505.24298
>
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream
> Generation https://arxiv.org/abs/2504.15930
>
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
>
### Core Contributions
* **Resource Isolation**: Unlike the hybrid_engine setup, the Rollouter and Trainer use separate computing resources, and
the resources each occupies must be specified separately.
* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.
* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to
multiple steps, making the asynchronous solution more flexible.
* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter
and Trainer.
* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single
sample as the minimum transmission unit.
* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it
supports training with samples generated by old parameters.
* **PartialRollout**: The Rollouter's inference process supports partial rollout. By adding `sleep()` and `resume()`
logic around parameter synchronization, samples from in-flight rollouts are saved and continued in the next rollout,
reducing the time spent waiting for ongoing tasks to finish during parameter synchronization (a conceptual sketch
follows at the end of this section).
Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop.
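The following is a purely conceptual sketch of the partial-rollout idea described above; the names and data structures are illustrative and do not reflect the recipe's actual vLLM/AgentLoop integration.

```python
# Conceptual sketch only (assumed names): unfinished generations are parked at
# parameter-sync time and resumed under the new weights afterwards.
from dataclasses import dataclass, field


@dataclass
class PartialSample:
    prompt: str
    tokens: list = field(default_factory=list)  # tokens generated so far
    done: bool = False
    start_version: int = 0  # parameter version under which generation started


def pause_for_sync(in_flight, backlog):
    """On parameter sync, keep unfinished generations instead of discarding them."""
    backlog.extend(s for s in in_flight if not s.done)
    in_flight.clear()


def resume_after_sync(backlog):
    """After weights are updated, hand the saved samples back to the rollout loop."""
    resumed = list(backlog)
    backlog.clear()
    return resumed
```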
## Design
The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four
parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
![fully_async_policy_structure](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)
1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the
production speed controlled by freshness.
2. MessageQueue is used to temporarily store samples generated by Rollouter.
3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size`
samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers
a parameter synchronization with Rollouter.
4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability.
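Below is a minimal, schematic sketch of the producer/consumer relationship between these four parts; it is not the recipe's actual implementation, and `generate_one_sample`, `train_step`, and `sync_parameters` are placeholder callables.

```python
# Schematic sketch (assumed helper callables): the Rollouter streams samples into
# a queue, the Trainer consumes fixed-size batches and periodically syncs parameters.
import queue

REQUIRE_BATCHES = 1
PPO_MINI_BATCH_SIZE = 32
TRIGGER_PARAMETER_SYNC_STEP = 4

message_queue = queue.Queue()  # stands in for MessageQueue


def rollouter_loop(generate_one_sample, should_stop):
    """Produce samples one by one and push them into the queue."""
    while not should_stop():
        message_queue.put(generate_one_sample())


def trainer_loop(train_step, sync_parameters, total_updates):
    """Consume require_batches * ppo_mini_batch_size samples per local update."""
    fetch_size = REQUIRE_BATCHES * PPO_MINI_BATCH_SIZE
    for update in range(1, total_updates + 1):
        batch = [message_queue.get() for _ in range(fetch_size)]
        train_step(batch)
        # Every trigger_parameter_sync_step local updates, push new weights to
        # the Rollouter (done via NCCL by the ParameterSynchronizer in the recipe).
        if update % TRIGGER_PARAMETER_SYNC_STEP == 0:
            sync_parameters()
```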
Compared with the colocate baseline, the benefit comes from the fact that in the colocate case, giving more resources to
rollout still cannot remove the idleness caused by long-tail samples.
After resource isolation, the rollout and train phases may each take longer than before (because each uses fewer resources),
but overlapping them reduces the end-to-end time.
![fully_async_policy_revenue](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)
## Usage
### Parameter Description
| parameter | description |
|-----------------------------------------------|------------------------------------------------------------------------------------------------|
| `trainer.nnodes` | Number of nodes for Trainer |
| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
| `rollout.nnodes` | Number of nodes for Rollouter |
| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter |
| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) |
| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) |
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
**Further Explanation:**
* `rollout.total_rollout_steps`
Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step:
`rollout.total_rollout_steps = data.train_batch_size * step`.
* `async_training.trigger_parameter_sync_step`
In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches
`require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with the Rollouter.
Between two consecutive parameter synchronizations, the Trainer processes
`trigger_parameter_sync_step * require_batches * ppo_mini_batch_size` samples.
To fairly compare speed with colocate, trigger_parameter_sync_step should be set to
`data.train_batch_size / (require_batches * ppo_mini_batch_size)` (a worked example follows this list).
* `async_training.staleness_threshold`
In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.
* staleness_threshold=0 indicates synchronous training.
The Rollouter generates a fixed number of samples between two parameter updates; the sample count is:
$$rollout\_num = trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size$$
* staleness_threshold>0 indicates asynchronous training; it can be set to a fractional value for more flexible
asynchrony.
The Rollouter generates at most the following number of samples between two parameter updates:
$$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample$$
num_staleness_sample is the number of stale samples generated in excess during the previous rollout.
Because the system is streaming, the Rollouter keeps generating while the Trainer keeps consuming. If the Rollouter is slower,
the Trainer triggers parameter synchronization first, and the Rollouter will not actually produce rollout_num samples.
When rollout is fast enough, setting staleness_threshold to 1 is roughly equivalent to the one_step_off policy.
To avoid too many stale samples harming training accuracy, it is recommended to keep this value below 1.
* `async_training.partial_rollout`
partial_rollout only actually takes effect when staleness_threshold>0.
* `async_training.use_rollout_log_probs`
In reinforcement learning algorithms, log_probs are implicitly tied to the parameter version and the sampled tokens. Due to
the design of algorithms like PPO/GRPO/DAPO, when computing the importance-sampling ratio,
old_log_prob must be the log_probs produced by the rollout parameters for the rollout tokens to keep the algorithm
correct. In the fully async strategy, old_log_prob is therefore computed by the rollout engine by default rather than
recomputed by the Trainer.
* `async_training.require_batches`
In streaming training, require_batches is normally set to 1, meaning training is performed as soon as
ppo_mini_batch_size samples have been produced.
In practice we found that dispatching fewer samples at a time can, because of the order in which data is distributed,
cause training instability and longer response lengths.
We therefore expose require_batches to control how many mini-batches are dispatched and trained at once in the
streaming pipeline.
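As a worked example of the formulas above, here is a small sketch (hypothetical helper functions) using the settings from the experiments section (train_batch_size=512, require_batches=4, ppo_mini_batch_size=32, staleness_threshold=0.3):

```python
# Worked-example sketch of the two relationships documented above.
def trigger_sync_step(train_batch_size, require_batches, ppo_mini_batch_size):
    """trigger_parameter_sync_step for a fair speed comparison with colocate."""
    return train_batch_size // (require_batches * ppo_mini_batch_size)


def max_rollout_num(trigger_parameter_sync_step, require_batches, ppo_mini_batch_size,
                    staleness_threshold, num_staleness_sample=0):
    """Upper bound on samples the Rollouter may generate between two parameter syncs."""
    base = trigger_parameter_sync_step * require_batches * ppo_mini_batch_size
    if staleness_threshold == 0:
        return base  # synchronous streaming: fixed sample count
    return int((1 + staleness_threshold) * base) - num_staleness_sample


print(trigger_sync_step(512, 4, 32))                       # 4, as in the 7B experiments
print(max_rollout_num(4, 4, 32, staleness_threshold=0.3))  # 665
```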
### Supported Modes
1. on policy pipeline:
   1. **trigger_parameter_sync_step=1, staleness_threshold=0**
   2. The Rollouter produces `require_batches*ppo_mini_batch_size` samples at a time; the Trainer fetches these samples for
      training, and after training completes, the Trainer and Rollouter perform a parameter synchronization.
   3. During the rollout phase, if there are long-tail samples but only a few samples in flight, shorter samples cannot fill
      the idle resources, causing some resource waste.
   4. As shown in figure a.
2. stream off policy pipeline:
   1. **trigger_parameter_sync_step>1, staleness_threshold=0**
   2. Synchronous streaming training is performed. The Rollouter produces
      `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at a time; the Trainer performs a local
      update every time it fetches `require_batches*ppo_mini_batch_size` samples, and after
      trigger_parameter_sync_step updates, the Trainer and Rollouter perform a parameter synchronization.
   3. Compared to mode a, more samples are generated at once, so resources sit idle less.
   4. Within one training step there are still two idle periods: while fetching the first batch, the Trainer waits for
      `require_batches*ppo_mini_batch_size` samples to be produced, and around the last parameter update, the Rollouter
      waits for training to complete.
   5. As shown in figure b.
3. async stream pipeline with stale samples:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
   2. After each parameter update, the Rollouter plans to produce at most rollout_num samples (in practice, the number
      actually generated may be smaller, depending on rollout speed).
   3. If rollout is relatively fast, the Rollouter generates some additional samples (num_stale_samples) before parameter
      synchronization, for immediate use by the Trainer after synchronization.
      When parameter synchronization is triggered, if the Rollouter has in-flight tasks, it waits for them to complete
      and does not add new ones.
   4. Compared to mode b, after the first training step there is no longer a wait for the first batch of rollouts to
      finish, but there is a wait for active tasks to finish.
   5. As shown in figure c.
4. async stream pipeline with partial rollout:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
   2. Compared to mode c, when parameter synchronization is triggered and the Rollouter still has samples in flight, it
      interrupts the rollout, synchronizes parameters, and then continues generating the interrupted samples after
      synchronization. This reduces the time spent waiting for active tasks to finish.
   3. As shown in figure d (a small configuration-to-mode sketch follows the figure).
![fully_async_policy_mode](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)
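A small sketch (hypothetical helper, not part of the recipe) mapping the three switches to the four modes above:

```python
# Illustrative mapping from async_training settings to the documented pipeline modes.
def pipeline_mode(trigger_parameter_sync_step, staleness_threshold, partial_rollout):
    if staleness_threshold == 0:
        if trigger_parameter_sync_step == 1:
            return "on policy pipeline"                      # figure a
        return "stream off policy pipeline"                  # figure b
    if partial_rollout:
        return "async stream pipeline with partial rollout"  # figure d
    return "async stream pipeline with stale samples"        # figure c
```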
### Key Metrics
| metric | meaning |
|------------------------------------------------|--------------------------------------------------------------------------------------------------------|
| `trainer/idle_ratio` | Trainer idle ratio |
| `rollouter/idle_ratio` | Rollouter idle ratio |
| `fully_async/count/stale_samples_processed` | Total number of stale samples used in training |
| `fully_async/count/stale_trajectory_processed` | Total number of stale trajectories used in training (one sample produces rollout.n trajectories) |
| `fully_async/partial/total_partial_num` | Number of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/max_partial_span` | Maximum parameter-version span of partial samples processed by the Trainer between two parameter synchronizations |
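A rough illustration of how some of these quantities relate (assumed definitions for the sketch; the recipe's actual bookkeeping may differ):

```python
# Assumed-definition sketch: idle ratios are time fractions, and each prompt
# sample expands into rollout.n trajectories.
def idle_ratio(idle_seconds, elapsed_seconds):
    return idle_seconds / max(elapsed_seconds, 1e-9)


def stale_trajectories(stale_samples_processed, rollout_n):
    return stale_samples_processed * rollout_n
```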
### Parameter Tuning Recommendations
* Resource Allocation and Adjustment:
  * Reasonable resource allocation is the prerequisite for good training efficiency. Ideally, rollout time and train time
    should be close, which minimizes pipeline bubbles in the end-to-end training process, avoids idle resources, and keeps
    the Trainer from consuming stale samples.
  * In real training scenarios, allocation can be adjusted based on the idle time of rollout and train, which can be read
    from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and trainer/idle_ratio is low,
    increase Trainer resources and reduce Rollouter resources, and vice versa (see the sketch after these recommendations).
* Key Parameters:
  * staleness_threshold: Setting it too high causes more stale samples to be used, hurting model performance. It is
    recommended to keep it below 1.
  * require_batches: The closer it is to 1, the closer the process is to pure streaming, the smaller the training bubbles,
    and the larger the speedup, but it affects the order in which samples are processed.
  * trigger_parameter_sync_step: The smaller the value, the closer to on-policy training, but parameter synchronization
    becomes frequent and long-tail samples waste resources that short samples cannot fill, lowering resource utilization.
    The larger the value, the higher the computational efficiency, but accuracy suffers from the stronger off-policy effect.
  * rollout.test_freq: Validation occupies Rollouter resources, so this value should not be set too small.
* Mode Selection: By adjusting these parameters, the fully async architecture supports acceleration at different levels and
  suits different scenarios.
  * For small-scale tasks that must preserve training stability and on-policy behavior and have low speed requirements,
    try the on policy pipeline mode (Mode 1).
  * For scenarios that need higher training throughput but are sensitive to staleness, try the stream off policy pipeline
    mode: set trigger_parameter_sync_step>1 to improve training efficiency while keeping the synchronization mechanism
    (staleness_threshold=0) (Mode 2).
  * For large-scale tasks with high training speed requirements that can tolerate a certain degree of off-policy data and
    staleness, set staleness_threshold>0 and partial_rollout=True to improve training efficiency, i.e., the async stream
    pipeline modes (Mode 3 or 4).
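The rebalancing rule above can be summarized in a small sketch (hypothetical helper; in practice the adjustment is made by changing trainer.nnodes / rollout.nnodes between runs):

```python
# Illustrative helper following the tuning rule documented above.
def suggest_rebalance(rollouter_idle_ratio, trainer_idle_ratio, tolerance=0.1):
    if rollouter_idle_ratio - trainer_idle_ratio > tolerance:
        return "shift GPUs from Rollouter to Trainer"
    if trainer_idle_ratio - rollouter_idle_ratio > tolerance:
        return "shift GPUs from Trainer to Rollouter"
    return "allocation looks balanced"
```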
### Quick Start
```shell
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
use_dynamic_bsz=True # referenced by the overrides below; value assumed
total_rollout_steps=$(((512*400)))
test_freq=10
staleness_threshold=0
trigger_parameter_sync_step=16
partial_rollout=False
python -m recipe.fully_async_policy.fully_async_main \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.partial_rollout="${partial_rollout}"
```
## Experiments
### Asynchronous Training on 7B Model
We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy with long responses and at multiple resource scales.
Using the `async stream pipeline with stale samples` strategy, we achieved roughly a 2x speedup on 32, 64, and
128 GPUs without significantly affecting the results.
* Machine: H20
* Model: Qwen2.5-Math-7B
* Rollout length: max_response_length (FSDP2): 28K tokens
* Algorithm: DAPO
* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
  * step: 400
  * train_batch_size: 512
* fully_async_policy:
  * total_rollout_steps: 512*400
  * require_batches: 4
  * trigger_parameter_sync_step: 4
  * staleness_threshold: 0.3
  * partial_rollout: True
| training mode | resource allocation | step time (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
### 128-card 7B Asynchronous Mode Experiment
We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async.
We can see that streaming alone brings roughly a 0.6x additional speedup, and after combining staleness and
partial_rollout, the overall speedup reaches 2.35x.
| mode | step time (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| colocate sync<br>(128 GPUs) | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step=4,<br>require_batches=4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-card Stale Ablation Experiment
Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training
efficiency.
We found that the larger the staleness, the more obvious the final gains.
We also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps
increase, the response length changes significantly, causing training instability.
Further analysis and optimization are needed for this issue.
| staleness_threshold | step time (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-card 7B require_batches Ablation Experiment
In multiple tests, we found that the number of samples issued each time in streaming affects the response length during
training, which in turn affects training time. We verified the impact on results by modifying
`async_training.require_batches`.
| require_batches | step time (s) | gen (s) | old_log_prob (s) | update_actor (s) | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
### 30B Model Mode Experiment
TODO: The 30B experiment is still in progress.
* Machine: H20
* Model: Qwen2.5-32B
* Rollout length: max_response_length (FSDP2): 20K tokens
* Algorithm: DAPO
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
  * step: 200
  * train_batch_size: 512
* fully_async_policy:
  * total_rollout_steps: 512*200
  * trigger_parameter_sync_step: 512/32 = 16
  * staleness_threshold: 0
  * partial_rollout: False
| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|
| colocate sync | 128 | | | | | | | |
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |
## Future Plans
* GRPO experiments
* Megatron adaptation
* SGLang integration
* Transfer queue integration
* Asynchronous parameter synchronization
* AReaL asynchronous algorithm implementation
* TPPO algorithm implementation
* Multi-turn and Tool support


@ -124,6 +124,7 @@ verl is fast with:
advance/rollout_is_migration.md
advance/one_step_off
advance/agent_loop
advance/fully_async
.. toctree::
:maxdepth: 1

View File

@ -0,0 +1,428 @@
# Recipe: Fully Async Policy Async Trainer
**Author:** `https://github.com/meituan-search`
Last updated: 10/17/2025.
This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,
supporting asynchronous sample generation and training.
Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs,
without significantly affecting the results.
## Introduction
### Background
The separated rollout and train architecture, compared to the colocate architecture, can allocate resources more
flexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training
efficiency caused by long-tail problems.
The one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by
designing a separated architecture and performing asynchronous training between rollout and train for one round.
However, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot
completely eliminate the impact of long-tail on training efficiency.
In other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have
been implemented based on the separated architecture and have achieved gains.
We借鉴 their methods and implemented them in VERL. The fully_async_policy supports asynchronous, streaming, and partial
rollout training.
By reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy
can significantly improve training efficiency.
> Magistral https://arxiv.org/abs/2506.10910
>
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language
> Reasoning https://arxiv.org/abs/2505.24298
>
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream
> Generation https://arxiv.org/abs/2504.15930
>
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
>
### Core Contributions
* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to
specify the resources they occupy separately.
* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.
* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to
multiple steps, making the asynchronous solution more flexible.
* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter
and Trainer.
* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single
sample as the minimum transmission unit.
* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it
supports training with samples generated by old parameters.
* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter
synchronization, by adding `sleep() and resume()` logic, it
saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for
ongoing tasks to finish during parameter synchronization.
Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop.
## Design
The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four
parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
![fully_async_policy_structure](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)
1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the
production speed controlled by freshness.
2. MessageQueue is used to temporarily store samples generated by Rollouter.
3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size`
samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers
a parameter synchronization with Rollouter.
4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability.
The source of benefits compared to the base scheme lies in the fact that in the colocate case, using more resources for
rollout cannot solve the idleness caused by long-tail samples.
After we perform resource isolation, the time for rollout and train may be longer than before (because fewer resources
are used),
but the overlap in their time consumption reduces the end-to-end time consumption.
![fully_async_policy_revenue](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)
## Usage
### Parameter Description
| super params | implication |
|-----------------------------------------------|------------------------------------------------------------------------------------------------|
| `trainer.nnodes` | Number of nodes for Trainer |
| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
| `rollout.nnodes` | Number of nodes for Rollouter |
| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter |
| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) |
| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) |
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
**Further Explanation:**
* `rollout.total_rollout_steps`
Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step:
`rollout.total_rollout_steps = data.train_batch_size * step`.
* `async_training.trigger_parameter_sync_step`
In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches
`require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter.
Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process
`trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples.
To fairly compare speed with colocate, trigger_parameter_sync_step should be set to
`data.train_batch_size / (require_batches * ppo_mini_batch_size)`.
* `async_training.staleness_threshold`
In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.
* staleness_threshold=0, indicates synchronous training.
Rollouter will generate a fixed number of samples between two parameter updates, the sample count is:
$$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$
* staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous
calls.
Rollouter will generate at most the following number of samples between two parameter updates:
$$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$
num_staleness_sample represents the number of stale samples generated in excess during the last rollout.
Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower,
trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples.
When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy.
To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1.
* `async_training.partial_rollout`
partial_rollout only actually takes effect when staleness_threshold>0.
* `async_training.use_rollout_log_probs`
In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
correctness. In the fully
async strategy, we default to old_log_prob being calculated by rollout rather than by trainer.
* `async_training.require_batches`
In streaming training, require_batches should be set to 1, indicating that training is performed after producing
enough ppo_mini_batch_size samples.
In actual testing, we found that if fewer samples are issued at once, due to the order of data distribution, it can
cause training instability and longer response lengths.
Here, we additionally provide require_batches for streaming distribution and control the number of samples
participating in training at once.
### Supported Modes
1. on policy pipeline:
1. **trigger_parameter_sync_step=1, staleness_threshold=0**
2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for
training, and after training completes, Trainer and Rollouter perform a parameter synchronization;
3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill
idle resources, causing some resource waste.
4. As shown in figure a;
2. stream off policy pipeline:
1. **trigger_parameter_sync_step>1, staleness_threshold=0**
2. Synchronous streaming training will be performed. Rollouter produces
`require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local
training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training
trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;
3. Compared to a, since more samples are generated at once, resource idleness will be lower.
4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples,
train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter
update, rollout waits for training to complete.
5. As shown in figure b;
3. async stream pipeline with stale samples:
1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number
of samples generated may be less than this value depending on rollout speed).
3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples
before parameter synchronization for immediate use by Trainer after synchronization.
When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete
and not add new tasks;
4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the
first batch rollout to finish, but will have the time to wait for active tasks to finish.
5. As shown in figure c;
4. async stream pipeline with partial rollout:
1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will
interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be
generated after synchronization. This reduces the time to wait for active tasks to finish.
3. As shown in figure d;
![fully_async_policy_mode](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)
### Key Metrics
| metrics | implication |
|------------------------------------------------|--------------------------------------------------------------------------------------------------------|
| `trainer/idle_ratio` | Trainer idle rate |
| `rollouter/idle_ratio` | Rollouter idle rate |
| `fully_async/count/stale_samples_processed` | Total number of old samples used in training |
| `fully_async/count/stale_trajectory_processed` | Total number of old trajectories used in training (one sample produces rollout.n trajectories) |
| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two trigger_parameter_sync_step |
| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two trigger_parameter_sync_step |
| `fully_async/partial/max_partial_span` | Maximum parameter span of partial samples processed by Trainer between two trigger_parameter_sync_step |
### Parameter Tuning Recommendations
* Resource Allocation and Adjustment:
* Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource
allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire
training process,
avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource
allocation can be adjusted based on the idle time of rollout and train during actual training,
which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and
trainer/idle_ratio is low,
Trainer resources should be increased and Rollouter resources should be reduced, and vice versa.
* Key Parameters:
* staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It
is recommended to set it to less than 1.
* require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and
the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample
processing;
* trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent
parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in
low resource utilization.
The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy.
* rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small.
* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at
different levels, suitable for tasks in different scenarios.
* For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed
requirements, the on policy pipeline mode (Mode 1) can be tried.
* For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy
pipeline mode can be tried. That is, by
setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization
mechanism (staleness_threshold=0) (Mode 2).
* For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and
staleness, setting staleness_threshold>
0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4).
### Quick Start
```shell
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=10
staleness_threshold=0
trigger_parameter_sync_step=16
partial_rollout=False
python -m recipe.fully_async_policy.fully_async_main \
train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.partial_rollout="${partial_rollout}"
```
## Experiments
### Asynchronous Training on 7B Model
We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under long candidates and multiple resources.
Using the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards,
64 cards, and 128 cards without significantly affecting experimental results.
* Machine: H20
* Model: Qwen2.5-Math-7B
* Rollout length: max_response_length FSDP2: 28K tokens;
* Algorithm: DAPO
* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
* step: 400
* train_batch_size: 512
* fully_async_policy
* total_rollout_steps: 512*400
* require_batches: 4
* trigger_parameter_sync_step: 4
* staleness_threshold: 0.3
* partial_rollout: True
| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
### 128-card 7B Asynchronous Mode Experiment
We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async.
We can see that the benefit brought by streaming is approximately 0.6x, and after combining staleness and
partial_rollout, the benefit reaches 2.35x.
| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-card Stale Ablation Experiment
Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training
efficiency.
We found that the larger the staleness, the more obvious the final gains.
We also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps
increase, the response length changes significantly, causing training instability.
Further analysis and optimization are needed for this issue.
| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-card 7B require_batches Ablation Experiment
In multiple tests, we found that the number of samples issued each time in streaming affects the response length during
training, which in turn affects training time. We verified the impact on results by modifying
`async_training.require_batches`.
| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
### 30B Model Mode Experiment
TODO: The 30B experiment is still in progress.
* Machine: H20
* Model: Qwen2.5-32B
* Rollout length: max_response_length (FSDP2): 20K tokens;
* Algorithm: DAPO
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
* step:200
* train_batch_size: 512
* fully_async_policy
* total_rollout_steps: 512*200
* trigger_parameter_sync_step: 512/32 = 16
* staleness_threshold: 0
* partial_rollout: False
| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|
| colocate sync | 128 | | | | | | | |
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |
## Future Plans
* GRPO experiments
* Megatron adaptation
* SGLang integration
* Transfer queue integration
* Asynchronous parameter synchronization
* AReaL asynchronous algorithm implementation
* TPPO algorithm implementation
* Multi-turn and Tool support


@ -0,0 +1,373 @@
# Recipe: Fully Async Policy Async Trainer
**Author:** `https://github.com/meituan-search`
Last updated: 10/17/2025.
This document describes a fully asynchronous PPO training system that completely decouples the Trainer and the Rollouter and supports asynchronous sample generation and training.
With this system, we trained a Qwen2.5-7B model on 128 GPUs and obtained a 2.35x-2.67x speedup without a significant impact on model quality.
## Introduction
### Background
Compared with the colocate architecture, a disaggregated rollout/train architecture can allocate resources more flexibly and support more flexible training logic,
addressing the low GPU utilization and low training efficiency caused by long-tail samples and similar issues.
one_step_off_policy adopts such a disaggregated design and runs rollout and training one step off policy, which alleviates overly long rollout times and brings some efficiency gains.
However, it is forced to use data that is exactly one step stale, which is inflexible, and it cannot fully remove the impact of long-tail samples on training efficiency. Other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow
have implemented asynchronous and streaming training on top of disaggregated architectures and obtained gains; we borrow their ideas and implement them in verl. fully_async_policy supports asynchronous, streaming, and partial-rollout
training. With a reasonable resource split, parameter synchronization frequency, and other settings, fully_async_policy can significantly improve training efficiency.
> Magistral https://arxiv.org/abs/2506.10910
>
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language
> Reasoning https://arxiv.org/abs/2505.24298
>
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream
> Generation https://arxiv.org/abs/2504.15930
>
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
>
### Core Contributions
* **Resource isolation**: unlike the hybrid_engine setup, the Rollouter and the Trainer run on separate compute resources, which must be specified for each of them.
* **Generation and training in parallel**: while the Trainer is training, the Rollouter keeps generating new samples.
* **Multi-step asynchrony**: compared with one step off policy, asynchrony from a fraction of a step up to multiple steps is supported, making the asynchronous scheme more flexible.
* **NCCL parameter synchronization**: NCCL communication primitives are used to transfer parameters between the Rollouter and the Trainer.
* **Streaming inference and training**: the Rollouter generates data sample by sample, and data is transferred with a single sample as the smallest unit.
* **Asynchronous training with staleness control**: by setting `async_training.staleness_threshold`, samples generated with older parameters can be used for training.
* **Partial rollout**: the Rollouter supports partial rollout; `sleep()`/`resume()` logic around parameter synchronization saves in-flight rollout samples
  and continues them in the next rollout, reducing the time spent waiting for in-flight tasks to finish before a parameter sync.
Currently the supported combination is fsdp+vllm, and vLLM must be used in the AgentLoop-based server mode.
## Design
The overall architecture of fully_async_policy is shown in the figure below. It consists of four components: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
![fully_async_policy_structure](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true)
1. Rollouter: generates sequences sample by sample and puts each generated sample into the MessageQueue; the production rate is governed by the staleness control.
2. MessageQueue: buffers the samples produced by the Rollouter.
3. Trainer: fetches samples one by one from the MessageQueue; once `require_batches*ppo_mini_batch_size` samples have been collected it performs a training step,
   and after `async_training.trigger_parameter_sync_step` local steps it triggers a parameter synchronization with the Rollouter.
4. ParameterSynchronizer: implements synchronous, NCCL-based parameter synchronization.
The gain over the baseline comes from the following observation: in the colocate setup, even giving rollout more resources cannot remove the idle time caused by long-tail samples.
After resource isolation, the rollout time and the train time may each become longer than before (since each side uses fewer resources),
but because the two overlap with each other, the end-to-end time is reduced, as illustrated in the figure and the sketch below.
![fully_async_policy_revenue](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true)
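
The producer/consumer flow described above can be illustrated with a minimal, self-contained sketch; plain Python threads and `queue.Queue` stand in for the Ray actors, the MessageQueue, and the NCCL parameter synchronization, and all names and constants are illustrative rather than the recipe's actual API:

```python
import queue
import threading

# Illustrative constants; see the parameter description in the next section.
REQUIRE_BATCHES = 1
PPO_MINI_BATCH_SIZE = 4
TRIGGER_PARAMETER_SYNC_STEP = 2
TOTAL_ROLLOUT_SAMPLES = 32      # total number of prompts to roll out

message_queue: queue.Queue = queue.Queue()
rollout_param_version = 0       # bumped on every parameter sync


def rollouter() -> None:
    """Generate samples one by one and push them into the message queue."""
    for sample_id in range(TOTAL_ROLLOUT_SAMPLES):
        message_queue.put({"sample_id": sample_id, "param_version": rollout_param_version})
    message_queue.put(None)     # sentinel: no more samples


def trainer() -> None:
    """Consume require_batches * ppo_mini_batch_size samples per local step and
    trigger a parameter sync every trigger_parameter_sync_step local steps."""
    global rollout_param_version
    local_steps, batch = 0, []
    while True:
        sample = message_queue.get()
        if sample is None:
            break
        batch.append(sample)
        if len(batch) == REQUIRE_BATCHES * PPO_MINI_BATCH_SIZE:
            local_steps += 1
            print(f"train step {local_steps}: {len(batch)} samples")
            batch = []
            if local_steps % TRIGGER_PARAMETER_SYNC_STEP == 0:
                rollout_param_version += 1   # stands in for the NCCL weight broadcast
                print(f"parameter sync -> version {rollout_param_version}")


producer = threading.Thread(target=rollouter)
consumer = threading.Thread(target=trainer)
producer.start()
consumer.start()
producer.join()
consumer.join()
```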
## Usage
### Parameter Description
| super params | implication |
|-----------------------------------------------|-----------------------------------------------------------------|
| `trainer.nnodes` | Number of nodes used by the Trainer |
| `trainer.n_gpus_per_node` | Number of GPUs per Trainer node |
| `rollout.nnodes` | Number of nodes used by the Rollouter |
| `rollout.n_gpus_per_node` | Number of GPUs per Rollouter node |
| `data.train_batch_size` | Not effective in the fully async strategy (set to 0 by default) |
| `data.gen_batch_size` | The fully async strategy uses streaming sample production (set to 1 by default) |
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | Run validation once every this many Rollouter parameter updates |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
| `async_training.require_batches` | Number of ppo_mini_batch_size batches the FullyAsyncTrainer fetches at a time |
| `async_training.trigger_parameter_sync_step` | Number of local updates after which the FullyAsyncTrainer performs a parameter synchronization |
| `async_training.staleness_threshold` | Staleness control |
| `async_training.partial_rollout` | Whether to perform partial_rollout |
| `async_training.use_rollout_log_probs` | Use the log_probs produced by the rollout |
**Further explanation:**
* `rollout.total_rollout_steps`
  Compared with colocate training, the total can be aligned by multiplying train_batch_size by the number of steps:
  `rollout.total_rollout_steps = data.train_batch_size * step`
* `async_training.trigger_parameter_sync_step`
  In the fully async strategy, this is the number of local updates the Trainer performs (i.e., the number of times it fetches `require_batches * ppo_mini_batch_size` samples)
  before one parameter synchronization with the Rollouter.
  Between two parameter synchronizations, the Trainer processes `trigger_parameter_sync_step * require_batches * ppo_mini_batch_size` samples.
  For a speed comparison that is fair to colocate training, trigger_parameter_sync_step should be set to
  `data.train_batch_size / (require_batches * ppo_mini_batch_size)` (see the worked example after this list).
* `async_training.staleness_threshold`
  In the fully async strategy, this is the maximum allowed proportion of stale samples.
  * staleness_threshold=0 means synchronous training.
    Between two Rollouter parameter updates, a fixed number of samples is generated:
    $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$
  * staleness_threshold>0 means asynchronous training; fractional values are allowed, enabling more flexible asynchrony.
    Between two Rollouter parameter updates, at most
    $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$
    samples are generated, where num_staleness_sample is the number of extra stale samples produced in the previous round.
    Because this is a streaming system, the Rollouter produces continuously while the Trainer consumes continuously. If the Rollouter is slower, the Trainer triggers the parameter sync earlier and the Rollouter never actually produces rollout_num samples.
    When rollout is fast enough, staleness_threshold=1 is roughly equivalent to one_step_off policy.
    To avoid too many stale samples hurting training accuracy, a value below 1 is recommended.
* `async_training.partial_rollout`
  partial_rollout only takes effect when staleness_threshold>0.
* `async_training.use_rollout_log_probs`
  In reinforcement learning algorithms, log_probs are implicitly tied to both the parameter version and the tokens. Because of how PPO/GRPO/DAPO are defined, the importance-sampling term
  (i.e., old_log_prob) must use the log_probs that correspond to the rollout parameters and tokens for the algorithm to be correct. In the fully
  async strategy, old_log_prob is therefore computed by the rollout by default, not by the trainer.
* `async_training.require_batches`
  In streaming training, require_batches should be set to 1, meaning training starts as soon as ppo_mini_batch_size samples have been produced.
  In practice we found that when each dispatch delivers only a few samples, the data dispatch order can make training unstable and the response length grows.
  We therefore expose require_batches to control how many samples are dispatched per streaming step and take part in a single training step.
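
The relationships above can be checked with a short worked example; the values are taken from the 7B experiment configuration reported later in this document, and the snippet is plain arithmetic rather than recipe code:

```python
# Settings from the 7B experiment below: train_batch_size=512, 400 colocate steps,
# ppo_mini_batch_size=32, require_batches=4, staleness_threshold=0.3.
train_batch_size = 512
colocate_steps = 400
ppo_mini_batch_size = 32
require_batches = 4
staleness_threshold = 0.3
num_staleness_sample = 0   # stale samples carried over from the previous round

# Align the total number of rollout prompts with the colocate baseline.
total_rollout_steps = train_batch_size * colocate_steps                                    # 204800

# Fair-comparison setting against colocate training.
trigger_parameter_sync_step = train_batch_size // (require_batches * ppo_mini_batch_size)  # 4

# Maximum number of samples the Rollouter may generate between two parameter syncs.
rollout_num = (1 + staleness_threshold) * (
    trigger_parameter_sync_step * require_batches * ppo_mini_batch_size
) - num_staleness_sample                                                                   # 665.6

print(total_rollout_steps, trigger_parameter_sync_step, rollout_num)
```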
### Supported Modes
1. on policy pipeline:
   1. **trigger_parameter_sync_step=1, staleness_threshold=0**
   2. The Rollouter produces `require_batches*ppo_mini_batch_size` samples at a time; the Trainer trains once it has fetched them, and afterwards a parameter synchronization is performed between the Trainer and the Rollouter;
   3. During rollout, if there are long-tail samples but the number of rollout samples is small, shorter samples cannot fill the idle resources, so some resources are wasted.
   4. Shown as (a) in the figure below.
2. stream off policy pipeline:
   1. **trigger_parameter_sync_step>1, staleness_threshold=0**
   2. Synchronous streaming training: the Rollouter produces `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples per round; the Trainer performs one local training step for every `require_batches*ppo_mini_batch_size`
      samples it fetches, and after trigger_parameter_sync_step such steps a parameter synchronization is performed between the Trainer and the Rollouter;
   3. Compared with (a), more samples are generated per round, so resources sit idle less often.
   4. Within one training round there are still two idle periods: when fetching the first batch, training waits for `require_batches*ppo_mini_batch_size`
      samples to be produced, and at the final parameter update the rollout waits for training to finish.
   5. Shown as (b) in the figure below.
3. async stream pipeline with staleness samples:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
   2. After each parameter update, the Rollouter plans to produce at most rollout_num samples; depending on the rollout speed, it may actually produce fewer.
   3. If rollout is fast enough, the Rollouter generates some extra samples (num_stale_samples) before the parameter sync, so they can be fed to the Trainer right after the sync.
      When a parameter sync is triggered, in-flight generation tasks are waited for and no new tasks are added.
   4. Compared with (b), every round except the first no longer waits for the first batch of rollouts to finish, but it still waits for active tasks to finish.
   5. Shown as (c) in the figure below.
4. async stream pipeline with partial rollout:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
   2. Compared with (c), when a parameter sync is triggered, samples that are still being generated are interrupted and the parameters are synchronized; the interrupted samples continue generating after the sync.
      This removes the wait-for-active-tasks time.
   3. Shown as (d) in the figure below; the four parameter combinations are also summarized in the sketch after the figure.
![fully_async_policy_mode](
https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true)
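
As referenced above, the four modes reduce to different `async_training` settings; the step and staleness values for modes 2-4 below are example values taken from the 128-card experiments later in this document, and any settings satisfying the constraints listed above select the same mode:

```python
# Example async_training settings for the four modes
# (modes 2-4 use the values from the 128-card experiments below).
MODES = {
    "on policy pipeline": dict(trigger_parameter_sync_step=1, staleness_threshold=0, partial_rollout=False),
    "stream off policy pipeline": dict(trigger_parameter_sync_step=4, staleness_threshold=0, partial_rollout=False),
    "async stream pipeline with staleness samples": dict(
        trigger_parameter_sync_step=4, staleness_threshold=0.5, partial_rollout=False
    ),
    "async stream pipeline with partial rollout": dict(
        trigger_parameter_sync_step=4, staleness_threshold=0.5, partial_rollout=True
    ),
}
```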
### Key Metrics
| metrics | implication |
|------------------------------------------------|-----------------------------------------------------------|
| `trainer/idle_ratio` | Trainer idle ratio |
| `rollouter/idle_ratio` | Rollouter idle ratio |
| `fully_async/count/stale_samples_processed` | Total number of stale samples used in training |
| `fully_async/count/stale_trajectory_processed` | Total number of stale trajectories used in training (one sample produces rollout.n trajectories) |
| `fully_async/partial/total_partial_num` | Number of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/partial_ratio` | Proportion of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/max_partial_span` | Maximum parameter-version span of the partial samples processed by the Trainer between two parameter synchronizations |
### Tuning Recommendations
* Resource allocation and adjustment:
  * A reasonable resource split is the precondition for good training efficiency. Ideally the rollout time and the train time should be close, so that pipeline bubbles are minimized,
    resources do not sit idle, and the Trainer does not consume stale samples. In real training, the split can be adjusted based on the observed idle time of rollout and train
    (available from rollouter/idle_ratio and trainer/idle_ratio): if rollouter/idle_ratio is high and trainer/idle_ratio is low,
    move resources from the Rollouter to the Trainer, and vice versa (see the sketch at the end of this section).
* Key parameters:
  * staleness_threshold: setting it too large leads to many stale samples being used and hurts model quality; a value below 1 is recommended.
  * require_batches: the closer to 1, the closer to a purely streaming process, with smaller bubbles and higher speed, but it affects the order in which samples are processed.
  * trigger_parameter_sync_step: smaller values are closer to on-policy training, but they cause frequent parameter synchronization, and the resources wasted on long-tail samples cannot be filled by short samples, lowering utilization.
    Larger values give higher compute efficiency, but accuracy suffers more from being off-policy.
  * rollout.test_freq: validation occupies Rollouter resources, so it should not be set too small.
* Mode selection: by adjusting these parameters, the Fully Async architecture supports different levels of acceleration for different kinds of tasks.
  * For small-scale tasks that require training stability and on-policy behavior and are not speed-critical, try the on policy pipeline mode (mode 1).
  * For tasks that need higher training throughput but are sensitive to staleness, try the stream off policy pipeline mode, i.e.,
    set trigger_parameter_sync_step>1 to improve training efficiency while keeping the synchronous mechanism (staleness_threshold=0) (mode 2).
  * For large-scale tasks with high speed requirements that can tolerate some off-policyness and staleness, set staleness_threshold>0 and partial_rollout=True
    to improve training efficiency, using the async stream pipeline modes (mode 3 or 4).
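
The resource-adjustment rule referenced above can be written as a toy heuristic; the threshold is illustrative, since the recipe only reports the two idle-ratio metrics and leaves the actual reallocation to the user:

```python
def suggest_resource_shift(rollouter_idle_ratio: float, trainer_idle_ratio: float,
                           tolerance: float = 0.1) -> str:
    """Toy heuristic mirroring the advice above: shift GPUs toward the busier side."""
    if rollouter_idle_ratio - trainer_idle_ratio > tolerance:
        return "Rollouter idles more: move GPUs from the Rollouter to the Trainer."
    if trainer_idle_ratio - rollouter_idle_ratio > tolerance:
        return "Trainer idles more: move GPUs from the Trainer to the Rollouter."
    return "Idle ratios are balanced: keep the current allocation."


print(suggest_resource_shift(rollouter_idle_ratio=0.35, trainer_idle_ratio=0.05))
```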
### Quick Start
```shell
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=10
staleness_threshold=0
trigger_parameter_sync_step=16
partial_rollout=False
# `use_dynamic_bsz` is referenced by the overrides below; define it explicitly (True is a typical choice)
use_dynamic_bsz=True
python -m recipe.fully_async_policy.fully_async_main \
    data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.partial_rollout="${partial_rollout}"
```
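
With the default `async_training.require_batches=1`, the script's `trigger_parameter_sync_step=16` corresponds to the fair-comparison setting `512 / (1 * 32) = 16` against a colocate baseline with `train_batch_size=512`, and `total_rollout_steps=512*400` matches a 400-step colocate run. Since `staleness_threshold=0` and `partial_rollout=False`, this example runs in the stream off policy pipeline mode.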
## Experiments
### Asynchronous Training of a 7B Model
We use Qwen2.5-Math-7B to verify the gains of the fully async strategy with long generations under several resource scales.
With the `async stream pipeline with staleness samples` strategy, we obtain roughly a 2x speedup on 32, 64, and 128 GPUs without significantly affecting the results.
* Machine: H20
* Model: Qwen2.5-Math-7B
* Rollout length: max_response_length (FSDP2): 28K tokens;
* Algorithm: DAPO
* Dataset: TRAIN_FILE: dapo-math-17k.parquet, TEST_FILE: aime-2024.parquet
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
  * step: 400
  * train_batch_size: 512
* fully_async_policy
  * total_rollout_steps: 512*400
  * require_batches: 4
  * trigger_parameter_sync_step: 4
  * staleness_threshold: 0.3
  * partial_rollout: True
| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
### 128-card 7B Asynchronous Mode Experiment
We use Qwen2.5-Math-7B to verify the effect of each mode supported by fully async.
Streaming alone brings roughly a 0.6x improvement; stacking staleness and partial_rollout on top raises the overall gain to 2.35x.
| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| colocate sync | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| `async stream pipeline with staleness samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-card Staleness Ablation Experiment
Under the `async stream pipeline with partial rollout` mode, we verify the impact of the staleness setting on training efficiency.
We find that the larger the staleness, the larger the final gain.
We also notice that the times for staleness 0.3 and 0.5 are quite close: as the number of training steps grows, the response length changes significantly and training becomes unstable.
Further analysis and optimization of this issue are still needed.
| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_stale?nw=nwuserhouzg
### 128-card 7B require_batches Ablation Experiment
In repeated tests we found that the number of samples dispatched per streaming step affects the response length during training and hence the training time.
We verify the impact by varying `async_training.require_batches`.
| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
### 30B Model Mode Experiment
TODO: the 30B experiments are still in progress.
* Machine: H20
* Model: Qwen2.5-32B
* Rollout length: max_response_length (FSDP2): 20K tokens;
* Algorithm: DAPO
* Engine: vllm+FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20
* colocate sync:
  * step: 200
  * train_batch_size: 512
* fully_async_policy
  * total_rollout_steps: 512*200
  * trigger_parameter_sync_step: 512/32 = 16
  * staleness_threshold: 0
  * partial_rollout: False
| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|--------------------|---------------------|----------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|
| colocate sync | 128 | | | | | | | |
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with staleness samples | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |
## Future Plans
* GRPO experiments
* Megatron adaptation
* SGLang integration
* Transfer queue integration
* Asynchronous parameter synchronization
* AReaL asynchronous algorithm implementation
* TPPO algorithm implementation
* Multi-turn and tool support


@ -0,0 +1,19 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_loop import FullyAsyncAgentLoopManager
from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop
_ = [PartialSingleTurnAgentLoop]
__all__ = ["FullyAsyncAgentLoopManager"]


@ -0,0 +1,275 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
from typing import Any, Optional
import hydra
import numpy as np
import ray
import torch
from omegaconf import DictConfig
from recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica
from verl.experimental.agent_loop.agent_loop import (
AgentLoopManager,
AgentLoopOutput,
AgentLoopWorkerBase,
AsyncLLMServerManager,
BatchExecutor,
_agent_loop_registry,
_DummyConfig,
get_trajectory_info,
)
from verl.protocol import DataProto
from verl.single_controller.ray import RayWorkerGroup
from verl.utils.rollout_trace import rollout_trace_attr
from verl.workers.rollout.replica import TokenOutput
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class FullyAsyncLLMServerManager(AsyncLLMServerManager):
async def generate_for_partial(self, request_id, prompt_ids, sampling_params) -> TokenOutput:
"""Generate tokens from prompt ids. with partial rollout function"""
server = self._choose_server(request_id)
output = await server.generate_for_partial.remote(
request_id=request_id,
prompt_ids=prompt_ids,
sampling_params=sampling_params,
)
return output
class FullyAsyncAgentLoopOutput(AgentLoopOutput):
"""Agent loop output."""
is_cancel: bool = False
"""Indicates whether the request was interrupted"""
log_probs: list[float] = None
"""Response token log probs including LLM generated token, tool response token."""
param_version_start: int = 0
"""Indicate start parameter version when this response is generated"""
param_version_end: int = 0
"""Indicate end parameter version when this response is generated, used for partial rollout"""
@ray.remote
class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase):
def __init__(
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
):
self.server_manager = FullyAsyncLLMServerManager(config, server_handles)
super().__init__(config, server_handles, rm_executor)
async def generate_sequences_no_post(
self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]]
) -> list[AgentLoopOutput]:
"""Generate sequences from agent loop.
Args:
batch (DataProto): Input batch.
            partial_output_list (Optional[list[AgentLoopOutput]]): Partial rollout results generated earlier, if any.
Returns:
list[FullyAsyncAgentLoopOutput]: List of agent loop outputs, one per sample in the batch.
"""
config = self.config.actor_rollout_ref.rollout
sampling_params = dict(
temperature=config.temperature,
top_p=config.top_p,
repetition_penalty=1.0,
logprobs=config.calculate_log_probs,
)
# override sampling params for validation
if batch.meta_info.get("validate", False):
sampling_params["top_p"] = config.val_kwargs.top_p
sampling_params["temperature"] = config.val_kwargs.temperature
# by default, we assume it's a single turn agent
if "agent_name" not in batch.non_tensor_batch:
batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)
if "index" in batch.non_tensor_batch:
index = batch.non_tensor_batch["index"]
else:
index = np.arange(len(batch))
trajectory_info = await get_trajectory_info(
batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
)
if not partial_output_list:
partial_output_list = [None] * len(batch)
tasks = []
for i in range(len(batch)):
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
kwargs["output"] = partial_output_list[i]
tasks.append(
asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs))
)
return await asyncio.gather(*tasks)
async def _partial_run_agent_loop(
self,
sampling_params: dict[str, Any],
trajectory: dict[str, Any],
*,
agent_name: str,
**kwargs,
) -> AgentLoopOutput:
with rollout_trace_attr(
step=trajectory["step"],
sample_index=trajectory["sample_index"],
rollout_n=trajectory["rollout_n"],
validate=trajectory["validate"],
name="agent_loop",
):
assert agent_name in _agent_loop_registry, (
f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
)
agent_loop_config = _agent_loop_registry[agent_name]
agent_loop = hydra.utils.instantiate(
config=agent_loop_config,
trainer_config=_DummyConfig(config=self.config),
server_manager=self.server_manager,
tokenizer=self.tokenizer,
processor=self.processor,
)
return await agent_loop.run(sampling_params, **kwargs)
class FullyAsyncAgentLoopManager(AgentLoopManager):
def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
self.config = config
self.worker_group = worker_group
self.rm_executor = None
self.rm_micro_batch_size = None
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
self.rollout_replica_class = FullyAsyncvLLMReplica
self.rm_wg = rm_wg
self.rollout_replicas = None
self.server_handles = None
self.server_addresses = None
self.agent_loop_workers = None
@classmethod
async def create(cls, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
instance = cls(config, worker_group, rm_wg)
await instance._async_init()
return instance
async def _async_init(self):
if self.rm_wg:
def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
new_data_list = []
for data in data_list:
temp_non_tensor_batch = {"__num_turns__": data.non_tensor_batch["__num_turns__"]}
temp_data = DataProto(batch=data.batch, non_tensor_batch=temp_non_tensor_batch)
new_data_list.append(temp_data)
new_batch = DataProto.concat(new_data_list)
out_data = self.rm_wg.compute_rm_score(new_batch)
return out_data.split(1)
self.rm_executor = BatchExecutor.options(
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=ray.get_runtime_context().get_node_id(),
soft=False,
),
).remote(batch_fn, self.rm_wg.world_size)
self.rm_micro_batch_size = self.rm_wg.world_size
await self._initialize_llm_servers_async()
self._init_agent_loop_workers()
async def _initialize_llm_servers_async(self):
rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
world_size = (
self.worker_group.world_size
if self.worker_group
else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
)
num_replicas = world_size // rollout_world_size
rollout_config = self.config.actor_rollout_ref.rollout
model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
self.rollout_replica_class(
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
gpus_per_node=self.config.trainer.n_gpus_per_node,
)
for replica_rank in range(num_replicas)
]
if self.worker_group:
await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
else:
await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])
self.server_handles = [server._server_handle for server in self.rollout_replicas]
self.server_addresses = [server._server_address for server in self.rollout_replicas]
async def generate_single_sample_async(
self,
sample: DataProto,
partial_output_list: Optional[list[AgentLoopOutput]],
) -> list[AgentLoopOutput]:
"""
Asynchronously process a single sample
Args:
sample: Single sample data
            partial_output_list (Optional[list[AgentLoopOutput]]): Partial rollout results generated earlier, if any.
Returns:
list[AgentLoopOutput]: Processing results
"""
worker = self._select_best_worker()
output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list)
return await asyncio.wrap_future(output_future.future())
def _select_best_worker(self):
"""Select the best worker, simple round-robin load balancing"""
if not hasattr(self, "_worker_index"):
self._worker_index = 0
worker = self.agent_loop_workers[self._worker_index]
self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers)
return worker
async def cancel(self):
await asyncio.gather(*[replica.cancel() for replica in self.rollout_replicas])
async def resume(self):
await asyncio.gather(*[replica.resume() for replica in self.rollout_replicas])
async def wake_up(self):
await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas])
async def sleep(self):
await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])
async def reset_prefix_cache(self):
await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas])


@ -0,0 +1,92 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Optional
from uuid import uuid4
from recipe.fully_async_policy.agent_loop.agent_loop import AgentLoopOutput, FullyAsyncAgentLoopOutput
from verl.experimental.agent_loop import AgentLoopBase
from verl.experimental.agent_loop.agent_loop import register
from verl.utils.profiler import simple_timer
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
@register("partial_single_turn_agent")
class PartialSingleTurnAgentLoop(AgentLoopBase):
"""Naive agent loop that only do single turn chat completion."""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
self.response_length = self.config.actor_rollout_ref.rollout.response_length
self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {})
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
output: Optional[FullyAsyncAgentLoopOutput] = kwargs.get("output", None)
messages = list(kwargs["raw_prompt"])
param_version = kwargs.get("param_version", 0)
metrics = {}
request_id = uuid4().hex
param_version_start = param_version
param_version_end = param_version
if not output:
prompt_ids = await self.loop.run_in_executor(
None,
lambda: self.tokenizer.apply_chat_template(
messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
),
)
else:
if output.is_cancel:
# Resume the paused sample,
# add the result directly after prompt_ids,
# and reset generate_sequences metric
prompt_ids = output.prompt_ids + output.response_ids
metrics["generate_sequences"] = output.metrics.generate_sequences
param_version_start = output.param_version_start
else:
# In the same batch of samples,
                # some are canceled and some are not.
# The samples without partial rollout are returned directly.
return output
with simple_timer("generate_sequences", metrics):
response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial(
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
)
if not output:
response_mask = [1] * len(response_ids)
else:
            # Resumed sample: restore the original prompt_ids and concatenate
            # the earlier partial output with the new continuation, then rebuild response_mask
prompt_ids = output.prompt_ids
log_probs = output.log_probs + log_probs
response_ids = output.response_ids + response_ids
response_mask = [1] * len(response_ids)
return FullyAsyncAgentLoopOutput(
prompt_ids=prompt_ids,
response_ids=response_ids[: self.response_length],
response_mask=response_mask[: self.response_length],
num_turns=2,
metrics=metrics,
is_cancel=is_cancel,
log_probs=log_probs,
param_version_start=param_version_start,
param_version_end=param_version_end,
)


@ -0,0 +1,55 @@
hydra:
  searchpath:
    - file://verl/trainer/config

defaults:
  - ppo_trainer
  - _self_

async_training:
  # Maximum allowed proportion of stale samples (staleness threshold)
  staleness_threshold: 0.1
  # Number of trainer local steps between two parameter synchronizations;
  # one local step means the trainer has obtained a batch of required samples
  trigger_parameter_sync_step: 4
  # Number of ppo_mini_batches the FullyAsyncTrainer obtains at a time
  require_batches: 1
  # When synchronizing parameters, whether to interrupt the rollouter and perform partial rollout
  partial_rollout: True
  # Whether to use rollout log probs for training
  use_rollout_log_probs: True

# Rollout config
rollout:
  # Number of nodes used in the rollout
  nnodes: 1
  # Number of GPUs per node
  n_gpus_per_node: 8
  # Number of responses sampled per prompt (> 1 for GRPO)
  n: 4
  # Total number of rollout samples  # TODO rename to total_rollout_samples
  total_rollout_steps: 100
  # Number of epochs in training
  total_epochs: 10
  # Test frequency: run validation after every `test_freq` parameter updates
  test_freq: 1

data:
  # Number of prompts generated per request; currently only 1 is supported
  gen_batch_size: 1

actor_rollout_ref:
  actor:
    # Whether to use rollout log probs for training
    use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}


@ -0,0 +1,474 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from collections import defaultdict
from dataclasses import dataclass
from typing import Any, Optional
import numpy as np
import torch
from tensordict import TensorDict
from verl import DataProto
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
from verl.trainer.ppo.ray_trainer import compute_response_mask
def postprocess_agent_loop_outputs(inputs: list[AgentLoopOutput], tokenizer, config) -> DataProto:
"""Static method to postprocess a list of AgentLoopOutput into DataProto
Args:
inputs: List of AgentLoopOutput
tokenizer: Tokenizer instance
config: Configuration object
Returns:
DataProto: Processed batch data
"""
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
# input_ids: prompt + response
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
# prompts
tokenizer.padding_side = "left"
outputs = tokenizer.pad(
[{"input_ids": input.prompt_ids} for input in inputs],
padding="max_length",
max_length=config.actor_rollout_ref.rollout.prompt_length,
return_tensors="pt",
return_attention_mask=True,
)
prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"]
# responses
tokenizer.padding_side = "right"
outputs = tokenizer.pad(
[{"input_ids": input.response_ids} for input in inputs],
padding="max_length",
max_length=config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=True,
)
response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"]
# response_mask
outputs = tokenizer.pad(
[{"input_ids": input.response_mask} for input in inputs],
padding="max_length",
max_length=config.actor_rollout_ref.rollout.response_length,
return_tensors="pt",
return_attention_mask=False,
)
response_mask = outputs["input_ids"]
assert response_ids.shape == response_mask.shape, (
f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}"
)
response_mask = response_mask * response_attention_mask
input_ids = torch.cat([prompt_ids, response_ids], dim=1)
attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
batch = TensorDict(
{
"prompts": prompt_ids, # [bsz, prompt_length]
"responses": response_ids, # [bsz, response_length]
"response_mask": response_mask, # [bsz, response_length]
"input_ids": input_ids, # [bsz, prompt_length + response_length]
"attention_mask": attention_mask, # [bsz, prompt_length + response_length]
"position_ids": position_ids, # [bsz, prompt_length + response_length]
},
batch_size=len(input_ids),
)
num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)
metrics = [input.metrics.model_dump() for input in inputs]
return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics})
@dataclass
class RolloutSample:
"""Enhanced rollout sample containing both original batch info and AgentLoopOutput"""
# Original batch information
full_batch: Any
# AgentLoopOutput from generation
agent_loop_output_list: list[Any] # AgentLoopOutput
# Metadata
sample_id: str
epoch: int
# Processing metadata
processing_times: list[float]
param_version: int
param_version_start: list[int]
param_version_end: list[int]
rollout_status: dict[str, Any]
@dataclass
class ValidateMetrics:
"""Metrics for validation"""
timing_raw: dict[str, Any]
metrics: Optional[dict[str, Any]] = None
global_steps: Optional[int] = None
param_version: Optional[int] = None
def prepare_single_generation_data(batch_dict, global_steps, rollout_n) -> DataProto:
"""
Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample.
Separate the data used for generation from the original data.
Returns:
        DataProto: The full batch prepared for generation (each prompt repeated rollout_n times).
"""
full_batch = DataProto.from_single_dict(batch_dict)
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
full_batch.pop(
batch_keys=batch_keys_to_pop,
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
)
# Setting agent - partial_single_turn_agent, that supports partial
full_batch.non_tensor_batch["agent_name"] = np.array(["partial_single_turn_agent"] * len(full_batch), dtype=object)
    # Repeat each prompt rollout_n times (interleaved), one row per response
full_batch = full_batch.repeat(repeat_times=rollout_n, interleave=True)
return full_batch
def process_rollout_log_probs(data_proto: DataProto, rollout_log_probs: list[list[float]]) -> torch.Tensor:
"""
Process rollout_log_probs according to the mask in DataProto
mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
Args:
data_proto: A DataProto object containing batch information
rollout_log_probs: A two-dimensional list, each sublist containing the log_probs of a sample
Returns:
torch.Tensor: The processed log_probs tensor, with shape: [bsz, response_length]
"""
batch = data_proto.batch
response_mask = batch["response_mask"]
rollout_log_probs_tensor = torch.zeros(response_mask.shape, dtype=torch.float32) - 1
for i, log_probs_seq in enumerate(rollout_log_probs):
# Get the effective length of the current sample (the number of positions with 1 in the mask)
valid_length = response_mask[i].sum().item()
# Ensure that the length of log_probs_seq does not exceed the valid length
actual_length = min(len(log_probs_seq), valid_length)
# Fill log_probs into the corresponding position
if actual_length > 0:
rollout_log_probs_tensor[i, :actual_length] = torch.tensor(log_probs_seq[:actual_length])
rollout_log_probs_tensor = rollout_log_probs_tensor.to(torch.float32)
return rollout_log_probs_tensor
def merge_rollout_sample(config, tokenizer, rs: RolloutSample):
"""
Supplement and refine the RolloutSample object,
"""
# Step 1: Create a DataProto from the AgentLoopOutput to generate the result
gen_batch_output = postprocess_agent_loop_outputs(rs.agent_loop_output_list, tokenizer, config)
rollout_log_probs = [x.log_probs for x in rs.agent_loop_output_list]
rollout_log_probs = process_rollout_log_probs(gen_batch_output, rollout_log_probs)
gen_batch_output.batch["rollout_log_probs"] = rollout_log_probs.to(torch.float32)
# Step 2: Add uid
rs.full_batch.non_tensor_batch["uid"] = np.array([f"uid_{rs.sample_id}"] * len(rs.full_batch), dtype=object)
    # Step 3: Merge batches
# Merge the non_tensor_batch and meta_info of original_batch into final_batch
for key, value in rs.full_batch.non_tensor_batch.items():
gen_batch_output.non_tensor_batch[key] = value
gen_batch_output.meta_info.update(rs.full_batch.meta_info)
    # Step 4: set full_batch
rs.full_batch = gen_batch_output
rs.processing_times = []
for agent_loop in rs.agent_loop_output_list:
rs.processing_times.append(agent_loop.metrics.generate_sequences)
rs.param_version_start = [agent_loop.param_version_start for agent_loop in rs.agent_loop_output_list]
rs.param_version_end = [agent_loop.param_version_end for agent_loop in rs.agent_loop_output_list]
    # Step 5: clear agent_loop_output_list
rs.agent_loop_output_list = []
return rs
def assemble_batch_from_rollout_samples(
rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None
) -> DataProto:
"""
Assemble gen_batch_output from RolloutSample objects
Assembles batches from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer.
Args:
rollout_samples: List of RolloutSample objects
tokenizer: Tokenizer instance
config: Configuration object containing trainer settings
        balance_batch: Optional callable that balances the batch in place (no balancing if None)
Returns:
DataProto: Assembled gen_batch_output
Raises:
ValueError: If rollout_samples is empty
"""
start_time = time.time()
if not rollout_samples:
raise ValueError("Empty rollout_samples provided for batch assembly")
print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects")
rollout_samples_batch = []
processing_times = []
rollout_status = rollout_samples[0].rollout_status
# Add a prefix to all rollout_status keys
rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()}
for rs in rollout_samples:
rollout_samples_batch.append(rs.full_batch)
processing_times.extend(rs.processing_times)
final_batch = DataProto.concat(rollout_samples_batch)
# Calculate response_mask (if not present)
if "response_mask" not in final_batch.batch.keys():
final_batch.batch["response_mask"] = compute_response_mask(final_batch)
if balance_batch:
balance_batch(final_batch, metrics={})
# Calculate the global valid token number
if "attention_mask" in final_batch.batch:
final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist()
# Collect statistics
param_versions = [rs.param_version for rs in rollout_samples]
trajectorys_param_versions = [version for rs in rollout_samples for version in rs.param_version_end]
processing_time_stats = {
"processing_time/avg": np.mean(processing_times),
"processing_time/max": np.max(processing_times),
"processing_time/min": np.min(processing_times),
"processing_time/tp50": np.percentile(processing_times, 50),
"processing_time/tp99": np.percentile(processing_times, 99),
"processing_time/tp95": np.percentile(processing_times, 95),
}
processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()}
    param_version_diff = [
        abs(a - b) for rs in rollout_samples for a, b in zip(rs.param_version_end, rs.param_version_start, strict=False)
    ]
num_diff0 = param_version_diff.count(0)
partial_stats = {
"fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0,
"fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff),
"fully_async/partial/max_partial_span": max(param_version_diff),
}
# add meta_info
final_batch.meta_info.update(
{
"rollout_param_versions": param_versions,
"param_version_diversity": len(set(param_versions)) if param_versions else 0,
"trajectory_param_versions": trajectorys_param_versions,
**processing_time_stats,
**rollout_status,
**partial_stats,
}
)
print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s")
return final_batch
class MetricsAggregator:
"""Metrics aggregator, used to combine metrics from multiple training steps"""
def __init__(self, total_gpus: int):
# Store all values for each metric
self.metric_values: dict[str, list[float]] = defaultdict(list)
# Store the number of samples at each step for weighted averaging
self.sample_counts: list[int] = []
# Store the timestamp of each step for time-related calculations
self.timestamps: list[float] = []
# Step Count
self.step_count = 0
# total num gpus used
self.total_gpus = total_gpus
# Metric aggregation rule configuration
self.aggregation_rules = self._init_aggregation_rules()
def _init_aggregation_rules(self) -> dict[str, dict[str, list[str]]]:
"""Initialize metrics aggregation rules"""
return {
# Time-Based metrics, can add metrics here
"time_sum": ["perf/time_per_step"],
"last": [
"fully_async/count/total_generated_samples",
"fully_async/count/stale_samples_processed",
"fully_async/count/stale_trajectory_processed",
"fully_async/count/current_param_version",
"fully_async/count/dropped_stale_samples",
"training/global_step", # TODO change name to: total_step
],
}
def add_step_metrics(self, metrics: dict[str, Any], sample_count: int, timestamp: float = None):
"""Adding a single-step metrics"""
if timestamp is None:
timestamp = time.time()
self.sample_counts.append(sample_count)
self.timestamps.append(timestamp)
self.step_count += 1
# Store all metrics values
for key, value in metrics.items():
if isinstance(value, int | float | np.number):
self.metric_values[key].append(float(value))
elif isinstance(value, torch.Tensor):
self.metric_values[key].append(float(value.item()))
def _get_aggregation_type(self, metric_name: str) -> str:
"""Determine the aggregation type based on the metric name"""
for agg_type, metric_list in self.aggregation_rules.items():
if metric_name in metric_list:
return agg_type
metric_lower = metric_name.lower()
if any(keyword in metric_lower for keyword in ["timing_s/"]):
return "time_sum"
if any(keyword in metric_lower for keyword in ["mean", "avg", "average"]):
return "avg"
if any(keyword in metric_lower for keyword in ["max", "maximum"]):
return "max"
if any(keyword in metric_lower for keyword in ["min", "minimum"]):
return "min"
if any(keyword in metric_lower for keyword in ["sum", "total"]):
return "sum"
if any(keyword in metric_lower for keyword in ["weighted_avg"]):
return "weighted_avg"
return "avg"
def _aggregate_single_metric(self, metric_name: str, values: list[float]) -> float:
"""Aggregating a single metric"""
if not values:
return 0.0
agg_type = self._get_aggregation_type(metric_name)
if agg_type == "last":
return values[-1]
elif agg_type == "weighted_avg":
# Weighted average
if len(values) != len(self.sample_counts):
# If the lengths do not match, use a simple average
return sum(values) / len(values)
total_samples = sum(self.sample_counts)
if total_samples == 0:
return sum(values) / len(values)
weighted_sum = sum(v * c for v, c in zip(values, self.sample_counts, strict=False))
return weighted_sum / total_samples
elif agg_type == "sum" or agg_type == "time_sum":
return sum(values)
elif agg_type == "avg":
return sum(values) / len(values)
elif agg_type == "max":
return max(values)
elif agg_type == "min":
return min(values)
else:
# Default average
return sum(values) / len(values)
def get_aggregated_metrics(self) -> dict[str, Any]:
"""aggregated metrics"""
t = time.time()
if self.step_count == 0:
return {}
aggregated = {}
# Aggregate all metrics
for metric_name, values in self.metric_values.items():
aggregated[metric_name] = self._aggregate_single_metric(metric_name, values)
# Aggregate special metrics
        aggregated = self._special_metrics_aggregate(aggregated)
print(f"aggregated metrics done. cost {time.time() - t}")
return aggregated
    def _special_metrics_aggregate(self, aggregated: dict[str, Any]) -> dict[str, Any]:
        """Compute derived metrics from the aggregated values."""
# global_seqlen/minmax_diff
if "global_seqlen/minmax_diff" in aggregated.keys():
aggregated["global_seqlen/minmax_diff"] = aggregated["global_seqlen/max"] - aggregated["global_seqlen/min"]
# perf/throughput
REQUIRED_PERF_KEYS = {"perf/throughput", "perf/total_num_tokens", "perf/time_per_step"}
if REQUIRED_PERF_KEYS.issubset(aggregated):
aggregated["perf/throughput"] = aggregated["perf/total_num_tokens"] / (
aggregated["perf/time_per_step"] * self.total_gpus
)
# trainer/idle_ratio
if "timing_s/gen" in aggregated.keys() and "timing_s/step" in aggregated.keys():
aggregated["trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"]
return aggregated
def reset(self):
"""Reset Aggregator"""
self.metric_values.clear()
self.sample_counts.clear()
self.timestamps.clear()
self.step_count = 0
def get_current_stats(self) -> dict[str, Any]:
"""Get statistics about the current aggregation state (for debugging)"""
return {
"step_count": self.step_count,
"metric_count": len(self.metric_values),
"total_samples": sum(self.sample_counts),
"metric_names": list(self.metric_values.keys()),
}


@ -0,0 +1,136 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import torch
import torch.distributed
from omegaconf import DictConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils.device import (
get_device_name,
get_torch_device,
)
from verl.utils.fsdp_utils import (
fsdp_version,
)
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
device_name = get_device_name()
__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
def get_inference_model(rollout):
"""
get models according to different types of inference_engine
Args:
rollout: rollout object
Returns:
model: model object
"""
inference_engine = rollout.inference_engine
if hasattr(inference_engine, "llm_engine"):
inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
elif hasattr(inference_engine, "worker"):
inference_model = inference_engine.worker.model_runner.model
else:
raise AttributeError(
f"Unsupported inference_engine type: {type(inference_engine)}. "
f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)."
)
return inference_model
class DetachNcclSync(AsyncActorRolloutRefWorker):
def _get_actor_params(self):
pass
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
def sync_rollout_weights(self):
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
assert hasattr(self, "_weights_info") and self._weights_info is not None
params = self._get_actor_params() if self._is_actor else None
if self._is_rollout:
inference_model = get_inference_model(self.rollout)
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
patch_vllm_moe_model_weight_loader(inference_model)
for key, shape, dtype in self._weights_info:
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
if self._is_actor:
assert key in params
origin_data = params[key]
if hasattr(origin_data, "full_tensor"):
origin_data = origin_data.full_tensor()
if torch.distributed.get_rank() == 0:
tensor.copy_(origin_data)
from ray.util.collective import collective
collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
if self._is_rollout:
inference_model.load_weights([(key, tensor)])
get_torch_device().empty_cache()
class DetachActorWorker(DetachNcclSync):
def _get_actor_params(self):
assert self._is_actor
params = self.actor_module_fsdp.state_dict()
from verl.utils.model import convert_weight_keys
params = convert_weight_keys(
params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
)
return params
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self):
assert self._is_actor
if hasattr(self, "_weights_info"):
return self._weights_info
if fsdp_version(self.actor_module_fsdp) == 1:
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
FSDP.set_state_dict_type(
self.actor_module_fsdp,
state_dict_type=StateDictType.SHARDED_STATE_DICT,
state_dict_config=ShardedStateDictConfig(),
)
params = self._get_actor_params()
ret = []
for key, tensor in params.items():
ret.append((key, tensor.size(), tensor.dtype))
self._weights_info = ret
return ret
class DetachAsyncRolloutWorker(DetachNcclSync):
def __init__(self, config: DictConfig, role: str):
print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
ActorRolloutRefWorker.__init__(self, config, role)
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def set_actor_weights_info(self, weights_info):
assert self._is_rollout
self._weights_info = weights_info


@ -0,0 +1,306 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import socket
import threading
from pprint import pprint
import hydra
import ray
from omegaconf import OmegaConf
from recipe.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter
from recipe.fully_async_policy.fully_async_trainer import FullyAsyncTrainer
from recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import Role
from verl.utils.fs import copy_to_local
def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
"""
Create resource pool manager
Args:
config: Configuration object
roles: List of roles that need to create resource pools
Returns:
ResourcePoolManager: Resource pool manager
"""
resource_pool_spec = {}
mapping = {}
# Actor/Critic resource pool
if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]):
assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"
trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
resource_pool_spec["trainer_pool"] = trainer_pool
# Map training-related roles to the same resource pool
for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]:
if role in roles:
mapping[role] = "trainer_pool"
# Rollout resource pool
if Role.Rollout in roles:
assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"
rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
resource_pool_spec["rollout_pool"] = rollout_pool
mapping[Role.Rollout] = "rollout_pool"
return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
def create_role_worker_mapping(config):
"""
Create mapping from roles to worker classes
Args:
config: Configuration object
Returns:
dict: Mapping from roles to worker classes
"""
# Select worker class based on strategy
if config.actor_rollout_ref.actor.strategy == "fsdp2":
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
from recipe.fully_async_policy.fsdp_workers import (
CriticWorker,
DetachActorWorker,
DetachAsyncRolloutWorker,
)
from verl.single_controller.ray import RayWorkerGroup
ray_worker_group_cls = RayWorkerGroup
# TODO megatron support
else:
raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}")
role_worker_mapping = {
Role.Actor: ray.remote(DetachActorWorker),
Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
Role.Critic: ray.remote(CriticWorker),
}
if config.reward_model.enable:
if config.reward_model.strategy == "fsdp2":
from verl.workers.fsdp_workers import RewardModelWorker
# TODO megatron support
else:
raise NotImplementedError(f"Unsupported reward model strategy: {config.reward_model.strategy}")
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
# Add reference policy (if KL loss or reward is required)
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker)
return role_worker_mapping, ray_worker_group_cls
@ray.remote(num_cpus=1)
class FullyAsyncTaskRunner:
"""
Ray remote class for executing distributed PPO training tasks.
"""
def __init__(self):
self.running = False
self.components = {}
self.shutdown_event = threading.Event()
def run(self, config):
print("[ASYNC MAIN] Starting fully async PPO training...")
self._initialize_components(config)
self._run_training_loop()
def _initialize_components(self, config) -> None:
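# Initialization order (summarized from the steps below):
#   1. resolve the config and load the tokenizer/processor
#   2. build the role -> worker mapping and the reward functions
#   3. create the Rollouter and Trainer actors on their own resource pools
#   4. create the MessageQueue sized by the Rollouter and hand its client to both sides
#   5. create the ParameterSynchronizer, load the checkpoint and push the initial weights (version 0)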
print(f"[ASYNC MAIN] TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
pprint(OmegaConf.to_container(config, resolve=True))
OmegaConf.resolve(config)
print("[ASYNC MAIN] Initializing model and tokenizer...")
local_path = copy_to_local(
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
)
from verl.utils import hf_processor, hf_tokenizer
trust_remote_code = config.data.get("trust_remote_code", False)
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
# Used for multimodal LLM, could be None
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
self.components["tokenizer"] = tokenizer
self.components["processor"] = processor
self.components["config"] = config
print("[ASYNC MAIN] Creating worker mapping and resource pools...")
role_worker_mapping, ray_worker_group_cls = create_role_worker_mapping(config)
self.components["role_worker_mapping"] = role_worker_mapping
self.components["ray_worker_group_cls"] = ray_worker_group_cls
print("[ASYNC MAIN] Loading reward functions...")
reward_fn = load_reward_manager(
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
)
val_reward_fn = load_reward_manager(
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
)
self.components["reward_fn"] = reward_fn
self.components["val_reward_fn"] = val_reward_fn
print("[ASYNC MAIN] Creating FullyAsyncRollouter...")
self._create_rollouter(config)
print("[ASYNC MAIN] Creating FullyAsyncTrainer...")
self._create_trainer(config)
# sync total_train_steps between rollouter and trainer
total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote())
print(f"total_train_steps {total_train_steps}")
ray.get(self.components["trainer"].set_total_train_steps.remote(total_train_steps))
# fetch max_queue_size from the rollouter to size the MessageQueue
max_queue_size = ray.get(self.components["rollouter"].get_max_queue_size.remote())
print(f"[ASYNC MAIN] Creating MessageQueue... max_queue_size {max_queue_size}")
message_queue = MessageQueue.remote(config, max_queue_size)
message_queue_client = MessageQueueClient(message_queue)
self.components["message_queue"] = message_queue
self.components["message_queue_client"] = message_queue_client
ray.get(self.components["rollouter"].set_message_queue_client.remote(self.components["message_queue_client"]))
ray.get(self.components["trainer"].set_message_queue_client.remote(self.components["message_queue_client"]))
print("[ASYNC MAIN] Setting up parameter synchronization...")
from recipe.fully_async_policy.param_sync import ParameterSynchronizer
param_synchronizer = ParameterSynchronizer.remote(
config=config,
trainer=self.components["trainer"],
rollouter=self.components["rollouter"],
mq=self.components["message_queue_client"],
)
ray.get(self.components["trainer"].set_parameter_synchronizer.remote(param_synchronizer))
# load checkpoint and sync parameters before doing anything
val_before_train = val_reward_fn is not None and config.trainer.get("val_before_train", True)
ray.get(self.components["trainer"].load_checkpoint.remote())
ray.get(param_synchronizer.sync_weights.remote(version=0, validate=val_before_train))
self.components["param_synchronizer"] = param_synchronizer
print("[ASYNC MAIN] All components initialized successfully")
def _create_rollouter(self, config) -> None:
rollouter = FullyAsyncRollouter.remote(
config=config,
tokenizer=self.components["tokenizer"],
role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]},
resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]),
ray_worker_group_cls=self.components["ray_worker_group_cls"],
processor=self.components["processor"],
reward_fn=self.components["reward_fn"],
val_reward_fn=self.components["val_reward_fn"],
device_name=config.trainer.device,
)
ray.get(rollouter.init_workers.remote())
ray.get(rollouter.set_max_required_samples.remote())
self.components["rollouter"] = rollouter
print("[ASYNC MAIN] Rollouter created and initialized successfully")
def _create_trainer(self, config) -> None:
trainer_role_mapping = {
role: worker_cls
for role, worker_cls in self.components["role_worker_mapping"].items()
if role != Role.Rollout
}
trainer = FullyAsyncTrainer.remote(
config=config,
tokenizer=self.components["tokenizer"],
role_worker_mapping=trainer_role_mapping,
resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())),
ray_worker_group_cls=self.components["ray_worker_group_cls"],
processor=self.components["processor"],
reward_fn=self.components["reward_fn"],
val_reward_fn=self.components["val_reward_fn"],
device_name=config.trainer.device,
)
ray.get(trainer.init_workers.remote())
self.components["trainer"] = trainer
print("[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully")
def _run_training_loop(self):
self.running = True
print("[ASYNC MAIN] Starting Rollouter and Trainer...")
rollouter_future = self.components["rollouter"].fit.remote()
trainer_future = self.components["trainer"].fit.remote()
futures = [rollouter_future, trainer_future]
try:
while futures:
# Use ray.wait to monitor all futures and return when any one is completed.
done_futures, remaining_futures = ray.wait(futures, num_returns=1, timeout=None)
for future in done_futures:
try:
ray.get(future)
print("[ASYNC MAIN] One component completed successfully")
except Exception as e:
print(f"[ASYNC MAIN] Component failed with error: {e}")
for remaining_future in remaining_futures:
ray.cancel(remaining_future)
raise e
futures = remaining_futures
except Exception as e:
print(f"[ASYNC MAIN] Training failed: {e}")
for future in futures:
ray.cancel(future)
raise
finally:
self.components["message_queue_client"].clear_queue()
print("[ASYNC MAIN] Training completed or interrupted")
@hydra.main(config_path="config", config_name="fully_async_ppo_trainer", version_base=None)
def main(config):
from verl.trainer.main_ppo import run_ppo
# Ensure async training config exists
if not hasattr(config, "async_training"):
raise RuntimeError("must set async_training config")
from time import time
start_time = time()
run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
print(f"total time: {time() - start_time:.2f} seconds")
if __name__ == "__main__":
main()
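# Launch sketch (hypothetical module path and override values; the Hydra config name comes from
# the decorator above):
#   python -m recipe.fully_async_policy.fully_async_main \
#       trainer.nnodes=1 trainer.n_gpus_per_node=4 \
#       rollout.nnodes=1 rollout.n_gpus_per_node=4 \
#       async_training.trigger_parameter_sync_step=4 async_training.staleness_threshold=1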

@@ -0,0 +1,646 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import time
from pprint import pformat
import ray
from ray import ObjectRef
from recipe.fully_async_policy.detach_utils import (
RolloutSample,
ValidateMetrics,
merge_rollout_sample,
prepare_single_generation_data,
)
from recipe.fully_async_policy.message_queue import MessageQueueClient
from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType
from verl.utils.profiler import marked_timer
from verl.utils.tracking import ValidationGenerationsLogger
@ray.remote(num_cpus=10, max_concurrency=100)
class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
"""
Asynchronous sample generator, responsible for continuously generating training samples
and putting them into the MessageQueue.
Based on and extending the mature OneStepOffRayTrainer implementation.
"""
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
device_name=None,
):
# Store the tokenizer for text processing
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert not self.hybrid_engine
assert self.config.data.train_batch_size == 0, "train_batch_size must be zero"
assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one"
assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must be greater than or equal to 0"
assert self.config.async_training.trigger_parameter_sync_step >= 1, (
"trigger_parameter_sync_step must be greater than or equal to 1"
)
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
self.validation_generations_logger = ValidationGenerationsLogger(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
)
self.ref_in_actor = False
self.kl_ctrl_in_reward = False
self.use_critic = False
self.use_reference_policy = False
self.use_rm = False
print("[FullyAsyncRollouter] Creating datasets...")
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
from verl.utils.dataset.rl_dataset import collate_fn
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
train_sampler = create_rl_sampler(config.data, train_dataset)
self._validate_config()
print(f"[FullyAsyncRollouter] Rollouter _create_dataloader...\n{train_dataset}\n{val_dataset}")
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
# ==================== fully async config ====================
self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
if self.config.rollout.total_rollout_steps is not None:
self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)
print(f"[FullyAsyncRollouter] Total rollout steps: {self.total_rollout_steps}")
self.total_train_steps = None
# Rollouter parameter configuration
self.message_queue_client = None
# Worker groups: rollout_wg is the same as actor_rollout_wg
self.rollout_wg = None
self.actor_rollout_wg = None
self.async_rollout_manager = None
# Config
self.staleness_threshold: float = config.async_training.get("staleness_threshold", 1)
# required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples.
self.require_batches = config.async_training.require_batches
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
self.max_required_samples = None
self.max_concurrent_samples = None
# queue size
self.max_queue_size = None
# Statistics
self.current_param_version = 0
self.total_generated_samples = 0
self.staleness_samples = 0
self.dropped_stale_samples = 0
self.processed_sample_count = 0
self.global_steps = 0
self.idle_start_time = None
self.version_start_time = None
# Concurrency control
# Modified by self.pause() or self._should_pause_generation()
self.paused = False
self.running = True
self.monitor_loop_trigger = True
# Initialize async locks directly
self.lock = asyncio.Lock()
self.condition = asyncio.Condition(self.lock)
# Initialize async queues
self.pending_queue = asyncio.Queue(maxsize=128)
self.active_tasks = set()
self.result_queue = asyncio.Queue()
self.cancel_queue = asyncio.Queue()
async def set_message_queue_client(self, message_queue_client: MessageQueueClient):
"""Set message queue client"""
async with self.lock:
self.message_queue_client = message_queue_client
async def set_max_required_samples(self):
async with self.lock:
self.max_required_samples = int(
self.required_samples
* (self.staleness_threshold + 1)
* self.config.async_training.trigger_parameter_sync_step
)
self.total_train_steps = int(
self.total_rollout_steps
/ (self.required_samples * self.config.async_training.trigger_parameter_sync_step)
)
self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16
self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples)
self.max_queue_size = self.max_required_samples
print(
f"[FullyAsyncRollouter] required_samples : {self.required_samples} "
f"max_required_samples: {self.max_required_samples} "
f"max_queue_size: {self.max_queue_size} "
f"total_train_steps: {self.total_train_steps} "
f"total_rollout_steps: {self.total_rollout_steps} "
f"max_concurrent_samples: {self.max_concurrent_samples} "
)
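# Worked example (hypothetical numbers, following the formulas above):
#   ppo_mini_batch_size = 32, require_batches = 2            -> required_samples = 64
#   staleness_threshold = 1, trigger_parameter_sync_step = 4 -> max_required_samples = 64 * (1 + 1) * 4 = 512
#   total_rollout_steps = 25600                              -> total_train_steps = 25600 / (64 * 4) = 100
#   8 rollout server handles                                 -> max_concurrent_samples = min(8 * 16, 512) = 128
#   max_queue_size = max_required_samples = 512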
def get_rollout_wg(self):
"""Get rollout worker group"""
return self.rollout_wg
def get_max_queue_size(self):
return self.max_queue_size
def get_total_train_steps(self):
return self.total_train_steps
async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0):
"""Update current parameter version"""
async with self.lock:
old_version = self.current_param_version
self.current_param_version = version
# every time the parameters change, reset staleness_samples
self.staleness_samples = (
len(self.active_tasks)
+ self.result_queue.qsize()
+ self.cancel_queue.qsize()
+ await self.message_queue_client.get_queue_size()
)
timing_raw = {}
idle_ratio = None
if self.idle_start_time is not None and self.version_start_time is not None:
rollout_active_time = self.idle_start_time - self.version_start_time
rollout_version_time = time.time() - self.version_start_time
idle_ratio = 1 - rollout_active_time / rollout_version_time
timing_raw["rollouter/active_time"] = rollout_active_time
timing_raw["rollouter/version_time"] = rollout_version_time
timing_raw["rollouter/idle_ratio"] = idle_ratio
self.idle_start_time = None
print(
f"[FullyAsyncRollouter][Public][update_param_version] "
f"Parameter version updated from {old_version} to {version} "
f",reset staleness_samples to: {self.staleness_samples}"
f",idle_ratio: {idle_ratio}"
)
val_metrics = None
if (
self.val_reward_fn is not None
and self.config.rollout.test_freq > 0
and self.current_param_version % self.config.rollout.test_freq == 0
and self.current_param_version > 0 # don't test here in the initial parameter sync
) or (validate and self.val_reward_fn is not None):
with marked_timer("rollouter/validate_time", timing_raw, color="green"):
val_metrics: dict = self._validate()
data = ValidateMetrics(
timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version
)
await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
self.version_start_time = time.time()
def _validate_config(self):
# Validate asynchronous training configuration
if not hasattr(self.config, "async_training"):
raise ValueError("[FullyAsyncRollouter] Missing async_training configuration")
assert self.config.actor_rollout_ref.rollout.calculate_log_probs, "rollout must calculate log_probs"
async def init_workers(self):
"""Initialize distributed training workers using Ray backend.
Creates:
1. Ray resource pools from configuration
2. Worker groups for each role (actor, critic, etc.)
"""
self._init_resource_pools()
self._create_worker_classes()
self._init_worker_groups()
self._init_models()
await self._init_async_rollout_manager()
def _create_actor_rollout_classes(self):
# only create rollout
for role in [Role.Rollout]:
resource_pool = self.resource_pool_manager.get_resource_pool(role)
role_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[role],
config=self.config.actor_rollout_ref,
role=str(role),
)
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
def _init_models(self):
self.rollout_wg = self.all_wg[str(Role.Rollout)]
self.rollout_wg.init_model()
self.actor_rollout_wg = self.rollout_wg
def _create_continuous_iterator(self):
"""
Create a continuous data iterator across epochs
"""
for epoch in range(self.config.rollout.total_epochs):
iterator = iter(self.train_dataloader)
for batch_dict in iterator:
yield epoch, batch_dict
async def _init_async_rollout_manager(self):
# create async rollout manager and request scheduler
assert self.config.actor_rollout_ref.rollout.mode == "async"
from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
self.async_rollout_mode = True
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
config=self.config,
worker_group=self.rollout_wg,
)
# Add samples to the pending_queue
async def _feed_samples(self):
continuous_iterator = self._create_continuous_iterator()
for epoch, batch_dict in continuous_iterator:
# Similar to _prepare_generate_batch: Separate data
full_batch = prepare_single_generation_data(
batch_dict, self.global_steps, self.config.actor_rollout_ref.rollout.n
)
sample_id = f"sample_{epoch}_{self.global_steps}"
rollout_sample = RolloutSample(
full_batch=full_batch,
agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n,
sample_id=sample_id,
epoch=epoch,
param_version=0,
param_version_start=[],
param_version_end=[],
processing_times=[],
rollout_status={},
)
await self.pending_queue.put(rollout_sample)
# Check if we have reached the last step
if self.global_steps >= self.total_rollout_steps:
print(
f"[FullyAsyncRollouter][Feed] "
f"Maximum count has been reached, stop adding new samples"
f"{self.global_steps} >= {self.total_rollout_steps}"
)
break
self.global_steps += 1
# End signal
await self.pending_queue.put("DONE")
print(f"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added")
async def _processor_worker(self):
"""
Streaming worker coroutine: each sample is submitted for processing immediately, without waiting for a full batch
"""
while True:
if self.paused or await self._should_pause_generation():
print(
"[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..."
)
async with self.lock:
self.paused = True
while self.active_tasks:
async with self.lock:
# After acquiring the lock, the number of active_tasks may have changed, so verify it again
if self.active_tasks:
done_tasks, self.active_tasks = await asyncio.wait(
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done_tasks:
await task
async with self.lock:
while self.paused:
self.idle_start_time = time.time()
await self.condition.wait()
continue
sample_from_cancel_queue = False
if not self.cancel_queue.empty():
rollout_sample = await self.cancel_queue.get()
sample_from_cancel_queue = True
else:
rollout_sample = await self.pending_queue.get()
self.staleness_samples += 1
if rollout_sample == "DONE":
print(
"[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..."
)
while self.active_tasks:
async with self.lock:
if self.active_tasks:
done_tasks, self.active_tasks = await asyncio.wait(
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done_tasks:
await task
break
# Check whether the number of concurrent tasks exceeds the limit
while len(self.active_tasks) >= self.max_concurrent_samples:
async with self.lock:
if self.active_tasks:
done_tasks, self.active_tasks = await asyncio.wait(
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
)
for task in done_tasks:
await task
# Submit single sample processing
async with self.lock:
# After acquiring the lock, check again whether we are still in the pause phase;
# if so, keep waiting until resumed
while self.paused:
await self.condition.wait()
task = asyncio.create_task(
self._process_single_sample_streaming(rollout_sample),
name=rollout_sample.sample_id,
)
self.active_tasks.add(task)
if sample_from_cancel_queue:
self.cancel_queue.task_done()
else:
self.pending_queue.task_done()
async def _process_single_sample_streaming(self, rollout_sample: RolloutSample):
"""Process a single sample streamingly"""
# Call the asynchronous generation method
rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len(
rollout_sample.full_batch
)
agent_loop_output_list = await self.async_rollout_manager.generate_single_sample_async(
rollout_sample.full_batch, rollout_sample.agent_loop_output_list
)
rollout_sample.agent_loop_output_list = agent_loop_output_list
is_cancel = False
for agent_loop in agent_loop_output_list:
if not is_cancel and agent_loop.is_cancel:
is_cancel = True
if is_cancel:
# Put in the cancel queue and wait for the generation to resume
await self.cancel_queue.put(rollout_sample)
else:
# put into the result_queue
rollout_sample.param_version = self.current_param_version
rollout_sample.rollout_status = await self.get_statistics()
await self.result_queue.put(rollout_sample)
self.processed_sample_count += 1
async def _consumer_worker(self):
"""
The consumer coroutine is responsible for obtaining the processing results
from the result queue and putting them into the message queue
"""
while True:
rollout_sample = await self.result_queue.get()
rollout_sample = merge_rollout_sample(self.config, self.tokenizer, rollout_sample)
# Put RolloutSample into the message queue
success = await self.message_queue_client.put_sample(
sample=ray.cloudpickle.dumps(rollout_sample),
param_version=rollout_sample.param_version,
)
if success:
self.total_generated_samples += 1
else:
self.dropped_stale_samples += 1
self.result_queue.task_done()
async def _streaming_generation_main(self):
"""The main entry method for stream processing"""
# we start from step 1
self.global_steps += 1
if self.async_rollout_manager is None:
await self._init_async_rollout_manager()
# Start the streaming loop
print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}")
# Start sample feed coroutine, streaming process coroutine and consumer coroutine
self.feed_task = asyncio.create_task(self._feed_samples())
self.processor_task = asyncio.create_task(self._processor_worker())
self.consumer_task = asyncio.create_task(self._consumer_worker())
try:
# Wait for sample feed to complete
await self.feed_task
print("[FullyAsyncRollouter] Sample feed completed")
# Wait for streaming to complete
await self.processor_task
print("[FullyAsyncRollouter] Streaming process completed")
# Waiting for the result queue to clear
await self.result_queue.join()
print("[FullyAsyncRollouter] Result queue cleared")
except Exception as e:
print(f"[FullyAsyncRollouter] Streaming process exception:{e}")
finally:
if self.processor_task:
self.processor_task.cancel()
if self.consumer_task:
self.consumer_task.cancel()
await asyncio.gather(self.processor_task, self.consumer_task, return_exceptions=True)
# Send a finish signal
await self.message_queue_client.put_sample(
sample=None,
param_version=self.current_param_version,
)
async with self.lock:
self.running = False
async def fit(self):
"""
Start the async rollouter. Entry point that sets up and coordinates all async coroutines.
"""
print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...")
if self.message_queue_client is None:
raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.")
# Set the running status flag
async with self.lock:
self.paused = False
self.running = True
# Create the main asynchronous task
generation_task = asyncio.create_task(self._streaming_generation_main())
monitor_task = asyncio.create_task(self._async_monitor_loop())
try:
# Run the generation and monitoring tasks concurrently
await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
except Exception as e:
print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}")
finally:
if not generation_task.done():
generation_task.cancel()
if not monitor_task.done():
monitor_task.cancel()
# Wait for the task to complete
await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
print("[FullyAsyncRollouter] Rollouter fit completed")
async def _async_monitor_loop(self):
"""
Async monitoring coroutine:
1. Periodically log statistics
2. Trigger rollout recovery (resume generation) once back-pressure clears
"""
last_stats_time = time.time()
stats_interval = 60.0
check_interval = 10.0
while True:
async with self.lock:
if not self.running:
break
await asyncio.sleep(check_interval)
# Print statistics periodically
current_time = time.time()
if current_time - last_stats_time >= stats_interval:
stats = await self.get_statistics()
print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}")
last_stats_time = current_time
# Trigger rollout recovery
if self.monitor_loop_trigger:
if not await self._should_pause_generation():
async with self.lock:
self.paused = False
self.condition.notify_all()
async def _should_pause_generation(self) -> bool:
"""Determine whether the build should be paused"""
queue_stats = self.message_queue_client.get_statistics_sync()
queue_size = queue_stats["queue_size"]
if queue_size >= self.max_queue_size:
if not self.paused:
print(
f"[FullyAsyncRollouter][ShouldPause] "
f"due to full queue: size={queue_size}, max={self.max_queue_size}"
)
return True
if self.staleness_samples >= self.max_required_samples:
if not self.paused:
print(
"[FullyAsyncRollouter][ShouldPause] "
f"due to "
f"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples} "
)
return True
return False
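# Backpressure sketch (hypothetical sizes): with max_queue_size = max_required_samples = 512,
# generation pauses once the MessageQueue holds 512 samples or once staleness_samples reaches 512;
# the monitor loop resumes generation when neither limit is exceeded.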
async def pause(self):
"""pause rollout"""
print("[FullyAsyncRollouter][Public][Pause]")
async with self.lock:
self.paused = True
# Cancel all rollout tasks
if self.config.async_training.partial_rollout:
await self.async_rollout_manager.cancel()
if self.active_tasks:
await asyncio.gather(*self.active_tasks, return_exceptions=True)
self.active_tasks.clear()
print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
await self.async_rollout_manager.reset_prefix_cache()
self.monitor_loop_trigger = False
async def resume(self, dependency_ref: ObjectRef = None):
if dependency_ref is not None:
ray.get(dependency_ref)
print("[FullyAsyncRollouter][Public][Resume]")
async with self.lock:
self.paused = False
self.monitor_loop_trigger = True
self.condition.notify_all()
if self.config.async_training.partial_rollout:
await self.async_rollout_manager.resume()
async def get_statistics(self) -> dict:
queue_stats = self.message_queue_client.get_statistics_sync()
stats = {
# monitor stats
"monitor/active_tasks_size": len(self.active_tasks),
"monitor/queue/pending_queue_size": self.pending_queue.qsize(),
"monitor/queue/cancel_queue_size": self.cancel_queue.qsize(),
"monitor/queue/result_queue_size": self.result_queue.qsize(),
"monitor/queue/mq_queue_size": queue_stats["queue_size"],
# counting stats
"count/current_param_version": self.current_param_version,
"count/total_generated_samples": self.total_generated_samples,
"count/staleness_samples": self.staleness_samples,
"count/dropped_stale_samples": self.dropped_stale_samples,
# static stats
"static/max_required_samples": self.max_required_samples,
"static/required_samples": self.required_samples,
"static/staleness_threshold": self.staleness_threshold,
"static/max_queue_size": self.max_queue_size,
"static/max_concurrent_samples": self.max_concurrent_samples,
}
return stats

@@ -0,0 +1,360 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from datetime import datetime
from pprint import pprint
from typing import Any
import ray
from omegaconf import OmegaConf
from tqdm import tqdm
from recipe.fully_async_policy.detach_utils import (
MetricsAggregator,
ValidateMetrics,
assemble_batch_from_rollout_samples,
)
from recipe.fully_async_policy.message_queue import MessageQueueClient
from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.trainer.ppo import core_algos
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.debug import marked_timer
@ray.remote(num_cpus=10)
class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
"""
A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training.
Based on an improved implementation of OneStepOffRayTrainer
"""
def __init__(
self,
config,
tokenizer,
role_worker_mapping: dict[Role, WorkerType],
resource_pool_manager: ResourcePoolManager,
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
processor=None,
reward_fn=None,
val_reward_fn=None,
device_name=None,
):
# Store the tokenizer for text processing
self.tokenizer = tokenizer
self.processor = processor
self.config = config
self.reward_fn = reward_fn
self.val_reward_fn = val_reward_fn
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
assert not self.hybrid_engine
self.role_worker_mapping = role_worker_mapping
self.resource_pool_manager = resource_pool_manager
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
self.use_rm = need_reward_model(self.role_worker_mapping)
self.use_critic = need_critic(self.config)
self.ray_worker_group_cls = ray_worker_group_cls
self.device_name = device_name if device_name else self.config.trainer.device
# if ref_in_actor is True, the reference policy will be actor without lora applied
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
# define in-reward KL control
# KL loss control is currently not supported
if self.config.algorithm.use_kl_in_reward:
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
# ==================== fully async config ====================
self.message_queue_client = None
self.param_synchronizer = None
# Statistics
# we start from step 1
self.global_steps = 1
self.local_trigger_step = 1
self.processed_samples = 0
self.stale_samples_processed = 0
self.stale_trajectory_processed = 0
self.current_param_version = 0
self.total_train_steps = None
self.progress_bar = None
self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step
# required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples.
self.require_batches = config.async_training.require_batches
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
total_gpus = (
config.trainer.nnodes * config.trainer.n_gpus_per_node
+ config.rollout.nnodes * config.rollout.n_gpus_per_node
)
self.metrics_aggregator = MetricsAggregator(total_gpus=total_gpus)
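# Example (hypothetical sizes): trainer 2 nodes x 8 GPUs + rollout 1 node x 8 GPUs
# -> total_gpus = 24, which is passed to the MetricsAggregator.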
def set_message_queue_client(self, message_queue_client: MessageQueueClient):
"""Set message queue client"""
self.message_queue_client = message_queue_client
def set_parameter_synchronizer(self, param_synchronizer):
"""Set parameter synchronizer"""
self.param_synchronizer = param_synchronizer
def set_total_train_steps(self, total_train_steps):
self.total_train_steps = total_train_steps
self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc="Training Progress")
def get_actor_wg(self):
"""Get actor worker group"""
return self.actor_wg
def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]:
"""
Get samples from the message queue and assemble a training batch
Uses a loop to continuously collect samples until enough are gathered
Returns:
tuple: (epoch, batch), or (None, None) if not enough samples could be collected
"""
print(
f"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue",
flush=True,
)
# Collect samples using a simple loop calling get_sample
consumer_start = time.time()
queue_samples = []
queue_len = 0
while len(queue_samples) < self.required_samples:
# Get a single sample and wait until there is a sample or None is received
sample, queue_len = self.message_queue_client.get_sample_sync()
if sample is None:
print(
f"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. "
f"Collected {len(queue_samples)}/{self.required_samples} samples"
)
break
queue_samples.append(sample)
if len(queue_samples) % 64 == 0:
print(
f"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. "
f"mq_len: {queue_len}"
)
consumer_end = time.time()
if not queue_samples or len(queue_samples) < self.required_samples:
print("[FullyAsyncTrainer] not enough samples collected after loop")
return None, None
total_wait_time = consumer_end - consumer_start
print(
f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, "
f"total wait time: {total_wait_time:.2f} seconds."
f"mq_len: {queue_len}"
)
queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples]
# Assemble batch - now working directly with RolloutSample objects
if self.config.trainer.balance_batch:
batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch)
else:
batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None)
batch.meta_info["fully_async/total_wait_time"] = total_wait_time
return 0, batch
def _create_actor_rollout_classes(self):
# create actor
for role in [Role.Actor]:
resource_pool = self.resource_pool_manager.get_resource_pool(role)
role_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[role],
config=self.config.actor_rollout_ref,
role=str(role),
)
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
def _init_models(self):
if self.use_critic:
self.critic_wg = self.all_wg[str(Role.Critic)]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = self.all_wg[str(Role.RewardModel)]
self.rm_wg.init_model()
self.actor_wg = self.all_wg[str(Role.Actor)]
self.actor_wg.init_model()
self.actor_rollout_wg = self.actor_wg  # to stay compatible with functions that have not been modified
def _init_async_rollout_manager(self):
pass
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
print("[FullyAsyncTrainer] Starting FullyAsyncTrainer...")
if self.message_queue_client is None:
raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.")
if self.param_synchronizer is None:
raise ValueError("param_synchronizer client not set. Call set_parameter_synchronizer() first.")
from verl.utils.tracking import Tracking
self.logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.max_steps_duration = 0
# get validate data before training
if self.config.trainer.val_before_train and self.reward_fn is not None:
ray.get(self.param_synchronizer.wait_last_valid.remote())
val_data = self.message_queue_client.get_validate_sync()
if val_data:
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
if val_data.metrics:
self.logger.log(data=val_data.metrics, step=val_data.param_version)
pprint(f"[FullyAsyncTrainer] Initial validation metrics: {val_data.metrics}")
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
# Queue mode: samples are pulled from the MessageQueue instead of a traditional dataloader iterator
while True:
metrics = {}
timing_raw = {}
with marked_timer("step", timing_raw):
with marked_timer("gen", timing_raw, color="red"):
epoch, batch = self._get_samples_from_queue()
if batch is None:
break
self._collect_metrics_from_samples(batch, metrics)
batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)
self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
self._check_save_checkpoint(False, timing_raw)
self._collect_metrics(batch, 0, metrics, timing_raw)
self.metrics_aggregator.add_step_metrics(
metrics=metrics, sample_count=self.required_samples, timestamp=time.time()
)
# Trigger parameter synchronization after training step
time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3]
print(
f"[FullyAsyncTrainer] global_steps: {self.global_steps} "
f"local_trigger_step: {self.local_trigger_step} "
f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} "
f"{time_str}"
)
self._trigger_parameter_sync_after_step(global_steps=self.global_steps)
val_data = self.message_queue_client.get_validate_sync()
if val_data:
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
if val_data.metrics:
self.logger.log(data=val_data.metrics, step=val_data.param_version)
pprint(
f"[FullyAsyncTrainer] parameter version: {val_data.param_version} \
Validation metrics: {val_data.metrics}"
)
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
self.global_steps += 1
# final parameter sync and validate
if val_data is None or val_data.metrics is None:
self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps - 1)
ray.get(self.param_synchronizer.wait_last_valid.remote())
val_data = self.message_queue_client.get_validate_sync()
if val_data:
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
if val_data.metrics:
self.logger.log(data=val_data.metrics, step=val_data.param_version)
pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}")
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
else:
pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}")
self.progress_bar.close()
self._check_save_checkpoint(True, timing_raw) # TODO: check checkpoint
def load_checkpoint(self):
return self._load_checkpoint()
def _collect_metrics_from_samples(self, batch, metrics):
"""
Collect metrics from samples
"""
if hasattr(batch, "meta_info") and batch.meta_info:
samples_param_versions = batch.meta_info["rollout_param_versions"]
stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1)
self.stale_samples_processed += stale_count
trajectory_param_versions = batch.meta_info["trajectory_param_versions"]
stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1)
self.stale_trajectory_processed += stale_traj_count
metrics.update(
{
"fully_async/count/stale_samples_processed": self.stale_samples_processed,
"fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed,
"fully_async/count/current_param_version": self.current_param_version,
}
)
for key, value in batch.meta_info.items():
if key.startswith("fully_async"):
metrics[key] = value
def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None):
"""
Trigger parameter synchronization after a training step.
This ensures the rollouter always uses the latest trained parameters.
"""
if self.local_trigger_step < self.trigger_parameter_sync_step and not validate:
self.local_trigger_step += 1
return
self.current_param_version += 1
self.local_trigger_step = 1
self.logger.log(
data=self.metrics_aggregator.get_aggregated_metrics(),
step=self.current_param_version,
)
self.progress_bar.update(1)
self.metrics_aggregator.reset()
timing_param_sync = {}
with marked_timer("timing_s/wait_last_valid", timing_param_sync):
ray.get(self.param_synchronizer.wait_last_valid.remote())
with marked_timer("timing_s/param_sync", timing_param_sync):
ray.get(
self.param_synchronizer.sync_weights.remote(
self.current_param_version, validate=validate, global_steps=global_steps
)
)
self.logger.log(data=timing_param_sync, step=self.current_param_version)

@@ -0,0 +1,265 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from collections import deque
from typing import Any
import ray
from omegaconf import DictConfig
logger = logging.getLogger(__name__)
@ray.remote(num_cpus=2, max_concurrency=20)
class MessageQueue:
"""
Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer
"""
def __init__(self, config: DictConfig, max_queue_size: int = 1000):
self.config = config
if max_queue_size is None:
raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}")
self.max_queue_size = int(max_queue_size)
self.queue = deque(maxlen=self.max_queue_size)
self.current_param_version = 0
self.val_queue = deque()
try:
if hasattr(config, "async_training") and config.async_training is not None:
self.staleness_threshold = getattr(config.async_training, "staleness_threshold", 3)
else:
self.staleness_threshold = 3
except (AttributeError, RecursionError):
self.staleness_threshold = 3
# Asyncio for message handling
self.running = True
# async safe
self._lock = asyncio.Lock()
self._consumer_condition = asyncio.Condition(self._lock)
# statistic message
self.total_produced = 0
self.total_consumed = 0
self.dropped_samples = 0
print(
f"[MessageQueue] initialized with max_queue_size={max_queue_size},"
f"staleness_threshold={self.staleness_threshold}"
)
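# Protocol summary: the Rollouter pushes cloudpickled RolloutSample payloads via put_sample();
# the Trainer blocks on get_sample() until a sample arrives; a payload of None is the
# Rollouter's end-of-stream signal; validation metrics travel through the separate val_queue.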
async def put_sample(self, sample: Any, param_version: int) -> bool:
"""
Put a batch sample into the queue
Args:
sample: Sample data
param_version: Parameter version number
Returns:
bool: Whether the sample was successfully put into the queue
"""
async with self._lock:
# If queue is full, remove the oldest sample (rarely happens)
is_drop = False
if len(self.queue) >= self.max_queue_size:
self.queue.popleft()
self.dropped_samples += 1
is_drop = True
logger.warning("Queue full, dropped sample")
self.queue.append(sample)
self.total_produced += 1
# Notify waiting consumers
self._consumer_condition.notify_all()
if self.total_produced % 100 == 0:
print(f"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}")
if is_drop:
return False
return True
async def get_sample(self) -> Any | None:
"""
Get a single sample from the queue, wait until one is available
Returns:
tuple: (sample, remaining queue size), or None if the queue is closed and empty
"""
async with self._lock:
while len(self.queue) == 0 and self.running:
await self._consumer_condition.wait()
# If queue is closed and empty, return None
if not self.running and len(self.queue) == 0:
return None
# Get one sample
data = self.queue.popleft()
self.total_consumed += 1
return data, len(self.queue)
async def update_param_version(self, version: int):
"""Update current parameter version"""
async with self._lock:
old_version = self.current_param_version
self.current_param_version = version
print(f"Parameter version updated from {old_version} to {version}")
async def get_queue_size(self) -> int:
"""Get current queue length"""
async with self._lock:
return len(self.queue)
async def get_statistics(self) -> dict[str, Any]:
"""Get queue statistics"""
async with self._lock:
return {
"queue_size": len(self.queue),
"total_produced": self.total_produced,
"total_consumed": self.total_consumed,
"dropped_samples": self.dropped_samples,
"current_param_version": self.current_param_version,
"staleness_threshold": self.staleness_threshold,
"max_queue_size": self.max_queue_size,
}
async def clear_queue(self):
"""Clear the queue"""
async with self._lock:
cleared_count = len(self.queue)
self.queue.clear()
logger.info(f"Cleared {cleared_count} samples from queue")
async def shutdown(self):
"""Shutdown the message queue"""
async with self._lock:
self.running = False
# Notify all waiting coroutines so they can exit
self._consumer_condition.notify_all()
logger.info("MessageQueue shutdown")
async def get_memory_usage(self) -> dict:
"""Get memory usage statistics"""
async with self._lock:
# Estimate memory usage of samples in queue
import sys
total_size = 0
sample_count = len(self.queue)
if sample_count > 0:
# Estimate size of a single sample (simplified estimation)
sample = list(self.queue)[0]
try:
sample_size = sys.getsizeof(sample)
# Since we now store RolloutSample directly, estimate based on its components
if hasattr(sample, "original_batch_dict") and sample.original_batch_dict:
# Estimate batch data size
batch_data = sample.original_batch_dict.get("batch", {})
sample_size += len(batch_data) * 1000 # Roughly estimate 1KB per batch entry
if hasattr(sample, "agent_loop_output"):
# Estimate AgentLoopOutput size
sample_size += 5000 # Roughly estimate 5KB for AgentLoopOutput
total_size = sample_size * sample_count
except Exception:
total_size = sample_count * 15000 # Roughly estimate 15KB per RolloutSample
return {
"queue_samples": sample_count,
"estimated_memory_bytes": total_size,
"estimated_memory_mb": total_size / (1024 * 1024),
}
async def put_validate(self, data):
async with self._lock:
self.val_queue.append(data)
async def get_validate(self):
async with self._lock:
if self.val_queue:
return self.val_queue.popleft()
else:
return None
class MessageQueueClient:
"""Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor"""
def __init__(self, queue_actor: Any):
self.queue_actor = queue_actor
async def put_sample(self, sample: Any, param_version: int) -> bool:
"""Put batch into queue (async)"""
future = self.queue_actor.put_sample.remote(sample, param_version)
return await asyncio.wrap_future(future.future())
async def put_validate(self, data: Any) -> bool:
future = self.queue_actor.put_validate.remote(data)
return await asyncio.wrap_future(future.future())
def get_validate_sync(self) -> Any | None:
return ray.get(self.queue_actor.get_validate.remote())
async def get_sample(self) -> Any | None:
"""Get single sample from queue, wait until one is available (async)"""
future = self.queue_actor.get_sample.remote()
return await asyncio.wrap_future(future.future())
async def get_queue_size(self) -> int:
"""Get queue size (async)"""
future = self.queue_actor.get_queue_size.remote()
return await asyncio.wrap_future(future.future())
async def get_statistics(self) -> dict[str, Any]:
"""Get statistics (async)"""
future = self.queue_actor.get_statistics.remote()
return await asyncio.wrap_future(future.future())
async def clear_queue(self):
"""Clear queue (async)"""
future = self.queue_actor.clear_queue.remote()
await asyncio.wrap_future(future.future())
async def shutdown(self):
"""Shutdown queue (async)"""
future = self.queue_actor.shutdown.remote()
await asyncio.wrap_future(future.future())
async def get_memory_usage(self) -> dict:
"""Get memory usage statistics (async)"""
future = self.queue_actor.get_memory_usage.remote()
return await asyncio.wrap_future(future.future())
# Synchronous versions of the methods above (deprecated)
def put_sample_sync(self, sample: Any, param_version: int) -> bool:
"""Put batch into queue (sync - deprecated, use put_sample instead)"""
return ray.get(self.queue_actor.put_sample.remote(sample, param_version))
def get_sample_sync(self) -> Any | None:
"""Get single sample from queue (sync - deprecated, use get_sample instead)"""
return ray.get(self.queue_actor.get_sample.remote())
def get_statistics_sync(self) -> dict[str, Any]:
"""Get statistics (sync - deprecated, use get_statistics instead)"""
return ray.get(self.queue_actor.get_statistics.remote())
def update_param_version_sync(self, version: int):
"""Update parameter version (async)"""
return ray.get(self.queue_actor.update_param_version.remote(version))
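# Usage sketch (hypothetical sizes; assumes Ray is initialized and `config` is available):
#   queue = MessageQueue.remote(config, max_queue_size=512)
#   client = MessageQueueClient(queue)
#   await client.put_sample(payload, param_version=0)   # producer side (Rollouter)
#   sample, remaining = client.get_sample_sync()        # consumer side (Trainer)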

@@ -0,0 +1,105 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
import ray
from ray.util.collective import collective
logger = logging.getLogger(__name__)
@ray.remote
class ParameterSynchronizer:
"""
Unified parameter synchronizer, responsible for synchronizing model parameters between actor and rollout
Based on the mature synchronization mode of one_step_off_policy.
Merges the functionality of the original multiple synchronizer classes.
"""
def __init__(self, config, trainer, rollouter, mq):
self.config = config
self.trainer = trainer
self.rollouter = rollouter
self.mq_client = mq
self.actor_wg = ray.get(trainer.get_actor_wg.remote())
self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote())
# Basic attributes
self.weights_info = None
self.sync_group_initialized = False
self.sync_group_name = "actor_rollout"
self.wait_last_update = None
self.wait_last_resume = None
# Statistics
self.current_version = 0
self._init_weights_info()
self._init_sync_group()
def get_current_param_version(self) -> int:
"""Get current parameter version number"""
return self.current_version
def get_weights_info(self):
"""Get weights info"""
return self.weights_info
def _init_weights_info(self):
self.weights_info = self.actor_wg.get_actor_weights_info()[0]
self.rollout_wg.set_actor_weights_info(self.weights_info)
def _init_sync_group(self):
print("[ParameterSynchronizer] Initializing parameter synchronization group...")
actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
collective.create_collective_group(
actor_rollout_workers,
len(actor_rollout_workers),
list(range(0, len(actor_rollout_workers))),
backend="nccl",
group_name=self.sync_group_name,
)
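# Rank layout of the "actor_rollout" NCCL group: actor workers occupy ranks
# [0, len(self.actor_wg.workers)) and rollout workers follow, matching the concatenation order above.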
def sync_weights(self, version, validate=False, global_steps=0):
"""Sync weights between trainer and rollouter, and update parameter version"""
start_time = time.time()
self.current_version = version
print(f"[ParameterSynchronizer] Starting weight synchronization (version {self.current_version})...")
ray.get(self.rollouter.pause.remote())
# Update MQ version
self.mq_client.update_param_version_sync(version)
# sync weights
self.actor_wg.sync_rollout_weights()
ray.get(self.rollout_wg.sync_rollout_weights())
end_time = time.time()
print(f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds")
# Async Update rollout version & validation
self.wait_last_update = self.rollouter.update_param_version.remote(version, validate, global_steps)
self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update)
def wait_last_valid(self):
print("[ParameterSynchronizer] Waiting last sync and validate...")
start_time = time.time()
if self.wait_last_update:
ray.get(self.wait_last_update)
if self.wait_last_resume:
ray.get(self.wait_last_resume)
print(f"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds")

@@ -0,0 +1,528 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace
"""
import uuid
from copy import deepcopy
from pprint import pprint
import numpy as np
import ray
import torch
from omegaconf import OmegaConf
from tqdm import tqdm
from verl import DataProto
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
from verl.single_controller.ray import RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
)
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role
from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
from verl.utils.config import omega_conf_to_dataclass
from verl.utils.debug import marked_timer
from verl.utils.metric import (
reduce_metrics,
)
from verl.utils.rollout_skip import RolloutSkip
class FullyAsyncRayPPOTrainer(RayPPOTrainer):
def init_workers(self):
"""Initialize distributed training workers using Ray backend.
Creates:
1. Ray resource pools from configuration
2. Worker groups for each role (actor, critic, etc.)
"""
self._init_resource_pools()
self._create_worker_classes()
self._init_worker_groups()
self._init_models()
self._init_async_rollout_manager()
def _init_resource_pools(self):
self.resource_pool_manager.create_resource_pool()
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
def _create_worker_classes(self):
self._create_actor_rollout_classes()
self._create_critic_class()
self._create_reference_policy_class()
self._create_reward_model_class()
def _create_actor_rollout_classes(self):
raise NotImplementedError
def _create_critic_class(self):
# create critic
if self.use_critic:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
def _create_reference_policy_class(self):
# create reference policy if needed
if self.use_reference_policy:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role=str(Role.RefPolicy),
profile_option=self.config.trainer.npu_profile.options,
)
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
def _create_reward_model_class(self):
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
def _init_worker_groups(self):
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
# you should not use `create_colocated_worker_cls`.
# Instead, directly pass different resource pool to different worker groups.
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
all_wg = {}
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
if OmegaConf.select(self.config.global_profiler, "steps") is not None:
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
# Only require nsight worker options when tool is nsys
if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
assert (
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
is not None
), "worker_nsight_options must be set when using nsys with profile_steps"
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
)
wg_kwargs["device_name"] = self.device_name
for resource_pool, class_dict in self.resource_pool_to_cls.items():
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
wg_dict = self.ray_worker_group_cls(
resource_pool=resource_pool,
ray_cls_with_init=worker_dict_cls,
**wg_kwargs,
)
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
all_wg.update(spawn_wg)
self.all_wg = all_wg
def _init_models(self):
if self.use_critic:
self.critic_wg = self.all_wg[str(Role.Critic)]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
self.ref_policy_wg.init_model()
if self.use_rm:
self.rm_wg = self.all_wg[str(Role.RewardModel)]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
self.actor_rollout_wg.init_model()
def _init_async_rollout_manager(self):
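# No-op in this base class; subclasses that own rollout workers set up the async rollout manager here.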
pass
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker groups through RPC
to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
rollout_skip.wrap_generate_sequences()
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
self.max_steps_duration = 0
prev_step_profile = False
curr_step_profile = (
self.global_steps in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
next_step_profile = False
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}
with marked_timer("start_profile", timing_raw):
self._start_profiling(
not prev_step_profile and curr_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
batch, gen_batch = self._prepare_generate_batch(batch_dict)
is_last_step = self.global_steps >= self.total_training_steps
with marked_timer("step", timing_raw):
# generate a batch
with marked_timer("gen", timing_raw, color="red"):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
timing_raw.update(gen_batch_output.meta_info["timing"])
gen_batch_output.meta_info.pop("timing", None)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
if self.reward_fn is None:
raise ValueError("A reward_fn is required for REMAX advantage estimation.")
with marked_timer("gen_max", timing_raw, color="purple"):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
if not self.async_rollout_mode:
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
else:
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
batch = batch.union(gen_baseline_output)
reward_baseline_tensor = self.reward_fn(batch)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
batch.batch["reward_baselines"] = reward_baseline_tensor
del gen_baseline_batch, gen_baseline_output
batch = self._post_generate_batch(batch, gen_batch_output, metrics)
batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)
self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw)
self._check_save_checkpoint(is_last_step, timing_raw)
with marked_timer("stop_profile", timing_raw):
next_step_profile = (
self.global_steps + 1 in self.config.global_profiler.steps
if self.config.global_profiler.steps is not None
else False
)
self._stop_profiling(
curr_step_profile and not next_step_profile
if self.config.global_profiler.profile_continuous_steps
else curr_step_profile
)
prev_step_profile = curr_step_profile
curr_step_profile = next_step_profile
self._collect_metrics(batch, epoch, metrics, timing_raw)
self._post_batch_processing(batch)
# TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
progress_bar.update(1)
self.global_steps += 1
if (
hasattr(self.config.actor_rollout_ref.actor, "profiler")
and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
):
self.actor_rollout_wg.dump_memory_snapshot(
tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
def _prepare_generate_batch(self, batch_dict):
batch: DataProto = DataProto.from_single_dict(batch_dict)
# add uid to batch
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
gen_batch = self._get_gen_batch(batch)
# pass global_steps to trace
gen_batch.meta_info["global_steps"] = self.global_steps
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
return batch, gen_batch
def _post_generate_batch(self, batch, gen_batch_output, metrics):
# repeat to align with repeated responses in rollout
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
batch = batch.union(gen_batch_output)
if "response_mask" not in batch.batch.keys():
batch.batch["response_mask"] = compute_response_mask(batch)
# Balance the number of valid tokens across DP ranks.
# NOTE: This usually changes the order of data in the `batch`,
# which won't affect the advantage calculation (since it's based on uid),
# but might affect the loss calculation (due to the change of mini-batching).
# TODO: Decouple the DP balancing and mini-batching.
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
return batch
def _process_batch_common(self, batch, metrics, timing_raw):
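# Shared post-generation pipeline: reward, old/ref log probs, values, advantages, and critic/actor updates.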
with marked_timer("reward", timing_raw, color="yellow"):
# compute reward model score
if self.use_rm:
reward_tensor = self.rm_wg.compute_rm_score(batch)
batch = batch.union(reward_tensor)
if self.config.reward_model.launch_reward_fn_async:
future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
else:
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
# recompute old_log_probs
with marked_timer("old_log_prob", timing_raw, color="blue"):
async_training = self.config.get("async_training", None)
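# Reuse the log probs recorded at rollout time as old_log_probs instead of recomputing them with the actor.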
if async_training and async_training.use_rollout_log_probs:
batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
else:
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
entropys = old_log_prob.batch["entropys"]
response_masks = batch.batch["response_mask"]
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
metrics.update(old_log_prob_metrics)
old_log_prob.batch.pop("entropys")
batch = batch.union(old_log_prob)
if "rollout_log_probs" in batch.batch.keys():
# TODO: we may want to add diff of probs too.
from verl.utils.debug.metrics import calculate_debug_metrics
metrics.update(calculate_debug_metrics(batch))
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with marked_timer("values", timing_raw, color="cyan"):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with marked_timer("adv", timing_raw, color="brown"):
# we combine with rule-based rm
reward_extra_infos_dict: dict[str, list]
if self.config.reward_model.launch_reward_fn_async:
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
batch.batch["token_level_scores"] = reward_tensor
if reward_extra_infos_dict:
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
# compute rewards. apply_kl_penalty if available
if self.config.algorithm.use_kl_in_reward:
batch, kl_metrics = apply_kl_penalty(
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(kl_metrics)
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
) # GRPO adv normalization factor
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
config=self.config.algorithm,
)
# update critic
if self.use_critic:
with marked_timer("update_critic", timing_raw, color="pink"):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# implement critic warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with marked_timer("update_actor", timing_raw, color="red"):
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
return batch, reward_extra_infos_dict
def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw):
# Log rollout generations if enabled
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
if rollout_data_dir:
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch]
if "request_id" in batch.non_tensor_batch:
reward_extra_infos_dict.setdefault(
"request_id",
batch.non_tensor_batch["request_id"].tolist(),
)
self._dump_generations(
inputs=inputs,
outputs=outputs,
gts=sample_gts,
scores=scores,
reward_extra_infos_dict=reward_extra_infos_dict,
dump_path=rollout_data_dir,
)
def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw):
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with marked_timer("testing", timing_raw, color="green"):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
return last_val_metrics
def _check_save_checkpoint(self, is_last_step, timing_raw):
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
esi_close_to_expiration = should_save_ckpt_esi(
max_steps_duration=self.max_steps_duration,
redundant_time=self.config.trainer.esi_redundant_time,
)
# Check if the conditions for saving a checkpoint are met.
# The conditions include a mandatory condition (1) and
# one of the following optional conditions (2/3/4):
# 1. The save frequency is set to a positive value.
# 2. It's the last training step.
# 3. The current step number is a multiple of the save frequency.
# 4. The ESI (Elastic Server Instance)/training plan is close to expiration.
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
):
if esi_close_to_expiration:
print("Force saving checkpoint: ESI instance expiration approaching.")
with marked_timer("save_checkpoint", timing_raw, color="green"):
self._save_checkpoint()
def _collect_metrics(self, batch, epoch, metrics, timing_raw):
steps_duration = timing_raw["step"]
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
# training metrics
metrics.update(
{
"training/global_step": self.global_steps,
"training/epoch": epoch,
}
)
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflops and theoretical tflops
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
def _post_batch_processing(self, batch: DataProto):
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
self.train_dataloader.sampler.update(batch=batch)
# this is experimental and may be changed/removed in the future
# in favor of a general-purpose data buffer pool
if hasattr(self.train_dataset, "on_batch_end"):
# The dataset may be changed after each training batch
self.train_dataset.on_batch_end(batch=batch)

@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_16-16'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 28))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=4
sp_size=4
fsdp_size=8
# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-2}
NNODES_TRAIN=${NNODES_TRAIN:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=20
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
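# Fully async knobs (roughly, based on the config names below): staleness_threshold bounds how stale
# queued samples may be, trigger_parameter_sync_step sets how often trainer weights are synced to the
# rollouter, require_batches is how many batches are accumulated per training step, and partial_rollout
# lets in-flight generations resume after a parameter sync.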
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_32-32'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 28))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=4
sp_size=4
fsdp_size=8
# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-4}
NNODES_TRAIN=${NNODES_TRAIN:-4}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=20
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,164 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-12'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2
# Fully async specific parameters
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq="${test_freq}" \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,164 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-4'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2
# Fully async specific parameters
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=4
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=False \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 28))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=4
sp_size=4
fsdp_size=8
# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-8}
NNODES_TRAIN=${NNODES_TRAIN:-8}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=20
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-8-8'
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# Important: after downloading the model from Hugging Face, set max_position_embeddings to 32768 in its config.json.
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2
# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-1}
NNODES_TRAIN=${NNODES_TRAIN:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True
python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True

@@ -0,0 +1,4 @@
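# Ray runtime env: environment variables propagated to the Ray workers.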
env_vars:
VLLM_USE_V1: "1"
NCCL_DEBUG: "INFO"
HYDRA_FULL_ERROR: "1"

@@ -0,0 +1,176 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import random
import time
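# Minimal asyncio demo: a data-stream producer, a bounded-concurrency submitter, and a result consumer
# communicating through queues.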
class SimpleStreamingSystem:
"""Simplified streaming system demonstration"""
def __init__(self, max_concurrent_tasks: int = 4):
self.max_concurrent_tasks = max_concurrent_tasks
self.data_queue = asyncio.Queue()
self.result_queue = asyncio.Queue()
self.consumer_count = 0
# Data stream coroutine
async def data_stream(self):
# Prepare and enqueue the initial batch of test data
test_data = [{"id": f"task_{i}", "content": f"data_{i}"} for i in range(8)]
await self.add_data_stream(test_data)
# Simulate subsequent data stream
await asyncio.sleep(3)
print("\nAdding second batch of data...")
extra_data = [{"id": f"extra_{i}", "content": f"extra_data_{i}"} for i in range(5)]
await self.add_data_stream(extra_data)
# Send termination signal
await asyncio.sleep(1)
await self.data_queue.put("DONE")
print("Sending termination signal")
async def add_data_stream(self, data_list: list[dict]):
"""Simulate data stream"""
print("Starting to add data stream...")
for i, data_item in enumerate(data_list):
await self.data_queue.put(data_item)
print(f"Data {data_item['id']} added to pending queue")
# Simulate interval between data streams
if i < len(data_list) - 1: # Don't wait after the last item
await asyncio.sleep(0.8)
print("Initial data stream added successfully")
async def _process_data_async(self, data_item: dict):
"""Asynchronously process a single data item"""
data_id = data_item["id"]
content = data_item["content"]
# Simulate different processing times (1-3 seconds)
processing_time = random.uniform(1, 3)
print(f" Starting to process {data_id}, estimated time {processing_time:.1f}s")
# Asynchronously wait for processing completion
await asyncio.sleep(processing_time)
result = {
"id": data_id,
"processed_content": f"Processed {content}",
"processing_time": round(processing_time, 2),
"completed_at": time.time(),
}
# Immediately put into result queue
await self.result_queue.put(result)
print(f" {data_id} processing completed! (took {processing_time:.1f}s) -> Added to result queue")
async def _submit_worker(self):
"""Stream submission worker coroutine"""
active_tasks = set()
print("Stream submitter started...")
while True:
# Get data to process
data_item = await self.data_queue.get()
if data_item == "DONE":
print("Received termination signal, waiting for remaining tasks to complete...")
if active_tasks:
await asyncio.gather(*active_tasks, return_exceptions=True)
break
# Check concurrent limit
while len(active_tasks) >= self.max_concurrent_tasks:
print(f"Reached maximum concurrency {self.max_concurrent_tasks}, waiting for tasks to complete...")
done_tasks, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)
# Clean up completed tasks
for task in done_tasks:
try:
await task
print(f"Task completed {task}")
except Exception as e:
print(f"Task execution failed: {e}")
# Immediately submit new task
task = asyncio.create_task(self._process_data_async(data_item), name=f"active {data_item}")
active_tasks.add(task)
print(f"Submitted task {data_item['id']}, current concurrency: {len(active_tasks)}")
async def _consumer_worker(self):
"""Result consumer coroutine"""
print("Consumer started...")
while True:
try:
# Get processing result from result queue
result = await asyncio.wait_for(self.result_queue.get(), timeout=2.0)
self.consumer_count += 1
print(
f"Consumed #{self.consumer_count}: {result['id']} "
f"(processing time {result['processing_time']}s) - {result['processed_content']}"
)
except asyncio.TimeoutError:
print(" Consumer waiting...")
await asyncio.sleep(0.5)
async def run_demo(self):
"""Run demonstration"""
print("=" * 60)
print(f"Maximum concurrency: {self.max_concurrent_tasks}")
print("=" * 60)
# Start core coroutines
stream_task = asyncio.create_task(self.data_stream())
submit_task = asyncio.create_task(self._submit_worker())
consumer_task = asyncio.create_task(self._consumer_worker())
try:
# Wait for data stream to complete
await stream_task
print("Data stream completed")
# Wait for processing to complete
await submit_task
print("All tasks processed")
finally:
# Cleanup
submit_task.cancel()
consumer_task.cancel()
await asyncio.gather(submit_task, consumer_task, return_exceptions=True)
print(f"\nFinal statistics: Consumed {self.consumer_count} results")
async def main():
"""Main function"""
system = SimpleStreamingSystem(max_concurrent_tasks=3)
await system.run_demo()
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,13 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,154 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import Any, Optional, Sequence

import ray
from ray.actor import ActorHandle
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput

from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
    _qwen2_5_vl_dedup_image_tokens,
    vLLMHttpServerBase,
    vLLMReplica,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


@ray.remote(num_cpus=1)
class vLLMHttpServerForPartial(vLLMHttpServerBase):
    def __init__(
        self,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        gpus_per_node: int,
        nnodes: int,
    ):
        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)
        # for cancel LLMServer
        self.paused = False
        self.lock = asyncio.Lock()
        self.cancel_event: dict[str, asyncio.Event] = {}
        self.req_output: dict[str, Optional[RequestOutput]] = {}

    async def _generate_step(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ):
        max_tokens = self.config.max_model_len - len(prompt_ids)
        sampling_params["logprobs"] = 1
        sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
        prompt = TokensPrompt(
            prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
        )
        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)

        # Get final response
        self.req_output[request_id]: Optional[RequestOutput] = None
        async for output in generator:
            self.req_output[request_id] = output
        assert self.req_output[request_id] is not None

    async def generate_for_partial(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:
        async with self.lock:
            if self.paused:
                # After cancel, all tasks will return directly and wait for the next submission
                return [], [], True
            self.cancel_event[request_id] = asyncio.Event()
            cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())
            generation_handle = asyncio.create_task(
                self._generate_step(prompt_ids, sampling_params, request_id, image_data)
            )

        done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            await task
        for task in pend:
            task.cancel()

        async with self.lock:
            token_ids = self.req_output[request_id].outputs[0].token_ids
            log_probs: list[float] = []
            for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs):
                # In sampling_params, logprobs is set to 1, which should return 1,
                # but in practice there are multiple. Take the log_prob corresponding to token_id
                token_id = self.req_output[request_id].outputs[0].token_ids[i]
                log_probs.append(x[token_id].logprob)
            is_cancel = generation_handle not in done
            self.cancel_event.pop(request_id, None)
            self.req_output.pop(request_id, None)
        return token_ids, log_probs, is_cancel

    async def cancel(self):
        async with self.lock:
            self.paused = True
            for request_id in self.cancel_event:
                self.cancel_event[request_id].set()

    async def resume(self):
        async with self.lock:
            self.paused = False

    async def reset_prefix_cache(self):
        async with self.lock:
            await self.engine.reset_prefix_cache()


class FullyAsyncvLLMReplica(vLLMReplica):
    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
    ):
        super().__init__(replica_rank, config, model_config, gpus_per_node)
        self.server_class = vLLMHttpServerForPartial

    async def cancel(self):
        """Cancel each rollout server."""
        await asyncio.gather(*[server.cancel.remote() for server in self.servers])

    async def resume(self):
        """Resume each rollout server."""
        await asyncio.gather(*[server.resume.remote() for server in self.servers])

    async def reset_prefix_cache(self):
        """reset kv cache in each rollout server."""
        await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers])
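
For review convenience, here is a minimal, hypothetical caller sketch of the API above. `server` is one of the `vLLMHttpServerForPartial` actor handles held by a replica, while `resubmit_queue` and `param_sync_fn` are placeholders for the Rollouter's queue and the trainer-side weight sync; the actual recipe drives this through its Rollouter rather than calling the server handle directly.

```python
import asyncio
import uuid


async def rollout_one_prompt(server, prompt_ids, resubmit_queue: asyncio.Queue):
    """Hypothetical driver: generate until finished, or park the request when the
    server is cancelled for a parameter sync (partial rollout)."""
    request_id = str(uuid.uuid4())
    sampling_params = {"temperature": 1.0, "top_p": 1.0}
    token_ids, log_probs, is_cancel = await server.generate_for_partial.remote(
        prompt_ids, sampling_params, request_id
    )
    if is_cancel:
        # Interrupted by cancel(): one plausible way to carry partial progress is to
        # resubmit prompt + partial response after resume().
        await resubmit_queue.put((list(prompt_ids) + list(token_ids), log_probs))
        return None
    return token_ids, log_probs


async def sync_weights_then_resume(replica, param_sync_fn):
    """Hypothetical coordination: pause generation, sync weights, drop stale KV, resume."""
    await replica.cancel()  # flips paused=True and sets every pending cancel_event
    await param_sync_fn()  # placeholder for the trainer -> rollouter weight sync
    await replica.reset_prefix_cache()  # KV computed with the old weights is no longer valid
    await replica.resume()
```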

View File

@ -0,0 +1,196 @@
#!/usr/bin/env bash
set -xeuo pipefail
# Test script for fully_async_policy E2E regression testing
# This script runs fully async PPO training with both FSDP2 and Megatron backends
# to ensure the asynchronous training mechanism works correctly
NUM_GPUS=${NUM_GPUS:-8}
ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron
# Download model if not exists
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}"
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi
# Algorithm parameters
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
# Response length parameters
max_prompt_length=1024
max_response_length=2048
enable_overlong_buffer=True
overlong_buffer_len=128
overlong_penalty_factor=1.0
# Training parameters
loss_agg_mode="token-mean"
# Temperature parameters
temperature=1.0
top_p=1.0
top_k=-1
val_top_p=0.7
# Fully async specific parameters
n_gpus_rollout=4
n_gpus_training=4
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=16
total_rollout_steps=$(((128)))
test_freq=-1
staleness_threshold=0.1
trigger_parameter_sync_step=4
partial_rollout=True
exp_name="$(basename "${MODEL_ID,,}")-fully-async-policy-${ACTOR_STRATEGY}-minimal"
echo "Running fully_async_policy with ${ACTOR_STRATEGY} strategy"
echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}"
# Common parameters for both FSDP2 and Megatron
common_params=(
data.train_files="${HOME}/data/gsm8k/train.parquet"
data.val_files="${HOME}/data/gsm8k/test.parquet"
data.prompt_key=prompt
data.truncation='left'
data.max_prompt_length=${max_prompt_length}
data.max_response_length=${max_response_length}
data.train_batch_size=${train_prompt_bsz}
data.gen_batch_size=${gen_prompt_bsz}
data.return_raw_chat=${return_raw_chat}
actor_rollout_ref.rollout.n=${n_resp_per_prompt}
actor_rollout_ref.rollout.calculate_log_probs=True
algorithm.adv_estimator=${adv_estimator}
algorithm.use_kl_in_reward=${use_kl_in_reward}
algorithm.kl_ctrl.kl_coef=${kl_coef}
actor_rollout_ref.hybrid_engine=False
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}
actor_rollout_ref.actor.clip_ratio_c=10.0
actor_rollout_ref.model.path="${MODEL_PATH}"
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.optim.lr_warmup_steps=-1
actor_rollout_ref.actor.optim.weight_decay=0.1
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}
actor_rollout_ref.rollout.gpu_memory_utilization=0.80
actor_rollout_ref.rollout.temperature=${temperature}
actor_rollout_ref.rollout.top_p=${top_p}
actor_rollout_ref.rollout.top_k=${top_k}
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}
actor_rollout_ref.rollout.val_kwargs.do_sample=True
actor_rollout_ref.rollout.val_kwargs.n=1
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.name=${rollout_name}
actor_rollout_ref.rollout.mode=${rollout_mode}
reward_model.reward_manager=dapo
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False
+reward_model.reward_kwargs.max_resp_len=${max_response_length}
trainer.logger=['console']
trainer.project_name='verl-test-fully-async'
trainer.experiment_name="${exp_name}"
trainer.val_before_train=True
trainer.save_freq=-1
trainer.resume_mode=disable
trainer.nnodes=1
trainer.n_gpus_per_node=${n_gpus_training}
rollout.nnodes=1
rollout.n_gpus_per_node=${n_gpus_rollout}
rollout.total_rollout_steps=${total_rollout_steps}
rollout.total_epochs=2
rollout.test_freq=${test_freq}
# Fully async specific configurations
async_training.staleness_threshold=${staleness_threshold}
async_training.partial_rollout="${partial_rollout}"
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}"
)
if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then
echo "Running fully async training with FSDP2 strategy..."
# FSDP2 specific parameters
gen_tp=1
sp_size=1
fsdp_size=1
ref_offload=True
actor_offload=False
python3 -m recipe.fully_async_policy.fully_async_main \
"${common_params[@]}" \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@
elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then
echo "Running fully async training with Megatron strategy..."
# Megatron specific parameters
gen_tp=2
train_tp=1
train_pp=2
ref_offload=True
actor_offload=False
python3 -m recipe.fully_async_policy.fully_async_main \
--config-path=config \
--config-name='fully_async_ppo_megatron_trainer.yaml' \
"${common_params[@]}" \
actor_rollout_ref.actor.strategy=megatron \
critic.strategy=megatron \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@
else
echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'"
exit 1
fi
echo "Fully async policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy"

View File

@ -48,6 +48,7 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
NCCL_KEYWORD_CHECK_WHITELIST = [
"verl/utils/device.py",
"verl/third_party/sglang/parallel_state.py", # appear in default backend
"verl/recipe/fully_async_policy/param_sync.py", # fully_async_policy in default backend
]
SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST

View File

@ -24,6 +24,7 @@ license_head_sglang = "Copyright 2023-2024 SGLang Team"
license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates"
license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates"
license_head_facebook = "Copyright (c) 2016- Facebook, Inc"
license_head_meituan = "Copyright 2025 Meituan Ltd. and/or its affiliates"
license_headers = [
license_head_bytedance,
license_head_bytedance_25,
@ -33,6 +34,7 @@ license_headers = [
license_head_modelbest,
license_head_amazon,
license_head_facebook,
license_head_meituan,
]

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .agent_loop import AgentLoopBase, AgentLoopManager, AsyncLLMServerManager
from .agent_loop import AgentLoopBase, AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager
from .single_turn_agent_loop import SingleTurnAgentLoop
from .tool_agent_loop import ToolAgentLoop
_ = [SingleTurnAgentLoop, ToolAgentLoop]
__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager"]
__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager", "AgentLoopWorker"]

View File

@ -370,8 +370,7 @@ class RewardManagerWorker:
return self.reward_manager(data, return_dict)
@ray.remote
class AgentLoopWorker:
class AgentLoopWorkerBase:
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
def __init__(
@ -384,7 +383,11 @@ class AgentLoopWorker:
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
"""
self.config = config
self.server_manager = AsyncLLMServerManager(config, server_handles)
# for recipe to change
if not hasattr(self, "server_manager"):
self.server_manager = AsyncLLMServerManager(config, server_handles)
self.rm_executor = rm_executor
model_path = config.actor_rollout_ref.model.path
@ -720,6 +723,22 @@ class AgentLoopWorker:
)
@ray.remote
class AgentLoopWorker(AgentLoopWorkerBase):
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""
def __init__(
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
):
"""Initialize agent loop manager.
Args:
config (DictConfig): YAML config.
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
"""
super().__init__(config, server_handles, rm_executor)
async def get_trajectory_info(step, index, validate):
"""Get trajectory info.
@ -778,6 +797,12 @@ class AgentLoopManager:
self.rm_micro_batch_size = rm_wg.world_size
# for recipe to change
if not hasattr(self, "rollout_replica_class"):
self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
if not hasattr(self, "agent_loop_workers_class"):
self.agent_loop_workers_class = AgentLoopWorker
self._initialize_llm_servers()
self._init_agent_loop_workers()
@ -798,11 +823,10 @@ class AgentLoopManager:
)
num_replicas = world_size // rollout_world_size
rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
rollout_config = self.config.actor_rollout_ref.rollout
model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
rollout_replica_class(
self.rollout_replica_class(
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
@ -826,7 +850,7 @@ class AgentLoopManager:
# Round-robin scheduling over the all nodes
node_id = node_ids[i % len(node_ids)]
self.agent_loop_workers.append(
AgentLoopWorker.options(
self.agent_loop_workers_class.options(
name=f"agent_loop_worker_{i}",
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id, soft=True
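
The hooks above (`server_manager`, `rollout_replica_class`, `agent_loop_workers_class`) only take effect when a subclass assigns them before the base `__init__` runs, since the base class falls back to the defaults only when the attributes are absent. A minimal sketch of how a recipe might plug into them is shown below; `MyAgentLoopWorker` and `MyAgentLoopManager` are hypothetical names, and the imports of `AgentLoopWorkerBase`, `AgentLoopManager`, and `FullyAsyncvLLMReplica` are assumed.

```python
import ray

# Hypothetical recipe-side subclasses, shown only to illustrate the extension points;
# AgentLoopWorkerBase / AgentLoopManager come from verl's agent_loop module and
# FullyAsyncvLLMReplica from this recipe.


@ray.remote
class MyAgentLoopWorker(AgentLoopWorkerBase):
    def __init__(self, *args, **kwargs):
        # Anything set here before super().__init__ (e.g. self.server_manager)
        # is kept, because the base class only fills in attributes that are absent.
        super().__init__(*args, **kwargs)


class MyAgentLoopManager(AgentLoopManager):
    def __init__(self, *args, **kwargs):
        # Assign the overrides before super().__init__, which checks hasattr(...)
        # and only falls back to the defaults when these attributes are missing.
        self.rollout_replica_class = FullyAsyncvLLMReplica
        self.agent_loop_workers_class = MyAgentLoopWorker
        super().__init__(*args, **kwargs)
```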

View File

@ -39,4 +39,4 @@ entropy_from_logits_with_chunking: False
entropy_checkpointing: False
# Whether to remove padding tokens in inputs during training
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}

View File

@ -43,13 +43,14 @@ def main(config):
# Define a function to run the PPO-like training process
def run_ppo(config) -> None:
def run_ppo(config, task_runner_class=None) -> None:
"""Initialize Ray cluster and run distributed PPO training process.
Args:
config: Training configuration object containing all necessary parameters
for distributed PPO training including Ray initialization settings,
model paths, and training hyperparameters.
task_runner_class: Optional Ray remote class used in place of the default TaskRunner, allowing recipes to customize the training entry point.
"""
# Check if Ray is not initialized
if not ray.is_initialized():
@ -65,6 +66,9 @@ def run_ppo(config) -> None:
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs))
if task_runner_class is None:
task_runner_class = TaskRunner
# Create a remote instance of the TaskRunner class, and
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
if (
@ -79,9 +83,9 @@ def run_ppo(config) -> None:
nsight_options = OmegaConf.to_container(
config.global_profiler.global_tool_config.nsys.controller_nsight_options
)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
runner = task_runner_class.remote()
ray.get(runner.run.remote(config))
# [Optional] get the path of the timeline trace file from the configuration, default to None
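
With this hook, a recipe can reuse the stock entry point while swapping in its own runner. A minimal sketch, assuming a hypothetical `FullyAsyncTaskRunner` defined by the recipe and the import path of `run_ppo`:

```python
import ray

from verl.trainer.main_ppo import run_ppo  # import path assumed


@ray.remote(num_cpus=1)
class FullyAsyncTaskRunner:  # hypothetical recipe-side runner
    """Exposes the same run(config) entry point that run_ppo calls via ray.get(runner.run.remote(config))."""

    def run(self, config):
        # Build the Trainer and Rollouter here instead of the default colocated PPO flow.
        ...


def main(config):
    # `config` is the Hydra/OmegaConf config produced by the recipe's entry point.
    run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
```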

View File

@ -602,11 +602,9 @@ class RayPPOTrainer:
sample_scores.extend(scores)
reward_extra_infos_dict["reward"].extend(scores)
print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}")
# collect num_turns of each prompt
if "__num_turns__" in test_batch.non_tensor_batch:
@ -676,9 +674,9 @@ class RayPPOTrainer:
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role="actor_rollout",
role=str(Role.ActorRollout),
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls
else:
raise NotImplementedError
@ -687,7 +685,7 @@ class RayPPOTrainer:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
# create reference policy if needed
if self.use_reference_policy:
@ -695,16 +693,16 @@ class RayPPOTrainer:
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role="ref",
role=str(Role.RefPolicy),
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
@ -739,20 +737,20 @@ class RayPPOTrainer:
all_wg.update(spawn_wg)
if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg = all_wg[str(Role.Critic)]
self.critic_wg.init_model()
if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
self.ref_policy_wg.init_model()
self.rm_wg = None
if self.use_rm:
self.rm_wg = all_wg["rm"]
self.rm_wg = all_wg[str(Role.RewardModel)]
self.rm_wg.init_model()
# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg = all_wg[str(Role.ActorRollout)]
self.actor_rollout_wg.init_model()
# create async rollout manager and request scheduler
@ -800,11 +798,13 @@ class RayPPOTrainer:
)
if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic")
critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic))
critic_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
else os.path.join(
self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic)
)
)
self.critic_wg.save_checkpoint(
critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
@ -860,7 +860,7 @@ class RayPPOTrainer:
print(f"Resuming from {global_step_folder}")
actor_path = os.path.join(global_step_folder, "actor")
critic_path = os.path.join(global_step_folder, "critic")
critic_path = os.path.join(global_step_folder, str(Role.Critic))
# load actor
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
@ -1129,7 +1129,7 @@ class RayPPOTrainer:
if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, color="olive"):
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:

View File

@ -36,6 +36,37 @@ class Role(Enum):
RewardModel = 5
ActorRolloutRef = 6
def __str__(self):
return self._get_role_string()
def _get_role_string(self):
role_mapping = {
Role.Actor: "actor",
Role.Rollout: "rollout",
Role.ActorRollout: "actor_rollout",
Role.Critic: "critic",
Role.RefPolicy: "ref",
Role.RewardModel: "rm",
Role.ActorRolloutRef: "actor_rollout_ref",
}
return role_mapping.get(self, self.name.lower())
@classmethod
def from_string(cls, name: str):
string_mapping = {
"actor": cls.Actor,
"rollout": cls.Rollout,
"actor_rollout": cls.ActorRollout,
"critic": cls.Critic,
"ref": cls.RefPolicy,
"rm": cls.RewardModel,
"actor_rollout_ref": cls.ActorRolloutRef,
}
role = string_mapping.get(name.lower())
if role is None:
raise ValueError(f"No Role found for string: {name}")
return role
def need_reference_policy(
role_worker_mapping: dict[Role, WorkerType],
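
Since worker-group keys and checkpoint sub-directories are now derived from `str(Role.X)`, the mapping can be exercised directly; a quick illustrative round-trip (import path assumed):

```python
# Illustrative round-trip of the new Role <-> string mapping
# (assumes: from verl.trainer.ppo.utils import Role -- import path assumed).
assert str(Role.ActorRollout) == "actor_rollout"
assert str(Role.RefPolicy) == "ref"
assert str(Role.RewardModel) == "rm"
assert Role.from_string("critic") is Role.Critic
assert Role.from_string("ACTOR") is Role.Actor  # lookup is case-insensitive
try:
    Role.from_string("judge")  # unknown names raise
except ValueError as err:
    print(err)  # -> No Role found for string: judge
```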

View File

@ -78,7 +78,7 @@ class DataParallelPPOActor(BasePPOActor):
self.compute_entropy_from_logits = (
torch.compile(entropy_from_logits, dynamic=True)
if self.config.get("use_torch_compile", True)  # use torch compile by default
else entropy_from_logits
)
self.device_name = get_device_name()
@ -427,10 +427,14 @@ class DataParallelPPOActor(BasePPOActor):
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
)
if on_policy:
old_log_prob = log_prob.detach()
else:
# for fully_async_policy recipe
if hasattr(self.config, "use_rollout_log_probs") and self.config.use_rollout_log_probs:
old_log_prob = model_inputs["old_log_probs"]
else:
if on_policy:
old_log_prob = log_prob.detach()
else:
old_log_prob = model_inputs["old_log_probs"]
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
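
The branch above matters because, in the fully async recipe, samples may have been generated by an older rollout snapshot than the current training weights, so the "old" policy in the PPO ratio should be the log-probs recorded by the rollout engine rather than freshly recomputed ones. A schematic restatement of the branch, under the assumption that `old_log_probs` carry the rollout-engine log-probs in this recipe:

```python
import torch


def ppo_ratio(log_prob, old_log_probs, on_policy, use_rollout_log_probs):
    """Schematic only: which tensor plays the 'old' policy in the importance ratio."""
    if use_rollout_log_probs:
        # Fully async recipe: the behavior policy is the rollout snapshot that
        # produced the sample, and its log-probs travel with the batch.
        old = old_log_probs
    elif on_policy:
        old = log_prob.detach()  # ratio starts at exactly 1 for fresh samples
    else:
        old = old_log_probs
    return torch.exp(log_prob - old)
```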

View File

@ -231,6 +231,7 @@ class FSDPActorConfig(ActorConfig):
fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
use_remove_padding: bool = False
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
use_rollout_log_probs: bool = False
def __post_init__(self):
"""Validate FSDP actor configuration parameters."""

View File

@ -110,8 +110,7 @@ class ExternalZeroMQDistributedExecutor(Executor):
return
@ray.remote(num_cpus=1)
class vLLMHttpServer:
class vLLMHttpServerBase:
"""vLLM http server in single node, this is equivalent to launch server with command line:
```
vllm serve --tensor-parallel-size=8 ...
@ -399,10 +398,42 @@ class vLLMHttpServer:
await self.engine.wait_for_requests_to_drain()
@ray.remote(num_cpus=1)
class vLLMHttpServer(vLLMHttpServerBase):
"""vLLM http server in single node, this is equivalent to launch server with command line:
```
vllm serve --tensor-parallel-size=8 ...
```
"""
def __init__(
self,
config: RolloutConfig | RewardModelConfig,
model_config: HFModelConfig,
rollout_mode: RolloutMode,
workers: list[ActorHandle],
replica_rank: int,
node_rank: int,
gpus_per_node: int,
nnodes: int,
):
super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)
_rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout)
class vLLMReplica(RolloutReplica):
def __init__(
self,
replica_rank: int,
config: RolloutConfig | RewardModelConfig,
model_config: HFModelConfig,
gpus_per_node: int = 8,
):
super().__init__(replica_rank, config, model_config, gpus_per_node)
self.server_class = vLLMHttpServer
def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
"""Get rollout worker actor class for colocated and standalone mode."""
worker_dict_cls = RayClassWithInitArgs(
@ -437,7 +468,7 @@ class vLLMReplica(RolloutReplica):
for node_rank in range(nnodes):
workers = self.workers[node_rank * gpus_per_node : (node_rank + 1) * gpus_per_node]
node_id = worker_node_ids[node_rank * gpus_per_node]
server = vLLMHttpServer.options(
server = self.server_class.options(
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,