Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 21:53:50 +08:00)
[trainer, recipe] feat: fully async training recipe (#2981)
### What does this PR do?

To implement a fully asynchronous training workflow, we further split the training process into a Trainer and a Rollouter on top of the existing one-step-off-policy code, with samples transmitted via a message queue. We will continue to integrate partial rollout to mitigate the impact of long-tail samples on training.

> Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Related PRs:
- https://github.com/volcengine/verl/pull/2231
- https://github.com/volcengine/verl/pull/2200

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,`, e.g. `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes, if any, and provide usage example(s) if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: meituan-search <machi04@meituan.com>
Co-authored-by: wangshulin02 <wangshulin02@meituan.com>
Co-authored-by: arron <arron@MBP-2G17FXQ05P-2332.local>
Co-authored-by: wangshulin02 <953550366@qq.com>
Co-authored-by: hadoop-ai-search <hadoop-ai-search@set-zw04-mlp-codelab-pc1189.mt>
Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com>
Co-authored-by: arron <arron@MBP-VH9RV7LTJC-1907.local>
Co-authored-by: arron <arron@MBP-JFQXPWR11F-1943.local>
.github/workflows/e2e_fully_async_policy.yml (vendored, new file, 149 lines)
@ -0,0 +1,149 @@
# # Tests layout

# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance:
# - `tests/trainer` for testing functionality related to `verl/trainer`
# - `tests/models` for testing functionality related to `verl/models`
# - ...

# There are a few folders with `special_` prefix, created for special purposes:
# - `special_distributed`: unit tests that must run with multiple GPUs
# - `special_e2e`: end-to-end tests with training/generation scripts
# - `special_npu`: tests for NPUs
# - `special_sanity`: a suite of quick sanity tests
# - `special_standalone`: a set of tests designed to run in dedicated environments

# Accelerators for tests
# - By default, tests are run with GPUs available, except for those under `special_npu` and any test script whose name ends with `on_cpu.py`.
# - Test scripts with the `on_cpu.py` name suffix are run on CPU resources in a Linux environment.

# # Workflow layout

# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs:
# 1. A list of always-triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `pre-commit.yml`, `doc.yml`
# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml`
# 3. End-to-end tests: `e2e_*.yml`
# 4. Unit tests
#   - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py`
#   - `gpu_unit_tests.yml`, run pytest on all scripts without the `on_cpu.py` suffix.
#   - Since cpu/gpu unit tests by default run all tests under `tests`, please make sure tests are manually excluded in them when
#     - a new workflow yaml is added to `.github/workflows`
#     - new tests are added to the workflows mentioned in 2.


name: e2e_fully_async_policy

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch.
  # For push, for now only anti-patterns are specified so it is more conservative
  # and achieves higher coverage.
  push:
    branches:
      - main
      - v0.*
    paths:
      - "**/*.py"
      - "!**/*.md"
      - "!**/*.sh"
      # Other entrypoints
      - "!examples/*trainer*"
      - "!tests/**"
      - "!verl/trainer/main_*.py"
      - "!verl/trainer/fsdp_sft_trainer.py"
      - "!recipe/**"
      - "recipe/fully_async_policy"
  pull_request:
    branches:
      - main
      - v0.*
    paths:
      - "**/*.py"
      - "!**/*.md"
      - "!**/*.sh"
      # Other entrypoints
      - "!examples/**"
      - "!tests/**"
      - "!verl/trainer/main_*.py"
      - "!verl/trainer/fsdp_sft_trainer.py"
      # Other recipes
      - "!recipe/**"
      # Home
      - "recipe/fully_async_policy"
      # Entrypoints
      - ".github/workflows/e2e_fully_async_policy.yml"
      - "examples/data_preprocess/gsm8k.py"
      - "tests/special_e2e/run_fully_async_policy.sh"

# Cancel jobs on the same ref if a new one is triggered
concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
  cancel-in-progress: ${{ github.ref != 'refs/heads/main' }}

# Declare permissions to just read contents.
permissions:
  contents: read

env:
  IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2"
  DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner"
  TRANSFORMERS_VERSION: "4.56.2"

jobs:
  setup:
    if: github.repository_owner == 'volcengine'
    runs-on: ubuntu-latest
    outputs:
      runner-label: ${{ steps.create-runner.outputs.runner-label }}
      mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }}
    steps:
      - uses: actions/checkout@v4
      - id: create-runner
        uses: volcengine/vemlp-github-runner@v1
        with:
          mode: "create"
          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
          mlp-image: "${{ env.IMAGE }}"

  # Test FSDP2 strategy
  e2e_fully_async_policy_fsdp2:
    needs: setup
    runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
    timeout-minutes: 10 # Increased timeout for async training
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1,hf-mirror.com"
      HF_ENDPOINT: "https://hf-mirror.com"
      HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable
      ACTOR_STRATEGY: "fsdp2"
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install --no-deps -e .[test,gpu]
          pip3 install transformers==$TRANSFORMERS_VERSION
      - name: Prepare GSM8K dataset
        run: |
          python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k
      - name: Run the E2E test with the fully_async_policy algorithm (FSDP2)
        run: |
          ray stop --force
          bash tests/special_e2e/run_fully_async_policy.sh

  cleanup:
    runs-on: ubuntu-latest
    needs:
      [
        setup,
        e2e_fully_async_policy_fsdp2
      ]
    if: always()
    steps:
      - id: destroy-runner
        uses: volcengine/vemlp-github-runner@v1
        with:
          mode: "destroy"
          faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}"
          mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}"
docs/advance/fully_async.md (new file, 428 lines)
@ -0,0 +1,428 @@
# Recipe: Fully Async Policy Trainer

**Author:** `https://github.com/meituan-search`

Last updated: 10/17/2025.

This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and the Rollouter, supporting asynchronous sample generation and training. With this system, we achieved a 2.35x-2.67x speedup when training the Qwen2.5-7B model on 128 GPUs, without significantly affecting the results.
## Introduction

### Background

Compared with the colocated architecture, the separated rollout/train architecture can allocate resources more flexibly and supports more flexible training logic, which helps address the low GPU utilization and poor training efficiency caused by long-tail samples. The one_step_off_policy recipe alleviates long rollout times and gains some training efficiency by using a separated architecture and training asynchronously, one step apart from rollout. However, it always trains on data that is exactly one step stale, which is inflexible and cannot completely eliminate the impact of the long tail on training efficiency.

Other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow have implemented asynchronous and streaming training on top of a separated architecture and achieved gains. We drew on their methods and implemented them in verl: fully_async_policy supports asynchronous, streaming, and partial-rollout training. With reasonable settings for resource allocation, parameter synchronization frequency, and other parameters, fully_async_policy can significantly improve training efficiency.
> Magistral: https://arxiv.org/abs/2506.10910
>
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language Reasoning: https://arxiv.org/abs/2505.24298
>
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream Generation: https://arxiv.org/abs/2504.15930
>
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training: https://arxiv.org/abs/2507.01663
### Core Contributions

* **Resource Isolation**: Unlike the hybrid engine, the Rollouter and the Trainer run on separate computing resources, and each must be assigned its own resources.
* **Parallel Generation and Training**: While the Trainer is training, the Rollouter keeps generating new samples.
* **Multi-step Asynchrony**: Compared to one-step-off policy, asynchrony can be set anywhere from a fraction of a step (0.x) to multiple steps, making the asynchronous scheme more flexible.
* **NCCL Parameter Synchronization**: Parameters are exchanged between the Rollouter and the Trainer using NCCL communication primitives.
* **Streaming Inference and Training**: The Rollouter generates data sample by sample, and data is transmitted with a single sample as the minimum unit.
* **Asynchronous Training and Freshness Control**: The `async_training.staleness_threshold` parameter allows training on samples generated by older parameter versions.
* **PartialRollout**: The Rollouter's inference supports partial-rollout logic. During parameter synchronization, `sleep()` and `resume()` calls save in-flight rollouts and resume them in the next rollout round, reducing the time spent waiting for ongoing tasks to finish during parameter synchronization.

Currently, the supported combination is FSDP + vLLM, and vLLM must use the AgentLoop-based server mode.
## Design

The overall architecture of fully_async_policy is shown in the figure below. It consists of four main parts: the Rollouter, the MessageQueue, the Trainer, and the ParameterSynchronizer.



1. The Rollouter generates sequences sample by sample and puts them into the MessageQueue, with its production speed throttled by the freshness control.
2. The MessageQueue temporarily stores the samples generated by the Rollouter.
3. The Trainer fetches samples from the MessageQueue one by one. After collecting `require_batches * ppo_mini_batch_size` samples it performs a training step, and after `async_training.trigger_parameter_sync_step` such steps it triggers a parameter synchronization with the Rollouter.
4. The ParameterSynchronizer implements synchronous parameter synchronization over NCCL.

The gain over the baseline comes from the fact that, in the colocated case, giving more resources to rollout cannot remove the idle time caused by long-tail samples. With resource isolation, the rollout and train phases may each take longer than before (because each uses fewer resources), but overlapping them reduces the end-to-end time.


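The data flow above can be summarized with a minimal producer/consumer sketch. It is illustrative only and does not reflect the recipe's actual classes or method names (the real `recipe/fully_async_policy` code uses Ray workers, vLLM server rollout via AgentLoop, and NCCL parameter sync); helpers such as `rollouter` and `trainer` below are hypothetical.

```python
import queue
import threading

# Illustrative stand-ins for the recipe's configuration values.
REQUIRE_BATCHES = 1
PPO_MINI_BATCH_SIZE = 32
TRIGGER_PARAMETER_SYNC_STEP = 4

message_queue: "queue.Queue[dict]" = queue.Queue()
sync_event = threading.Event()  # set by the trainer when it wants new weights pushed


def rollouter(total_rollout_steps: int) -> None:
    """Produce samples one by one and push them into the message queue."""
    for i in range(total_rollout_steps):
        sample = {"prompt_id": i, "response": f"rollout-{i}"}  # placeholder generation
        message_queue.put(sample)
        if sync_event.is_set():
            # In the recipe this is where NCCL parameter sync (and, optionally,
            # partial-rollout save/resume) would happen.
            sync_event.clear()


def trainer(total_rollout_steps: int) -> None:
    """Consume samples in mini-batches and periodically request parameter sync."""
    batch_size = REQUIRE_BATCHES * PPO_MINI_BATCH_SIZE
    steps_since_sync = 0
    consumed = 0
    while consumed < total_rollout_steps:
        n = min(batch_size, total_rollout_steps - consumed)
        batch = [message_queue.get() for _ in range(n)]
        consumed += len(batch)
        # ... run one local PPO update on `batch` ...
        steps_since_sync += 1
        if steps_since_sync == TRIGGER_PARAMETER_SYNC_STEP:
            sync_event.set()  # trainer-side trigger for a Rollouter weight update
            steps_since_sync = 0


if __name__ == "__main__":
    n_samples = 256
    t1 = threading.Thread(target=rollouter, args=(n_samples,))
    t2 = threading.Thread(target=trainer, args=(n_samples,))
    t1.start(); t2.start()
    t1.join(); t2.join()
```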
## Usage

### Parameter Description

| parameter | meaning |
|---|---|
| `trainer.nnodes` | Number of nodes for the Trainer |
| `trainer.n_gpus_per_node` | Number of GPUs per node for the Trainer |
| `rollout.nnodes` | Number of nodes for the Rollouter |
| `rollout.n_gpus_per_node` | Number of GPUs per node for the Rollouter |
| `data.train_batch_size` | Not used in the fully async strategy (default 0) |
| `data.gen_batch_size` | The fully async strategy uses streaming sample production (default 1) |
| `rollout.total_rollout_steps` | Total number of rollout samples |
| `rollout.test_freq` | Number of Rollouter parameter updates between validations |
| `actor_rollout_ref.actor.ppo_mini_batch_size` | Global mini-batch size across all workers/GPUs |
| `async_training.require_batches` | Number of mini-batches (of `ppo_mini_batch_size`) the FullyAsyncTrainer fetches at once |
| `async_training.trigger_parameter_sync_step` | Number of local updates the FullyAsyncTrainer performs before a parameter synchronization |
| `async_training.staleness_threshold` | Freshness control |
| `async_training.partial_rollout` | Whether to perform partial rollout |
| `async_training.use_rollout_log_probs` | Use the log_probs produced by rollout |
**Further Explanation:**

* `rollout.total_rollout_steps`

  To match a colocated run, multiply its train batch size by its number of steps:
  `rollout.total_rollout_steps = data.train_batch_size * step`.

* `async_training.trigger_parameter_sync_step`

  In the fully async strategy, this is the number of local updates the Trainer performs (i.e., how many times it fetches `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with the Rollouter. Between two consecutive parameter synchronizations, the Trainer therefore processes `trigger_parameter_sync_step * require_batches * ppo_mini_batch_size` samples. For a fair speed comparison with colocate, set `trigger_parameter_sync_step` to `data.train_batch_size / (require_batches * ppo_mini_batch_size)` (see the sketch after this list).

* `async_training.staleness_threshold`

  In the fully async strategy, this is the maximum allowed proportion of stale samples.

  * `staleness_threshold = 0` means synchronous training. The Rollouter generates a fixed number of samples between two parameter updates:

    $$rollout\_num = trigger\_parameter\_sync\_step \times require\_batches \times ppo\_mini\_batch\_size$$

  * `staleness_threshold > 0` means asynchronous training; it may be fractional for more flexible asynchrony. The Rollouter generates at most the following number of samples between two parameter updates:

    $$rollout\_num = (1 + staleness\_threshold) \times (trigger\_parameter\_sync\_step \times require\_batches \times ppo\_mini\_batch\_size) - num\_stale\_samples$$

    where `num_stale_samples` is the number of extra (stale) samples generated during the previous rollout round.

  Because this is a streaming system, the Rollouter keeps producing while the Trainer keeps consuming. If the Rollouter is slower, the Trainer triggers parameter synchronization earlier and the Rollouter never actually produces `rollout_num` samples. When rollout is fast enough, setting `staleness_threshold` to 1 is roughly equivalent to one-step-off policy. To avoid too many stale samples hurting training accuracy, it is recommended to keep this value below 1 (the sketch after this list shows the resulting sample budget).

* `async_training.partial_rollout`

  `partial_rollout` only takes effect when `staleness_threshold > 0`.

* `async_training.use_rollout_log_probs`

  In reinforcement learning algorithms, log_probs are implicitly tied to the parameter version and the sampled tokens. Because of how algorithms such as PPO/GRPO/DAPO compute importance sampling, `old_log_prob` must be the log_probs of the rollout parameters on the rollout tokens for the algorithm to be correct. In the fully async strategy, `old_log_prob` is therefore computed by rollout rather than by the trainer by default.

* `async_training.require_batches`

  In streaming training, `require_batches` should normally be set to 1, meaning training starts as soon as `ppo_mini_batch_size` samples have been produced. In practice we found that dispatching fewer samples at a time can, because of the order in which data is distributed, cause training instability and longer responses. We therefore expose `require_batches` to control how many mini-batches are dispatched and trained on at once.
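As a concrete check of the two relations above, here is a small arithmetic sketch using the settings from the 7B experiments below (train_batch_size 512, ppo_mini_batch_size 32, require_batches 4, staleness_threshold 0.3). The helper names are ours, not part of the recipe.

```python
def fair_trigger_parameter_sync_step(train_batch_size: int, require_batches: int,
                                     ppo_mini_batch_size: int) -> int:
    """Trigger step that makes one Rollouter<->Trainer sync cycle consume as many
    samples as one colocated training step."""
    samples_per_local_update = require_batches * ppo_mini_batch_size
    assert train_batch_size % samples_per_local_update == 0
    return train_batch_size // samples_per_local_update


def max_rollout_num(trigger_parameter_sync_step: int, require_batches: int,
                    ppo_mini_batch_size: int, staleness_threshold: float,
                    num_stale_samples: int = 0) -> int:
    """Upper bound on samples the Rollouter may produce between two parameter syncs."""
    per_cycle = trigger_parameter_sync_step * require_batches * ppo_mini_batch_size
    return int((1 + staleness_threshold) * per_cycle) - num_stale_samples


# Settings used in the 7B experiments of this document.
step = fair_trigger_parameter_sync_step(train_batch_size=512, require_batches=4,
                                        ppo_mini_batch_size=32)
print(step)                               # 4, matching trigger_parameter_sync_step there
print(max_rollout_num(step, 4, 32, 0.0))  # 512: synchronous sample budget
print(max_rollout_num(step, 4, 32, 0.3))  # 665: async budget, before subtracting leftover stale samples
```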
### Supported Modes

1. On-policy pipeline:
   1. **trigger_parameter_sync_step=1, staleness_threshold=0**
   2. The Rollouter produces `require_batches * ppo_mini_batch_size` samples at a time; the Trainer fetches these samples and trains on them, and after training completes the Trainer and the Rollouter perform a parameter synchronization.
   3. During the rollout phase, if there are long-tail samples but few samples overall, shorter samples cannot fill the idle resources, so some resources are wasted.
   4. Shown as figure (a).

2. Stream off-policy pipeline:
   1. **trigger_parameter_sync_step>1, staleness_threshold=0**
   2. Synchronous streaming training. The Rollouter produces `require_batches * ppo_mini_batch_size * trigger_parameter_sync_step` samples at a time; the Trainer performs one local update every time it fetches `require_batches * ppo_mini_batch_size` samples, and after `trigger_parameter_sync_step` updates the Trainer and the Rollouter perform a parameter synchronization.
   3. Compared to (a), more samples are generated at a time, so resources idle less.
   4. Within one step of training there are still two idle periods: while fetching the first batch, the Trainer waits for `require_batches * ppo_mini_batch_size` samples to be produced, and around the last parameter update, rollout waits for training to complete.
   5. Shown as figure (b).

3. Async stream pipeline with stale samples:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
   2. After each parameter update, the Rollouter plans to produce at most `rollout_num` samples (in practice it may produce fewer, depending on rollout speed).
   3. If rollout is relatively fast, the Rollouter generates some additional samples (`num_stale_samples`) before parameter synchronization, for immediate use by the Trainer right after synchronization. When parameter synchronization is triggered while the Rollouter still has in-flight tasks, it waits for them to complete and does not start new ones.
   4. Compared to (b), after the first step there is no longer a wait for the first batch of rollouts, but there is a wait for in-flight tasks to finish.
   5. Shown as figure (c).

4. Async stream pipeline with partial rollout:
   1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
   2. Compared to (c), when parameter synchronization is triggered while samples are still being produced, the Rollouter interrupts the rollout, synchronizes parameters, and resumes the interrupted samples afterwards. This reduces the time spent waiting for in-flight tasks to finish.
   3. Shown as figure (d).


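For reference, the four modes map onto the following command-line overrides. This is only a summary of the bolded settings above; the concrete values for modes (b)-(d) are taken from the 128-GPU experiments later in this document, and everything else follows the Quick Start script.

```python
# Illustrative summary of the four supported modes (not an exhaustive config).
MODE_OVERRIDES = {
    # (a) on policy pipeline
    "on_policy_pipeline": {
        "async_training.trigger_parameter_sync_step": 1,
        "async_training.staleness_threshold": 0,
        "async_training.partial_rollout": False,
    },
    # (b) stream off policy pipeline
    "stream_off_policy_pipeline": {
        "async_training.trigger_parameter_sync_step": 4,
        "async_training.staleness_threshold": 0,
        "async_training.partial_rollout": False,
    },
    # (c) async stream pipeline with stale samples
    "async_stream_with_stale_samples": {
        "async_training.trigger_parameter_sync_step": 4,
        "async_training.staleness_threshold": 0.5,
        "async_training.partial_rollout": False,
    },
    # (d) async stream pipeline with partial rollout
    "async_stream_with_partial_rollout": {
        "async_training.trigger_parameter_sync_step": 4,
        "async_training.staleness_threshold": 0.5,
        "async_training.partial_rollout": True,
    },
}
```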
### Key Metrics

| metric | meaning |
|---|---|
| `trainer/idle_ratio` | Trainer idle ratio |
| `rollouter/idle_ratio` | Rollouter idle ratio |
| `fully_async/count/stale_samples_processed` | Total number of stale samples used in training |
| `fully_async/count/stale_trajectory_processed` | Total number of stale trajectories used in training (one sample produces `rollout.n` trajectories) |
| `fully_async/partial/total_partial_num` | Number of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by the Trainer between two parameter synchronizations |
| `fully_async/partial/max_partial_span` | Maximum parameter-version span of partial samples processed by the Trainer between two parameter synchronizations |
### Parameter Tuning Recommendations

* Resource allocation and adjustment:
  * Reasonable resource allocation is a prerequisite for good training efficiency. Ideally, rollout time and train time should be close, which minimizes pipeline bubbles, avoids idle resources, and keeps the Trainer from consuming stale samples. In real training, resource allocation can be adjusted from the observed idle time of rollout and train, available as `rollouter/idle_ratio` and `trainer/idle_ratio`. If `rollouter/idle_ratio` is high and `trainer/idle_ratio` is low, move resources from the Rollouter to the Trainer, and vice versa. (A tiny sketch of this rule of thumb follows this list.)

* Key parameters:
  * `staleness_threshold`: Setting it too high causes more stale samples to be used, hurting model performance; values below 1 are recommended.
  * `require_batches`: The closer to 1, the closer to a pure streaming process and the smaller the training bubbles, so the larger the speedup; however, it affects the order in which samples are processed.
  * `trigger_parameter_sync_step`: Smaller values are closer to on-policy but cause frequent parameter synchronization, and long-tail samples leave idle resources that short samples cannot fill, lowering utilization. Larger values give higher computational efficiency, but accuracy suffers from being more off-policy.
  * `rollout.test_freq`: Validation occupies Rollouter resources, so do not set this too small (i.e., do not validate too often).

* Mode selection: By adjusting these parameters, the fully async architecture supports different levels of acceleration, suited to different scenarios.
  * For small-scale tasks that need training stability and on-policy behavior and have low speed requirements, try the on-policy pipeline mode (Mode 1).
  * For scenarios that need higher training throughput but are sensitive to staleness, try the stream off-policy pipeline mode: set trigger_parameter_sync_step>1 to improve training efficiency while keeping the synchronization mechanism (staleness_threshold=0) (Mode 2).
  * For large-scale tasks with high speed requirements that can tolerate some off-policyness and staleness, set staleness_threshold>0 and partial_rollout=True to improve training efficiency, i.e., the async stream pipeline modes (Mode 3 or 4).
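The resource-adjustment rule of thumb above can be sketched as follows. This is not part of the recipe; the 0.2 idle-ratio gap and the 8-GPU step are arbitrary illustrative assumptions, and in practice you would relaunch with new `trainer.*`/`rollout.*` resource settings.

```python
def suggest_gpu_rebalance(rollouter_idle_ratio: float, trainer_idle_ratio: float,
                          rollout_gpus: int, trainer_gpus: int,
                          gap: float = 0.2, step: int = 8) -> tuple[int, int]:
    """Shift GPUs toward the busier side when the idle ratios differ by more than `gap`."""
    if rollouter_idle_ratio - trainer_idle_ratio > gap and rollout_gpus > step:
        # Rollouter mostly waiting: training is the bottleneck, so grow the Trainer.
        return rollout_gpus - step, trainer_gpus + step
    if trainer_idle_ratio - rollouter_idle_ratio > gap and trainer_gpus > step:
        # Trainer mostly waiting: rollout is the bottleneck, so grow the Rollouter.
        return rollout_gpus + step, trainer_gpus - step
    return rollout_gpus, trainer_gpus  # reasonably balanced


# Example: Rollouter mostly idle, Trainer busy -> give the Trainer more GPUs.
print(suggest_gpu_rebalance(0.45, 0.05, rollout_gpus=64, trainer_gpus=64))  # (56, 72)
```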
### Quick Start

```shell
rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
    export VLLM_USE_V1=1
    return_raw_chat="True"
fi

# NNODES_TRAIN, NNODES_ROLLOUT and NGPUS_PER_NODE must be set to match your cluster.
use_dynamic_bsz=True  # assumed here; adjust to your setup
train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$((512*400))
test_freq=10
staleness_threshold=0
trigger_parameter_sync_step=16
partial_rollout=False


python -m recipe.fully_async_policy.fully_async_main \
    data.train_batch_size=${train_prompt_bsz} \
    data.gen_batch_size=${gen_prompt_bsz} \
    data.return_raw_chat=${return_raw_chat} \
    actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
    actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
    actor_rollout_ref.actor.strategy=fsdp2 \
    critic.strategy=fsdp2 \
    actor_rollout_ref.hybrid_engine=False \
    actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
    actor_rollout_ref.rollout.name=${rollout_name} \
    actor_rollout_ref.rollout.mode=${rollout_mode} \
    actor_rollout_ref.rollout.calculate_log_probs=True \
    trainer.nnodes="${NNODES_TRAIN}" \
    trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.nnodes="${NNODES_ROLLOUT}" \
    rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
    rollout.total_rollout_steps="${total_rollout_steps}" \
    rollout.test_freq="${test_freq}" \
    async_training.staleness_threshold="${staleness_threshold}" \
    async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
    async_training.partial_rollout="${partial_rollout}"
```
## Experiments

### Asynchronous Training of a 7B Model

We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy with long rollouts at several resource scales. Using the `async stream pipeline with stale samples` strategy, we achieved roughly a 2x speedup at 32, 64, and 128 GPUs without significantly affecting the results.

* Machine: H20
* Model: Qwen2.5-Math-7B
* Rollout length: max_response_length (FSDP2): 28K tokens
* Algorithm: DAPO
* Dataset: TRAIN_FILE: dapo-math-17k.parquet, TEST_FILE: aime-2024.parquet
* Engine: vLLM + FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20

* colocate sync:
  * step: 400
  * train_batch_size: 512

* fully_async_policy:
  * total_rollout_steps: 512*400
  * require_batches: 4
  * trigger_parameter_sync_step: 4
  * staleness_threshold: 0.3
  * partial_rollout: True

| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |

> Source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
### 128-GPU 7B Asynchronous Mode Experiment

We used Qwen2.5-Math-7B to verify the effect of the various modes supported by fully async. Streaming alone brings roughly a 0.6x improvement (about a 1.6x speedup); combined with staleness and partial_rollout, the improvement reaches 2.35x.

| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| colocate sync (128 GPUs) | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step=4,<br>require_batches=4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |

> Source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-GPU Staleness Ablation Experiment

Under the `async stream pipeline with partial rollout` mode, we verified the impact of the staleness setting on training efficiency. The larger the staleness, the larger the overall gain. We also noticed that the total times for staleness 0.3 and 0.5 are very close, because as training proceeds the response length changes significantly, causing training instability; this needs further analysis and optimization.

| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |

> Source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
### 128-GPU 7B require_batches Ablation Experiment

In multiple tests we found that the number of samples dispatched at a time in streaming affects the response length during training, which in turn affects training time. We verified the impact by varying `async_training.require_batches`.

| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |

> Source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
### 30B Model Mode Experiment

TODO: The 30B experiment is still in progress.

* Machine: H20
* Model: Qwen2.5-32B
* Rollout length: max_response_length (FSDP2): 20K tokens
* Algorithm: DAPO
* Engine: vLLM + FSDP2
* rollout.n: 16
* ppo_mini_batch_size: 32
* test_freq: 20

* colocate sync:
  * step: 200
  * train_batch_size: 512

* fully_async_policy:
  * total_rollout_steps: 512*200
  * trigger_parameter_sync_step: 512/32 = 16
  * staleness_threshold: 0
  * partial_rollout: False

| training mode | resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|---|---|---|---|---|---|---|---|---|
| colocate sync | 128 | | | | | | | |
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | |
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |

## Future Plans

* GRPO experiments
* Megatron adaptation
* SGLang integration
* Transfer queue integration
* Asynchronous parameter synchronization
* AReaL asynchronous algorithm implementation
* TPPO algorithm implementation
* Multi-turn and tool support
@ -124,6 +124,7 @@ verl is fast with:
   advance/rollout_is_migration.md
   advance/one_step_off
   advance/agent_loop
   advance/fully_async

.. toctree::
   :maxdepth: 1
recipe/fully_async_policy/README.md (new file, 428 lines)
@ -0,0 +1,428 @@
|
||||
# Recipe: Fully Async Policy Async Trainer
|
||||
|
||||
**Author:** `https://github.com/meituan-search`
|
||||
|
||||
Last updated: 10/17/2025.
|
||||
|
||||
This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter,
|
||||
supporting asynchronous sample generation and training.
|
||||
Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs,
|
||||
without significantly affecting the results.
|
||||
|
||||
## Introduction
|
||||
|
||||
### Background
|
||||
|
||||
The separated rollout and train architecture, compared to the colocate architecture, can allocate resources more
|
||||
flexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training
|
||||
efficiency caused by long-tail problems.
|
||||
The one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by
|
||||
designing a separated architecture and performing asynchronous training between rollout and train for one round.
|
||||
However, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot
|
||||
completely eliminate the impact of long-tail on training efficiency.
|
||||
In other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have
|
||||
been implemented based on the separated architecture and have achieved gains.
|
||||
We借鉴 their methods and implemented them in VERL. The fully_async_policy supports asynchronous, streaming, and partial
|
||||
rollout training.
|
||||
By reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy
|
||||
can significantly improve training efficiency.
|
||||
|
||||
> Magistral https://arxiv.org/abs/2506.10910
|
||||
>
|
||||
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language
|
||||
> Reasoning https://arxiv.org/abs/2505.24298
|
||||
>
|
||||
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream
|
||||
> Generation https://arxiv.org/abs/2504.15930
|
||||
>
|
||||
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
|
||||
>
|
||||
|
||||
### Core Contributions
|
||||
|
||||
* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to
|
||||
specify the resources they occupy separately.
|
||||
* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples.
|
||||
* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to
|
||||
multiple steps, making the asynchronous solution more flexible.
|
||||
* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter
|
||||
and Trainer.
|
||||
* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single
|
||||
sample as the minimum transmission unit.
|
||||
* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it
|
||||
supports training with samples generated by old parameters.
|
||||
* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter
|
||||
synchronization, by adding `sleep() and resume()` logic, it
|
||||
saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for
|
||||
ongoing tasks to finish during parameter synchronization.
|
||||
|
||||
Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop.
|
||||
|
||||
## Design
|
||||
|
||||
The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four
|
||||
parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
|
||||
|
||||

|
||||
|
||||
1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the
|
||||
production speed controlled by freshness.
|
||||
2. MessageQueue is used to temporarily store samples generated by Rollouter.
|
||||
3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size`
|
||||
samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers
|
||||
a parameter synchronization with Rollouter.
|
||||
4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability.
|
||||
|
||||
The source of benefits compared to the base scheme lies in the fact that in the colocate case, using more resources for
|
||||
rollout cannot solve the idleness caused by long-tail samples.
|
||||
After we perform resource isolation, the time for rollout and train may be longer than before (because fewer resources
|
||||
are used),
|
||||
but the overlap in their time consumption reduces the end-to-end time consumption.
|
||||
|
||||

|
||||
|
||||
## Usage
|
||||
|
||||
### Parameter Description
|
||||
|
||||
| super params | implication |
|
||||
|-----------------------------------------------|------------------------------------------------------------------------------------------------|
|
||||
| `trainer.nnodes` | Number of nodes for Trainer |
|
||||
| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer |
|
||||
| `rollout.nnodes` | Number of nodes for Rollouter |
|
||||
| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter |
|
||||
| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) |
|
||||
| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) |
|
||||
| `rollout.total_rollout_steps` | Total number of rollout samples |
|
||||
| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation |
|
||||
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
|
||||
| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once |
|
||||
| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization |
|
||||
| `async_training.staleness_threshold` | Freshness control |
|
||||
| `async_training.partial_rollout` | Whether to perform partial_rollout |
|
||||
| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout |
|
||||
|
||||
**Further Explanation:**
|
||||
|
||||
* `rollout.total_rollout_steps`
|
||||
|
||||
Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step:
|
||||
`rollout.total_rollout_steps = data.train_batch_size * step`.
|
||||
|
||||
* `async_training.trigger_parameter_sync_step`
|
||||
|
||||
In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches
|
||||
`require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter.
|
||||
Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process
|
||||
`trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples.
|
||||
To fairly compare speed with colocate, trigger_parameter_sync_step should be set to
|
||||
`data.train_batch_size / (require_batches * ppo_mini_batch_size)`.
|
||||
|
||||
* `async_training.staleness_threshold`
|
||||
|
||||
In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used.
|
||||
|
||||
* staleness_threshold=0, indicates synchronous training.
|
||||
Rollouter will generate a fixed number of samples between two parameter updates, the sample count is:
|
||||
$$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$
|
||||
* staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous
|
||||
calls.
|
||||
Rollouter will generate at most the following number of samples between two parameter updates:
|
||||
$$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$
|
||||
|
||||
num_staleness_sample represents the number of stale samples generated in excess during the last rollout.
|
||||
|
||||
Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower,
|
||||
trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples.
|
||||
When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy.
|
||||
To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1.
|
||||
|
||||
* `async_training.partial_rollout`
|
||||
|
||||
partial_rollout only actually takes effect when staleness_threshold>0.
|
||||
|
||||
* `async_training.use_rollout_log_probs`
|
||||
|
||||
In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to
|
||||
the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling,
|
||||
old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm
|
||||
correctness. In the fully
|
||||
async strategy, we default to old_log_prob being calculated by rollout rather than by trainer.
|
||||
|
||||
* `async_training.require_batches`
|
||||
|
||||
In streaming training, require_batches should be set to 1, indicating that training is performed after producing
|
||||
enough ppo_mini_batch_size samples.
|
||||
In actual testing, we found that if fewer samples are issued at once, due to the order of data distribution, it can
|
||||
cause training instability and longer response lengths.
|
||||
Here, we additionally provide require_batches for streaming distribution and control the number of samples
|
||||
participating in training at once.
|
||||
|
||||
### Supported Modes
|
||||
|
||||
1. on policy pipeline:
|
||||
1. **trigger_parameter_sync_step=1, staleness_threshold=0**
|
||||
2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for
|
||||
training, and after training completes, Trainer and Rollouter perform a parameter synchronization;
|
||||
3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill
|
||||
idle resources, causing some resource waste.
|
||||
4. As shown in figure a;
|
||||
|
||||
2. stream off policy pipeline:
|
||||
1. **trigger_parameter_sync_step>1, staleness_threshold=0**
|
||||
2. Synchronous streaming training will be performed. Rollouter produces
|
||||
`require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local
|
||||
training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training
|
||||
trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization;
|
||||
3. Compared to a, since more samples are generated at once, resource idleness will be lower.
|
||||
4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples,
|
||||
train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter
|
||||
update, rollout waits for training to complete.
|
||||
5. As shown in figure b;
|
||||
|
||||
3. async stream pipeline with stale samples:
|
||||
1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False**
|
||||
2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number
|
||||
of samples generated may be less than this value depending on rollout speed).
|
||||
3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples
|
||||
before parameter synchronization for immediate use by Trainer after synchronization.
|
||||
When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete
|
||||
and not add new tasks;
|
||||
4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the
|
||||
first batch rollout to finish, but will have the time to wait for active tasks to finish.
|
||||
5. As shown in figure c;
|
||||
|
||||
4. async stream pipeline with partial rollout:
|
||||
1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True**
|
||||
2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will
|
||||
interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be
|
||||
generated after synchronization. This reduces the time to wait for active tasks to finish.
|
||||
3. As shown in figure d;
|
||||
|
||||

|
||||
|
||||
### Key Metrics
|
||||
|
||||
| metrics | implication |
|
||||
|------------------------------------------------|--------------------------------------------------------------------------------------------------------|
|
||||
| `trainer/idle_ratio` | Trainer idle rate |
|
||||
| `rollouter/idle_ratio` | Rollouter idle rate |
|
||||
| `fully_async/count/stale_samples_processed` | Total number of old samples used in training |
|
||||
| `fully_async/count/stale_trajectory_processed` | Total number of old trajectories used in training (one sample produces rollout.n trajectories) |
|
||||
| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two trigger_parameter_sync_step |
|
||||
| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two trigger_parameter_sync_step |
|
||||
| `fully_async/partial/max_partial_span` | Maximum parameter span of partial samples processed by Trainer between two trigger_parameter_sync_step |
|
||||
|
||||
### Parameter Tuning Recommendations
|
||||
|
||||
* Resource Allocation and Adjustment:
|
||||
* Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource
|
||||
allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire
|
||||
training process,
|
||||
avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource
|
||||
allocation can be adjusted based on the idle time of rollout and train during actual training,
|
||||
which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and
|
||||
trainer/idle_ratio is low,
|
||||
Trainer resources should be increased and Rollouter resources should be reduced, and vice versa.
|
||||
|
||||
* Key Parameters:
|
||||
* staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It
|
||||
is recommended to set it to less than 1.
|
||||
* require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and
|
||||
the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample
|
||||
processing;
|
||||
* trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent
|
||||
parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in
|
||||
low resource utilization.
|
||||
The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy.
|
||||
* rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small.
|
||||
|
||||
* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at
|
||||
different levels, suitable for tasks in different scenarios.
|
||||
* For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed
|
||||
requirements, the on policy pipeline mode (Mode 1) can be tried.
|
||||
* For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy
|
||||
pipeline mode can be tried. That is, by
|
||||
setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization
|
||||
mechanism (staleness_threshold=0) (Mode 2).
|
||||
* For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and
|
||||
staleness, setting staleness_threshold>
|
||||
0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4).
|
||||
|
||||
### Quick Start
|
||||
|
||||
```shell
|
||||
rollout_mode="async"
|
||||
rollout_name="vllm" # sglang or vllm
|
||||
if [ "$rollout_mode" = "async" ]; then
|
||||
export VLLM_USE_V1=1
|
||||
return_raw_chat="True"
|
||||
fi
|
||||
|
||||
train_prompt_bsz=0
|
||||
gen_prompt_bsz=1
|
||||
n_resp_per_prompt=16
|
||||
train_prompt_mini_bsz=32
|
||||
total_rollout_steps=$(((512*400)))
|
||||
test_freq=10
|
||||
staleness_threshold=0
|
||||
trigger_parameter_sync_step=16
|
||||
partial_rollout=False
|
||||
|
||||
|
||||
python -m recipe.fully_async_policy.fully_async_main \
|
||||
train_batch_size=${train_prompt_bsz} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.return_raw_chat=${return_raw_chat} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.name=${rollout_name} \
|
||||
actor_rollout_ref.rollout.mode=${rollout_mode} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
trainer.nnodes="${NNODES_TRAIN}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.nnodes="${NNODES_ROLLOUT}" \
|
||||
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.total_rollout_steps="${total_rollout_steps}" \
|
||||
rollout.test_freq="${test_freq}" \
|
||||
async_training.staleness_threshold="${staleness_threshold}" \
|
||||
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
|
||||
async_training.partial_rollout="${partial_rollout}"
|
||||
```
|
||||
|
||||
## Experiments
|
||||
|
||||
### Asynchronous Training on 7B Model
|
||||
|
||||
We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy with long response generation across
several resource scales. Using the `async stream pipeline with stale samples` strategy, we achieved roughly a 2x
speedup at 32, 64, and 128 cards without significantly affecting experimental results.
|
||||
|
||||
* Machine: H20
|
||||
* Model: Qwen2.5-Math-7B
|
||||
* Rollout length: max_response_length: 28K tokens
|
||||
* Algorithm: DAPO
|
||||
* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
|
||||
* Engine: vllm+FSDP2
|
||||
* rollout.n: 16
|
||||
* ppo_mini_batch_size: 32
|
||||
* test_freq: 20
|
||||
|
||||
* colocate sync:
|
||||
* step: 400
|
||||
* train_batch_size: 512
|
||||
|
||||
* fully_async_policy
|
||||
* total_rollout_steps: 512*400
|
||||
* require_batches: 4
|
||||
* trigger_parameter_sync_step: 4
|
||||
* staleness_threshold: 0.3
|
||||
* partial_rollout: True
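Note that these fully_async_policy settings (staleness_threshold>0 with partial_rollout=True) correspond to the
`async stream pipeline with partial rollout` mode (Mode 4).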
|
||||
|
||||
| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|
|
||||
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
|
||||
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
|
||||
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
|
||||
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
|
||||
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
|
||||
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
|
||||
|
||||
### 128-card 7B Asynchronous Mode Experiment
|
||||
|
||||
We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async.
|
||||
Streaming alone brings roughly an additional 0.6x speedup over colocate sync, and after adding staleness and
partial_rollout the overall speedup reaches about 2.35x.
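As a rough check against the 400-step totals in the table below: colocate sync takes 1d 16h 48m (≈40.8 h), the stream
off policy pipeline 1d 1h 53m (≈25.9 h, about 1.6x, i.e. an extra ~0.6x), and the partial-rollout configuration
17h 22m (≈17.4 h, about 2.35x).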
|
||||
|
||||
| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
|
||||
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
|
||||
| `async stream pipeline with stale samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
|
||||
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
|
||||
|
||||
### 128-card Stale Ablation Experiment
|
||||
|
||||
Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training
|
||||
efficiency.
|
||||
We found that the larger the staleness, the larger the end-to-end speedup.
We also noticed that the total times for staleness values of 0.3 and 0.5 are quite close: as training proceeds, the
response length changes significantly and training becomes unstable.
Further analysis and optimization of this issue are needed.
|
||||
|
||||
| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
|
||||
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
|
||||
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
|
||||
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
|
||||
|
||||
### 128-card 7B require_batches Ablation Experiment
|
||||
|
||||
Across multiple tests, we found that the number of samples dispatched per streaming step affects the response length
during training, which in turn affects training time. We verified this by varying `async_training.require_batches`.
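For reference, each local update consumes `require_batches * ppo_mini_batch_size` samples; with ppo_mini_batch_size=32,
the settings below correspond to 32, 64, and 128 samples per fetch.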
|
||||
|
||||
| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|
||||
|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
|
||||
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
|
||||
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
|
||||
|
||||
### 30B Model Mode Experiment
|
||||
|
||||
TODO: The 30B experiment is still in progress.
|
||||
|
||||
* Machine: H20
|
||||
* Model: Qwen2.5-32B
|
||||
* Rollout length: max_response_length: 20K tokens
|
||||
* Algorithm: DAPO
|
||||
* Engine: vllm+FSDP2
|
||||
* rollout.n: 16
|
||||
* ppo_mini_batch_size: 32
|
||||
* test_freq: 20
|
||||
|
||||
* colocate sync:
|
||||
* step:200
|
||||
* train_batch_size: 512
|
||||
|
||||
* fully_async_policy
|
||||
* total_rollout_steps: 512*200
|
||||
* trigger_parameter_sync_step: 512/32 = 16
|
||||
* staleness_threshold: 0
|
||||
* partial_rollout: False
|
||||
|
||||
| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|
||||
|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|
|
||||
| colocate sync | 128 | | | | | | | |
|
||||
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
|
||||
| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | |
|
||||
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |
|
||||
|
||||
|
||||
## Future Plans
|
||||
|
||||
* GRPO experiments
|
||||
* Megatron adaptation
|
||||
* SGLang integration
|
||||
* Transfer queue integration
|
||||
* Asynchronous parameter synchronization
|
||||
* AReaL asynchronous algorithm implementation
|
||||
* TPPO algorithm implementation
|
||||
* Multi-turn and Tool support
|
recipe/fully_async_policy/README_zh.md (new file, 373 lines)
@@ -0,0 +1,373 @@
|
||||
# Recipe: Fully Async Policy Async Trainer
|
||||
|
||||
**Author:** `https://github.com/meituan-search`
|
||||
|
||||
Last updated: 10/17/2025.
|
||||
|
||||
本文档介绍了完全异步PPO训练系统,该系统实现了 Trainer 和 Rollouter 的完全解耦,支持异步样本生成和训练。
|
||||
在该系统下,我们使用128卡训练qwen2.5-7B模型取得了2.35x-2.67x的性能提升,同时效果没有显著受到影响。
|
||||
|
||||
## Introduction
|
||||
|
||||
### Background
|
||||
|
||||
rollout和train分离架构相较于colocate的架构能够更加灵活地分配资源,设计更加灵活的训练逻辑,从而处理长尾等问题带来的GPU利用率低,训练效率低的问题。
|
||||
one_step_off_policy通过分离架构的设计并进行rollout和train一轮异步的训练方法,缓解了rollout时间过长的问题,并在训练效率上取得了一些收益,
|
||||
但其强制使用一轮异步的数据,存在不够灵活等问题,而且并不能完全去除长尾对训练效率带来的影响;在其他框架如areal、Magistral、streamrl、asyncflow上,
|
||||
已经基于分离架构实现了异步训练、流式训练,并取得了收益;我们借鉴其方法,在verl上进行了实现。fully_async_policy支持异步、流式、partial
|
||||
rollout的训练, 通过合理设置资源分配情况、参数同步频率等参数,fully_async_policy能够显著提高训练效率。
|
||||
|
||||
> Magistral https://arxiv.org/abs/2506.10910
|
||||
>
|
||||
> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language
|
||||
> Reasoning https://arxiv.org/abs/2505.24298
|
||||
>
|
||||
> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream
|
||||
> Generation https://arxiv.org/abs/2504.15930
|
||||
>
|
||||
> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663
|
||||
>
|
||||
|
||||
### 核心贡献
|
||||
|
||||
* **资源隔离**:与使用hybrid_engine不同,Rollouter和Trainer使用分离的计算资源,需要分别指定所占用的资源。
|
||||
* **生成与训练并行**:Trainer在训练的同时,Rollouter在生成新的样本。
|
||||
* **多步异步**: 相比 one step off policy 支持0.x步到多步的异步设定,异步方案更加灵活。
|
||||
* **nccl参数同步**:使用nccl通信原语进行Rollouter与Trainer参数的通信。
|
||||
* **Stream推理与训练**:Rollouter逐样本生成数据,同时数据传输以单个sample为最小传输单位。
|
||||
* **异步训练与新鲜度控制**:通过设置参数async_training.staleness_threshold,支持使用旧参数生成的样本进行训练。
|
||||
* **PartialRollout**: Rollouter推理过程支持partial rollout逻辑:在参数同步时,通过添加`sleep()`和`resume()`逻辑,保存进行中的rollout样本,并在下一次rollout中继续使用,减少参数同步时等待进行中任务结束的时间。
|
||||
|
||||
目前支持使用模式为 fsdp+vllm。vllm必须使用基于AgentLoop的server模式。
|
||||
|
||||
## 设计
|
||||
|
||||
fully_async_policy的整体架构如下图所示,fully_async_policy主要由Rollouter、MessageQueue、Trainer、ParameterSynchronizer四部分组成。
|
||||
|
||||

|
||||
|
||||
1. Rollouter逐样本生成序列,并将生成的sample放入MessageQueue中,生产的速度受新鲜度控制。
|
||||
2. MessageQueue用于暂存Rollouter生成的sample。
|
||||
3. Trainer逐样本从MessageQueue中获取,获取到`require_batches*ppo_mini_batch_size`
|
||||
数量的样本后,就会进行训练,训练async_training.trigger_parameter_sync_step轮后,触发与Rollouter的一次参数同步。
|
||||
4. ParameterSynchronizer 实现了Nccl的同步参数同步能力。
|
||||
|
||||
当前方案对比base的收益来源,在于colocate情况下,rollout使用更多的资源无法解决长尾样本带来的空闲,
|
||||
当我们进行资源隔离后,rollout的时间和train的时间都可能相较于之前更长(因为使用的资源变少了),
|
||||
但是相互之间的耗时overlap,端到端的耗时反而有所缩减。
|
||||
|
||||

|
||||
|
||||
## 使用方式
|
||||
|
||||
### 参数说明
|
||||
|
||||
| super params | implication |
|
||||
|-----------------------------------------------|-----------------------------------------------------------------|
|
||||
| `trainer.nnodes` | Trainer的node数量 |
|
||||
| `trainer.n_gpus_per_node` | Trainer每个node上gpu的数量 |
|
||||
| `rollout.nnodes` | Rollouter的node数量 |
|
||||
| `rollout.n_gpus_per_node` | Rollouter每个node上gpu的数量 |
|
||||
| `data.train_batch_size` | 在fully async策略中,该值不生效(默认设置为0) |
|
||||
| `data.gen_batch_size` | 在fully async策略中,使用流式的样本生产逻辑(默认设置为1) |
|
||||
| `rollout.total_rollout_steps` | 总的rollout的sample数量 |
|
||||
| `rollout.test_freq` | Rollouter每更新多少次参数,进行一次validation |
|
||||
| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus |
|
||||
| `async_training.require_batches` | FullyAsyncTrainer一次性获取的ppo_mini_batch_size的数量 |
|
||||
| `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 |
|
||||
| `async_training.staleness_threshold` | 新鲜度控制 |
|
||||
| `async_training.partial_rollout` | 是否进行partial_rollout |
|
||||
| `async_training.use_rollout_log_probs` | 使用rollout产生的log_probs |
|
||||
|
||||
**进一步的解释:**
|
||||
|
||||
* `rollout.total_rollout_steps`
|
||||
|
||||
与 colocate 相比,数量可以通过 train_batch_size 与 step 相乘对齐:
|
||||
`rollout.total_rollout_steps = data.train_batch_size * step`。
|
||||
|
||||
* `async_training.trigger_parameter_sync_step`
|
||||
|
||||
在fully async策略中,表示Trainer进行多少次本地更新后(也就是获取多少次`require_batches * ppo_mini_batch_size`数量样本),
|
||||
与Rollouter之间进行一次参数同步。
|
||||
每两次Rollouter和Trainer参数同步之间,Trainer将会处理`trigger_parameter_sync_step * require_batches * ppo_mini_batch_size`份sample。
|
||||
如果为了与colocate在公平的情况下对比速度,trigger_parameter_sync_step应该设置为 `data.train_batch_size / (
|
||||
require_batches * ppo_mini_batch_size)`。
|
||||
|
||||
* `async_training.staleness_threshold`
|
||||
|
||||
在fully async策略中,表示最大允许使用的staleness样本的比例。
|
||||
|
||||
* staleness_threshold=0,表示同步训练。
|
||||
Rollouter两次参数更新之间将会生成固定数量的样本,样本数为:
|
||||
$$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$
|
||||
* staleness_threshold>0,表示异步训练, 可以设置为小数,支持更灵活的异步调用。
|
||||
Rollouter两次参数更新之间将会最多生成的样本数为:
|
||||
$$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$
|
||||
|
||||
num_staleness_sample 表示上一次rollout多生成的陈旧样本数。
|
||||
|
||||
由于是流式系统,rollout持续生成,trainer持续消费。如果rollouter较慢,trainer会更早触发参数同步,rollouter并不会实际生产rollout_num个样本。
|
||||
当rollout 足够快时,staleness_threshold设置为1,基本上等价于one_step_off policy。
|
||||
为了避免过期样本太多影响训练精度,建议该值设置小于1。
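举例说明(按下文128卡实验的设置,仅作示意):当 trigger_parameter_sync_step=4、require_batches=4、ppo_mini_batch_size=32、staleness_threshold=0.5 时,两次参数同步之间 Rollouter 最多生成

$$rollout\_num = (1+0.5)\times(4\times4\times32) - num\_staleness\_sample = 768 - num\_staleness\_sample$$

个样本。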
|
||||
|
||||
* `async_training.partial_rollout`
|
||||
|
||||
partial_rollout只会在staleness_threshold>0时才实际上起作用。
|
||||
|
||||
* `async_training.use_rollout_log_probs`
|
||||
|
||||
在强化学习算法中,log_probs与参数版本,token都存在隐性的相关性。由于PPO/GRPO/DAPO等算法的设定,我们在计算重要性采样时,
|
||||
即 old_log_prob必须使用rollout参数及token所对应log_probs,才能保证算法的正确性。在fully
|
||||
async策略中,我们默认old_log_prob是由rollout所计算的,而不是由trainer所计算。
|
||||
|
||||
* `async_training.require_batches`
|
||||
|
||||
在流式训练中,require_batches 应该设置为1,表示生产够ppo_mini_batch_size样本后,就进行训练。
|
||||
在实际测试中,我们发现,如果单次下发的样本较少,由于数据分发的顺序,会导致训练不稳定,response 长度变长。
|
||||
在这里,我们额外提供 require_batches 进行流式分发,单次参与训练的样本数量控制。
|
||||
|
||||
### 模式支持
|
||||
|
||||
1. on policy pipeline:
|
||||
1. **trigger_parameter_sync_step=1,staleness_threshold=0**
|
||||
2. Rollouter一次生产`require_batches*ppo_mini_batch_size`
|
||||
的samples,Trainer获取这些samples后进行训练,训练完后Trainer和Rollouter之间进行一次参数同步;
|
||||
3. 在rollout阶段,如果存在长尾的样本,但是rollout样本数较少时,较短的样本无法填充到空闲的资源中,会造成一定的资源浪费。
|
||||
4. 如图a所示;
|
||||
|
||||
2. stream off policy pipeline:
|
||||
1. **trigger_parameter_sync_step>1,staleness_threshold=0**
|
||||
2. 将会进行同步的流式训练,Rollouter一次生产`require_batches*ppo_mini_batch_size*trigger_parameter_sync_step`
|
||||
的samples,Trainer每获取`require_batches*ppo_mini_batch_size`
|
||||
就进行一次本地训练,训练trigger_parameter_sync_step次后,Trainer和Rollouter之间进行一次参数同步;
|
||||
3. 相较于a,由于一次生成的样本更多,资源的空闲会更低。
|
||||
4. 在一次step训练中,会存在两次资源闲置的时间,分别是在第一次获取样本时,train等待`require_batches*ppo_mini_batch_size`
|
||||
个样本生产,以及最后一次参数更新时,rollout等待训练完成。
|
||||
5. 如图b所示;
|
||||
|
||||
3. async stream pipeline with staleness samples:
|
||||
1. **trigger_parameter_sync_step>=1,staleness_threshold>0,partial_rollout=False**
|
||||
2. Rollouter在每次参数更新后将计划最多生产rollout_num个样本(实际根据rollout速度,生成的样本可能会少于这个值)。
|
||||
3. 如果rollout过程比较快,Rollouter将会在参数同步前额外生成一部分样本num_stale_samples,用于参数同步后立即给Trainer使用。
|
||||
触发参数同步时,如果Rollouter有正在生产的任务,将会等待任务完成,同时不会添加新的任务;
|
||||
4. 相较于b,除第一次step训练外,后续的训练都不会有wait first batch rollout finish的时间,但是会有wait active task
|
||||
finish的时间。
|
||||
5. 如图c所示;
|
||||
|
||||
4. async stream pipeline with partial rollout:
|
||||
1. **trigger_parameter_sync_step>=1,staleness_threshold>0,partial_rollout=True**
|
||||
2. 相较于c,触发参数同步时,Rollouter如果有正在生产的sample,会打断rollout过程并进行参数同步,被中断的sample会在参数同步后继续生成。减少了wait
|
||||
active task finish的时间。
|
||||
3. 如图d所示;
|
||||
|
||||

|
||||
|
||||
### 关键指标
|
||||
|
||||
| metrics | implication |
|
||||
|------------------------------------------------|-----------------------------------------------------------|
|
||||
| `trainer/idle_ratio` | Trainer闲置率 |
|
||||
| `rollouter/idle_ratio` | Rollouter闲置率 |
|
||||
| `fully_async/count/stale_samples_processed` | 训练使用的旧sample总数 |
|
||||
| `fully_async/count/stale_trajectory_processed` | 训练使用的旧trajectory总数(一个sample会生产rollout.n条trajectory) |
|
||||
| `fully_async/partial/total_partial_num` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本数 |
|
||||
| `fully_async/partial/partial_ratio` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的比例 |
|
||||
| `fully_async/partial/max_partial_span` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的最大参数跨度 |
|
||||
|
||||
### 调参建议
|
||||
|
||||
* 资源分配与调整:
|
||||
* 合理的资源分配是获得好的训练效率的前提。理想的资源分配情况应该是使得Rollout的时间和Train的时间接近,从而使得整个训练过程流水气泡最小,
|
||||
避免资源闲置,同时Trainer不会使用旧样本。在真实训练场景下,可以根据实际训练过程中rollout和train的空闲时间调整资源分配,
|
||||
可从rollouter/idle_ratio和trainer/idle_ratio获得,如果rollouter/idle_ratio较高trainer/idle_ratio较低,
|
||||
应该增多Trainer的资源减少Rollouter的资源,反之亦然。
|
||||
|
||||
* 关键参数:
|
||||
* staleness_threshold: 设置太大会导致较多的旧样本使用,影响模型效果,建议设置小于1。
|
||||
* require_batches:越接近1,越接近纯流式过程,训练过程中bubble越小,能够在速度上获得更快的加速效果,但会对样本的处理顺序产生影响;
|
||||
* trigger_parameter_sync_step: 设置的越小越接近on policy,但会导致频繁的参数同步,长尾样本浪费的资源无法被短样本填充,资源利用率低。
|
||||
设置的越大有更高的计算效率,但是精度上会受到off policy的影响。
|
||||
* rollout.test_freq: 会占用Rollouter资源,不建议设置太小。
|
||||
|
||||
* 模式选择:通过调整不同的参数,Fully Async架构支持不同程度上的优化加速,适用于不同场景的任务。
|
||||
* 对于小规模任务,需要保证训练的稳定性和 on-policy 性,对速度要求不高的场景,可以尝试使用on policy pipeline的模式(模式1)。
|
||||
* 对于需要提高训练吞吐量,但对 staleness 敏感的场景,可以尝试使用 stream off policy pipeline 的模式。即通过
|
||||
设置trigger_parameter_sync_step>1 ,提高 训练效率,但仍保持同步机制 (staleness_threshold=0 )(模式2)。
|
||||
* 对于大规模任务,对训练速度有较高要求,且可以容忍一定 off-policy 程度、staleness的场景,可以设置staleness_threshold>
|
||||
0、partial_rollout=True提高训练效率,使用 async stream pipeline 模式(模式 3 或 4)。
|
||||
|
||||
### 快速开始
|
||||
|
||||
```shell
|
||||
rollout_mode="async"
|
||||
rollout_name="vllm" # sglang or vllm
|
||||
if [ "$rollout_mode" = "async" ]; then
|
||||
export VLLM_USE_V1=1
|
||||
return_raw_chat="True"
|
||||
fi
|
||||
|
||||
train_prompt_bsz=0
|
||||
gen_prompt_bsz=1
|
||||
n_resp_per_prompt=16
|
||||
train_prompt_mini_bsz=32
|
||||
total_rollout_steps=$(((512*400)))
|
||||
test_freq=10
|
||||
staleness_threshold=0
|
||||
trigger_parameter_sync_step=16
|
||||
partial_rollout=False
# referenced below; True enables dynamic batch sizing
use_dynamic_bsz=True
|
||||
|
||||
|
||||
python -m recipe.fully_async_policy.fully_async_main \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.return_raw_chat=${return_raw_chat} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.name=${rollout_name} \
|
||||
actor_rollout_ref.rollout.mode=${rollout_mode} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
trainer.nnodes="${NNODES_TRAIN}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.nnodes="${NNODES_ROLLOUT}" \
|
||||
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.total_rollout_steps="${total_rollout_steps}" \
|
||||
rollout.test_freq="${test_freq}" \
|
||||
async_training.staleness_threshold="${staleness_threshold}" \
|
||||
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
|
||||
async_training.partial_rollout="${partial_rollout}"
|
||||
```
|
||||
|
||||
## 实验
|
||||
|
||||
### 在7B模型上进行异步训练
|
||||
|
||||
我们使用 Qwen2.5-Math-7B 验证 fully async 策略在长候选下,多种资源下的收益情况。
|
||||
使用`async stream pipeline with staleness samples` 策略,我们在32卡,64卡,128卡都取得2x左右的性能提升,同时没有显著影响实验效果。
|
||||
|
||||
* 机器:H20
|
||||
* 模型:Qwen2.5-Math-7B
|
||||
* rollout长度:max_response_length FSDP2: 28K tokens;
|
||||
* 算法:DAPO
|
||||
* 数据集: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet
|
||||
* engine: vllm+FSDP2
|
||||
* rollout.n: 16
|
||||
* ppo_mini_batch_size: 32
|
||||
* test_freq: 20
|
||||
|
||||
* colocate sync:
|
||||
* step: 400
|
||||
* train_batch_size: 512
|
||||
|
||||
* fully_async_policy
|
||||
* total_rollout_steps: 512*400
|
||||
* require_batches: 4
|
||||
* trigger_parameter_sync_step: 4
|
||||
* staleness_threshold: 0.3
|
||||
* partial_rollout: True
|
||||
|
||||
| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:|
|
||||
| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br>last: 0.2448 |
|
||||
| fully_async_policy | 16:16 | | | \ | | | | | | max: <br>last: |
|
||||
| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br>last: 0.2333 |
|
||||
| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br>(2.09x) | 10h 14m<br>(2.03x) | 16h 58m<br>(1.83x) | 21h 40m<br>(1.92x) | max: 0.3677<br>last: 0.3406 |
|
||||
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
|
||||
| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br>(2.67x) | 6h 46m<br>(2.65x) | 10h 53m<br>(2.67x) | 17h 22m<br>(2.35x) | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg
|
||||
|
||||
### 128卡 7B 异步模式实验
|
||||
|
||||
我们使用 Qwen2.5-Math-7B 验证 fully async 所支持的各个模式的效果。
|
||||
我们可以看到 stream 带来的收益大约0.6x,叠加 staleness 和 partial_rollout 后,收益为2.35x。
|
||||
|
||||
| mode | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br>last: 0.2958 |
|
||||
| `stream off policy pipeline`<br>(+fully async: trigger_parameter_sync_step= 4,<br>require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
|
||||
| `async stream pipeline with staleness samples`<br>(+staleness_threshold=0.5) | | | | | | | | | |
|
||||
| `async stream pipeline with partial rollout`<br>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg
|
||||
|
||||
### 128卡 stale 消融实验
|
||||
|
||||
在 `async stream pipeline with partial rollout` 模式下,我们验证 staleness 的设置对于训练效率的影响。
|
||||
我们可以发现,staleness 越大,最终取得的收益越明显。
|
||||
同时我们也注意到 staleness 取 0.3 和 0.5 的时间比较接近,原因是随着训练步数的增量,response 长度变化较大,训练出现了不稳定的问题。
|
||||
后续还需要针对该问题进行进一步的分析和优化。
|
||||
|
||||
| staleness_threshold | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | total time<br>400 step | acc/mean@1 |
|
||||
|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br>last: 0.2604 |
|
||||
| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br>last: 0.2979 |
|
||||
| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br>last: 0.2865 |
|
||||
| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br>last: 0.3094 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_stale?nw=nwuserhouzg
|
||||
|
||||
### 128卡 7B require_batches 消融实验
|
||||
|
||||
在多次测试下,我们发现流式每次下发样本的数量会影响训练的response长度,进而影响训练时长,我们通过修改
|
||||
`async_training.require_batches` 验证对与结果的影响。
|
||||
|
||||
| require_batches | step | gen | old_log_prob | update_actor | total time<br>100 step | total time<br>200 step | total time<br>300 step | acc/mean@1 |
|
||||
|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:|
|
||||
| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br>last: 0.326 |
|
||||
| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br>last: 0.3406 |
|
||||
| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br>last: 0.3521 |
|
||||
|
||||
> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg
|
||||
|
||||
### 30B模型模式实验
|
||||
|
||||
TODO: 30B 的实验,还在完善中。
|
||||
|
||||
* 机器: H20
|
||||
* 模型:Qwen2.5-32B
|
||||
* rollout长度:max_response_length FSDP2: 20K tokens;
|
||||
* 算法:DAPO
|
||||
* engine: vllm+FSDP2
|
||||
* rollout.n: 16
|
||||
* ppo_mini_batch_size: 32
|
||||
* test_freq: 20
|
||||
|
||||
* colocate sync:
|
||||
* step:200
|
||||
* train_batch_size: 512
|
||||
|
||||
* fully_async_policy
|
||||
* total_rollout_steps: 512*200
|
||||
* trigger_parameter_sync_step: 512/32 = 16
|
||||
* staleness_threshold: 0
|
||||
* partial_rollout: False
|
||||
|
||||
| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean |
|
||||
|--------------------|---------------------|----------------------------------------------|------|--------------------|--------------|--------------|------------|------------------|
|
||||
| colocate sync | 128 | | | | | | | |
|
||||
| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | |
|
||||
| fully_async_policy | 64:64 | async stream pipeline with staleness samples | | | | | | |
|
||||
| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | |
|
||||
|
||||
|
||||
## 后续计划
|
||||
|
||||
* GRPO实验
|
||||
* megatron 适配
|
||||
* sglang 集成
|
||||
* transfer queue 集成
|
||||
* 异步参数同步
|
||||
* Areal异步算法实现
|
||||
* TPPO算法实现
|
||||
* 多轮及Tool的支持
|
recipe/fully_async_policy/agent_loop/__init__.py (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .agent_loop import FullyAsyncAgentLoopManager
|
||||
from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop
|
||||
|
||||
_ = [PartialSingleTurnAgentLoop]
|
||||
__all__ = ["FullyAsyncAgentLoopManager"]
|
recipe/fully_async_policy/agent_loop/agent_loop.py (new file, 275 lines)
@@ -0,0 +1,275 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import hydra
|
||||
import numpy as np
|
||||
import ray
|
||||
import torch
|
||||
from omegaconf import DictConfig
|
||||
|
||||
from recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica
|
||||
from verl.experimental.agent_loop.agent_loop import (
|
||||
AgentLoopManager,
|
||||
AgentLoopOutput,
|
||||
AgentLoopWorkerBase,
|
||||
AsyncLLMServerManager,
|
||||
BatchExecutor,
|
||||
_agent_loop_registry,
|
||||
_DummyConfig,
|
||||
get_trajectory_info,
|
||||
)
|
||||
from verl.protocol import DataProto
|
||||
from verl.single_controller.ray import RayWorkerGroup
|
||||
from verl.utils.rollout_trace import rollout_trace_attr
|
||||
from verl.workers.rollout.replica import TokenOutput
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
class FullyAsyncLLMServerManager(AsyncLLMServerManager):
|
||||
async def generate_for_partial(self, request_id, prompt_ids, sampling_params) -> TokenOutput:
|
||||
"""Generate tokens from prompt ids. with partial rollout function"""
|
||||
server = self._choose_server(request_id)
|
||||
output = await server.generate_for_partial.remote(
|
||||
request_id=request_id,
|
||||
prompt_ids=prompt_ids,
|
||||
sampling_params=sampling_params,
|
||||
)
|
||||
return output
|
||||
|
||||
|
||||
class FullyAsyncAgentLoopOutput(AgentLoopOutput):
|
||||
"""Agent loop output."""
|
||||
|
||||
is_cancel: bool = False
|
||||
"""Indicates whether the request was interrupted"""
|
||||
log_probs: list[float] = None
|
||||
"""Response token log probs including LLM generated token, tool response token."""
|
||||
param_version_start: int = 0
|
||||
"""Indicate start parameter version when this response is generated"""
|
||||
param_version_end: int = 0
|
||||
"""Indicate end parameter version when this response is generated, used for partial rollout"""
|
||||
|
||||
|
||||
@ray.remote
|
||||
class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase):
|
||||
def __init__(
|
||||
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
|
||||
):
|
||||
self.server_manager = FullyAsyncLLMServerManager(config, server_handles)
|
||||
super().__init__(config, server_handles, rm_executor)
|
||||
|
||||
async def generate_sequences_no_post(
|
||||
self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]]
|
||||
) -> list[AgentLoopOutput]:
|
||||
"""Generate sequences from agent loop.
|
||||
|
||||
Args:
|
||||
batch (DataProto): Input batch.
|
||||
partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result.
|
||||
|
||||
Returns:
|
||||
list[FullyAsyncAgentLoopOutput]: List of agent loop outputs, one per sample in the batch.
|
||||
"""
|
||||
config = self.config.actor_rollout_ref.rollout
|
||||
sampling_params = dict(
|
||||
temperature=config.temperature,
|
||||
top_p=config.top_p,
|
||||
repetition_penalty=1.0,
|
||||
logprobs=config.calculate_log_probs,
|
||||
)
|
||||
|
||||
# override sampling params for validation
|
||||
if batch.meta_info.get("validate", False):
|
||||
sampling_params["top_p"] = config.val_kwargs.top_p
|
||||
sampling_params["temperature"] = config.val_kwargs.temperature
|
||||
|
||||
# by default, we assume it's a single turn agent
|
||||
if "agent_name" not in batch.non_tensor_batch:
|
||||
batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object)
|
||||
|
||||
if "index" in batch.non_tensor_batch:
|
||||
index = batch.non_tensor_batch["index"]
|
||||
else:
|
||||
index = np.arange(len(batch))
|
||||
|
||||
trajectory_info = await get_trajectory_info(
|
||||
batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False)
|
||||
)
|
||||
|
||||
if not partial_output_list:
|
||||
partial_output_list = [None] * len(batch)
|
||||
|
||||
tasks = []
|
||||
for i in range(len(batch)):
|
||||
kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()}
|
||||
kwargs["output"] = partial_output_list[i]
|
||||
tasks.append(
|
||||
asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs))
|
||||
)
|
||||
return await asyncio.gather(*tasks)
|
||||
|
||||
async def _partial_run_agent_loop(
|
||||
self,
|
||||
sampling_params: dict[str, Any],
|
||||
trajectory: dict[str, Any],
|
||||
*,
|
||||
agent_name: str,
|
||||
**kwargs,
|
||||
) -> AgentLoopOutput:
|
||||
with rollout_trace_attr(
|
||||
step=trajectory["step"],
|
||||
sample_index=trajectory["sample_index"],
|
||||
rollout_n=trajectory["rollout_n"],
|
||||
validate=trajectory["validate"],
|
||||
name="agent_loop",
|
||||
):
|
||||
assert agent_name in _agent_loop_registry, (
|
||||
f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}"
|
||||
)
|
||||
|
||||
agent_loop_config = _agent_loop_registry[agent_name]
|
||||
agent_loop = hydra.utils.instantiate(
|
||||
config=agent_loop_config,
|
||||
trainer_config=_DummyConfig(config=self.config),
|
||||
server_manager=self.server_manager,
|
||||
tokenizer=self.tokenizer,
|
||||
processor=self.processor,
|
||||
)
|
||||
return await agent_loop.run(sampling_params, **kwargs)
|
||||
|
||||
|
||||
class FullyAsyncAgentLoopManager(AgentLoopManager):
|
||||
def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
|
||||
self.config = config
|
||||
self.worker_group = worker_group
|
||||
self.rm_executor = None
|
||||
self.rm_micro_batch_size = None
|
||||
self.agent_loop_workers_class = FullyAsyncAgentLoopWorker
|
||||
self.rollout_replica_class = FullyAsyncvLLMReplica
|
||||
|
||||
self.rm_wg = rm_wg
|
||||
self.rollout_replicas = None
|
||||
self.server_handles = None
|
||||
self.server_addresses = None
|
||||
self.agent_loop_workers = None
|
||||
|
||||
@classmethod
|
||||
async def create(cls, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None):
|
||||
instance = cls(config, worker_group, rm_wg)
|
||||
await instance._async_init()
|
||||
return instance
|
||||
|
||||
async def _async_init(self):
|
||||
if self.rm_wg:
|
||||
|
||||
def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]:
|
||||
new_data_list = []
|
||||
for data in data_list:
|
||||
temp_non_tensor_batch = {"__num_turns__": data.non_tensor_batch["__num_turns__"]}
|
||||
temp_data = DataProto(batch=data.batch, non_tensor_batch=temp_non_tensor_batch)
|
||||
new_data_list.append(temp_data)
|
||||
|
||||
new_batch = DataProto.concat(new_data_list)
|
||||
out_data = self.rm_wg.compute_rm_score(new_batch)
|
||||
return out_data.split(1)
|
||||
|
||||
self.rm_executor = BatchExecutor.options(
|
||||
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
|
||||
node_id=ray.get_runtime_context().get_node_id(),
|
||||
soft=False,
|
||||
),
|
||||
).remote(batch_fn, self.rm_wg.world_size)
|
||||
|
||||
self.rm_micro_batch_size = self.rm_wg.world_size
|
||||
|
||||
await self._initialize_llm_servers_async()
|
||||
self._init_agent_loop_workers()
|
||||
|
||||
async def _initialize_llm_servers_async(self):
|
||||
rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size
|
||||
world_size = (
|
||||
self.worker_group.world_size
|
||||
if self.worker_group
|
||||
else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes
|
||||
)
|
||||
num_replicas = world_size // rollout_world_size
|
||||
|
||||
rollout_config = self.config.actor_rollout_ref.rollout
|
||||
model_config = self.config.actor_rollout_ref.model
|
||||
self.rollout_replicas = [
|
||||
self.rollout_replica_class(
|
||||
replica_rank=replica_rank,
|
||||
config=rollout_config,
|
||||
model_config=model_config,
|
||||
gpus_per_node=self.config.trainer.n_gpus_per_node,
|
||||
)
|
||||
for replica_rank in range(num_replicas)
|
||||
]
|
||||
|
||||
if self.worker_group:
|
||||
await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas])
|
||||
else:
|
||||
await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas])
|
||||
|
||||
self.server_handles = [server._server_handle for server in self.rollout_replicas]
|
||||
self.server_addresses = [server._server_address for server in self.rollout_replicas]
|
||||
|
||||
async def generate_single_sample_async(
|
||||
self,
|
||||
sample: DataProto,
|
||||
partial_output_list: Optional[list[AgentLoopOutput]],
|
||||
) -> list[AgentLoopOutput]:
|
||||
"""
|
||||
Asynchronously process a single sample
|
||||
|
||||
Args:
|
||||
sample: Single sample data
|
||||
partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result.
|
||||
|
||||
Returns:
|
||||
list[AgentLoopOutput]: Processing results
|
||||
"""
|
||||
worker = self._select_best_worker()
|
||||
output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list)
|
||||
return await asyncio.wrap_future(output_future.future())
|
||||
|
||||
def _select_best_worker(self):
|
||||
"""Select the best worker, simple round-robin load balancing"""
|
||||
if not hasattr(self, "_worker_index"):
|
||||
self._worker_index = 0
|
||||
|
||||
worker = self.agent_loop_workers[self._worker_index]
|
||||
self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers)
|
||||
return worker
|
||||
|
||||
async def cancel(self):
|
||||
await asyncio.gather(*[replica.cancel() for replica in self.rollout_replicas])
|
||||
|
||||
async def resume(self):
|
||||
await asyncio.gather(*[replica.resume() for replica in self.rollout_replicas])
|
||||
|
||||
async def wake_up(self):
|
||||
await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas])
|
||||
|
||||
async def sleep(self):
|
||||
await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas])
|
||||
|
||||
async def reset_prefix_cache(self):
|
||||
await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas])
|
@@ -0,0 +1,92 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from recipe.fully_async_policy.agent_loop.agent_loop import AgentLoopOutput, FullyAsyncAgentLoopOutput
|
||||
from verl.experimental.agent_loop import AgentLoopBase
|
||||
from verl.experimental.agent_loop.agent_loop import register
|
||||
from verl.utils.profiler import simple_timer
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
|
||||
@register("partial_single_turn_agent")
|
||||
class PartialSingleTurnAgentLoop(AgentLoopBase):
|
||||
"""Naive agent loop that only do single turn chat completion."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length
|
||||
self.response_length = self.config.actor_rollout_ref.rollout.response_length
|
||||
self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {})
|
||||
|
||||
async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput:
|
||||
output: Optional[FullyAsyncAgentLoopOutput] = kwargs.get("output", None)
|
||||
messages = list(kwargs["raw_prompt"])
|
||||
param_version = kwargs.get("param_version", 0)
|
||||
|
||||
metrics = {}
|
||||
request_id = uuid4().hex
|
||||
|
||||
param_version_start = param_version
|
||||
param_version_end = param_version
|
||||
|
||||
if not output:
|
||||
prompt_ids = await self.loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.tokenizer.apply_chat_template(
|
||||
messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs
|
||||
),
|
||||
)
|
||||
else:
|
||||
if output.is_cancel:
|
||||
# Resume the paused sample,
|
||||
# add the result directly after prompt_ids,
|
||||
# and reset generate_sequences metric
|
||||
prompt_ids = output.prompt_ids + output.response_ids
|
||||
metrics["generate_sequences"] = output.metrics.generate_sequences
|
||||
param_version_start = output.param_version_start
|
||||
else:
|
||||
# In the same batch of samples,
|
||||
# some are canceled and some are not.
|
||||
# The samples without partial rollout are returned directly.
|
||||
return output
|
||||
with simple_timer("generate_sequences", metrics):
|
||||
response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial(
|
||||
request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params
|
||||
)
|
||||
if not output:
|
||||
response_mask = [1] * len(response_ids)
|
||||
else:
|
||||
# Pause the sample to be resumed, add the output result to response_ids, and reset response_mask
|
||||
prompt_ids = output.prompt_ids
|
||||
log_probs = output.log_probs + log_probs
|
||||
response_ids = output.response_ids + response_ids
|
||||
response_mask = [1] * len(response_ids)
|
||||
|
||||
return FullyAsyncAgentLoopOutput(
|
||||
prompt_ids=prompt_ids,
|
||||
response_ids=response_ids[: self.response_length],
|
||||
response_mask=response_mask[: self.response_length],
|
||||
num_turns=2,
|
||||
metrics=metrics,
|
||||
is_cancel=is_cancel,
|
||||
log_probs=log_probs,
|
||||
param_version_start=param_version_start,
|
||||
param_version_end=param_version_end,
|
||||
)
|
@@ -0,0 +1,55 @@
|
||||
hydra:
|
||||
searchpath:
|
||||
- file://verl/trainer/config
|
||||
|
||||
defaults:
|
||||
- ppo_trainer
|
||||
- _self_
|
||||
|
||||
async_training:
|
||||
|
||||
# Maximum samples staleness threshold
|
||||
staleness_threshold: 0.1
|
||||
|
||||
# Frequency of parameter synchronization between rollouter and trainer,
|
||||
# One step means trainer obtains a batch of required samples
|
||||
trigger_parameter_sync_step: 4
|
||||
|
||||
# The number of ppo_mini_batches that the FullyAsyncTrainer obtains once
|
||||
require_batches: 1
|
||||
|
||||
# When synchronizing parameters, whether to interrupt rollouter and perform partial rollout
|
||||
partial_rollout: True
|
||||
|
||||
# Whether to use rollout log probs for training
|
||||
use_rollout_log_probs: True
|
||||
|
||||
# Rollout config
|
||||
rollout:
|
||||
|
||||
# Number of nodes used in the rollout
|
||||
nnodes: 1
|
||||
|
||||
# Number of GPUs per node
|
||||
n_gpus_per_node: 8
|
||||
|
||||
# number of responses (i.e. num sample times). > 1 for grpo
|
||||
n: 4
|
||||
|
||||
# total rollout samples # TODO rename to total_rollout_samples
|
||||
total_rollout_steps: 100
|
||||
|
||||
# Number of epochs in training
|
||||
total_epochs: 10
|
||||
|
||||
# Test frequency, how many times a parameter update triggers a validation
|
||||
test_freq: 1
|
||||
|
||||
data:
|
||||
# Number of samples generated, currently only support 1
|
||||
gen_batch_size: 1
|
||||
|
||||
actor_rollout_ref:
|
||||
actor:
|
||||
# Whether to use rollout log probs for training
|
||||
use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True}
|
recipe/fully_async_policy/detach_utils.py (new file, 474 lines)
@@ -0,0 +1,474 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tensordict import TensorDict
|
||||
|
||||
from verl import DataProto
|
||||
from verl.experimental.agent_loop.agent_loop import AgentLoopOutput
|
||||
from verl.trainer.ppo.ray_trainer import compute_response_mask
|
||||
|
||||
|
||||
def postprocess_agent_loop_outputs(inputs: list[AgentLoopOutput], tokenizer, config) -> DataProto:
|
||||
"""Static method to postprocess a list of AgentLoopOutput into DataProto
|
||||
|
||||
Args:
|
||||
inputs: List of AgentLoopOutput
|
||||
tokenizer: Tokenizer instance
|
||||
config: Configuration object
|
||||
|
||||
Returns:
|
||||
DataProto: Processed batch data
|
||||
"""
|
||||
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
|
||||
# prompts: left pad
|
||||
# responses: right pad
|
||||
# input_ids: prompt + response
|
||||
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
|
||||
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]
|
||||
|
||||
# prompts
|
||||
tokenizer.padding_side = "left"
|
||||
outputs = tokenizer.pad(
|
||||
[{"input_ids": input.prompt_ids} for input in inputs],
|
||||
padding="max_length",
|
||||
max_length=config.actor_rollout_ref.rollout.prompt_length,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"]
|
||||
|
||||
# responses
|
||||
tokenizer.padding_side = "right"
|
||||
outputs = tokenizer.pad(
|
||||
[{"input_ids": input.response_ids} for input in inputs],
|
||||
padding="max_length",
|
||||
max_length=config.actor_rollout_ref.rollout.response_length,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=True,
|
||||
)
|
||||
response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"]
|
||||
|
||||
# response_mask
|
||||
outputs = tokenizer.pad(
|
||||
[{"input_ids": input.response_mask} for input in inputs],
|
||||
padding="max_length",
|
||||
max_length=config.actor_rollout_ref.rollout.response_length,
|
||||
return_tensors="pt",
|
||||
return_attention_mask=False,
|
||||
)
|
||||
response_mask = outputs["input_ids"]
|
||||
assert response_ids.shape == response_mask.shape, (
|
||||
f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}"
|
||||
)
|
||||
response_mask = response_mask * response_attention_mask
|
||||
|
||||
input_ids = torch.cat([prompt_ids, response_ids], dim=1)
|
||||
attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1)
|
||||
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
|
||||
|
||||
batch = TensorDict(
|
||||
{
|
||||
"prompts": prompt_ids, # [bsz, prompt_length]
|
||||
"responses": response_ids, # [bsz, response_length]
|
||||
"response_mask": response_mask, # [bsz, response_length]
|
||||
"input_ids": input_ids, # [bsz, prompt_length + response_length]
|
||||
"attention_mask": attention_mask, # [bsz, prompt_length + response_length]
|
||||
"position_ids": position_ids, # [bsz, prompt_length + response_length]
|
||||
},
|
||||
batch_size=len(input_ids),
|
||||
)
|
||||
|
||||
num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32)
|
||||
metrics = [input.metrics.model_dump() for input in inputs]
|
||||
return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics})
|
||||
|
||||
|
||||
@dataclass
|
||||
class RolloutSample:
|
||||
"""Enhanced rollout sample containing both original batch info and AgentLoopOutput"""
|
||||
|
||||
# Original batch information
|
||||
full_batch: Any
|
||||
|
||||
# AgentLoopOutput from generation
|
||||
agent_loop_output_list: list[Any] # AgentLoopOutput
|
||||
|
||||
# Metadata
|
||||
sample_id: str
|
||||
epoch: int
|
||||
|
||||
# Processing metadata
|
||||
processing_times: list[float]
|
||||
param_version: int
|
||||
param_version_start: list[int]
|
||||
param_version_end: list[int]
|
||||
rollout_status: dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidateMetrics:
|
||||
"""Metrics for validation"""
|
||||
|
||||
timing_raw: dict[str, Any]
|
||||
metrics: Optional[dict[str, Any]] = None
|
||||
global_steps: Optional[int] = None
|
||||
param_version: Optional[int] = None
|
||||
|
||||
|
||||
def prepare_single_generation_data(batch_dict, global_steps, rollout_n) -> DataProto:
|
||||
"""
|
||||
Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample.
|
||||
Separate the data used for generation from the original data.
|
||||
|
||||
Returns:
|
||||
tuple: (original_batch_dict, gen_data_for_single_sample)
|
||||
"""
|
||||
|
||||
full_batch = DataProto.from_single_dict(batch_dict)
|
||||
|
||||
batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"]
|
||||
non_tensor_batch_keys_to_pop = ["raw_prompt_ids"]
|
||||
|
||||
full_batch.pop(
|
||||
batch_keys=batch_keys_to_pop,
|
||||
non_tensor_batch_keys=non_tensor_batch_keys_to_pop,
|
||||
)
|
||||
|
||||
# Setting agent - partial_single_turn_agent, that supports partial
|
||||
full_batch.non_tensor_batch["agent_name"] = np.array(["partial_single_turn_agent"] * len(full_batch), dtype=object)
|
||||
|
||||
# Add global step count to generated data
|
||||
full_batch = full_batch.repeat(repeat_times=rollout_n, interleave=True)
|
||||
return full_batch
|
||||
|
||||
|
||||
def process_rollout_log_probs(data_proto: DataProto, rollout_log_probs: list[list[float]]) -> torch.Tensor:
|
||||
"""
|
||||
Process rollout_log_probs according to the mask in DataProto
|
||||
mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
|
||||
|
||||
Args:
|
||||
data_proto: A DataProto object containing batch information
|
||||
rollout_log_probs: A two-dimensional list, each sublist containing the log_probs of a sample
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The processed log_probs tensor, with shape: [bsz, response_length]
|
||||
"""
|
||||
|
||||
batch = data_proto.batch
|
||||
response_mask = batch["response_mask"]
|
||||
rollout_log_probs_tensor = torch.zeros(response_mask.shape, dtype=torch.float32) - 1
|
||||
|
||||
for i, log_probs_seq in enumerate(rollout_log_probs):
|
||||
# Get the effective length of the current sample (the number of positions with 1 in the mask)
|
||||
valid_length = response_mask[i].sum().item()
|
||||
|
||||
# Ensure that the length of log_probs_seq does not exceed the valid length
|
||||
actual_length = min(len(log_probs_seq), valid_length)
|
||||
|
||||
# Fill log_probs into the corresponding position
|
||||
if actual_length > 0:
|
||||
rollout_log_probs_tensor[i, :actual_length] = torch.tensor(log_probs_seq[:actual_length])
|
||||
|
||||
rollout_log_probs_tensor = rollout_log_probs_tensor.to(torch.float32)
|
||||
return rollout_log_probs_tensor
|
||||
|
||||
|
||||
def merge_rollout_sample(config, tokenizer, rs: RolloutSample):
|
||||
"""
|
||||
Supplement and refine the RolloutSample object,
|
||||
"""
|
||||
# Step 1: Create a DataProto from the AgentLoopOutput to generate the result
|
||||
gen_batch_output = postprocess_agent_loop_outputs(rs.agent_loop_output_list, tokenizer, config)
|
||||
rollout_log_probs = [x.log_probs for x in rs.agent_loop_output_list]
|
||||
rollout_log_probs = process_rollout_log_probs(gen_batch_output, rollout_log_probs)
|
||||
gen_batch_output.batch["rollout_log_probs"] = rollout_log_probs.to(torch.float32)
|
||||
|
||||
# Step 2: Add uid
|
||||
rs.full_batch.non_tensor_batch["uid"] = np.array([f"uid_{rs.sample_id}"] * len(rs.full_batch), dtype=object)
|
||||
|
||||
# Step 3: Merge batches
|
||||
# Merge the non_tensor_batch and meta_info of original_batch into final_batch
|
||||
for key, value in rs.full_batch.non_tensor_batch.items():
|
||||
gen_batch_output.non_tensor_batch[key] = value
|
||||
gen_batch_output.meta_info.update(rs.full_batch.meta_info)
|
||||
|
||||
# Step 4: set full_batch and processing metadata
|
||||
rs.full_batch = gen_batch_output
|
||||
rs.processing_times = []
|
||||
for agent_loop in rs.agent_loop_output_list:
|
||||
rs.processing_times.append(agent_loop.metrics.generate_sequences)
|
||||
rs.param_version_start = [agent_loop.param_version_start for agent_loop in rs.agent_loop_output_list]
|
||||
rs.param_version_end = [agent_loop.param_version_end for agent_loop in rs.agent_loop_output_list]
|
||||
# Step 5: clear agent_loop_output_list
|
||||
rs.agent_loop_output_list = []
|
||||
return rs
|
||||
|
||||
|
||||
def assemble_batch_from_rollout_samples(
|
||||
rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None
|
||||
) -> DataProto:
|
||||
"""
|
||||
Assemble gen_batch_output from RolloutSample objects
|
||||
Assembles batches from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer.
|
||||
|
||||
Args:
|
||||
rollout_samples: List of RolloutSample objects
|
||||
tokenizer: Tokenizer instance
|
||||
config: Configuration object containing trainer settings
|
||||
balance_batch: Whether to balance the batch (simplified version)
|
||||
|
||||
Returns:
|
||||
DataProto: Assembled gen_batch_output
|
||||
|
||||
Raises:
|
||||
ValueError: If rollout_samples is empty
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
if not rollout_samples:
|
||||
raise ValueError("Empty rollout_samples provided for batch assembly")
|
||||
|
||||
print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects")
|
||||
|
||||
rollout_samples_batch = []
|
||||
processing_times = []
|
||||
rollout_status = rollout_samples[0].rollout_status
|
||||
# Add a prefix to all rollout_status keys
|
||||
rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()}
|
||||
|
||||
for rs in rollout_samples:
|
||||
rollout_samples_batch.append(rs.full_batch)
|
||||
processing_times.extend(rs.processing_times)
|
||||
final_batch = DataProto.concat(rollout_samples_batch)
|
||||
|
||||
# Calculate response_mask (if not present)
|
||||
if "response_mask" not in final_batch.batch.keys():
|
||||
final_batch.batch["response_mask"] = compute_response_mask(final_batch)
|
||||
|
||||
if balance_batch:
|
||||
balance_batch(final_batch, metrics={})
|
||||
|
||||
# Calculate the global valid token number
|
||||
if "attention_mask" in final_batch.batch:
|
||||
final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist()
|
||||
|
||||
# Collect statistics
|
||||
param_versions = [rs.param_version for rs in rollout_samples]
|
||||
trajectorys_param_versions = [version for rs in rollout_samples for version in rs.param_version_end]
|
||||
|
||||
processing_time_stats = {
|
||||
"processing_time/avg": np.mean(processing_times),
|
||||
"processing_time/max": np.max(processing_times),
|
||||
"processing_time/min": np.min(processing_times),
|
||||
"processing_time/tp50": np.percentile(processing_times, 50),
|
||||
"processing_time/tp99": np.percentile(processing_times, 99),
|
||||
"processing_time/tp95": np.percentile(processing_times, 95),
|
||||
}
|
||||
processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()}
|
||||
|
||||
# parameter-version span per trajectory, aggregated across all samples in the batch
param_version_diff = [abs(e - s) for rs in rollout_samples for e, s in zip(rs.param_version_end, rs.param_version_start, strict=False)]
|
||||
num_diff0 = param_version_diff.count(0)
|
||||
partial_stats = {
|
||||
"fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0,
|
||||
"fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff),
|
||||
"fully_async/partial/max_partial_span": max(param_version_diff),
|
||||
}
|
||||
# add meta_info
|
||||
final_batch.meta_info.update(
|
||||
{
|
||||
"rollout_param_versions": param_versions,
|
||||
"param_version_diversity": len(set(param_versions)) if param_versions else 0,
|
||||
"trajectory_param_versions": trajectorys_param_versions,
|
||||
**processing_time_stats,
|
||||
**rollout_status,
|
||||
**partial_stats,
|
||||
}
|
||||
)
|
||||
|
||||
print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s")
|
||||
|
||||
return final_batch
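
The partial-rollout statistics assembled above can be sanity-checked with a small, self-contained sketch; the version numbers below are made up for illustration:

```python
# Hypothetical per-trajectory parameter versions across one assembled batch.
param_version_start = [3, 3, 3, 4]
param_version_end = [3, 4, 5, 4]

# A trajectory is "partial" when the parameter version changed while it was in flight.
param_version_diff = [abs(a - b) for a, b in zip(param_version_end, param_version_start)]
num_diff0 = param_version_diff.count(0)  # trajectories that finished on their starting version

partial_stats = {
    "fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0,  # 2
    "fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff),  # 0.5
    "fully_async/partial/max_partial_span": max(param_version_diff),  # 2
}
print(partial_stats)
```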
|
||||
|
||||
|
||||
class MetricsAggregator:
|
||||
"""Metrics aggregator, used to combine metrics from multiple training steps"""
|
||||
|
||||
def __init__(self, total_gpus: int):
|
||||
# Store all values for each metric
|
||||
self.metric_values: dict[str, list[float]] = defaultdict(list)
|
||||
# Store the number of samples at each step for weighted averaging
|
||||
self.sample_counts: list[int] = []
|
||||
# Store the timestamp of each step for time-related calculations
|
||||
self.timestamps: list[float] = []
|
||||
# Step Count
|
||||
self.step_count = 0
|
||||
# total num gpus used
|
||||
self.total_gpus = total_gpus
|
||||
|
||||
# Metric aggregation rule configuration
|
||||
self.aggregation_rules = self._init_aggregation_rules()
|
||||
|
||||
def _init_aggregation_rules(self) -> dict[str, dict[str, list[str]]]:
|
||||
"""Initialize metrics aggregation rules"""
|
||||
return {
|
||||
# Time-Based metrics, can add metrics here
|
||||
"time_sum": ["perf/time_per_step"],
|
||||
"last": [
|
||||
"fully_async/count/total_generated_samples",
|
||||
"fully_async/count/stale_samples_processed",
|
||||
"fully_async/count/stale_trajectory_processed",
|
||||
"fully_async/count/current_param_version",
|
||||
"fully_async/count/dropped_stale_samples",
|
||||
"training/global_step", # TODO change name to: total_step
|
||||
],
|
||||
}
|
||||
|
||||
def add_step_metrics(self, metrics: dict[str, Any], sample_count: int, timestamp: float | None = None):
|
||||
"""Adding a single-step metrics"""
|
||||
if timestamp is None:
|
||||
timestamp = time.time()
|
||||
|
||||
self.sample_counts.append(sample_count)
|
||||
self.timestamps.append(timestamp)
|
||||
self.step_count += 1
|
||||
|
||||
# Store all metrics values
|
||||
for key, value in metrics.items():
|
||||
if isinstance(value, int | float | np.number):
|
||||
self.metric_values[key].append(float(value))
|
||||
elif isinstance(value, torch.Tensor):
|
||||
self.metric_values[key].append(float(value.item()))
|
||||
|
||||
def _get_aggregation_type(self, metric_name: str) -> str:
|
||||
"""Determine the aggregation type based on the metric name"""
|
||||
for agg_type, metric_list in self.aggregation_rules.items():
|
||||
if metric_name in metric_list:
|
||||
return agg_type
|
||||
|
||||
metric_lower = metric_name.lower()
|
||||
if any(keyword in metric_lower for keyword in ["timing_s/"]):
|
||||
return "time_sum"
|
||||
if any(keyword in metric_lower for keyword in ["mean", "avg", "average"]):
|
||||
return "avg"
|
||||
if any(keyword in metric_lower for keyword in ["max", "maximum"]):
|
||||
return "max"
|
||||
if any(keyword in metric_lower for keyword in ["min", "minimum"]):
|
||||
return "min"
|
||||
if any(keyword in metric_lower for keyword in ["sum", "total"]):
|
||||
return "sum"
|
||||
if any(keyword in metric_lower for keyword in ["weighted_avg"]):
|
||||
return "weighted_avg"
|
||||
|
||||
return "avg"
|
||||
|
||||
def _aggregate_single_metric(self, metric_name: str, values: list[float]) -> float:
|
||||
"""Aggregating a single metric"""
|
||||
if not values:
|
||||
return 0.0
|
||||
|
||||
agg_type = self._get_aggregation_type(metric_name)
|
||||
|
||||
if agg_type == "last":
|
||||
return values[-1]
|
||||
|
||||
elif agg_type == "weighted_avg":
|
||||
# Weighted average
|
||||
if len(values) != len(self.sample_counts):
|
||||
# If the lengths do not match, use a simple average
|
||||
return sum(values) / len(values)
|
||||
|
||||
total_samples = sum(self.sample_counts)
|
||||
if total_samples == 0:
|
||||
return sum(values) / len(values)
|
||||
|
||||
weighted_sum = sum(v * c for v, c in zip(values, self.sample_counts, strict=False))
|
||||
return weighted_sum / total_samples
|
||||
|
||||
elif agg_type == "sum" or agg_type == "time_sum":
|
||||
return sum(values)
|
||||
|
||||
elif agg_type == "avg":
|
||||
return sum(values) / len(values)
|
||||
|
||||
elif agg_type == "max":
|
||||
return max(values)
|
||||
|
||||
elif agg_type == "min":
|
||||
return min(values)
|
||||
|
||||
else:
|
||||
# Default average
|
||||
return sum(values) / len(values)
|
||||
|
||||
def get_aggregated_metrics(self) -> dict[str, Any]:
|
||||
"""aggregated metrics"""
|
||||
t = time.time()
|
||||
if self.step_count == 0:
|
||||
return {}
|
||||
|
||||
aggregated = {}
|
||||
|
||||
# Aggregate all metrics
|
||||
for metric_name, values in self.metric_values.items():
|
||||
aggregated[metric_name] = self._aggregate_single_metric(metric_name, values)
|
||||
|
||||
# Aggregate special metrics
|
||||
aggregated = self._special_metrics_aggregate(aggregated)
|
||||
|
||||
print(f"aggregated metrics done. cost {time.time() - t}")
|
||||
|
||||
return aggregated
|
||||
|
||||
def _special_metrics_aggregate(self, aggregated: dict[str, Any]) -> dict[str, Any]:
|
||||
"""calculate special metrics"""
|
||||
|
||||
# global_seqlen/minmax_diff
|
||||
if "global_seqlen/minmax_diff" in aggregated.keys():
|
||||
aggregated["global_seqlen/minmax_diff"] = aggregated["global_seqlen/max"] - aggregated["global_seqlen/min"]
|
||||
|
||||
# perf/throughput
|
||||
REQUIRED_PERF_KEYS = {"perf/throughput", "perf/total_num_tokens", "perf/time_per_step"}
|
||||
if REQUIRED_PERF_KEYS.issubset(aggregated):
|
||||
aggregated["perf/throughput"] = aggregated["perf/total_num_tokens"] / (
|
||||
aggregated["perf/time_per_step"] * self.total_gpus
|
||||
)
|
||||
|
||||
# trainer/idle_ratio
|
||||
if "timing_s/gen" in aggregated.keys() and "timing_s/step" in aggregated.keys():
|
||||
aggregated["trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"]
|
||||
|
||||
return aggregated
|
||||
|
||||
def reset(self):
|
||||
"""Reset Aggregator"""
|
||||
self.metric_values.clear()
|
||||
self.sample_counts.clear()
|
||||
self.timestamps.clear()
|
||||
self.step_count = 0
|
||||
|
||||
def get_current_stats(self) -> dict[str, Any]:
|
||||
"""Get statistics about the current aggregation state (for debugging)"""
|
||||
return {
|
||||
"step_count": self.step_count,
|
||||
"metric_count": len(self.metric_values),
|
||||
"total_samples": sum(self.sample_counts),
|
||||
"metric_names": list(self.metric_values.keys()),
|
||||
}
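
A minimal end-to-end usage sketch of the aggregator; the metric names and values below are invented for illustration:

```python
from recipe.fully_async_policy.detach_utils import MetricsAggregator

aggregator = MetricsAggregator(total_gpus=16)

# In the recipe these dicts come from each PPO training step.
aggregator.add_step_metrics({"timing_s/gen": 12.0, "actor/pg_loss_mean": 0.5}, sample_count=64)
aggregator.add_step_metrics({"timing_s/gen": 10.0, "actor/pg_loss_mean": 0.3}, sample_count=64)

merged = aggregator.get_aggregated_metrics()
print(merged["timing_s/gen"])        # 22.0 -> "timing_s/" metrics are summed
print(merged["actor/pg_loss_mean"])  # 0.4  -> "*_mean" metrics fall back to a plain average
aggregator.reset()
```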
|
recipe/fully_async_policy/fsdp_workers.py (new file, 136 lines)
@@ -0,0 +1,136 @@
|
||||
# Copyright 2025 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
from omegaconf import DictConfig
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
|
||||
from verl.single_controller.base.decorator import Dispatch, register
|
||||
from verl.utils.device import (
|
||||
get_device_name,
|
||||
get_torch_device,
|
||||
)
|
||||
from verl.utils.fsdp_utils import (
|
||||
fsdp_version,
|
||||
)
|
||||
from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
|
||||
|
||||
device_name = get_device_name()
|
||||
|
||||
__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"]
|
||||
|
||||
|
||||
def get_inference_model(rollout):
|
||||
"""
|
||||
Get the underlying model from the different supported inference_engine types
|
||||
Args:
|
||||
rollout: rollout object
|
||||
Returns:
|
||||
model: model object
|
||||
"""
|
||||
inference_engine = rollout.inference_engine
|
||||
if hasattr(inference_engine, "llm_engine"):
|
||||
inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
|
||||
elif hasattr(inference_engine, "worker"):
|
||||
inference_model = inference_engine.worker.model_runner.model
|
||||
else:
|
||||
raise AttributeError(
|
||||
f"Unsupported inference_engine type: {type(inference_engine)}. "
|
||||
f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)."
|
||||
)
|
||||
return inference_model
|
||||
|
||||
|
||||
class DetachNcclSync(AsyncActorRolloutRefWorker):
|
||||
def _get_actor_params(self):
|
||||
pass
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False)
|
||||
def sync_rollout_weights(self):
|
||||
assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine
|
||||
assert hasattr(self, "_weights_info") and self._weights_info is not None
|
||||
|
||||
params = self._get_actor_params() if self._is_actor else None
|
||||
if self._is_rollout:
|
||||
inference_model = get_inference_model(self.rollout)
|
||||
|
||||
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
|
||||
|
||||
patch_vllm_moe_model_weight_loader(inference_model)
|
||||
for key, shape, dtype in self._weights_info:
|
||||
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
|
||||
if self._is_actor:
|
||||
assert key in params
|
||||
origin_data = params[key]
|
||||
if hasattr(origin_data, "full_tensor"):
|
||||
origin_data = origin_data.full_tensor()
|
||||
if torch.distributed.get_rank() == 0:
|
||||
tensor.copy_(origin_data)
|
||||
from ray.util.collective import collective
|
||||
|
||||
collective.broadcast(tensor, src_rank=0, group_name="actor_rollout")
|
||||
if self._is_rollout:
|
||||
inference_model.load_weights([(key, tensor)])
|
||||
get_torch_device().empty_cache()
|
||||
|
||||
|
||||
class DetachActorWorker(DetachNcclSync):
|
||||
def _get_actor_params(self):
|
||||
assert self._is_actor
|
||||
params = self.actor_module_fsdp.state_dict()
|
||||
from verl.utils.model import convert_weight_keys
|
||||
|
||||
params = convert_weight_keys(
|
||||
params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp)
|
||||
)
|
||||
return params
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def get_actor_weights_info(self):
|
||||
assert self._is_actor
|
||||
if hasattr(self, "_weights_info"):
|
||||
return self._weights_info
|
||||
if fsdp_version(self.actor_module_fsdp) == 1:
|
||||
from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType
|
||||
|
||||
FSDP.set_state_dict_type(
|
||||
self.actor_module_fsdp,
|
||||
state_dict_type=StateDictType.SHARDED_STATE_DICT,
|
||||
state_dict_config=ShardedStateDictConfig(),
|
||||
)
|
||||
params = self._get_actor_params()
|
||||
ret = []
|
||||
for key, tensor in params.items():
|
||||
ret.append((key, tensor.size(), tensor.dtype))
|
||||
self._weights_info = ret
|
||||
return ret
|
||||
|
||||
|
||||
class DetachAsyncRolloutWorker(DetachNcclSync):
|
||||
def __init__(self, config: DictConfig, role: str):
|
||||
print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}")
|
||||
ActorRolloutRefWorker.__init__(self, config, role)
|
||||
|
||||
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
|
||||
def set_actor_weights_info(self, weights_info):
|
||||
assert self._is_rollout
|
||||
self._weights_info = weights_info
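
A rough driver-side sketch of how these workers are expected to be wired together for a weight sync. The `actor_wg` / `rollout_wg` worker-group handles, the setup of the `actor_rollout` collective group, and the exact return shapes of the RPCs are assumptions; the recipe's ParameterSynchronizer is the real owner of this logic:

```python
import ray

# 1. The actor group reports its weight layout once; the rollout group caches it.
weights_info = actor_wg.get_actor_weights_info()[0]   # list of (key, shape, dtype)
rollout_wg.set_actor_weights_info(weights_info)

# 2. On every sync, both groups enter sync_rollout_weights together so the
#    "actor_rollout" broadcast has matching senders and receivers.
refs = actor_wg.sync_rollout_weights() + rollout_wg.sync_rollout_weights()
ray.get(refs)  # the RPCs are registered with blocking=False, so wait here
```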
|
recipe/fully_async_policy/fully_async_main.py (new file, 306 lines)
@@ -0,0 +1,306 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import socket
|
||||
import threading
|
||||
from pprint import pprint
|
||||
|
||||
import hydra
|
||||
import ray
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from recipe.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter
|
||||
from recipe.fully_async_policy.fully_async_trainer import FullyAsyncTrainer
|
||||
from recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient
|
||||
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
|
||||
from verl.trainer.ppo.reward import load_reward_manager
|
||||
from verl.trainer.ppo.utils import Role
|
||||
from verl.utils.fs import copy_to_local
|
||||
|
||||
|
||||
def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager:
|
||||
"""
|
||||
Create resource pool manager
|
||||
|
||||
Args:
|
||||
config: Configuration object
|
||||
roles: List of roles that need to create resource pools
|
||||
|
||||
Returns:
|
||||
ResourcePoolManager: Resource pool manager
|
||||
"""
|
||||
resource_pool_spec = {}
|
||||
mapping = {}
|
||||
|
||||
# Actor/Critic resource pool
|
||||
if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]):
|
||||
assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0"
|
||||
assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0"
|
||||
|
||||
trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes
|
||||
resource_pool_spec["trainer_pool"] = trainer_pool
|
||||
|
||||
# Map training-related roles to the same resource pool
|
||||
for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]:
|
||||
if role in roles:
|
||||
mapping[role] = "trainer_pool"
|
||||
|
||||
# Rollout resource pool
|
||||
if Role.Rollout in roles:
|
||||
assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0"
|
||||
assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0"
|
||||
|
||||
rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes
|
||||
resource_pool_spec["rollout_pool"] = rollout_pool
|
||||
mapping[Role.Rollout] = "rollout_pool"
|
||||
|
||||
return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
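
As a concrete (hypothetical) example, a config with `trainer.nnodes=2`, `trainer.n_gpus_per_node=8`, `rollout.nnodes=1`, `rollout.n_gpus_per_node=8` and all roles requested would produce the following spec and mapping:

```python
from verl.trainer.ppo.utils import Role

resource_pool_spec = {
    "trainer_pool": [8, 8],  # 2 trainer nodes x 8 GPUs each
    "rollout_pool": [8],     # 1 dedicated rollout node x 8 GPUs
}
mapping = {
    Role.Actor: "trainer_pool",
    Role.Critic: "trainer_pool",
    Role.RefPolicy: "trainer_pool",
    Role.RewardModel: "trainer_pool",
    Role.Rollout: "rollout_pool",
}
```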
|
||||
|
||||
|
||||
def create_role_worker_mapping(config):
|
||||
"""
|
||||
Create mapping from roles to worker classes
|
||||
|
||||
Args:
|
||||
config: Configuration object
|
||||
|
||||
Returns:
|
||||
dict: Mapping from roles to worker classes
|
||||
"""
|
||||
# Select worker class based on strategy
|
||||
if config.actor_rollout_ref.actor.strategy == "fsdp2":
|
||||
assert config.actor_rollout_ref.actor.strategy == config.critic.strategy
|
||||
from recipe.fully_async_policy.fsdp_workers import (
|
||||
CriticWorker,
|
||||
DetachActorWorker,
|
||||
DetachAsyncRolloutWorker,
|
||||
)
|
||||
from verl.single_controller.ray import RayWorkerGroup
|
||||
|
||||
ray_worker_group_cls = RayWorkerGroup
|
||||
|
||||
# TODO megatron support
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}")
|
||||
|
||||
role_worker_mapping = {
|
||||
Role.Actor: ray.remote(DetachActorWorker),
|
||||
Role.Rollout: ray.remote(DetachAsyncRolloutWorker),
|
||||
Role.Critic: ray.remote(CriticWorker),
|
||||
}
|
||||
|
||||
if config.reward_model.enable:
|
||||
if config.reward_model.strategy == "fsdp2":
|
||||
from verl.workers.fsdp_workers import RewardModelWorker
|
||||
# TODO megatron support
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported reward model strategy: {config.reward_model.strategy}")
|
||||
|
||||
role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker)
|
||||
|
||||
# Add reference policy (if KL loss or reward is required)
|
||||
if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss:
|
||||
role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker)
|
||||
|
||||
return role_worker_mapping, ray_worker_group_cls
|
||||
|
||||
|
||||
@ray.remote(num_cpus=1)
|
||||
class FullyAsyncTaskRunner:
|
||||
"""
|
||||
Ray remote class for executing distributed PPO training tasks.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.running = False
|
||||
self.components = {}
|
||||
self.shutdown_event = threading.Event()
|
||||
|
||||
def run(self, config):
|
||||
print("[ASYNC MAIN] Starting fully async PPO training...")
|
||||
self._initialize_components(config)
|
||||
self._run_training_loop()
|
||||
|
||||
def _initialize_components(self, config) -> None:
|
||||
print(f"[ASYNC MAIN] TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}")
|
||||
pprint(OmegaConf.to_container(config, resolve=True))
|
||||
OmegaConf.resolve(config)
|
||||
|
||||
print("[ASYNC MAIN] Initializing model and tokenizer...")
|
||||
local_path = copy_to_local(
|
||||
config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False)
|
||||
)
|
||||
from verl.utils import hf_processor, hf_tokenizer
|
||||
|
||||
trust_remote_code = config.data.get("trust_remote_code", False)
|
||||
tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)
|
||||
|
||||
# Used for multimodal LLM, could be None
|
||||
processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True)
|
||||
|
||||
self.components["tokenizer"] = tokenizer
|
||||
self.components["processor"] = processor
|
||||
self.components["config"] = config
|
||||
|
||||
print("[ASYNC MAIN] Creating worker mapping and resource pools...")
|
||||
role_worker_mapping, ray_worker_group_cls = create_role_worker_mapping(config)
|
||||
self.components["role_worker_mapping"] = role_worker_mapping
|
||||
self.components["ray_worker_group_cls"] = ray_worker_group_cls
|
||||
|
||||
print("[ASYNC MAIN] Loading reward functions...")
|
||||
reward_fn = load_reward_manager(
|
||||
config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {})
|
||||
)
|
||||
val_reward_fn = load_reward_manager(
|
||||
config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {})
|
||||
)
|
||||
self.components["reward_fn"] = reward_fn
|
||||
self.components["val_reward_fn"] = val_reward_fn
|
||||
|
||||
print("[ASYNC MAIN] Creating FullyAsyncRollouter...")
|
||||
self._create_rollouter(config)
|
||||
|
||||
print("[ASYNC MAIN] Creating FullyAsyncTrainer...")
|
||||
self._create_trainer(config)
|
||||
|
||||
# sync total_train_steps between rollouter and trainer
|
||||
total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote())
|
||||
print(f"total_train_steps {total_train_steps}")
|
||||
ray.get(self.components["trainer"].set_total_train_steps.remote(total_train_steps))
|
||||
|
||||
# max_queue_size
|
||||
max_queue_size = ray.get(self.components["rollouter"].get_max_queue_size.remote())
|
||||
print(f"[ASYNC MAIN] Creating MessageQueue... max_queue_size {max_queue_size}")
|
||||
message_queue = MessageQueue.remote(config, max_queue_size)
|
||||
message_queue_client = MessageQueueClient(message_queue)
|
||||
self.components["message_queue"] = message_queue
|
||||
self.components["message_queue_client"] = message_queue_client
|
||||
|
||||
ray.get(self.components["rollouter"].set_message_queue_client.remote(self.components["message_queue_client"]))
|
||||
ray.get(self.components["trainer"].set_message_queue_client.remote(self.components["message_queue_client"]))
|
||||
|
||||
print("[ASYNC MAIN] Setting up parameter synchronization...")
|
||||
from recipe.fully_async_policy.param_sync import ParameterSynchronizer
|
||||
|
||||
param_synchronizer = ParameterSynchronizer.remote(
|
||||
config=config,
|
||||
trainer=self.components["trainer"],
|
||||
rollouter=self.components["rollouter"],
|
||||
mq=self.components["message_queue_client"],
|
||||
)
|
||||
ray.get(self.components["trainer"].set_parameter_synchronizer.remote(param_synchronizer))
|
||||
|
||||
# load checkpoint and sync parameter before doing anything
|
||||
val_before_train = val_reward_fn is not None and config.trainer.get("val_before_train", True)
|
||||
ray.get(self.components["trainer"].load_checkpoint.remote())
|
||||
ray.get(param_synchronizer.sync_weights.remote(version=0, validate=val_before_train))
|
||||
|
||||
self.components["param_synchronizer"] = param_synchronizer
|
||||
print("[ASYNC MAIN] All components initialized successfully")
|
||||
|
||||
def _create_rollouter(self, config) -> None:
|
||||
rollouter = FullyAsyncRollouter.remote(
|
||||
config=config,
|
||||
tokenizer=self.components["tokenizer"],
|
||||
role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]},
|
||||
resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]),
|
||||
ray_worker_group_cls=self.components["ray_worker_group_cls"],
|
||||
processor=self.components["processor"],
|
||||
reward_fn=self.components["reward_fn"],
|
||||
val_reward_fn=self.components["val_reward_fn"],
|
||||
device_name=config.trainer.device,
|
||||
)
|
||||
|
||||
ray.get(rollouter.init_workers.remote())
|
||||
ray.get(rollouter.set_max_required_samples.remote())
|
||||
|
||||
self.components["rollouter"] = rollouter
|
||||
print("[ASYNC MAIN] Rollouter created and initialized successfully")
|
||||
|
||||
def _create_trainer(self, config) -> None:
|
||||
trainer_role_mapping = {
|
||||
role: worker_cls
|
||||
for role, worker_cls in self.components["role_worker_mapping"].items()
|
||||
if role != Role.Rollout
|
||||
}
|
||||
|
||||
trainer = FullyAsyncTrainer.remote(
|
||||
config=config,
|
||||
tokenizer=self.components["tokenizer"],
|
||||
role_worker_mapping=trainer_role_mapping,
|
||||
resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())),
|
||||
ray_worker_group_cls=self.components["ray_worker_group_cls"],
|
||||
processor=self.components["processor"],
|
||||
reward_fn=self.components["reward_fn"],
|
||||
val_reward_fn=self.components["val_reward_fn"],
|
||||
device_name=config.trainer.device,
|
||||
)
|
||||
|
||||
ray.get(trainer.init_workers.remote())
|
||||
self.components["trainer"] = trainer
|
||||
print("[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully")
|
||||
|
||||
def _run_training_loop(self):
|
||||
self.running = True
|
||||
|
||||
print("[ASYNC MAIN] Starting Rollouter and Trainer...")
|
||||
rollouter_future = self.components["rollouter"].fit.remote()
|
||||
trainer_future = self.components["trainer"].fit.remote()
|
||||
|
||||
futures = [rollouter_future, trainer_future]
|
||||
|
||||
try:
|
||||
while futures:
|
||||
# Use ray.wait to monitor all futures and return when any one is completed.
|
||||
done_futures, remaining_futures = ray.wait(futures, num_returns=1, timeout=None)
|
||||
|
||||
for future in done_futures:
|
||||
try:
|
||||
ray.get(future)
|
||||
print("[ASYNC MAIN] One component completed successfully")
|
||||
except Exception as e:
|
||||
print(f"[ASYNC MAIN] Component failed with error: {e}")
|
||||
for remaining_future in remaining_futures:
|
||||
ray.cancel(remaining_future)
|
||||
raise e
|
||||
|
||||
futures = remaining_futures
|
||||
|
||||
except Exception as e:
|
||||
print(f"[ASYNC MAIN] Training failed: {e}")
|
||||
for future in futures:
|
||||
ray.cancel(future)
|
||||
raise
|
||||
finally:
|
||||
self.components["message_queue_client"].clear_queue()
|
||||
print("[ASYNC MAIN] Training completed or interrupted")
|
||||
|
||||
|
||||
@hydra.main(config_path="config", config_name="fully_async_ppo_trainer", version_base=None)
|
||||
def main(config):
|
||||
from verl.trainer.main_ppo import run_ppo
|
||||
|
||||
# Ensure async training config exists
|
||||
if not hasattr(config, "async_training"):
|
||||
raise RuntimeError("must set async_training config")
|
||||
from time import time
|
||||
|
||||
start_time = time()
|
||||
run_ppo(config, task_runner_class=FullyAsyncTaskRunner)
|
||||
print(f"total time: {time() - start_time:.2f} seconds")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
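
The recipe refuses to start without an `async_training` section. The sketch below lists the keys the Rollouter and Trainer read from it; the values are purely illustrative, not documented defaults:

```python
from omegaconf import OmegaConf

async_training = OmegaConf.create(
    {
        "staleness_threshold": 1,          # >= 0, extra sample budget per parameter version
        "trigger_parameter_sync_step": 1,  # >= 1, trainer steps between parameter syncs
        "require_batches": 1,              # required_samples = ppo_mini_batch_size * require_batches
        "partial_rollout": True,           # cancel and later resume in-flight rollouts on pause
    }
)
```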
|
recipe/fully_async_policy/fully_async_rollouter.py (new file, 646 lines)
@@ -0,0 +1,646 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import asyncio
|
||||
import time
|
||||
from pprint import pformat
|
||||
|
||||
import ray
|
||||
from ray import ObjectRef
|
||||
|
||||
from recipe.fully_async_policy.detach_utils import (
|
||||
RolloutSample,
|
||||
ValidateMetrics,
|
||||
merge_rollout_sample,
|
||||
prepare_single_generation_data,
|
||||
)
|
||||
from recipe.fully_async_policy.message_queue import MessageQueueClient
|
||||
from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
|
||||
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
|
||||
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
|
||||
from verl.trainer.ppo.utils import Role, WorkerType
|
||||
from verl.utils.profiler import marked_timer
|
||||
from verl.utils.tracking import ValidationGenerationsLogger
|
||||
|
||||
|
||||
@ray.remote(num_cpus=10, max_concurrency=100)
|
||||
class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
|
||||
"""
|
||||
Asynchronous sample generator, responsible for continuously generating training samples
|
||||
and putting them into MessageQueue
|
||||
Based on the mature implementation improvements of OneStepOffRayTrainer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
tokenizer,
|
||||
role_worker_mapping: dict[Role, WorkerType],
|
||||
resource_pool_manager: ResourcePoolManager,
|
||||
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
|
||||
processor=None,
|
||||
reward_fn=None,
|
||||
val_reward_fn=None,
|
||||
device_name=None,
|
||||
):
|
||||
# Store the tokenizer for text processing
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
self.config = config
|
||||
self.reward_fn = reward_fn
|
||||
self.val_reward_fn = val_reward_fn
|
||||
|
||||
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
|
||||
|
||||
assert not self.hybrid_engine
|
||||
assert self.config.data.train_batch_size == 0, "train_batch_size must be zero"
|
||||
assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one"
|
||||
assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must be >= 0"
|
||||
assert self.config.async_training.trigger_parameter_sync_step >= 1, (
|
||||
"trigger_parameter_sync_step must larger than 1"
|
||||
)
|
||||
|
||||
self.role_worker_mapping = role_worker_mapping
|
||||
self.resource_pool_manager = resource_pool_manager
|
||||
self.ray_worker_group_cls = ray_worker_group_cls
|
||||
self.device_name = device_name if device_name else self.config.trainer.device
|
||||
self.validation_generations_logger = ValidationGenerationsLogger(
|
||||
project_name=self.config.trainer.project_name,
|
||||
experiment_name=self.config.trainer.experiment_name,
|
||||
)
|
||||
|
||||
self.ref_in_actor = False
|
||||
self.kl_ctrl_in_reward = False
|
||||
self.use_critic = False
|
||||
self.use_reference_policy = False
|
||||
self.use_rm = False
|
||||
|
||||
print("[FullyAsyncRollouter] Creating datasets...")
|
||||
from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
|
||||
from verl.utils.dataset.rl_dataset import collate_fn
|
||||
|
||||
train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
|
||||
val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
|
||||
train_sampler = create_rl_sampler(config.data, train_dataset)
|
||||
|
||||
self._validate_config()
|
||||
print(f"[FullyAsyncRollouter] Rollouter _create_dataloader...\n{train_dataset}\n{val_dataset}")
|
||||
|
||||
self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
|
||||
|
||||
# ==================== fully async config ====================
|
||||
|
||||
self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
|
||||
if self.config.rollout.total_rollout_steps is not None:
|
||||
self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)
|
||||
print(f"[FullyAsyncRollouter] Total rollout steps: {self.total_rollout_steps}")
|
||||
self.total_train_steps = None
|
||||
|
||||
# Rollouter parameter configuration
|
||||
self.message_queue_client = None
|
||||
|
||||
# Worker groups: rollout_wg is the same as actor_rollout_wg
|
||||
self.rollout_wg = None
|
||||
self.actor_rollout_wg = None
|
||||
self.async_rollout_manager = None
|
||||
|
||||
# Config
|
||||
self.staleness_threshold: float = config.async_training.get("staleness_threshold", 1)
|
||||
# required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples per training step.
|
||||
self.require_batches = config.async_training.require_batches
|
||||
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
|
||||
self.max_required_samples = None
|
||||
self.max_concurrent_samples = None
|
||||
# queue size
|
||||
self.max_queue_size = None
|
||||
|
||||
# Statistics
|
||||
self.current_param_version = 0
|
||||
self.total_generated_samples = 0
|
||||
self.staleness_samples = 0
|
||||
self.dropped_stale_samples = 0
|
||||
self.processed_sample_count = 0
|
||||
self.global_steps = 0
|
||||
self.idle_start_time = None
|
||||
self.version_start_time = None
|
||||
|
||||
# Concurrency control
|
||||
# Modified by self.pause() or self._should_pause_generation()
|
||||
self.paused = False
|
||||
self.running = True
|
||||
self.monitor_loop_trigger = True
|
||||
|
||||
# Initialize async locks directly
|
||||
self.lock = asyncio.Lock()
|
||||
self.condition = asyncio.Condition(self.lock)
|
||||
|
||||
# Initialize async queues
|
||||
self.pending_queue = asyncio.Queue(maxsize=128)
|
||||
self.active_tasks = set()
|
||||
self.result_queue = asyncio.Queue()
|
||||
self.cancel_queue = asyncio.Queue()
|
||||
|
||||
async def set_message_queue_client(self, message_queue_client: MessageQueueClient):
|
||||
"""Set message queue client"""
|
||||
async with self.lock:
|
||||
self.message_queue_client = message_queue_client
|
||||
|
||||
async def set_max_required_samples(self):
|
||||
async with self.lock:
|
||||
self.max_required_samples = int(
|
||||
self.required_samples
|
||||
* (self.staleness_threshold + 1)
|
||||
* self.config.async_training.trigger_parameter_sync_step
|
||||
)
|
||||
self.total_train_steps = int(
|
||||
self.total_rollout_steps
|
||||
/ (self.required_samples * self.config.async_training.trigger_parameter_sync_step)
|
||||
)
|
||||
|
||||
self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16
|
||||
self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples)
|
||||
self.max_queue_size = self.max_required_samples
|
||||
|
||||
print(
|
||||
f"[FullyAsyncRollouter] required_samples : {self.required_samples} "
|
||||
f"max_required_samples: {self.max_required_samples} "
|
||||
f"max_queue_size: {self.max_queue_size} "
|
||||
f"total_train_steps: {self.total_train_steps} "
|
||||
f"total_rollout_steps: {self.total_rollout_steps} "
|
||||
f"max_concurrent_samples: {self.max_concurrent_samples} "
|
||||
)
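
To make the budget formulas above concrete, a small worked example with invented numbers:

```python
# Hypothetical settings, only to illustrate the arithmetic in set_max_required_samples().
ppo_mini_batch_size = 32
require_batches = 2
staleness_threshold = 1
trigger_parameter_sync_step = 4
total_rollout_steps = 10_000

required_samples = ppo_mini_batch_size * require_batches                                                # 64 per trainer step
max_required_samples = int(required_samples * (staleness_threshold + 1) * trigger_parameter_sync_step)  # 512 per param version
total_train_steps = int(total_rollout_steps / (required_samples * trigger_parameter_sync_step))         # 39, forwarded to the trainer
print(required_samples, max_required_samples, total_train_steps)
```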
|
||||
|
||||
def get_rollout_wg(self):
|
||||
"""Get rollout worker group"""
|
||||
return self.rollout_wg
|
||||
|
||||
def get_max_queue_size(self):
|
||||
return self.max_queue_size
|
||||
|
||||
def get_total_train_steps(self):
|
||||
return self.total_train_steps
|
||||
|
||||
async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0):
|
||||
"""Update current parameter version"""
|
||||
async with self.lock:
|
||||
old_version = self.current_param_version
|
||||
self.current_param_version = version
|
||||
# every time param change, reset staleness_samples
|
||||
self.staleness_samples = (
|
||||
len(self.active_tasks)
|
||||
+ self.result_queue.qsize()
|
||||
+ self.cancel_queue.qsize()
|
||||
+ await self.message_queue_client.get_queue_size()
|
||||
)
|
||||
timing_raw = {}
|
||||
idle_ratio = None
|
||||
if self.idle_start_time is not None and self.version_start_time is not None:
|
||||
rollout_active_time = self.idle_start_time - self.version_start_time
|
||||
rollout_version_time = time.time() - self.version_start_time
|
||||
idle_ratio = 1 - rollout_active_time / rollout_version_time
|
||||
timing_raw["rollouter/active_time"] = rollout_active_time
|
||||
timing_raw["rollouter/version_time"] = rollout_version_time
|
||||
timing_raw["rollouter/idle_ratio"] = idle_ratio
|
||||
self.idle_start_time = None
|
||||
print(
|
||||
f"[FullyAsyncRollouter][Public][update_param_version] "
|
||||
f"Parameter version updated from {old_version} to {version} "
|
||||
f",reset staleness_samples to: {self.staleness_samples}"
|
||||
f",idle_ratio: {idle_ratio}"
|
||||
)
|
||||
val_metrics = None
|
||||
if (
|
||||
self.val_reward_fn is not None
|
||||
and self.config.rollout.test_freq > 0
|
||||
and self.current_param_version % self.config.rollout.test_freq == 0
|
||||
and self.current_param_version > 0 # don't test here in the initial parameter sync
|
||||
) or (validate and self.val_reward_fn is not None):
|
||||
with marked_timer("rollouter/validate_time", timing_raw, color="green"):
|
||||
val_metrics: dict = self._validate()
|
||||
data = ValidateMetrics(
|
||||
timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version
|
||||
)
|
||||
await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data))
|
||||
|
||||
self.version_start_time = time.time()
|
||||
|
||||
def _validate_config(self):
|
||||
# Validate asynchronous training configuration
|
||||
if not hasattr(self.config, "async_training"):
|
||||
raise ValueError("[FullyAsyncRollouter] Missing async_training configuration")
|
||||
assert self.config.actor_rollout_ref.rollout.calculate_log_probs, "must rollout calculate log_probs"
|
||||
|
||||
async def init_workers(self):
|
||||
"""Initialize distributed training workers using Ray backend.
|
||||
|
||||
Creates:
|
||||
1. Ray resource pools from configuration
|
||||
2. Worker groups for each role (actor, critic, etc.)
|
||||
"""
|
||||
self._init_resource_pools()
|
||||
self._create_worker_classes()
|
||||
self._init_worker_groups()
|
||||
self._init_models()
|
||||
await self._init_async_rollout_manager()
|
||||
|
||||
def _create_actor_rollout_classes(self):
|
||||
# only create rollout
|
||||
for role in [Role.Rollout]:
|
||||
resource_pool = self.resource_pool_manager.get_resource_pool(role)
|
||||
role_cls = RayClassWithInitArgs(
|
||||
cls=self.role_worker_mapping[role],
|
||||
config=self.config.actor_rollout_ref,
|
||||
role=str(role),
|
||||
)
|
||||
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
|
||||
|
||||
def _init_models(self):
|
||||
self.rollout_wg = self.all_wg[str(Role.Rollout)]
|
||||
self.rollout_wg.init_model()
|
||||
self.actor_rollout_wg = self.rollout_wg
|
||||
|
||||
def _create_continuous_iterator(self):
|
||||
"""
|
||||
Create a continuous data iterator across epoch
|
||||
"""
|
||||
for epoch in range(self.config.rollout.total_epochs):
|
||||
iterator = iter(self.train_dataloader)
|
||||
for batch_dict in iterator:
|
||||
yield epoch, batch_dict
|
||||
|
||||
async def _init_async_rollout_manager(self):
|
||||
# create async rollout manager and request scheduler
|
||||
assert self.config.actor_rollout_ref.rollout.mode == "async"
|
||||
from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager
|
||||
|
||||
self.async_rollout_mode = True
|
||||
self.async_rollout_manager = await FullyAsyncAgentLoopManager.create(
|
||||
config=self.config,
|
||||
worker_group=self.rollout_wg,
|
||||
)
|
||||
|
||||
# Add samples to the pending_queue
|
||||
async def _feed_samples(self):
|
||||
continuous_iterator = self._create_continuous_iterator()
|
||||
|
||||
for epoch, batch_dict in continuous_iterator:
|
||||
# Similar to _prepare_generate_batch: Separate data
|
||||
full_batch = prepare_single_generation_data(
|
||||
batch_dict, self.global_steps, self.config.actor_rollout_ref.rollout.n
|
||||
)
|
||||
|
||||
sample_id = f"sample_{epoch}_{self.global_steps}"
|
||||
|
||||
rollout_sample = RolloutSample(
|
||||
full_batch=full_batch,
|
||||
agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n,
|
||||
sample_id=sample_id,
|
||||
epoch=epoch,
|
||||
param_version=0,
|
||||
param_version_start=[],
|
||||
param_version_end=[],
|
||||
processing_times=[],
|
||||
rollout_status={},
|
||||
)
|
||||
|
||||
await self.pending_queue.put(rollout_sample)
|
||||
|
||||
# Check if have reached the last step
|
||||
if self.global_steps >= self.total_rollout_steps:
|
||||
print(
|
||||
f"[FullyAsyncRollouter][Feed] "
|
||||
f"Maximum count has been reached, stop adding new samples"
|
||||
f"{self.global_steps} >= {self.total_rollout_steps}"
|
||||
)
|
||||
break
|
||||
|
||||
self.global_steps += 1
|
||||
|
||||
# End signal
|
||||
await self.pending_queue.put("DONE")
|
||||
print(f"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added")
|
||||
|
||||
async def _processor_worker(self):
|
||||
"""
|
||||
Streaming worker coroutine: each sample is submitted for processing as soon as it is available, without waiting for a full batch
|
||||
"""
|
||||
while True:
|
||||
if self.paused or await self._should_pause_generation():
|
||||
print(
|
||||
"[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..."
|
||||
)
|
||||
async with self.lock:
|
||||
self.paused = True
|
||||
while self.active_tasks:
|
||||
async with self.lock:
|
||||
# After acquiring the lock, the number of active_tasks may change, need to be verified again
|
||||
if self.active_tasks:
|
||||
done_tasks, self.active_tasks = await asyncio.wait(
|
||||
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done_tasks:
|
||||
await task
|
||||
|
||||
async with self.lock:
|
||||
while self.paused:
|
||||
self.idle_start_time = time.time()
|
||||
await self.condition.wait()
|
||||
continue
|
||||
|
||||
sample_from_cancel_queue = False
|
||||
if not self.cancel_queue.empty():
|
||||
rollout_sample = await self.cancel_queue.get()
|
||||
sample_from_cancel_queue = True
|
||||
else:
|
||||
rollout_sample = await self.pending_queue.get()
|
||||
self.staleness_samples += 1
|
||||
|
||||
if rollout_sample == "DONE":
|
||||
print(
|
||||
"[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..."
|
||||
)
|
||||
while self.active_tasks:
|
||||
async with self.lock:
|
||||
if self.active_tasks:
|
||||
done_tasks, self.active_tasks = await asyncio.wait(
|
||||
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done_tasks:
|
||||
await task
|
||||
break
|
||||
|
||||
# Check whether the number of concurrent tasks exceeds the limit
|
||||
while len(self.active_tasks) >= self.max_concurrent_samples:
|
||||
async with self.lock:
|
||||
if self.active_tasks:
|
||||
done_tasks, self.active_tasks = await asyncio.wait(
|
||||
self.active_tasks, return_when=asyncio.FIRST_COMPLETED
|
||||
)
|
||||
for task in done_tasks:
|
||||
await task
|
||||
|
||||
# Submit single sample processing
|
||||
async with self.lock:
|
||||
# After acquiring the lock, re-check whether we are still in the paused
# phase; if so, keep waiting before submitting a new task
|
||||
while self.paused:
|
||||
await self.condition.wait()
|
||||
task = asyncio.create_task(
|
||||
self._process_single_sample_streaming(rollout_sample),
|
||||
name=rollout_sample.sample_id,
|
||||
)
|
||||
self.active_tasks.add(task)
|
||||
|
||||
if sample_from_cancel_queue:
|
||||
self.cancel_queue.task_done()
|
||||
else:
|
||||
self.pending_queue.task_done()
|
||||
|
||||
async def _process_single_sample_streaming(self, rollout_sample: RolloutSample):
|
||||
"""Process a single sample streamingly"""
|
||||
# Calling asynchronous generation methods
|
||||
rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len(
|
||||
rollout_sample.full_batch
|
||||
)
|
||||
agent_loop_output_list = await self.async_rollout_manager.generate_single_sample_async(
|
||||
rollout_sample.full_batch, rollout_sample.agent_loop_output_list
|
||||
)
|
||||
rollout_sample.agent_loop_output_list = agent_loop_output_list
|
||||
|
||||
is_cancel = False
|
||||
for agent_loop in agent_loop_output_list:
|
||||
if not is_cancel and agent_loop.is_cancel:
|
||||
is_cancel = True
|
||||
|
||||
if is_cancel:
|
||||
# Put in the cancel queue and wait for the generation to resume
|
||||
await self.cancel_queue.put(rollout_sample)
|
||||
else:
|
||||
# put into the result_queue
|
||||
rollout_sample.param_version = self.current_param_version
|
||||
rollout_sample.rollout_status = await self.get_statistics()
|
||||
await self.result_queue.put(rollout_sample)
|
||||
|
||||
self.processed_sample_count += 1
|
||||
|
||||
async def _consumer_worker(self):
|
||||
"""
|
||||
The consumer coroutine is responsible for obtaining the processing results
|
||||
from the result queue and putting them into the message queue
|
||||
"""
|
||||
while True:
|
||||
rollout_sample = await self.result_queue.get()
|
||||
rollout_sample = merge_rollout_sample(self.config, self.tokenizer, rollout_sample)
|
||||
|
||||
# Put RolloutSample into the message queue
|
||||
success = await self.message_queue_client.put_sample(
|
||||
sample=ray.cloudpickle.dumps(rollout_sample),
|
||||
param_version=rollout_sample.param_version,
|
||||
)
|
||||
if success:
|
||||
self.total_generated_samples += 1
|
||||
else:
|
||||
self.dropped_stale_samples += 1
|
||||
|
||||
self.result_queue.task_done()
|
||||
|
||||
async def _streaming_generation_main(self):
|
||||
"""The main entry method for stream processing"""
|
||||
|
||||
# we start from step 1
|
||||
self.global_steps += 1
|
||||
|
||||
if self.async_rollout_manager is None:
|
||||
await self._init_async_rollout_manager()
|
||||
|
||||
# Start the streaming loop
|
||||
print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}")
|
||||
|
||||
# Start sample feed coroutine, streaming process coroutine and consumer coroutine
|
||||
self.feed_task = asyncio.create_task(self._feed_samples())
|
||||
self.processor_task = asyncio.create_task(self._processor_worker())
|
||||
self.consumer_task = asyncio.create_task(self._consumer_worker())
|
||||
|
||||
try:
|
||||
# Wait for sample feed to complete
|
||||
await self.feed_task
|
||||
print("[FullyAsyncRollouter] Sample feed completed")
|
||||
|
||||
# Wait for streaming to complete
|
||||
await self.processor_task
|
||||
print("[FullyAsyncRollouter] Streaming process completed")
|
||||
|
||||
# Waiting for the result queue to clear
|
||||
await self.result_queue.join()
|
||||
print("[FullyAsyncRollouter] Result queue cleared")
|
||||
|
||||
except Exception as e:
|
||||
print(f"[FullyAsyncRollouter] Streaming process exception:{e}")
|
||||
|
||||
finally:
|
||||
if self.processor_task:
|
||||
self.processor_task.cancel()
|
||||
if self.consumer_task:
|
||||
self.consumer_task.cancel()
|
||||
|
||||
await asyncio.gather(self.processor_task, self.consumer_task, return_exceptions=True)
|
||||
|
||||
# Send a finish signal
|
||||
await self.message_queue_client.put_sample(
|
||||
sample=None,
|
||||
param_version=self.current_param_version,
|
||||
)
|
||||
|
||||
async with self.lock:
|
||||
self.running = False
|
||||
|
||||
async def fit(self):
|
||||
"""
|
||||
Start the async rollouter - entry point that sets up and runs async tasks
|
||||
Main async fit method that coordinates all coroutines
|
||||
"""
|
||||
|
||||
print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...")
|
||||
|
||||
if self.message_queue_client is None:
|
||||
raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.")
|
||||
|
||||
# Set the running status flag
|
||||
async with self.lock:
|
||||
self.paused = False
|
||||
self.running = True
|
||||
|
||||
# Create the main asynchronous task
|
||||
generation_task = asyncio.create_task(self._streaming_generation_main())
|
||||
monitor_task = asyncio.create_task(self._async_monitor_loop())
|
||||
|
||||
try:
|
||||
# Run build and monitoring tasks concurrently
|
||||
await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
|
||||
except Exception as e:
|
||||
print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}")
|
||||
finally:
|
||||
if not generation_task.done():
|
||||
generation_task.cancel()
|
||||
if not monitor_task.done():
|
||||
monitor_task.cancel()
|
||||
|
||||
# Wait for the task to complete
|
||||
await asyncio.gather(generation_task, monitor_task, return_exceptions=True)
|
||||
|
||||
print("[FullyAsyncRollouter] Rollouter fit completed")
|
||||
|
||||
async def _async_monitor_loop(self):
|
||||
"""
|
||||
Async coroutine for monitoring:
|
||||
Function 1: Log information output
|
||||
Function 2: Trigger rollout recovery
|
||||
"""
|
||||
last_stats_time = time.time()
|
||||
stats_interval = 60.0
|
||||
check_interval = 10.0
|
||||
|
||||
while True:
|
||||
async with self.lock:
|
||||
if not self.running:
|
||||
break
|
||||
await asyncio.sleep(check_interval)
|
||||
# Print statistics periodically
|
||||
current_time = time.time()
|
||||
if current_time - last_stats_time >= stats_interval:
|
||||
stats = await self.get_statistics()
|
||||
print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}")
|
||||
last_stats_time = current_time
|
||||
|
||||
# Trigger rollout recovery
|
||||
if self.monitor_loop_trigger:
|
||||
if not await self._should_pause_generation():
|
||||
async with self.lock:
|
||||
self.paused = False
|
||||
self.condition.notify_all()
|
||||
|
||||
async def _should_pause_generation(self) -> bool:
|
||||
"""Determine whether the build should be paused"""
|
||||
queue_stats = self.message_queue_client.get_statistics_sync()
|
||||
queue_size = queue_stats["queue_size"]
|
||||
|
||||
if queue_size >= self.max_queue_size:
|
||||
if not self.paused:
|
||||
print(
|
||||
f"[FullyAsyncRollouter][ShouldPause] "
|
||||
f"due to full queue: size={queue_size}, max={self.max_queue_size}"
|
||||
)
|
||||
return True
|
||||
|
||||
if self.staleness_samples >= self.max_required_samples:
|
||||
if not self.paused:
|
||||
print(
|
||||
"[FullyAsyncRollouter][ShouldPause] "
|
||||
f"due to "
|
||||
f"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples} "
|
||||
)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def pause(self):
|
||||
"""pause rollout"""
|
||||
print("[FullyAsyncRollouter][Public][Pause]")
|
||||
async with self.lock:
|
||||
self.paused = True
|
||||
# Cancel all rollout tasks
|
||||
if self.config.async_training.partial_rollout:
|
||||
await self.async_rollout_manager.cancel()
|
||||
if self.active_tasks:
|
||||
await asyncio.gather(*self.active_tasks, return_exceptions=True)
|
||||
self.active_tasks.clear()
|
||||
print("[FullyAsyncRollouter][Public][Pause] All active tasks completed")
|
||||
await self.async_rollout_manager.reset_prefix_cache()
|
||||
self.monitor_loop_trigger = False
|
||||
|
||||
async def resume(self, dependency_ref: ObjectRef = None):
|
||||
if dependency_ref is not None:
|
||||
ray.get(dependency_ref)
|
||||
print("[FullyAsyncRollouter][Public][Resume]")
|
||||
async with self.lock:
|
||||
self.paused = False
|
||||
self.monitor_loop_trigger = True
|
||||
self.condition.notify_all()
|
||||
|
||||
if self.config.async_training.partial_rollout:
|
||||
await self.async_rollout_manager.resume()
|
||||
|
||||
async def get_statistics(self) -> dict:
|
||||
queue_stats = self.message_queue_client.get_statistics_sync()
|
||||
|
||||
stats = {
|
||||
# monitor stats
|
||||
"monitor/active_tasks_size": len(self.active_tasks),
|
||||
"monitor/queue/pending_queue_size": self.pending_queue.qsize(),
|
||||
"monitor/queue/cancel_queue_size": self.cancel_queue.qsize(),
|
||||
"monitor/queue/result_queue_size": self.result_queue.qsize(),
|
||||
"monitor/queue/mq_queue_size": queue_stats["queue_size"],
|
||||
# counting stats
|
||||
"count/current_param_version": self.current_param_version,
|
||||
"count/total_generated_samples": self.total_generated_samples,
|
||||
"count/staleness_samples": self.staleness_samples,
|
||||
"count/dropped_stale_samples": self.dropped_stale_samples,
|
||||
# static stats
|
||||
"static/max_required_samples": self.max_required_samples,
|
||||
"static/required_samples": self.required_samples,
|
||||
"static/staleness_threshold": self.staleness_threshold,
|
||||
"static/max_queue_size": self.max_queue_size,
|
||||
"static/max_concurrent_samples": self.max_concurrent_samples,
|
||||
}
|
||||
|
||||
return stats
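
Putting the public methods above together, a simplified picture of one parameter-sync round as the driver would see it. The `rollouter` actor handle, `new_version`, and `do_weight_broadcast()` are placeholders; the actual orchestration lives in the recipe's ParameterSynchronizer:

```python
import ray

new_version = 1  # placeholder for the trainer's current parameter version

ray.get(rollouter.pause.remote())           # drain (or cancel, with partial_rollout) in-flight tasks
do_weight_broadcast()                       # hypothetical stand-in for the actor -> rollout weight sync
ray.get(rollouter.update_param_version.remote(version=new_version))  # resets staleness accounting, may validate
ray.get(rollouter.resume.remote())          # generation continues against the new weights
```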
|
recipe/fully_async_policy/fully_async_trainer.py (new file, 360 lines)
@@ -0,0 +1,360 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import time
|
||||
from datetime import datetime
|
||||
from pprint import pprint
|
||||
from typing import Any
|
||||
|
||||
import ray
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from recipe.fully_async_policy.detach_utils import (
|
||||
MetricsAggregator,
|
||||
ValidateMetrics,
|
||||
assemble_batch_from_rollout_samples,
|
||||
)
|
||||
from recipe.fully_async_policy.message_queue import MessageQueueClient
|
||||
from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
|
||||
from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
|
||||
from verl.trainer.ppo import core_algos
|
||||
from verl.trainer.ppo.ray_trainer import ResourcePoolManager
|
||||
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
|
||||
from verl.utils.debug import marked_timer
|
||||
|
||||
|
||||
@ray.remote(num_cpus=10)
|
||||
class FullyAsyncTrainer(FullyAsyncRayPPOTrainer):
|
||||
"""
|
||||
A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training.
|
||||
Based on and improving upon the OneStepOffRayTrainer implementation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
tokenizer,
|
||||
role_worker_mapping: dict[Role, WorkerType],
|
||||
resource_pool_manager: ResourcePoolManager,
|
||||
ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
|
||||
processor=None,
|
||||
reward_fn=None,
|
||||
val_reward_fn=None,
|
||||
device_name=None,
|
||||
):
|
||||
# Store the tokenizer for text processing
|
||||
self.tokenizer = tokenizer
|
||||
self.processor = processor
|
||||
self.config = config
|
||||
self.reward_fn = reward_fn
|
||||
self.val_reward_fn = val_reward_fn
|
||||
|
||||
self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
|
||||
assert not self.hybrid_engine
|
||||
|
||||
self.role_worker_mapping = role_worker_mapping
|
||||
self.resource_pool_manager = resource_pool_manager
|
||||
self.use_reference_policy = need_reference_policy(self.role_worker_mapping)
|
||||
self.use_rm = need_reward_model(self.role_worker_mapping)
|
||||
self.use_critic = need_critic(self.config)
|
||||
self.ray_worker_group_cls = ray_worker_group_cls
|
||||
self.device_name = device_name if device_name else self.config.trainer.device
|
||||
|
||||
# if ref_in_actor is True, the reference policy will be actor without lora applied
|
||||
self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0
|
||||
|
||||
# define in-reward KL control
|
||||
# KL loss control is currently not supported
|
||||
if self.config.algorithm.use_kl_in_reward:
|
||||
self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl)
|
||||
|
||||
# ==================== fully async config ====================
|
||||
|
||||
self.message_queue_client = None
|
||||
self.param_synchronizer = None
|
||||
|
||||
# Statistics
|
||||
# we start from step 1
|
||||
self.global_steps = 1
|
||||
self.local_trigger_step = 1
|
||||
self.processed_samples = 0
|
||||
self.stale_samples_processed = 0
|
||||
self.stale_trajectory_processed = 0
|
||||
self.current_param_version = 0
|
||||
self.total_train_steps = None
|
||||
self.progress_bar = None
|
||||
self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step
|
||||
|
||||
# required_samples uses ppo_mini_batch_size * require_batches as the minimum number of samples per training step.
|
||||
self.require_batches = config.async_training.require_batches
|
||||
self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches
|
||||
total_gpus = (
|
||||
config.trainer.nnodes * config.trainer.n_gpus_per_node
|
||||
+ config.rollout.nnodes * config.rollout.n_gpus_per_node
|
||||
)
|
||||
self.metrics_aggregator = MetricsAggregator(total_gpus=total_gpus)
|
||||
|
||||
def set_message_queue_client(self, message_queue_client: MessageQueueClient):
|
||||
"""Set message queue client"""
|
||||
self.message_queue_client = message_queue_client
|
||||
|
||||
def set_parameter_synchronizer(self, param_synchronizer):
|
||||
"""Set parameter synchronizer"""
|
||||
self.param_synchronizer = param_synchronizer
|
||||
|
||||
def set_total_train_steps(self, total_train_steps):
|
||||
self.total_train_steps = total_train_steps
|
||||
self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc="Training Progress")
|
||||
|
||||
def get_actor_wg(self):
|
||||
"""Get actor worker group"""
|
||||
return self.actor_wg
|
||||
|
||||
def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]:
|
||||
"""
|
||||
Get samples from message queue and compose gen_batch_output
|
||||
Uses a loop to continuously collect samples until enough are gathered
|
||||
|
||||
Returns:
|
||||
tuple: (0, batch) with the assembled DataProto, or (None, None) if not enough samples were collected
|
||||
"""
|
||||
print(
|
||||
f"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
# Collect samples using a simple loop calling get_sample
|
||||
consumer_start = time.time()
|
||||
queue_samples = []
|
||||
queue_len = 0
|
||||
while len(queue_samples) < self.required_samples:
|
||||
# Get a single sample and wait until there is a sample or None is received
|
||||
sample, queue_len = self.message_queue_client.get_sample_sync()
|
||||
|
||||
if sample is None:
|
||||
print(
|
||||
f"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. "
|
||||
f"Collected {len(queue_samples)}/{self.required_samples} samples"
|
||||
)
|
||||
break
|
||||
|
||||
queue_samples.append(sample)
|
||||
|
||||
if len(queue_samples) % 64 == 0:
|
||||
print(
|
||||
f"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. "
|
||||
f"mq_len: {queue_len}"
|
||||
)
|
||||
|
||||
consumer_end = time.time()
|
||||
|
||||
if not queue_samples or len(queue_samples) < self.required_samples:
|
||||
print("[FullyAsyncTrainer] not enough samples collected after loop")
|
||||
return None, None
|
||||
total_wait_time = consumer_end - consumer_start
|
||||
|
||||
print(
|
||||
f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, "
|
||||
f"total wait time: {total_wait_time:.2f} seconds."
|
||||
f"mq_len: {queue_len}"
|
||||
)
|
||||
|
||||
queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples]
|
||||
# Assemble batch - now working directly with RolloutSample objects
|
||||
if self.config.trainer.balance_batch:
|
||||
batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch)
|
||||
else:
|
||||
batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None)
|
||||
|
||||
batch.meta_info["fully_async/total_wait_time"] = total_wait_time
|
||||
return 0, batch
|
||||
|
||||
def _create_actor_rollout_classes(self):
|
||||
# create actor
|
||||
for role in [Role.Actor]:
|
||||
resource_pool = self.resource_pool_manager.get_resource_pool(role)
|
||||
role_cls = RayClassWithInitArgs(
|
||||
cls=self.role_worker_mapping[role],
|
||||
config=self.config.actor_rollout_ref,
|
||||
role=str(role),
|
||||
)
|
||||
self.resource_pool_to_cls[resource_pool][str(role)] = role_cls
|
||||
|
||||
def _init_models(self):
|
||||
if self.use_critic:
|
||||
self.critic_wg = self.all_wg[str(Role.Critic)]
|
||||
self.critic_wg.init_model()
|
||||
|
||||
if self.use_reference_policy and not self.ref_in_actor:
|
||||
self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
|
||||
self.ref_policy_wg.init_model()
|
||||
|
||||
if self.use_rm:
|
||||
self.rm_wg = self.all_wg[str(Role.RewardModel)]
|
||||
self.rm_wg.init_model()
|
||||
|
||||
self.actor_wg = self.all_wg[str(Role.Actor)]
|
||||
self.actor_wg.init_model()
|
||||
self.actor_rollout_wg = self.actor_wg  # kept for compatibility with functions that were not modified
|
||||
|
||||
def _init_async_rollout_manager(self):
|
||||
pass
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
The training loop of PPO.
|
||||
The driver process only needs to call the compute functions of the worker group through RPC
|
||||
to construct the PPO dataflow.
|
||||
The light-weight advantage computation is done on the driver process.
|
||||
"""
|
||||
print("[FullyAsyncTrainer] Starting FullyAsyncTrainer...")
|
||||
if self.message_queue_client is None:
|
||||
raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.")
|
||||
if self.param_synchronizer is None:
|
||||
raise ValueError("param_synchronizer client not set. Call set_parameter_synchronizer() first.")
|
||||
|
||||
from verl.utils.tracking import Tracking
|
||||
|
||||
self.logger = Tracking(
|
||||
project_name=self.config.trainer.project_name,
|
||||
experiment_name=self.config.trainer.experiment_name,
|
||||
default_backend=self.config.trainer.logger,
|
||||
config=OmegaConf.to_container(self.config, resolve=True),
|
||||
)
|
||||
|
||||
self.max_steps_duration = 0
|
||||
|
||||
# get validate data before training
|
||||
if self.config.trainer.val_before_train and self.reward_fn is not None:
|
||||
ray.get(self.param_synchronizer.wait_last_valid.remote())
|
||||
val_data = self.message_queue_client.get_validate_sync()
|
||||
if val_data:
|
||||
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
|
||||
if val_data.metrics:
|
||||
self.logger.log(data=val_data.metrics, step=val_data.param_version)
|
||||
pprint(f"[FullyAsyncTrainer] Initial validation metrics: {val_data.metrics}")
|
||||
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
|
||||
|
||||
# Queue mode: samples come from the message queue, so no traditional dataloader iterator is needed
|
||||
# Keep pulling batches until the queue signals termination
|
||||
while True:
|
||||
metrics = {}
|
||||
timing_raw = {}
|
||||
|
||||
with marked_timer("step", timing_raw):
|
||||
with marked_timer("gen", timing_raw, color="red"):
|
||||
epoch, batch = self._get_samples_from_queue()
|
||||
if batch is None:
|
||||
break
|
||||
self._collect_metrics_from_samples(batch, metrics)
|
||||
|
||||
batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)
|
||||
self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
|
||||
self._check_save_checkpoint(False, timing_raw)
|
||||
|
||||
self._collect_metrics(batch, 0, metrics, timing_raw)
|
||||
self.metrics_aggregator.add_step_metrics(
|
||||
metrics=metrics, sample_count=self.required_samples, timestamp=time.time()
|
||||
)
|
||||
# Trigger parameter synchronization after training step
|
||||
time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3]
|
||||
print(
|
||||
f"[FullyAsyncTrainer] global_steps: {self.global_steps} "
|
||||
f"local_trigger_step: {self.local_trigger_step} "
|
||||
f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} "
|
||||
f"{time_str}"
|
||||
)
|
||||
self._trigger_parameter_sync_after_step(global_steps=self.global_steps)
|
||||
val_data = self.message_queue_client.get_validate_sync()
|
||||
if val_data:
|
||||
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
|
||||
if val_data.metrics:
|
||||
self.logger.log(data=val_data.metrics, step=val_data.param_version)
|
||||
pprint(
|
||||
f"[FullyAsyncTrainer] parameter version: {val_data.param_version} \
|
||||
f"Validation metrics: {val_data.metrics}"
|
||||
)
|
||||
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
|
||||
self.global_steps += 1
|
||||
|
||||
# final parameter sync and validate
|
||||
if val_data is None or val_data.metrics is None:
|
||||
self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps - 1)
|
||||
ray.get(self.param_synchronizer.wait_last_valid.remote())
|
||||
val_data = self.message_queue_client.get_validate_sync()
|
||||
if val_data:
|
||||
val_data: ValidateMetrics = ray.cloudpickle.loads(val_data)
|
||||
if val_data.metrics:
|
||||
self.logger.log(data=val_data.metrics, step=val_data.param_version)
|
||||
pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}")
|
||||
self.logger.log(data=val_data.timing_raw, step=val_data.param_version)
|
||||
else:
|
||||
pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}")
|
||||
self.progress_bar.close()
|
||||
|
||||
self._check_save_checkpoint(True, timing_raw) # TODO: check checkpoint
|
||||
|
||||
def load_checkpoint(self):
|
||||
return self._load_checkpoint()
|
||||
|
||||
def _collect_metrics_from_samples(self, batch, metrics):
|
||||
"""
|
||||
Collect metrics from samples
|
||||
"""
|
||||
if hasattr(batch, "meta_info") and batch.meta_info:
|
||||
samples_param_versions = batch.meta_info["rollout_param_versions"]
|
||||
stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1)
|
||||
self.stale_samples_processed += stale_count
|
||||
trajectory_param_versions = batch.meta_info["trajectory_param_versions"]
|
||||
stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1)
|
||||
self.stale_trajectory_processed += stale_traj_count
|
||||
metrics.update(
|
||||
{
|
||||
"fully_async/count/stale_samples_processed": self.stale_samples_processed,
|
||||
"fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed,
|
||||
"fully_async/count/current_param_version": self.current_param_version,
|
||||
}
|
||||
)
|
||||
for key, value in batch.meta_info.items():
|
||||
if key.startswith("fully_async"):
|
||||
metrics[key] = value
|
||||
|
||||
def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None):
|
||||
"""
|
||||
Trigger parameter synchronization after training step
|
||||
This ensures the rollouter always uses the latest trained parameters
|
||||
"""
|
||||
if self.local_trigger_step < self.trigger_parameter_sync_step and not validate:
|
||||
self.local_trigger_step += 1
|
||||
return
|
||||
|
||||
self.current_param_version += 1
|
||||
self.local_trigger_step = 1
|
||||
self.logger.log(
|
||||
data=self.metrics_aggregator.get_aggregated_metrics(),
|
||||
step=self.current_param_version,
|
||||
)
|
||||
self.progress_bar.update(1)
|
||||
self.metrics_aggregator.reset()
|
||||
timing_param_sync = {}
|
||||
with marked_timer("timing_s/wait_last_valid", timing_param_sync):
|
||||
ray.get(self.param_synchronizer.wait_last_valid.remote())
|
||||
with marked_timer("timing_s/param_sync", timing_param_sync):
|
||||
ray.get(
|
||||
self.param_synchronizer.sync_weights.remote(
|
||||
self.current_param_version, validate=validate, global_steps=global_steps
|
||||
)
|
||||
)
|
||||
self.logger.log(data=timing_param_sync, step=self.current_param_version)
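# Illustrative cadence (assumed value): with trigger_parameter_sync_step=4, the first three
# calls only increment local_trigger_step and return; the fourth call bumps
# current_param_version, logs the aggregated metrics, waits for the previous round's
# validation to finish, then launches sync_weights on the ParameterSynchronizer.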
|
recipe/fully_async_policy/message_queue.py (new file, 265 lines)
@@ -0,0 +1,265 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections import deque
|
||||
from typing import Any
|
||||
|
||||
import ray
|
||||
from omegaconf import DictConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ray.remote(num_cpus=2, max_concurrency=20)
|
||||
class MessageQueue:
|
||||
"""
|
||||
Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer
|
||||
"""
|
||||
|
||||
def __init__(self, config: DictConfig, max_queue_size: int = 1000):
|
||||
self.config = config
|
||||
if max_queue_size is None:
|
||||
raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}")
|
||||
self.max_queue_size = int(max_queue_size)
|
||||
self.queue = deque(maxlen=self.max_queue_size)
|
||||
self.current_param_version = 0
|
||||
|
||||
self.val_queue = deque()
|
||||
|
||||
try:
|
||||
if hasattr(config, "async_training") and config.async_training is not None:
|
||||
self.staleness_threshold = getattr(config.async_training, "staleness_threshold", 3)
|
||||
else:
|
||||
self.staleness_threshold = 3
|
||||
except (AttributeError, RecursionError):
|
||||
self.staleness_threshold = 3
|
||||
|
||||
# Asyncio for message handling
|
||||
self.running = True
|
||||
|
||||
# async safe
|
||||
self._lock = asyncio.Lock()
|
||||
self._consumer_condition = asyncio.Condition(self._lock)
|
||||
|
||||
# statistic message
|
||||
self.total_produced = 0
|
||||
self.total_consumed = 0
|
||||
self.dropped_samples = 0
|
||||
|
||||
print(
|
||||
f"[MessageQueue] initialized with max_queue_size={max_queue_size},"
|
||||
f"staleness_threshold={self.staleness_threshold}"
|
||||
)
|
||||
|
||||
async def put_sample(self, sample: Any, param_version: int) -> bool:
|
||||
"""
|
||||
Put a batch sample into the queue
|
||||
|
||||
Args:
|
||||
sample: Sample data
|
||||
param_version: Parameter version number
|
||||
|
||||
Returns:
|
||||
bool: Whether the sample was successfully put into the queue
|
||||
"""
|
||||
async with self._lock:
|
||||
# If queue is full, remove the oldest sample (rarely happens)
|
||||
is_drop = False
|
||||
if len(self.queue) >= self.max_queue_size:
|
||||
self.queue.popleft()
|
||||
self.dropped_samples += 1
|
||||
is_drop = True
|
||||
logger.warning("Queue full, dropped sample")
|
||||
self.queue.append(sample)
|
||||
self.total_produced += 1
|
||||
|
||||
# Notify waiting consumers
|
||||
self._consumer_condition.notify_all()
|
||||
|
||||
if self.total_produced % 100 == 0:
|
||||
print(f"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}")
|
||||
if is_drop:
|
||||
return False
|
||||
return True
|
||||
|
||||
async def get_sample(self) -> tuple[Any | None, int]:
|
||||
"""
|
||||
Get a single sample from the queue, wait until one is available
|
||||
|
||||
Returns:
|
||||
tuple: (sample, remaining queue length); sample is None if the queue is closed and empty
|
||||
"""
|
||||
async with self._lock:
|
||||
while len(self.queue) == 0 and self.running:
|
||||
await self._consumer_condition.wait()
|
||||
|
||||
# If queue is closed and empty, return None
|
||||
if not self.running and len(self.queue) == 0:
|
||||
return None, 0
|
||||
|
||||
# Get one sample
|
||||
data = self.queue.popleft()
|
||||
self.total_consumed += 1
|
||||
return data, len(self.queue)
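# Producer/consumer handshake (sketch of the assumed flow): put_sample() appends under the
# lock and calls notify_all(); a concurrent get_sample() call wakes from wait(), pops one
# sample, and returns (sample, remaining_queue_length). After shutdown() is called and the
# queue is drained, get_sample() returns (None, 0) so consumers can exit cleanly.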
|
||||
|
||||
async def update_param_version(self, version: int):
|
||||
"""Update current parameter version"""
|
||||
async with self._lock:
|
||||
old_version = self.current_param_version
|
||||
self.current_param_version = version
|
||||
print(f"Parameter version updated from {old_version} to {version}")
|
||||
|
||||
async def get_queue_size(self) -> int:
|
||||
"""Get current queue length"""
|
||||
async with self._lock:
|
||||
return len(self.queue)
|
||||
|
||||
async def get_statistics(self) -> dict[str, Any]:
|
||||
"""Get queue statistics"""
|
||||
async with self._lock:
|
||||
return {
|
||||
"queue_size": len(self.queue),
|
||||
"total_produced": self.total_produced,
|
||||
"total_consumed": self.total_consumed,
|
||||
"dropped_samples": self.dropped_samples,
|
||||
"current_param_version": self.current_param_version,
|
||||
"staleness_threshold": self.staleness_threshold,
|
||||
"max_queue_size": self.max_queue_size,
|
||||
}
|
||||
|
||||
async def clear_queue(self):
|
||||
"""Clear the queue"""
|
||||
async with self._lock:
|
||||
cleared_count = len(self.queue)
|
||||
self.queue.clear()
|
||||
logger.info(f"Cleared {cleared_count} samples from queue")
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown the message queue"""
|
||||
async with self._lock:
|
||||
self.running = False
|
||||
# Notify all waiting coroutines so they can exit
|
||||
self._consumer_condition.notify_all()
|
||||
logger.info("MessageQueue shutdown")
|
||||
|
||||
async def get_memory_usage(self) -> dict:
|
||||
"""Get memory usage statistics"""
|
||||
async with self._lock:
|
||||
# Estimate memory usage of samples in queue
|
||||
import sys
|
||||
|
||||
total_size = 0
|
||||
sample_count = len(self.queue)
|
||||
|
||||
if sample_count > 0:
|
||||
# Estimate size of a single sample (simplified estimation)
|
||||
sample = list(self.queue)[0]
|
||||
try:
|
||||
sample_size = sys.getsizeof(sample)
|
||||
# Since we now store RolloutSample directly, estimate based on its components
|
||||
if hasattr(sample, "original_batch_dict") and sample.original_batch_dict:
|
||||
# Estimate batch data size
|
||||
batch_data = sample.original_batch_dict.get("batch", {})
|
||||
sample_size += len(batch_data) * 1000 # Roughly estimate 1KB per batch entry
|
||||
if hasattr(sample, "agent_loop_output"):
|
||||
# Estimate AgentLoopOutput size
|
||||
sample_size += 5000 # Roughly estimate 5KB for AgentLoopOutput
|
||||
total_size = sample_size * sample_count
|
||||
except Exception:
|
||||
total_size = sample_count * 15000 # Roughly estimate 15KB per RolloutSample
|
||||
|
||||
return {
|
||||
"queue_samples": sample_count,
|
||||
"estimated_memory_bytes": total_size,
|
||||
"estimated_memory_mb": total_size / (1024 * 1024),
|
||||
}
|
||||
|
||||
async def put_validate(self, data):
|
||||
async with self._lock:
|
||||
self.val_queue.append(data)
|
||||
|
||||
async def get_validate(self):
|
||||
async with self._lock:
|
||||
if self.val_queue:
|
||||
return self.val_queue.popleft()
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
class MessageQueueClient:
|
||||
"""Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor"""
|
||||
|
||||
def __init__(self, queue_actor: Any):
|
||||
self.queue_actor = queue_actor
|
||||
|
||||
async def put_sample(self, sample: Any, param_version: int) -> bool:
|
||||
"""Put batch into queue (async)"""
|
||||
future = self.queue_actor.put_sample.remote(sample, param_version)
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
async def put_validate(self, data: Any) -> bool:
|
||||
future = self.queue_actor.put_validate.remote(data)
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
def get_validate_sync(self) -> Any | None:
|
||||
return ray.get(self.queue_actor.get_validate.remote())
|
||||
|
||||
async def get_sample(self) -> tuple[Any | None, int]:
|
||||
"""Get single sample from queue, wait until one is available (async)"""
|
||||
future = self.queue_actor.get_sample.remote()
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
async def get_queue_size(self) -> int:
|
||||
"""Get queue size (async)"""
|
||||
future = self.queue_actor.get_queue_size.remote()
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
async def get_statistics(self) -> dict[str, Any]:
|
||||
"""Get statistics (async)"""
|
||||
future = self.queue_actor.get_statistics.remote()
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
async def clear_queue(self):
|
||||
"""Clear queue (async)"""
|
||||
future = self.queue_actor.clear_queue.remote()
|
||||
await asyncio.wrap_future(future.future())
|
||||
|
||||
async def shutdown(self):
|
||||
"""Shutdown queue (async)"""
|
||||
future = self.queue_actor.shutdown.remote()
|
||||
await asyncio.wrap_future(future.future())
|
||||
|
||||
async def get_memory_usage(self) -> dict:
|
||||
"""Get memory usage statistics (async)"""
|
||||
future = self.queue_actor.get_memory_usage.remote()
|
||||
return await asyncio.wrap_future(future.future())
|
||||
|
||||
# Synchronous versions of the methods above (deprecated)
|
||||
def put_sample_sync(self, sample: Any, param_version: int) -> bool:
|
||||
"""Put batch into queue (sync - deprecated, use put_sample instead)"""
|
||||
return ray.get(self.queue_actor.put_sample.remote(sample, param_version))
|
||||
|
||||
def get_sample_sync(self) -> tuple[Any | None, int]:
|
||||
"""Get single sample from queue (sync - deprecated, use get_sample instead)"""
|
||||
return ray.get(self.queue_actor.get_sample.remote())
|
||||
|
||||
def get_statistics_sync(self) -> dict[str, Any]:
|
||||
"""Get statistics (sync - deprecated, use get_statistics instead)"""
|
||||
return ray.get(self.queue_actor.get_statistics.remote())
|
||||
|
||||
def update_param_version_sync(self, version: int):
|
||||
"""Update parameter version (async)"""
|
||||
return ray.get(self.queue_actor.update_param_version.remote(version))
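# --- Illustrative usage sketch (not part of this file; values and the dict payload are
# assumptions for illustration only) ---
#
#   import ray
#   from recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient
#
#   ray.init()
#   queue_actor = MessageQueue.remote(config=None, max_queue_size=8)
#   client = MessageQueueClient(queue_actor)
#
#   # Producer side (normally the Rollouter); the payload is cloudpickled before enqueueing.
#   client.put_sample_sync(ray.cloudpickle.dumps({"prompt_id": 0}), param_version=1)
#
#   # Consumer side (normally the FullyAsyncTrainer).
#   sample, remaining = client.get_sample_sync()
#   print(ray.cloudpickle.loads(sample), remaining)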
|
recipe/fully_async_policy/param_sync.py (new file, 105 lines)
@@ -0,0 +1,105 @@
|
||||
# Copyright 2025 Meituan Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
import ray
|
||||
from ray.util.collective import collective
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ParameterSynchronizer:
|
||||
"""
|
||||
Unified parameter synchronizer responsible for synchronizing model parameters between the actor and rollout workers.
|
||||
Based on the synchronization mode of the one_step_off_policy recipe.
|
||||
Merges the functionality of the previously separate synchronizer classes.
|
||||
"""
|
||||
|
||||
def __init__(self, config, trainer, rollouter, mq):
|
||||
self.config = config
|
||||
self.trainer = trainer
|
||||
self.rollouter = rollouter
|
||||
self.mq_client = mq
|
||||
self.actor_wg = ray.get(trainer.get_actor_wg.remote())
|
||||
self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote())
|
||||
|
||||
# Basic attributes
|
||||
self.weights_info = None
|
||||
self.sync_group_initialized = False
|
||||
self.sync_group_name = "actor_rollout"
|
||||
self.wait_last_update = None
|
||||
self.wait_last_resume = None
|
||||
|
||||
# Statistics
|
||||
self.current_version = 0
|
||||
|
||||
self._init_weights_info()
|
||||
self._init_sync_group()
|
||||
|
||||
def get_current_param_version(self) -> int:
|
||||
"""Get current parameter version number"""
|
||||
return self.current_version
|
||||
|
||||
def get_weights_info(self):
|
||||
"""Get weights info"""
|
||||
return self.weights_info
|
||||
|
||||
def _init_weights_info(self):
|
||||
self.weights_info = self.actor_wg.get_actor_weights_info()[0]
|
||||
self.rollout_wg.set_actor_weights_info(self.weights_info)
|
||||
|
||||
def _init_sync_group(self):
|
||||
print("[ParameterSynchronizer] Initializing parameter synchronization group...")
|
||||
actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers
|
||||
collective.create_collective_group(
|
||||
actor_rollout_workers,
|
||||
len(actor_rollout_workers),
|
||||
list(range(0, len(actor_rollout_workers))),
|
||||
backend="nccl",
|
||||
group_name=self.sync_group_name,
|
||||
)
|
||||
|
||||
def sync_weights(self, version, validate=False, global_steps=0):
|
||||
"""Sync weights between trainer and rollouter, and update parameter version"""
|
||||
start_time = time.time()
|
||||
|
||||
self.current_version = version
|
||||
print(f"[ParameterSynchronizer] Starting weight synchronization (version {self.current_version})...")
|
||||
|
||||
ray.get(self.rollouter.pause.remote())
|
||||
|
||||
# Update MQ version
|
||||
self.mq_client.update_param_version_sync(version)
|
||||
|
||||
# sync weights
|
||||
self.actor_wg.sync_rollout_weights()
|
||||
ray.get(self.rollout_wg.sync_rollout_weights())
|
||||
end_time = time.time()
|
||||
print(f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds")
|
||||
|
||||
# Async Update rollout version & validation
|
||||
self.wait_last_update = self.rollouter.update_param_version.remote(version, validate, global_steps)
|
||||
self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update)
|
||||
|
||||
def wait_last_valid(self):
|
||||
print("[ParameterSynchronizer] Waiting last sync and validate...")
|
||||
start_time = time.time()
|
||||
if self.wait_last_update:
|
||||
ray.get(self.wait_last_update)
|
||||
if self.wait_last_resume:
|
||||
ray.get(self.wait_last_resume)
|
||||
print(f"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds")
|
recipe/fully_async_policy/ray_trainer.py (new file, 528 lines)
@@ -0,0 +1,528 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
PPO Trainer with Ray-based single controller.
|
||||
This trainer supports model-agnostic model initialization with HuggingFace models.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from copy import deepcopy
|
||||
from pprint import pprint
|
||||
|
||||
import numpy as np
|
||||
import ray
|
||||
import torch
|
||||
from omegaconf import OmegaConf
|
||||
from tqdm import tqdm
|
||||
|
||||
from verl import DataProto
|
||||
from verl.experimental.dataset.sampler import AbstractCurriculumSampler
|
||||
from verl.single_controller.ray import RayClassWithInitArgs
|
||||
from verl.single_controller.ray.base import create_colocated_worker_cls
|
||||
from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
|
||||
from verl.trainer.ppo.metric_utils import (
|
||||
compute_data_metrics,
|
||||
compute_throughout_metrics,
|
||||
compute_timing_metrics,
|
||||
)
|
||||
from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
|
||||
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
|
||||
from verl.trainer.ppo.utils import Role
|
||||
from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
|
||||
from verl.utils.config import omega_conf_to_dataclass
|
||||
from verl.utils.debug import marked_timer
|
||||
from verl.utils.metric import (
|
||||
reduce_metrics,
|
||||
)
|
||||
from verl.utils.rollout_skip import RolloutSkip
|
||||
|
||||
|
||||
class FullyAsyncRayPPOTrainer(RayPPOTrainer):
|
||||
def init_workers(self):
|
||||
"""Initialize distributed training workers using Ray backend.
|
||||
|
||||
Creates:
|
||||
1. Ray resource pools from configuration
|
||||
2. Worker groups for each role (actor, critic, etc.)
|
||||
"""
|
||||
self._init_resource_pools()
|
||||
self._create_worker_classes()
|
||||
self._init_worker_groups()
|
||||
self._init_models()
|
||||
self._init_async_rollout_manager()
|
||||
|
||||
def _init_resource_pools(self):
|
||||
self.resource_pool_manager.create_resource_pool()
|
||||
|
||||
self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()}
|
||||
|
||||
def _create_worker_classes(self):
|
||||
self._create_actor_rollout_classes()
|
||||
self._create_critic_class()
|
||||
self._create_reference_policy_class()
|
||||
self._create_reward_model_class()
|
||||
|
||||
def _create_actor_rollout_classes(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def _create_critic_class(self):
|
||||
# create critic
|
||||
if self.use_critic:
|
||||
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
|
||||
critic_cfg = omega_conf_to_dataclass(self.config.critic)
|
||||
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
|
||||
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls
|
||||
|
||||
def _create_reference_policy_class(self):
|
||||
# create reference policy if needed
|
||||
if self.use_reference_policy:
|
||||
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy)
|
||||
ref_policy_cls = RayClassWithInitArgs(
|
||||
self.role_worker_mapping[Role.RefPolicy],
|
||||
config=self.config.actor_rollout_ref,
|
||||
role=str(Role.RefPolicy),
|
||||
profile_option=self.config.trainer.npu_profile.options,
|
||||
)
|
||||
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls
|
||||
|
||||
def _create_reward_model_class(self):
|
||||
# create a reward model if reward_fn is None
|
||||
if self.use_rm:
|
||||
# we create a RM here
|
||||
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
|
||||
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
|
||||
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls
|
||||
|
||||
def _init_worker_groups(self):
|
||||
# initialize WorkerGroup
|
||||
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
|
||||
# you should not use `create_colocated_worker_cls`.
|
||||
# Instead, directly pass different resource pool to different worker groups.
|
||||
# See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information.
|
||||
all_wg = {}
|
||||
wg_kwargs = {} # Setting up kwargs for RayWorkerGroup
|
||||
if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
|
||||
wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
|
||||
if OmegaConf.select(self.config.global_profiler, "steps") is not None:
|
||||
wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
|
||||
# Only require nsight worker options when tool is nsys
|
||||
if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
|
||||
assert (
|
||||
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
|
||||
is not None
|
||||
), "worker_nsight_options must be set when using nsys with profile_steps"
|
||||
wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
|
||||
OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
|
||||
)
|
||||
wg_kwargs["device_name"] = self.device_name
|
||||
|
||||
for resource_pool, class_dict in self.resource_pool_to_cls.items():
|
||||
worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
|
||||
wg_dict = self.ray_worker_group_cls(
|
||||
resource_pool=resource_pool,
|
||||
ray_cls_with_init=worker_dict_cls,
|
||||
**wg_kwargs,
|
||||
)
|
||||
spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
|
||||
all_wg.update(spawn_wg)
|
||||
self.all_wg = all_wg
|
||||
|
||||
def _init_models(self):
|
||||
if self.use_critic:
|
||||
self.critic_wg = self.all_wg[str(Role.Critic)]
|
||||
self.critic_wg.init_model()
|
||||
|
||||
if self.use_reference_policy and not self.ref_in_actor:
|
||||
self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
|
||||
self.ref_policy_wg.init_model()
|
||||
|
||||
if self.use_rm:
|
||||
self.rm_wg = self.all_wg[str(Role.RewardModel)]
|
||||
self.rm_wg.init_model()
|
||||
|
||||
# we create the rollout last so that vLLM can better estimate the available KV cache memory
|
||||
self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
|
||||
self.actor_rollout_wg.init_model()
|
||||
|
||||
def _init_async_rollout_manager(self):
|
||||
pass
|
||||
|
||||
def fit(self):
|
||||
"""
|
||||
The training loop of PPO.
|
||||
The driver process only needs to call the compute functions of the worker group through RPC
|
||||
to construct the PPO dataflow.
|
||||
The light-weight advantage computation is done on the driver process.
|
||||
"""
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
from verl.utils.tracking import Tracking
|
||||
|
||||
logger = Tracking(
|
||||
project_name=self.config.trainer.project_name,
|
||||
experiment_name=self.config.trainer.experiment_name,
|
||||
default_backend=self.config.trainer.logger,
|
||||
config=OmegaConf.to_container(self.config, resolve=True),
|
||||
)
|
||||
|
||||
self.global_steps = 0
|
||||
|
||||
# load checkpoint before doing anything
|
||||
self._load_checkpoint()
|
||||
|
||||
# perform validation before training
|
||||
# currently, we only support validation using the reward_function.
|
||||
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
|
||||
val_metrics = self._validate()
|
||||
assert val_metrics, f"{val_metrics=}"
|
||||
pprint(f"Initial validation metrics: {val_metrics}")
|
||||
logger.log(data=val_metrics, step=self.global_steps)
|
||||
if self.config.trainer.get("val_only", False):
|
||||
return
|
||||
|
||||
if self.config.actor_rollout_ref.rollout.get("skip_rollout", False):
|
||||
rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg)
|
||||
rollout_skip.wrap_generate_sequences()
|
||||
|
||||
# add tqdm
|
||||
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
|
||||
|
||||
# we start from step 1
|
||||
self.global_steps += 1
|
||||
last_val_metrics = None
|
||||
self.max_steps_duration = 0
|
||||
|
||||
prev_step_profile = False
|
||||
curr_step_profile = (
|
||||
self.global_steps in self.config.global_profiler.steps
|
||||
if self.config.global_profiler.steps is not None
|
||||
else False
|
||||
)
|
||||
next_step_profile = False
|
||||
|
||||
for epoch in range(self.config.trainer.total_epochs):
|
||||
for batch_dict in self.train_dataloader:
|
||||
metrics = {}
|
||||
timing_raw = {}
|
||||
|
||||
with marked_timer("start_profile", timing_raw):
|
||||
self._start_profiling(
|
||||
not prev_step_profile and curr_step_profile
|
||||
if self.config.global_profiler.profile_continuous_steps
|
||||
else curr_step_profile
|
||||
)
|
||||
|
||||
batch, gen_batch = self._prepare_generate_batch(batch_dict)
|
||||
|
||||
is_last_step = self.global_steps >= self.total_training_steps
|
||||
|
||||
with marked_timer("step", timing_raw):
|
||||
# generate a batch
|
||||
with marked_timer("gen", timing_raw, color="red"):
|
||||
if not self.async_rollout_mode:
|
||||
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
|
||||
else:
|
||||
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch)
|
||||
timing_raw.update(gen_batch_output.meta_info["timing"])
|
||||
gen_batch_output.meta_info.pop("timing", None)
|
||||
|
||||
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
|
||||
if self.reward_fn is None:
|
||||
raise ValueError("A reward_fn is required for REMAX advantage estimation.")
|
||||
|
||||
with marked_timer("gen_max", timing_raw, color="purple"):
|
||||
gen_baseline_batch = deepcopy(gen_batch)
|
||||
gen_baseline_batch.meta_info["do_sample"] = False
|
||||
if not self.async_rollout_mode:
|
||||
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
|
||||
else:
|
||||
gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch)
|
||||
batch = batch.union(gen_baseline_output)
|
||||
reward_baseline_tensor = self.reward_fn(batch)
|
||||
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
|
||||
|
||||
batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
|
||||
|
||||
batch.batch["reward_baselines"] = reward_baseline_tensor
|
||||
|
||||
del gen_baseline_batch, gen_baseline_output
|
||||
|
||||
batch = self._post_generate_batch(batch, gen_batch_output, metrics)
|
||||
batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw)
|
||||
self._log_rollout(batch, reward_extra_infos_dict, timing_raw)
|
||||
|
||||
last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw)
|
||||
self._check_save_checkpoint(is_last_step, timing_raw)
|
||||
|
||||
with marked_timer("stop_profile", timing_raw):
|
||||
next_step_profile = (
|
||||
self.global_steps + 1 in self.config.global_profiler.steps
|
||||
if self.config.global_profiler.steps is not None
|
||||
else False
|
||||
)
|
||||
self._stop_profiling(
|
||||
curr_step_profile and not next_step_profile
|
||||
if self.config.global_profiler.profile_continuous_steps
|
||||
else curr_step_profile
|
||||
)
|
||||
prev_step_profile = curr_step_profile
|
||||
curr_step_profile = next_step_profile
|
||||
|
||||
self._collect_metrics(batch, epoch, metrics, timing_raw)
|
||||
self._post_batch_processing(batch)
|
||||
|
||||
# TODO: make a canonical logger that supports various backend
|
||||
logger.log(data=metrics, step=self.global_steps)
|
||||
|
||||
progress_bar.update(1)
|
||||
self.global_steps += 1
|
||||
|
||||
if (
|
||||
hasattr(self.config.actor_rollout_ref.actor, "profiler")
|
||||
and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory"
|
||||
):
|
||||
self.actor_rollout_wg.dump_memory_snapshot(
|
||||
tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}"
|
||||
)
|
||||
|
||||
if is_last_step:
|
||||
pprint(f"Final validation metrics: {last_val_metrics}")
|
||||
progress_bar.close()
|
||||
return
|
||||
|
||||
def _prepare_generate_batch(self, batch_dict):
|
||||
batch: DataProto = DataProto.from_single_dict(batch_dict)
|
||||
|
||||
# add uid to batch
|
||||
batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object)
|
||||
|
||||
gen_batch = self._get_gen_batch(batch)
|
||||
|
||||
# pass global_steps to trace
|
||||
gen_batch.meta_info["global_steps"] = self.global_steps
|
||||
gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
return batch, gen_batch
|
||||
|
||||
def _post_generate_batch(self, batch, gen_batch_output, metrics):
|
||||
# repeat to align with repeated responses in rollout
|
||||
batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
|
||||
batch = batch.union(gen_batch_output)
|
||||
|
||||
if "response_mask" not in batch.batch.keys():
|
||||
batch.batch["response_mask"] = compute_response_mask(batch)
|
||||
# Balance the number of valid tokens across DP ranks.
|
||||
# NOTE: This usually changes the order of data in the `batch`,
|
||||
# which won't affect the advantage calculation (since it's based on uid),
|
||||
# but might affect the loss calculation (due to the change of mini-batching).
|
||||
# TODO: Decouple the DP balancing and mini-batching.
|
||||
if self.config.trainer.balance_batch:
|
||||
self._balance_batch(batch, metrics=metrics)
|
||||
|
||||
# compute global_valid tokens
|
||||
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
|
||||
|
||||
return batch
|
||||
|
||||
def _process_batch_common(self, batch, metrics, timing_raw):
|
||||
with marked_timer("reward", timing_raw, color="yellow"):
|
||||
# compute reward model score
|
||||
if self.use_rm:
|
||||
reward_tensor = self.rm_wg.compute_rm_score(batch)
|
||||
batch = batch.union(reward_tensor)
|
||||
|
||||
if self.config.reward_model.launch_reward_fn_async:
|
||||
future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
|
||||
else:
|
||||
reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
|
||||
|
||||
# recompute old_log_probs
|
||||
with marked_timer("old_log_prob", timing_raw, color="blue"):
|
||||
async_training = self.config.get("async_training", None)
|
||||
if async_training and async_training.use_rollout_log_probs:
|
||||
batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"]
|
||||
batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature
|
||||
|
||||
else:
|
||||
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
|
||||
entropys = old_log_prob.batch["entropys"]
|
||||
response_masks = batch.batch["response_mask"]
|
||||
loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode
|
||||
entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode)
|
||||
old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()}
|
||||
metrics.update(old_log_prob_metrics)
|
||||
old_log_prob.batch.pop("entropys")
|
||||
batch = batch.union(old_log_prob)
|
||||
|
||||
if "rollout_log_probs" in batch.batch.keys():
|
||||
# TODO: we may want to add diff of probs too.
|
||||
from verl.utils.debug.metrics import calculate_debug_metrics
|
||||
|
||||
metrics.update(calculate_debug_metrics(batch))
|
||||
|
||||
if self.use_reference_policy:
|
||||
# compute reference log_prob
|
||||
with marked_timer("ref", timing_raw, color="olive"):
|
||||
if not self.ref_in_actor:
|
||||
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
|
||||
else:
|
||||
ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch)
|
||||
batch = batch.union(ref_log_prob)
|
||||
|
||||
# compute values
|
||||
if self.use_critic:
|
||||
with marked_timer("values", timing_raw, color="cyan"):
|
||||
values = self.critic_wg.compute_values(batch)
|
||||
batch = batch.union(values)
|
||||
|
||||
with marked_timer("adv", timing_raw, color="brown"):
|
||||
# we combine with rule-based rm
|
||||
reward_extra_infos_dict: dict[str, list]
|
||||
if self.config.reward_model.launch_reward_fn_async:
|
||||
reward_tensor, reward_extra_infos_dict = ray.get(future_reward)
|
||||
batch.batch["token_level_scores"] = reward_tensor
|
||||
|
||||
if reward_extra_infos_dict:
|
||||
batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()})
|
||||
|
||||
# compute rewards. apply_kl_penalty if available
|
||||
if self.config.algorithm.use_kl_in_reward:
|
||||
batch, kl_metrics = apply_kl_penalty(
|
||||
batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
|
||||
)
|
||||
metrics.update(kl_metrics)
|
||||
else:
|
||||
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
|
||||
|
||||
# compute advantages, executed on the driver process
|
||||
|
||||
norm_adv_by_std_in_grpo = self.config.algorithm.get(
|
||||
"norm_adv_by_std_in_grpo", True
|
||||
) # GRPO adv normalization factor
|
||||
|
||||
batch = compute_advantage(
|
||||
batch,
|
||||
adv_estimator=self.config.algorithm.adv_estimator,
|
||||
gamma=self.config.algorithm.gamma,
|
||||
lam=self.config.algorithm.lam,
|
||||
num_repeat=self.config.actor_rollout_ref.rollout.n,
|
||||
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
|
||||
config=self.config.algorithm,
|
||||
)
|
||||
|
||||
# update critic
|
||||
if self.use_critic:
|
||||
with marked_timer("update_critic", timing_raw, color="pink"):
|
||||
critic_output = self.critic_wg.update_critic(batch)
|
||||
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
|
||||
metrics.update(critic_output_metrics)
|
||||
|
||||
# implement critic warmup
|
||||
if self.config.trainer.critic_warmup <= self.global_steps:
|
||||
# update actor
|
||||
with marked_timer("update_actor", timing_raw, color="red"):
|
||||
batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable
|
||||
actor_output = self.actor_rollout_wg.update_actor(batch)
|
||||
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
|
||||
metrics.update(actor_output_metrics)
|
||||
return batch, reward_extra_infos_dict
|
||||
|
||||
def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw):
|
||||
# Log rollout generations if enabled
|
||||
rollout_data_dir = self.config.trainer.get("rollout_data_dir", None)
|
||||
if rollout_data_dir:
|
||||
with marked_timer("dump_rollout_generations", timing_raw, color="green"):
|
||||
inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True)
|
||||
outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True)
|
||||
scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist()
|
||||
sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch]
|
||||
|
||||
if "request_id" in batch.non_tensor_batch:
|
||||
reward_extra_infos_dict.setdefault(
|
||||
"request_id",
|
||||
batch.non_tensor_batch["request_id"].tolist(),
|
||||
)
|
||||
|
||||
self._dump_generations(
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
gts=sample_gts,
|
||||
scores=scores,
|
||||
reward_extra_infos_dict=reward_extra_infos_dict,
|
||||
dump_path=rollout_data_dir,
|
||||
)
|
||||
|
||||
def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw):
|
||||
if (
|
||||
self.val_reward_fn is not None
|
||||
and self.config.trainer.test_freq > 0
|
||||
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
|
||||
):
|
||||
with marked_timer("testing", timing_raw, color="green"):
|
||||
val_metrics: dict = self._validate()
|
||||
if is_last_step:
|
||||
last_val_metrics = val_metrics
|
||||
metrics.update(val_metrics)
|
||||
return last_val_metrics
|
||||
|
||||
def _check_save_checkpoint(self, is_last_step, timing_raw):
|
||||
# Check if the ESI (Elastic Server Instance)/training plan is close to expiration.
|
||||
esi_close_to_expiration = should_save_ckpt_esi(
|
||||
max_steps_duration=self.max_steps_duration,
|
||||
redundant_time=self.config.trainer.esi_redundant_time,
|
||||
)
|
||||
# Check if the conditions for saving a checkpoint are met.
|
||||
# The conditions include a mandatory condition (1) and
|
||||
# one of the following optional conditions (2/3/4):
|
||||
# 1. The save frequency is set to a positive value.
|
||||
# 2. It's the last training step.
|
||||
# 3. The current step number is a multiple of the save frequency.
|
||||
# 4. The ESI(Elastic Server Instance)/training plan is close to expiration.
|
||||
if self.config.trainer.save_freq > 0 and (
|
||||
is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration
|
||||
):
|
||||
if esi_close_to_expiration:
|
||||
print("Force saving checkpoint: ESI instance expiration approaching.")
|
||||
with marked_timer("save_checkpoint", timing_raw, color="green"):
|
||||
self._save_checkpoint()
|
||||
|
||||
def _collect_metrics(self, batch, epoch, metrics, timing_raw):
|
||||
steps_duration = timing_raw["step"]
|
||||
self.max_steps_duration = max(self.max_steps_duration, steps_duration)
|
||||
|
||||
# training metrics
|
||||
metrics.update(
|
||||
{
|
||||
"training/global_step": self.global_steps,
|
||||
"training/epoch": epoch,
|
||||
}
|
||||
)
|
||||
# collect metrics
|
||||
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
|
||||
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
|
||||
# TODO: implement actual tflpo and theoretical tflpo
|
||||
n_gpus = self.resource_pool_manager.get_n_gpus()
|
||||
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
|
||||
|
||||
def _post_batch_processing(self, batch: DataProto):
|
||||
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
|
||||
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):
|
||||
self.train_dataloader.sampler.update(batch=batch)
|
||||
|
||||
# this is experimental and may be changed/removed in the future
|
||||
# in favor of a general-purpose data buffer pool
|
||||
if hasattr(self.train_dataset, "on_batch_end"):
|
||||
# The dataset may be changed after each training batch
|
||||
self.train_dataset.on_batch_end(batch=batch)
|
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_16-16'
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# Very important: set max_position_embeddings to 32768 in the model's config.json after downloading it from Hugging Face
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
|
||||
rollout_mode="async"
|
||||
rollout_name="vllm" # sglang or vllm
|
||||
if [ "$rollout_mode" = "async" ]; then
|
||||
export VLLM_USE_V1=1
|
||||
return_raw_chat="True"
|
||||
fi
|
||||
|
||||
# Algorithm parameters
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
# Response length parameters
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 28))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
# Training parameters
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
ref_offload=True
|
||||
actor_offload=False
|
||||
gen_tp=4
|
||||
sp_size=4
|
||||
fsdp_size=8
|
||||
|
||||
# Fully async specific parameters
|
||||
NNODES_ROLLOUT=${NNODES_ROLLOUT:-2}
|
||||
NNODES_TRAIN=${NNODES_TRAIN:-2}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
train_prompt_bsz=0
|
||||
gen_prompt_bsz=1
|
||||
n_resp_per_prompt=16
|
||||
train_prompt_mini_bsz=32
|
||||
total_rollout_steps=$(((512*400)))
|
||||
test_freq=20
|
||||
staleness_threshold=0.1
|
||||
trigger_parameter_sync_step=4
|
||||
require_batches=4
|
||||
partial_rollout=True
|
||||
|
||||
python -m recipe.fully_async_policy.fully_async_main \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.return_raw_chat=${return_raw_chat} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
actor_rollout_ref.rollout.name=${rollout_name} \
|
||||
actor_rollout_ref.rollout.mode=${rollout_mode} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.nnodes="${NNODES_TRAIN}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.nnodes="${NNODES_ROLLOUT}" \
|
||||
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.total_rollout_steps="${total_rollout_steps}" \
|
||||
rollout.total_epochs=10 \
|
||||
rollout.test_freq="${test_freq}" \
|
||||
async_training.staleness_threshold="${staleness_threshold}" \
|
||||
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
|
||||
async_training.require_batches="${require_batches}" \
|
||||
async_training.partial_rollout="${partial_rollout}" \
|
||||
async_training.use_rollout_log_probs=True
|
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh (new file, 162 lines)
@@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env bash
|
||||
set -xeuo pipefail
|
||||
|
||||
project_name='DAPO'
|
||||
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_32-32'
|
||||
|
||||
# Ray
|
||||
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
|
||||
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
|
||||
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
|
||||
# Paths
|
||||
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
|
||||
# Very important: set max_position_embeddings to 32768 in the model's config.json after downloading it from Hugging Face
|
||||
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
|
||||
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
|
||||
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
|
||||
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
|
||||
|
||||
rollout_mode="async"
|
||||
rollout_name="vllm" # sglang or vllm
|
||||
if [ "$rollout_mode" = "async" ]; then
|
||||
export VLLM_USE_V1=1
|
||||
return_raw_chat="True"
|
||||
fi
|
||||
|
||||
# Algorithm parameters
|
||||
adv_estimator=grpo
|
||||
|
||||
use_kl_in_reward=False
|
||||
kl_coef=0.0
|
||||
use_kl_loss=False
|
||||
kl_loss_coef=0.0
|
||||
|
||||
clip_ratio_low=0.2
|
||||
clip_ratio_high=0.28
|
||||
|
||||
# Response length parameters
|
||||
max_prompt_length=$((1024 * 2))
|
||||
max_response_length=$((1024 * 28))
|
||||
enable_overlong_buffer=True
|
||||
overlong_buffer_len=$((1024 * 4))
|
||||
overlong_penalty_factor=1.0
|
||||
|
||||
# Training parameters
|
||||
loss_agg_mode="token-mean"
|
||||
|
||||
# Algorithm
|
||||
temperature=1.0
|
||||
top_p=1.0
|
||||
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
|
||||
val_top_p=0.7
|
||||
|
||||
# Performance Related Parameter
|
||||
use_dynamic_bsz=True
|
||||
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
|
||||
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
|
||||
ref_offload=True
|
||||
actor_offload=False
|
||||
gen_tp=4
|
||||
sp_size=4
|
||||
fsdp_size=8
|
||||
|
||||
# Fully async specific parameters
|
||||
NNODES_ROLLOUT=${NNODES_ROLLOUT:-4}
|
||||
NNODES_TRAIN=${NNODES_TRAIN:-4}
|
||||
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
|
||||
|
||||
train_prompt_bsz=0
|
||||
gen_prompt_bsz=1
|
||||
n_resp_per_prompt=16
|
||||
train_prompt_mini_bsz=32
|
||||
total_rollout_steps=$(((512*400)))
|
||||
test_freq=20
|
||||
staleness_threshold=0.1
|
||||
trigger_parameter_sync_step=4
|
||||
require_batches=4
|
||||
partial_rollout=True
|
||||
|
||||
python -m recipe.fully_async_policy.fully_async_main \
|
||||
data.train_files="${TRAIN_FILE}" \
|
||||
data.val_files="${TEST_FILE}" \
|
||||
data.prompt_key=prompt \
|
||||
data.truncation='left' \
|
||||
data.max_prompt_length=${max_prompt_length} \
|
||||
data.max_response_length=${max_response_length} \
|
||||
data.train_batch_size=${train_prompt_bsz} \
|
||||
data.gen_batch_size=${gen_prompt_bsz} \
|
||||
data.return_raw_chat=${return_raw_chat} \
|
||||
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
|
||||
algorithm.adv_estimator=${adv_estimator} \
|
||||
algorithm.use_kl_in_reward=${use_kl_in_reward} \
|
||||
algorithm.kl_ctrl.kl_coef=${kl_coef} \
|
||||
actor_rollout_ref.actor.strategy=fsdp2 \
|
||||
critic.strategy=fsdp2 \
|
||||
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
|
||||
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
|
||||
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
|
||||
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
|
||||
actor_rollout_ref.actor.clip_ratio_c=10.0 \
|
||||
actor_rollout_ref.model.use_remove_padding=True \
|
||||
actor_rollout_ref.hybrid_engine=False \
|
||||
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
|
||||
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
|
||||
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
|
||||
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
|
||||
actor_rollout_ref.model.path="${MODEL_PATH}" \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
|
||||
actor_rollout_ref.actor.optim.weight_decay=0.1 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
|
||||
actor_rollout_ref.actor.entropy_coeff=0 \
|
||||
actor_rollout_ref.actor.grad_clip=1.0 \
|
||||
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
|
||||
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
|
||||
actor_rollout_ref.rollout.enable_chunked_prefill=True \
|
||||
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
|
||||
actor_rollout_ref.rollout.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.top_p=${top_p} \
|
||||
actor_rollout_ref.rollout.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
|
||||
actor_rollout_ref.rollout.val_kwargs.n=1 \
|
||||
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
|
||||
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
|
||||
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
|
||||
actor_rollout_ref.rollout.name=${rollout_name} \
|
||||
actor_rollout_ref.rollout.mode=${rollout_mode} \
|
||||
actor_rollout_ref.rollout.calculate_log_probs=True \
|
||||
reward_model.reward_manager=dapo \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
|
||||
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
|
||||
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
|
||||
trainer.logger=['console','tensorboard'] \
|
||||
trainer.project_name="${project_name}" \
|
||||
trainer.experiment_name="${exp_name}" \
|
||||
trainer.val_before_train=True \
|
||||
trainer.save_freq=-1 \
|
||||
trainer.default_local_dir="${CKPTS_DIR}" \
|
||||
trainer.resume_mode=auto \
|
||||
trainer.nnodes="${NNODES_TRAIN}" \
|
||||
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.nnodes="${NNODES_ROLLOUT}" \
|
||||
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
|
||||
rollout.total_rollout_steps="${total_rollout_steps}" \
|
||||
rollout.total_epochs=10 \
|
||||
rollout.test_freq="${test_freq}" \
|
||||
async_training.staleness_threshold="${staleness_threshold}" \
|
||||
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
|
||||
async_training.require_batches="${require_batches}" \
|
||||
async_training.partial_rollout="${partial_rollout}" \
|
||||
async_training.use_rollout_log_probs=True
|
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh (new file, 164 lines)
@@ -0,0 +1,164 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-12'

# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2

# Fully async specific parameters
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True

python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq="${test_freq}" \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True
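The `_4_12`, `_4_4`, `_8_8`, `_32_32`, and `_64_64` suffixes in these script names appear to encode the total rollout / training GPU split. A small sanity-check sketch of that arithmetic (variable names mirror the scripts; the naming interpretation itself is an assumption):

```python
# Total GPU split implied by the script names, e.g. "..._4_12" = 4 rollout GPUs, 12 training GPUs.
def gpu_split(nnodes: int, ngpus_per_node: int, n_gpus_rollout_per_node: int) -> tuple[int, int]:
    rollout_gpus = nnodes * n_gpus_rollout_per_node
    training_gpus = nnodes * (ngpus_per_node - n_gpus_rollout_per_node)
    return rollout_gpus, training_gpus

assert gpu_split(nnodes=2, ngpus_per_node=8, n_gpus_rollout_per_node=2) == (4, 12)  # dapo_7b_math_fsdp2_4_12.sh
assert gpu_split(nnodes=1, ngpus_per_node=8, n_gpus_rollout_per_node=4) == (4, 4)   # dapo_7b_math_fsdp2_4_4.sh
```

The larger `_32_32` and `_64_64` recipes instead dedicate whole nodes to each role via `NNODES_ROLLOUT` and `NNODES_TRAIN`.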
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh (new file, 164 lines)
@@ -0,0 +1,164 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-4'

# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2

# Fully async specific parameters
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

n_gpus_rollout=4
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True

python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=False \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh (new file, 162 lines)
@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64'

# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 28))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=4
sp_size=4
fsdp_size=8

# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-8}
NNODES_TRAIN=${NNODES_TRAIN:-8}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*400)))
test_freq=20
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True

python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh (new file, 162 lines)
@@ -0,0 +1,162 @@
#!/usr/bin/env bash
set -xeuo pipefail

project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-8-8'

# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"

# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7

# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=1
sp_size=1
fsdp_size=2

# Fully async specific parameters
NNODES_ROLLOUT=${NNODES_ROLLOUT:-1}
NNODES_TRAIN=${NNODES_TRAIN:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=32
total_rollout_steps=$(((512*100)))
test_freq=10
staleness_threshold=0.1
trigger_parameter_sync_step=4
require_batches=4
partial_rollout=True

python -m recipe.fully_async_policy.fully_async_main \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
data.gen_batch_size=${gen_prompt_bsz} \
data.return_raw_chat=${return_raw_chat} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.enable_chunked_prefill=True \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
actor_rollout_ref.rollout.name=${rollout_name} \
actor_rollout_ref.rollout.mode=${rollout_mode} \
actor_rollout_ref.rollout.calculate_log_probs=True \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.save_freq=-1 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.nnodes="${NNODES_TRAIN}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.nnodes="${NNODES_ROLLOUT}" \
rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \
rollout.total_rollout_steps="${total_rollout_steps}" \
rollout.total_epochs=10 \
rollout.test_freq="${test_freq}" \
async_training.staleness_threshold="${staleness_threshold}" \
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \
async_training.require_batches="${require_batches}" \
async_training.partial_rollout="${partial_rollout}" \
async_training.use_rollout_log_probs=True
recipe/fully_async_policy/shell/runtime_env.yaml (new file, 4 lines)
@@ -0,0 +1,4 @@
env_vars:
  VLLM_USE_V1: "1"
  NCCL_DEBUG: "INFO"
  HYDRA_FULL_ERROR: "1"
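These environment variables are intended to be inherited by the Ray workers. As a minimal sketch of how such a runtime env file can be attached to a Ray session (the address and path here are illustrative assumptions, not part of the recipe):

```python
# Minimal sketch: loading the recipe's runtime_env.yaml and passing it to Ray.
# The cluster address and working directory are assumptions for illustration.
import ray
import yaml

with open("recipe/fully_async_policy/shell/runtime_env.yaml") as f:
    runtime_env = yaml.safe_load(f)

# Workers started under this session inherit VLLM_USE_V1, NCCL_DEBUG, HYDRA_FULL_ERROR.
ray.init(address="auto", runtime_env=runtime_env)
```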
recipe/fully_async_policy/unittest/simple_streaming_demo.py (new file, 176 lines)
@@ -0,0 +1,176 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import random
import time


class SimpleStreamingSystem:
    """Simplified streaming system demonstration"""

    def __init__(self, max_concurrent_tasks: int = 4):
        self.max_concurrent_tasks = max_concurrent_tasks
        self.data_queue = asyncio.Queue()
        self.result_queue = asyncio.Queue()
        self.consumer_count = 0

    # Data stream coroutine
    async def data_stream(self):
        # Add initial data
        # Prepare test data
        test_data = [{"id": f"task_{i}", "content": f"data_{i}"} for i in range(8)]
        await self.add_data_stream(test_data)

        # Simulate subsequent data stream
        await asyncio.sleep(3)
        print("\nAdding second batch of data...")
        extra_data = [{"id": f"extra_{i}", "content": f"extra_data_{i}"} for i in range(5)]
        await self.add_data_stream(extra_data)

        # Send termination signal
        await asyncio.sleep(1)
        await self.data_queue.put("DONE")
        print("Sending termination signal")

    async def add_data_stream(self, data_list: list[dict]):
        """Simulate data stream"""
        print("Starting to add data stream...")

        for i, data_item in enumerate(data_list):
            await self.data_queue.put(data_item)
            print(f"Data {data_item['id']} added to pending queue")

            # Simulate interval between data streams
            if i < len(data_list) - 1:  # Don't wait after the last item
                await asyncio.sleep(0.8)

        print("Initial data stream added successfully")

    async def _process_data_async(self, data_item: dict):
        """Asynchronously process a single data item"""
        data_id = data_item["id"]
        content = data_item["content"]

        # Simulate different processing times (1-3 seconds)
        processing_time = random.uniform(1, 3)

        print(f" Starting to process {data_id}, estimated time {processing_time:.1f}s")

        # Asynchronously wait for processing completion
        await asyncio.sleep(processing_time)

        result = {
            "id": data_id,
            "processed_content": f"Processed {content}",
            "processing_time": round(processing_time, 2),
            "completed_at": time.time(),
        }

        # Immediately put into result queue
        await self.result_queue.put(result)
        print(f" {data_id} processing completed! (took {processing_time:.1f}s) -> Added to result queue")

    async def _submit_worker(self):
        """Stream submission worker coroutine"""
        active_tasks = set()

        print("Stream submitter started...")

        while True:
            # Get data to process
            data_item = await self.data_queue.get()

            if data_item == "DONE":
                print("Received termination signal, waiting for remaining tasks to complete...")
                if active_tasks:
                    await asyncio.gather(*active_tasks, return_exceptions=True)
                break

            # Check concurrent limit
            while len(active_tasks) >= self.max_concurrent_tasks:
                print(f"Reached maximum concurrency {self.max_concurrent_tasks}, waiting for tasks to complete...")
                done_tasks, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED)

                # Clean up completed tasks
                for task in done_tasks:
                    try:
                        await task
                        print(f"Task completed {task}")
                    except Exception as e:
                        print(f"Task execution failed: {e}")

            # Immediately submit new task
            task = asyncio.create_task(self._process_data_async(data_item), name=f"active {data_item}")
            active_tasks.add(task)

            print(f"Submitted task {data_item['id']}, current concurrency: {len(active_tasks)}")

    async def _consumer_worker(self):
        """Result consumer coroutine"""
        print("Consumer started...")

        while True:
            try:
                # Get processing result from result queue
                result = await asyncio.wait_for(self.result_queue.get(), timeout=2.0)

                self.consumer_count += 1

                print(
                    f"Consumed #{self.consumer_count}: {result['id']} "
                    f"(processing time {result['processing_time']}s) - {result['processed_content']}"
                )

            except asyncio.TimeoutError:
                print(" Consumer waiting...")
                await asyncio.sleep(0.5)

    async def run_demo(self):
        """Run demonstration"""
        print("=" * 60)
        print(f"Maximum concurrency: {self.max_concurrent_tasks}")
        print("=" * 60)

        # Start core coroutines
        stream_task = asyncio.create_task(self.data_stream())
        submit_task = asyncio.create_task(self._submit_worker())
        consumer_task = asyncio.create_task(self._consumer_worker())

        try:
            # Wait for data stream to complete
            await stream_task
            print("Data stream completed")

            # Wait for processing to complete
            await submit_task
            print("All tasks processed")

        finally:
            # Cleanup
            submit_task.cancel()
            consumer_task.cancel()
            await asyncio.gather(submit_task, consumer_task, return_exceptions=True)

        print(f"\nFinal statistics: Consumed {self.consumer_count} results")


async def main():
    """Main function"""
    system = SimpleStreamingSystem(max_concurrent_tasks=3)
    await system.run_demo()


if __name__ == "__main__":
    asyncio.run(main())
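As a side note on the design choice in `_submit_worker`: the demo bounds concurrency by sizing the `active_tasks` set and waiting on `asyncio.FIRST_COMPLETED`. The same bound can also be expressed with an `asyncio.Semaphore`; a minimal alternative sketch (illustrative only, not used by the demo or the recipe):

```python
# Alternative sketch: bound concurrency with a semaphore instead of manually
# draining the active-task set.
import asyncio


async def process_with_limit(items, process, max_concurrent_tasks: int = 4):
    sem = asyncio.Semaphore(max_concurrent_tasks)

    async def run(item):
        async with sem:  # at most max_concurrent_tasks items in flight
            return await process(item)

    return await asyncio.gather(*(run(item) for item in items))
```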
recipe/fully_async_policy/vllm_rollout/__init__.py (new file, 13 lines)
@@ -0,0 +1,13 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
recipe/fully_async_policy/vllm_rollout/vllm_async_server.py (new file, 154 lines)
@@ -0,0 +1,154 @@
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
from typing import Any, Optional, Sequence

import ray
from ray.actor import ActorHandle
from vllm import SamplingParams
from vllm.inputs import TokensPrompt
from vllm.outputs import RequestOutput

from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig
from verl.workers.rollout.replica import RolloutMode
from verl.workers.rollout.vllm_rollout.vllm_async_server import (
    _qwen2_5_vl_dedup_image_tokens,
    vLLMHttpServerBase,
    vLLMReplica,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)


@ray.remote(num_cpus=1)
class vLLMHttpServerForPartial(vLLMHttpServerBase):
    def __init__(
        self,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        rollout_mode: RolloutMode,
        workers: list[ActorHandle],
        replica_rank: int,
        node_rank: int,
        gpus_per_node: int,
        nnodes: int,
    ):
        super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)

        # for cancel LLMServer
        self.paused = False
        self.lock = asyncio.Lock()
        self.cancel_event: dict[str, asyncio.Event] = {}
        self.req_output: dict[str, Optional[RequestOutput]] = {}

    async def _generate_step(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ):
        max_tokens = self.config.max_model_len - len(prompt_ids)
        sampling_params["logprobs"] = 1
        sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0))
        sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params)
        prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor)
        prompt = TokensPrompt(
            prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None
        )
        generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id)

        # Get final response
        self.req_output[request_id]: Optional[RequestOutput] = None
        async for output in generator:
            self.req_output[request_id] = output
        assert self.req_output[request_id] is not None

    async def generate_for_partial(
        self,
        prompt_ids: list[int],
        sampling_params: dict[str, Any],
        request_id: str,
        image_data: Optional[list[Any]] = None,
    ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]:
        async with self.lock:
            if self.paused:
                # After cancel, all tasks will return directly and wait for the next submission
                return [], [], True
            self.cancel_event[request_id] = asyncio.Event()
            cancel_handle = asyncio.create_task(self.cancel_event[request_id].wait())
            generation_handle = asyncio.create_task(
                self._generate_step(prompt_ids, sampling_params, request_id, image_data)
            )

        done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED)

        for task in done:
            await task

        for task in pend:
            task.cancel()

        async with self.lock:
            token_ids = self.req_output[request_id].outputs[0].token_ids
            log_probs: list[float] = []
            for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs):
                # In sampling_params, logprobs is set to 1, which should return 1,
                # but in practice there are multiple. Take the log_prob corresponding to token_id
                token_id = self.req_output[request_id].outputs[0].token_ids[i]
                log_probs.append(x[token_id].logprob)
            is_cancel = generation_handle not in done
            self.cancel_event.pop(request_id, None)
            self.req_output.pop(request_id, None)
        return token_ids, log_probs, is_cancel

    async def cancel(self):
        async with self.lock:
            self.paused = True
            for request_id in self.cancel_event:
                self.cancel_event[request_id].set()

    async def resume(self):
        async with self.lock:
            self.paused = False

    async def reset_prefix_cache(self):
        async with self.lock:
            await self.engine.reset_prefix_cache()


class FullyAsyncvLLMReplica(vLLMReplica):
    def __init__(
        self,
        replica_rank: int,
        config: RolloutConfig | RewardModelConfig,
        model_config: HFModelConfig,
        gpus_per_node: int = 8,
    ):
        super().__init__(replica_rank, config, model_config, gpus_per_node)
        self.server_class = vLLMHttpServerForPartial

    async def cancel(self):
        """Cancel each rollout server."""
        await asyncio.gather(*[server.cancel.remote() for server in self.servers])

    async def resume(self):
        """Resume each rollout server."""
        await asyncio.gather(*[server.resume.remote() for server in self.servers])

    async def reset_prefix_cache(self):
        """reset kv cache in each rollout server."""
        await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers])
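To make the partial-rollout control flow concrete, here is a hedged sketch of how a caller might drive `cancel()` / `resume()` around a parameter sync. Only `cancel`, `resume`, and `reset_prefix_cache` come from the replica class above; `sync_weights_to_rollout` and the ordering of the steps are assumptions for illustration, not APIs defined in this PR:

```python
# Hedged sketch of a pause -> sync -> resume cycle around FullyAsyncvLLMReplica.
import asyncio


async def sync_cycle(replica, sync_weights_to_rollout):
    # In-flight generate_for_partial calls return early with is_cancel=True,
    # so their partial outputs can be resubmitted after the sync (partial rollout).
    await replica.cancel()
    await sync_weights_to_rollout()     # placeholder: push the trainer's latest weights
    await replica.reset_prefix_cache()  # cached KV was computed with the old weights
    await replica.resume()              # new generation requests are accepted again
```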
tests/special_e2e/run_fully_async_policy.sh (new file, 196 lines)
@@ -0,0 +1,196 @@
#!/usr/bin/env bash
set -xeuo pipefail

# Test script for fully_async_policy E2E regression testing
# This script runs fully async PPO training with both FSDP2 and Megatron backends
# to ensure the asynchronous training mechanism works correctly

NUM_GPUS=${NUM_GPUS:-8}
ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron

# Download model if not exists
MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct}
MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}}
huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}"

rollout_mode="async"
rollout_name="vllm" # sglang or vllm
if [ "$rollout_mode" = "async" ]; then
export VLLM_USE_V1=1
return_raw_chat="True"
fi

# Algorithm parameters
adv_estimator=grpo

use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0

clip_ratio_low=0.2
clip_ratio_high=0.28

# Response length parameters
max_prompt_length=1024
max_response_length=2048
enable_overlong_buffer=True
overlong_buffer_len=128
overlong_penalty_factor=1.0

# Training parameters
loss_agg_mode="token-mean"

# Temperature parameters
temperature=1.0
top_p=1.0
top_k=-1
val_top_p=0.7

# Fully async specific parameters
n_gpus_rollout=4
n_gpus_training=4

train_prompt_bsz=0
gen_prompt_bsz=1
n_resp_per_prompt=16
train_prompt_mini_bsz=16
total_rollout_steps=$(((128)))
test_freq=-1
staleness_threshold=0.1
trigger_parameter_sync_step=4
partial_rollout=True

exp_name="$(basename "${MODEL_ID,,}")-fully-async-policy-${ACTOR_STRATEGY}-minimal"

echo "Running fully_async_policy with ${ACTOR_STRATEGY} strategy"
echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}"

# Common parameters for both FSDP2 and Megatron
common_params=(
data.train_files="${HOME}/data/gsm8k/train.parquet"
data.val_files="${HOME}/data/gsm8k/test.parquet"
data.prompt_key=prompt
data.truncation='left'
data.max_prompt_length=${max_prompt_length}
data.max_response_length=${max_response_length}
data.train_batch_size=${train_prompt_bsz}
data.gen_batch_size=${gen_prompt_bsz}
data.return_raw_chat=${return_raw_chat}
actor_rollout_ref.rollout.n=${n_resp_per_prompt}
actor_rollout_ref.rollout.calculate_log_probs=True
algorithm.adv_estimator=${adv_estimator}
algorithm.use_kl_in_reward=${use_kl_in_reward}
algorithm.kl_ctrl.kl_coef=${kl_coef}
actor_rollout_ref.hybrid_engine=False
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss}
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef}
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low}
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high}
actor_rollout_ref.actor.clip_ratio_c=10.0
actor_rollout_ref.model.path="${MODEL_PATH}"
actor_rollout_ref.model.enable_gradient_checkpointing=True
actor_rollout_ref.actor.optim.lr=1e-6
actor_rollout_ref.actor.optim.lr_warmup_steps=-1
actor_rollout_ref.actor.optim.weight_decay=0.1
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz}
actor_rollout_ref.actor.entropy_coeff=0
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode}
actor_rollout_ref.rollout.gpu_memory_utilization=0.80
actor_rollout_ref.rollout.temperature=${temperature}
actor_rollout_ref.rollout.top_p=${top_p}
actor_rollout_ref.rollout.top_k=${top_k}
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature}
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p}
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k}
actor_rollout_ref.rollout.val_kwargs.do_sample=True
actor_rollout_ref.rollout.val_kwargs.n=1
actor_rollout_ref.rollout.enable_chunked_prefill=True
actor_rollout_ref.rollout.name=${rollout_name}
actor_rollout_ref.rollout.mode=${rollout_mode}
reward_model.reward_manager=dapo
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer}
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len}
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor}
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False
+reward_model.reward_kwargs.max_resp_len=${max_response_length}
trainer.logger=['console']
trainer.project_name='verl-test-fully-async'
trainer.experiment_name="${exp_name}"
trainer.val_before_train=True
trainer.save_freq=-1
trainer.resume_mode=disable
trainer.nnodes=1
trainer.n_gpus_per_node=${n_gpus_training}
rollout.nnodes=1
rollout.n_gpus_per_node=${n_gpus_rollout}
rollout.total_rollout_steps=${total_rollout_steps}
rollout.total_epochs=2
rollout.test_freq=${test_freq}
# Fully async specific configurations
async_training.staleness_threshold=${staleness_threshold}
async_training.partial_rollout="${partial_rollout}"
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}"
)

if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then
echo "Running fully async training with FSDP2 strategy..."
# FSDP2 specific parameters
gen_tp=1
sp_size=1
fsdp_size=1
ref_offload=True
actor_offload=False

python3 -m recipe.fully_async_policy.fully_async_main \
"${common_params[@]}" \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@

elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then
echo "Running fully async training with Megatron strategy..."
# Megatron specific parameters
gen_tp=2
train_tp=1
train_pp=2
ref_offload=True
actor_offload=False

python3 -m recipe.fully_async_policy.fully_async_main \
--config-path=config \
--config-name='fully_async_ppo_megatron_trainer.yaml' \
"${common_params[@]}" \
actor_rollout_ref.actor.strategy=megatron \
critic.strategy=megatron \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \
actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \
actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \
actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \
actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@
else
echo "Error: Unknown strategy ${ACTOR_STRATEGY}. Please use 'fsdp2' or 'megatron'"
exit 1
fi

echo "Fully async policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy"
@@ -48,6 +48,7 @@ CUDA_KEYWORD_CHECK_WHITELIST = [
NCCL_KEYWORD_CHECK_WHITELIST = [
"verl/utils/device.py",
"verl/third_party/sglang/parallel_state.py",  # appear in default backend
"verl/recipe/fully_async_policy/param_sync.py",  # fully_async_policy in default backend
]

SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST
@@ -24,6 +24,7 @@ license_head_sglang = "Copyright 2023-2024 SGLang Team"
license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates"
license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates"
license_head_facebook = "Copyright (c) 2016- Facebook, Inc"
license_head_meituan = "Copyright 2025 Meituan Ltd. and/or its affiliates"
license_headers = [
license_head_bytedance,
license_head_bytedance_25,
@@ -33,6 +34,7 @@ license_headers = [
license_head_modelbest,
license_head_amazon,
license_head_facebook,
license_head_meituan,
]
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .agent_loop import AgentLoopBase, AgentLoopManager, AsyncLLMServerManager
from .agent_loop import AgentLoopBase, AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager
from .single_turn_agent_loop import SingleTurnAgentLoop
from .tool_agent_loop import ToolAgentLoop

_ = [SingleTurnAgentLoop, ToolAgentLoop]

__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager"]
__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager", "AgentLoopWorker"]
@@ -370,8 +370,7 @@ class RewardManagerWorker:
return self.reward_manager(data, return_dict)


@ray.remote
class AgentLoopWorker:
class AgentLoopWorkerBase:
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""

def __init__(
@@ -384,7 +383,11 @@ class AgentLoopWorker:
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
"""
self.config = config
self.server_manager = AsyncLLMServerManager(config, server_handles)

# for recipe to change
if not hasattr(self, "server_manager"):
self.server_manager = AsyncLLMServerManager(config, server_handles)

self.rm_executor = rm_executor

model_path = config.actor_rollout_ref.model.path
@@ -720,6 +723,22 @@ class AgentLoopWorker:
)


@ray.remote
class AgentLoopWorker(AgentLoopWorkerBase):
"""Agent loop worker takes a batch of messages and run each message in an agent loop."""

def __init__(
self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None
):
"""Initialize agent loop manager.

Args:
config (DictConfig): YAML config.
server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles.
"""
super().__init__(config, server_handles, rm_executor)


async def get_trajectory_info(step, index, validate):
"""Get trajectory info.

@@ -778,6 +797,12 @@ class AgentLoopManager:

self.rm_micro_batch_size = rm_wg.world_size

# for recipe to change
if not hasattr(self, "rollout_replica_class"):
self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
if not hasattr(self, "agent_loop_workers_class"):
self.agent_loop_workers_class = AgentLoopWorker

self._initialize_llm_servers()
self._init_agent_loop_workers()

@@ -798,11 +823,10 @@ class AgentLoopManager:
)
num_replicas = world_size // rollout_world_size

rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name)
rollout_config = self.config.actor_rollout_ref.rollout
model_config = self.config.actor_rollout_ref.model
self.rollout_replicas = [
rollout_replica_class(
self.rollout_replica_class(
replica_rank=replica_rank,
config=rollout_config,
model_config=model_config,
@@ -826,7 +850,7 @@ class AgentLoopManager:
# Round-robin scheduling over the all nodes
node_id = node_ids[i % len(node_ids)]
self.agent_loop_workers.append(
AgentLoopWorker.options(
self.agent_loop_workers_class.options(
name=f"agent_loop_worker_{i}",
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id, soft=True
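The `hasattr` guards added above are extension points: a recipe can pre-set `server_manager`, `rollout_replica_class`, or `agent_loop_workers_class` before the base initializers run, and the defaults will not overwrite them. A minimal sketch of that pattern follows; the import path and the `StickyServerManager` / `RecipeAgentLoopWorker` / `RecipeAgentLoopManager` names are assumptions for illustration, not part of this PR.

```python
import ray

# Assumed module path; only AgentLoopWorker is re-exported in __all__ above.
from verl.experimental.agent_loop.agent_loop import (
    AgentLoopManager,
    AgentLoopWorkerBase,
    AsyncLLMServerManager,
)


class StickyServerManager(AsyncLLMServerManager):
    """Hypothetical recipe-specific manager, e.g. with different request routing."""


@ray.remote
class RecipeAgentLoopWorker(AgentLoopWorkerBase):
    def __init__(self, config, server_handles, rm_executor=None):
        # Set server_manager first: the base __init__ only builds the default
        # AsyncLLMServerManager when the attribute is absent (hasattr check above).
        self.server_manager = StickyServerManager(config, server_handles)
        super().__init__(config, server_handles, rm_executor)


class RecipeAgentLoopManager(AgentLoopManager):
    def __init__(self, *args, **kwargs):
        # Same trick on the manager side: the hasattr check keeps this worker class.
        self.agent_loop_workers_class = RecipeAgentLoopWorker
        super().__init__(*args, **kwargs)
```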
@@ -39,4 +39,4 @@ entropy_from_logits_with_chunking: False
entropy_checkpointing: False

# Whether to remove padding tokens in inputs during training
use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false}
@@ -43,13 +43,14 @@ def main(config):


# Define a function to run the PPO-like training process
def run_ppo(config) -> None:
def run_ppo(config, task_runner_class=None) -> None:
"""Initialize Ray cluster and run distributed PPO training process.

Args:
config: Training configuration object containing all necessary parameters
for distributed PPO training including Ray initialization settings,
model paths, and training hyperparameters.
task_runner_class: For recipe to change TaskRunner.
"""
# Check if Ray is not initialized
if not ray.is_initialized():
@@ -65,6 +66,9 @@ def run_ppo(config) -> None:
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs))

if task_runner_class is None:
task_runner_class = TaskRunner

# Create a remote instance of the TaskRunner class, and
# Execute the `run` method of the TaskRunner instance remotely and wait for it to complete
if (
@@ -79,9 +83,9 @@ def run_ppo(config) -> None:
nsight_options = OmegaConf.to_container(
config.global_profiler.global_tool_config.nsys.controller_nsight_options
)
runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote()
runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote()
else:
runner = TaskRunner.remote()
runner = task_runner_class.remote()
ray.get(runner.run.remote(config))

# [Optional] get the path of the timeline trace file from the configuration, default to None
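With `task_runner_class` exposed, a recipe can reuse `run_ppo`'s Ray bootstrap and only swap the remote runner. A minimal sketch, assuming `run_ppo` stays importable from the trainer entrypoint module; `RecipeTaskRunner` is an illustrative name, not the recipe's actual class.

```python
import ray

from verl.trainer.main_ppo import run_ppo  # assumed import path


@ray.remote(num_cpus=1)
class RecipeTaskRunner:
    """Hypothetical runner: run_ppo only needs .options()/.remote() on the class
    and a run(config) method on the actor, as the hunk above shows."""

    def run(self, config):
        ...  # build the recipe-specific trainer/rollouter here


def main(config):
    # run_ppo falls back to the default TaskRunner when task_runner_class is None.
    run_ppo(config, task_runner_class=RecipeTaskRunner)
```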
@@ -602,11 +602,9 @@ class RayPPOTrainer:
sample_scores.extend(scores)

reward_extra_infos_dict["reward"].extend(scores)
print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}")
if "reward_extra_info" in result:
for key, lst in result["reward_extra_info"].items():
reward_extra_infos_dict[key].extend(lst)
print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}")

# collect num_turns of each prompt
if "__num_turns__" in test_batch.non_tensor_batch:
@@ -676,9 +674,9 @@ class RayPPOTrainer:
actor_rollout_cls = RayClassWithInitArgs(
cls=self.role_worker_mapping[Role.ActorRollout],
config=self.config.actor_rollout_ref,
role="actor_rollout",
role=str(Role.ActorRollout),
)
self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls
self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls
else:
raise NotImplementedError

@@ -687,7 +685,7 @@ class RayPPOTrainer:
resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic)
critic_cfg = omega_conf_to_dataclass(self.config.critic)
critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg)
self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls
self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls

# create reference policy if needed
if self.use_reference_policy:
@@ -695,16 +693,16 @@ class RayPPOTrainer:
ref_policy_cls = RayClassWithInitArgs(
self.role_worker_mapping[Role.RefPolicy],
config=self.config.actor_rollout_ref,
role="ref",
role=str(Role.RefPolicy),
)
self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls
self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls

# create a reward model if reward_fn is None
if self.use_rm:
# we create a RM here
resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel)
rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model)
self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls
self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls

# initialize WorkerGroup
# NOTE: if you want to use a different resource pool for each role, which can support different parallel size,
@@ -739,20 +737,20 @@ class RayPPOTrainer:
all_wg.update(spawn_wg)

if self.use_critic:
self.critic_wg = all_wg["critic"]
self.critic_wg = all_wg[str(Role.Critic)]
self.critic_wg.init_model()

if self.use_reference_policy and not self.ref_in_actor:
self.ref_policy_wg = all_wg["ref"]
self.ref_policy_wg = all_wg[str(Role.RefPolicy)]
self.ref_policy_wg.init_model()

self.rm_wg = None
if self.use_rm:
self.rm_wg = all_wg["rm"]
self.rm_wg = all_wg[str(Role.RewardModel)]
self.rm_wg.init_model()

# we should create rollout at the end so that vllm can have a better estimation of kv cache memory
self.actor_rollout_wg = all_wg["actor_rollout"]
self.actor_rollout_wg = all_wg[str(Role.ActorRollout)]
self.actor_rollout_wg.init_model()

# create async rollout manager and request scheduler
@@ -800,11 +798,13 @@ class RayPPOTrainer:
)

if self.use_critic:
critic_local_path = os.path.join(local_global_step_folder, "critic")
critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic))
critic_remote_path = (
None
if self.config.trainer.default_hdfs_dir is None
else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic")
else os.path.join(
self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic)
)
)
self.critic_wg.save_checkpoint(
critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep
@@ -860,7 +860,7 @@ class RayPPOTrainer:
print(f"Resuming from {global_step_folder}")

actor_path = os.path.join(global_step_folder, "actor")
critic_path = os.path.join(global_step_folder, "critic")
critic_path = os.path.join(global_step_folder, str(Role.Critic))
# load actor
self.actor_rollout_wg.load_checkpoint(
actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load
@@ -1129,7 +1129,7 @@ class RayPPOTrainer:

if self.use_reference_policy:
# compute reference log_prob
with marked_timer("ref", timing_raw, color="olive"):
with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"):
if not self.ref_in_actor:
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
else:
@@ -36,6 +36,37 @@ class Role(Enum):
RewardModel = 5
ActorRolloutRef = 6

def __str__(self):
return self._get_role_string()

def _get_role_string(self):
role_mapping = {
Role.Actor: "actor",
Role.Rollout: "rollout",
Role.ActorRollout: "actor_rollout",
Role.Critic: "critic",
Role.RefPolicy: "ref",
Role.RewardModel: "rm",
Role.ActorRolloutRef: "actor_rollout_ref",
}
return role_mapping.get(self, self.name.lower())

@classmethod
def from_string(cls, name: str):
string_mapping = {
"actor": cls.Actor,
"rollout": cls.Rollout,
"actor_rollout": cls.ActorRollout,
"critic": cls.Critic,
"ref": cls.RefPolicy,
"rm": cls.RewardModel,
"actor_rollout_ref": cls.ActorRolloutRef,
}
role = string_mapping.get(name.lower())
if role is None:
raise ValueError(f"No Role found for string: {name}")
return role


def need_reference_policy(
role_worker_mapping: dict[Role, WorkerType],
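The new `__str__`/`from_string` round trip keeps the literal keys used elsewhere in this diff ("actor_rollout", "critic", "ref", "rm"), so lookups such as `all_wg[str(Role.Critic)]` still resolve to the old `all_wg["critic"]` entries. A small illustration of the mapping (the import path is an assumption):

```python
from verl.trainer.ppo.utils import Role  # assumed import path

assert str(Role.ActorRollout) == "actor_rollout"
assert str(Role.RefPolicy) == "ref"
assert Role.from_string("rm") is Role.RewardModel

# Unknown names raise instead of silently mapping to something else.
try:
    Role.from_string("reward")
except ValueError as err:
    print(err)  # No Role found for string: reward
```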
@@ -78,7 +78,7 @@ class DataParallelPPOActor(BasePPOActor):

self.compute_entropy_from_logits = (
torch.compile(entropy_from_logits, dynamic=True)
if self.config.get("use_torch_compile", True) # use torch compile by default
else entropy_from_logits
)
self.device_name = get_device_name()
@@ -427,10 +427,14 @@ class DataParallelPPOActor(BasePPOActor):
model_inputs, temperature=temperature, calculate_entropy=calculate_entropy
)

if on_policy:
old_log_prob = log_prob.detach()
else:
# for fully_async_policy recipe
if hasattr(self.config, "use_rollout_log_probs") and self.config.use_rollout_log_probs:
old_log_prob = model_inputs["old_log_probs"]
else:
if on_policy:
old_log_prob = log_prob.detach()
else:
old_log_prob = model_inputs["old_log_probs"]

loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
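Because the hunk above is shown without indentation, the nesting is easy to misread: the new code first consults the recipe flag and only then falls back to the original on-policy choice. A standalone restatement of that selection logic (an illustration mirroring the diff, not the file's exact code):

```python
def select_old_log_prob(config, model_inputs, log_prob, on_policy):
    # for the fully_async_policy recipe: prefer the log probs captured at rollout time
    if getattr(config, "use_rollout_log_probs", False):
        return model_inputs["old_log_probs"]
    # original behaviour: on-policy reuses the freshly computed log probs,
    # off-policy falls back to the stored old_log_probs
    if on_policy:
        return log_prob.detach()
    return model_inputs["old_log_probs"]
```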
@@ -231,6 +231,7 @@ class FSDPActorConfig(ActorConfig):
fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig)
use_remove_padding: bool = False
profiler: ProfilerConfig = field(default_factory=ProfilerConfig)
use_rollout_log_probs: bool = False

def __post_init__(self):
"""Validate FSDP actor configuration parameters."""
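`use_rollout_log_probs` is the actor-side switch read by the branch above; in this recipe it is presumably paired with rollout-side log-prob collection, which the test script earlier in this diff enables via `actor_rollout_ref.rollout.calculate_log_probs=True`. A hedged OmegaConf sketch of that pairing (illustrative values, not asserted defaults):

```python
from omegaconf import OmegaConf

# Key names come from this diff and the script above; the values are illustrative.
overrides = OmegaConf.create(
    {
        "actor_rollout_ref": {
            "actor": {"use_rollout_log_probs": True},  # new FSDPActorConfig field (default False)
            "rollout": {"calculate_log_probs": True},  # ask the rollout to return its log probs
        }
    }
)
print(OmegaConf.to_yaml(overrides))
```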
@@ -110,8 +110,7 @@ class ExternalZeroMQDistributedExecutor(Executor):
return


@ray.remote(num_cpus=1)
class vLLMHttpServer:
class vLLMHttpServerBase:
"""vLLM http server in single node, this is equivalent to launch server with command line:
```
vllm serve --tensor-parallel-size=8 ...
@@ -399,10 +398,42 @@
await self.engine.wait_for_requests_to_drain()


@ray.remote(num_cpus=1)
class vLLMHttpServer(vLLMHttpServerBase):
"""vLLM http server in single node, this is equivalent to launch server with command line:
```
vllm serve --tensor-parallel-size=8 ...
```
"""

def __init__(
self,
config: RolloutConfig | RewardModelConfig,
model_config: HFModelConfig,
rollout_mode: RolloutMode,
workers: list[ActorHandle],
replica_rank: int,
node_rank: int,
gpus_per_node: int,
nnodes: int,
):
super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes)


_rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout)


class vLLMReplica(RolloutReplica):
def __init__(
self,
replica_rank: int,
config: RolloutConfig | RewardModelConfig,
model_config: HFModelConfig,
gpus_per_node: int = 8,
):
super().__init__(replica_rank, config, model_config, gpus_per_node)
self.server_class = vLLMHttpServer

def get_ray_class_with_init_args(self) -> RayClassWithInitArgs:
"""Get rollout worker actor class for colocated and standalone mode."""
worker_dict_cls = RayClassWithInitArgs(
@@ -437,7 +468,7 @@ class vLLMReplica(RolloutReplica):
for node_rank in range(nnodes):
workers = self.workers[node_rank * gpus_per_node : (node_rank + 1) * gpus_per_node]
node_id = worker_node_ids[node_rank * gpus_per_node]
server = vLLMHttpServer.options(
server = self.server_class.options(
scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
node_id=node_id,
soft=False,
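The `vLLMHttpServerBase`/`server_class` split mirrors the agent-loop hooks: a recipe can subclass the undecorated base and point its replica at the subclass, since the replica now instantiates servers through `self.server_class.options(...)`. A minimal sketch; the module path and the `RecipeHttpServer`/`RecipeVLLMReplica` names are assumptions:

```python
import ray

from verl.workers.rollout.vllm_rollout.vllm_async_server import (  # assumed module path
    vLLMHttpServerBase,
    vLLMReplica,
)


@ray.remote(num_cpus=1)
class RecipeHttpServer(vLLMHttpServerBase):
    """Hypothetical server, e.g. adding recipe-specific endpoints or weight-sync hooks."""


class RecipeVLLMReplica(vLLMReplica):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The hunk above creates servers via self.server_class.options(...),
        # so overriding the attribute is enough to swap the HTTP server.
        self.server_class = RecipeHttpServer
```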