From b25bb7d4f35628fc0a335aae40a0f64a85475087 Mon Sep 17 00:00:00 2001 From: arron Date: Fri, 17 Oct 2025 22:29:18 +0800 Subject: [PATCH] [trainer, recipe] feat: fully async training recipe (#2981) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### What does this PR do? To implement a purely asynchronous training workflow, we further split the training process into a Trainer and a Rollouter based on the existing one-step-off policy code, with samples transmitted via a message queue. We will continue to integrate partial rollout to mitigate the impact of long-tail training. > Add **concise** overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review. https://github.com/volcengine/verl/pull/2231 https://github.com/volcengine/verl/pull/2200 ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: ... - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. ### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).) 
--------- Co-authored-by: meituan-search Co-authored-by: wangshulin02 Co-authored-by: arron Co-authored-by: wangshulin02 <953550366@qq.com> Co-authored-by: hadoop-ai-search Co-authored-by: sl-1314 <82856253+sl-1314@users.noreply.github.com> Co-authored-by: arron Co-authored-by: arron --- .github/workflows/e2e_fully_async_policy.yml | 149 ++++ docs/advance/fully_async.md | 428 ++++++++++++ docs/index.rst | 1 + recipe/fully_async_policy/README.md | 428 ++++++++++++ recipe/fully_async_policy/README_zh.md | 373 ++++++++++ .../fully_async_policy/agent_loop/__init__.py | 19 + .../agent_loop/agent_loop.py | 275 ++++++++ .../partial_single_turn_agent_loop.py | 92 +++ .../config/fully_async_ppo_trainer.yaml | 55 ++ recipe/fully_async_policy/detach_utils.py | 474 +++++++++++++ recipe/fully_async_policy/fsdp_workers.py | 136 ++++ recipe/fully_async_policy/fully_async_main.py | 306 +++++++++ .../fully_async_rollouter.py | 646 ++++++++++++++++++ .../fully_async_policy/fully_async_trainer.py | 360 ++++++++++ recipe/fully_async_policy/message_queue.py | 265 +++++++ recipe/fully_async_policy/param_sync.py | 105 +++ recipe/fully_async_policy/ray_trainer.py | 528 ++++++++++++++ .../shell/dapo_7b_math_fsdp2_16-16.sh | 162 +++++ .../shell/dapo_7b_math_fsdp2_32_32.sh | 162 +++++ .../shell/dapo_7b_math_fsdp2_4_12.sh | 164 +++++ .../shell/dapo_7b_math_fsdp2_4_4.sh | 164 +++++ .../shell/dapo_7b_math_fsdp2_64_64.sh | 162 +++++ .../shell/dapo_7b_math_fsdp2_8_8.sh | 162 +++++ .../fully_async_policy/shell/runtime_env.yaml | 4 + .../unittest/simple_streaming_demo.py | 176 +++++ .../vllm_rollout/__init__.py | 13 + .../vllm_rollout/vllm_async_server.py | 154 +++++ tests/special_e2e/run_fully_async_policy.sh | 196 ++++++ .../special_sanity/check_device_api_usage.py | 1 + tests/special_sanity/check_license.py | 2 + verl/experimental/agent_loop/__init__.py | 4 +- verl/experimental/agent_loop/agent_loop.py | 36 +- verl/trainer/config/actor/dp_actor.yaml | 2 +- verl/trainer/main_ppo.py | 10 +- verl/trainer/ppo/ray_trainer.py | 32 +- verl/trainer/ppo/utils.py | 31 + verl/workers/actor/dp_actor.py | 12 +- verl/workers/config/actor.py | 1 + .../rollout/vllm_rollout/vllm_async_server.py | 37 +- 39 files changed, 6292 insertions(+), 35 deletions(-) create mode 100644 .github/workflows/e2e_fully_async_policy.yml create mode 100644 docs/advance/fully_async.md create mode 100644 recipe/fully_async_policy/README.md create mode 100644 recipe/fully_async_policy/README_zh.md create mode 100644 recipe/fully_async_policy/agent_loop/__init__.py create mode 100644 recipe/fully_async_policy/agent_loop/agent_loop.py create mode 100644 recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py create mode 100644 recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml create mode 100644 recipe/fully_async_policy/detach_utils.py create mode 100644 recipe/fully_async_policy/fsdp_workers.py create mode 100644 recipe/fully_async_policy/fully_async_main.py create mode 100644 recipe/fully_async_policy/fully_async_rollouter.py create mode 100644 recipe/fully_async_policy/fully_async_trainer.py create mode 100644 recipe/fully_async_policy/message_queue.py create mode 100644 recipe/fully_async_policy/param_sync.py create mode 100644 recipe/fully_async_policy/ray_trainer.py create mode 100644 recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh create mode 100644 recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh create mode 100644 recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh create mode 100644 
recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh create mode 100644 recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh create mode 100644 recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh create mode 100644 recipe/fully_async_policy/shell/runtime_env.yaml create mode 100644 recipe/fully_async_policy/unittest/simple_streaming_demo.py create mode 100644 recipe/fully_async_policy/vllm_rollout/__init__.py create mode 100644 recipe/fully_async_policy/vllm_rollout/vllm_async_server.py create mode 100644 tests/special_e2e/run_fully_async_policy.sh diff --git a/.github/workflows/e2e_fully_async_policy.yml b/.github/workflows/e2e_fully_async_policy.yml new file mode 100644 index 000000000..e2cf0d1c0 --- /dev/null +++ b/.github/workflows/e2e_fully_async_policy.yml @@ -0,0 +1,149 @@ +# # Tests layout + +# Each folder under tests/ corresponds to a test category for a sub-namespace in verl. For instance: +# - `tests/trainer` for testing functionality related to `verl/trainer` +# - `tests/models` for testing functionality related to `verl/models` +# - ... + +# There are a few folders with `special_` prefix, created for special purposes: +# - `special_distributed`: unit tests that must run with multiple GPUs +# - `special_e2e`: end-to-end tests with training/generation scripts +# - `special_npu`: tests for NPUs +# - `special_sanity`: a suite of quick sanity tests +# - `special_standalone`: a set of test that are designed to run in dedicated environments + +# Accelerators for tests +# - By default tests are run with GPU available, except for the ones under `special_npu`, and any test script whose name ends with `on_cpu.py`. +# - For test scripts with `on_cpu.py` name suffix would be tested on CPU resources in linux environment. + +# # Workflow layout + +# All CI tests are configured by yaml files in `.github/workflows/`. Here's an overview of all test configs: +# 1. A list of always triggered CPU sanity tests: `check-pr-title.yml`, `secrets_scan.yml`, `check-pr-title,yml`, `pre-commit.yml`, `doc.yml` +# 2. Some heavy multi-GPU unit tests, such as `model.yml`, `vllm.yml`, `sgl.yml` +# 3. End-to-end tests: `e2e_*.yml` +# 4. Unit tests +# - `cpu_unit_tests.yml`, run pytest on all scripts with file name pattern `tests/**/test_*_on_cpu.py` +# - `gpu_unit_tests.yml`, run pytest on all scripts with file without the `on_cpu.py` suffix. +# - Since cpu/gpu unit tests by default runs all tests under `tests`, please make sure tests are manually excluded in them when +# - new workflow yaml is added to `.github/workflows` +# - new tests are added to workflow mentioned in 2. + + +name: e2e_fully_async_policy + +on: + # Trigger the workflow on push or pull request, + # but only for the main branch + # For push, for now only anti-patterns are specified so it is more conservative + # and achieves higher coverage. 
+ push: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "!**/*.md" + - "!**/*.sh" + # Other entrypoints + - "!examples/*trainer*" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + - "!recipe/**" + - "recipe/fully_async_policy" + pull_request: + branches: + - main + - v0.* + paths: + - "**/*.py" + - "!**/*.md" + - "!**/*.sh" + # Other entrypoints + - "!examples/**" + - "!tests/**" + - "!verl/trainer/main_*.py" + - "!verl/trainer/fsdp_sft_trainer.py" + # Other recipes + - "!recipe/**" + # Home + - "recipe/fully_async_policy" + # Entrypoints + - ".github/workflows/e2e_fully_async_policy.yml" + - "examples/data_preprocess/gsm8k.py" + - "tests/special_e2e/run_fully_async_policy.sh" + +# Cancel jobs on the same ref if a new one is triggered +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + +# Declare permissions just read content. +permissions: + contents: read + +env: + IMAGE: "verl-ci-cn-beijing.cr.volces.com/verlai/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.13.0-te2.2" + DYNAMIC_RUNNER_ENDPOINT: "https://sd10g3clalm04ug7alq90.apigateway-cn-beijing.volceapi.com/runner" + TRANSFORMERS_VERSION: "4.56.2" + +jobs: + setup: + if: github.repository_owner == 'volcengine' + runs-on: ubuntu-latest + outputs: + runner-label: ${{ steps.create-runner.outputs.runner-label }} + mlp-task-id: ${{ steps.create-runner.outputs.mlp-task-id }} + steps: + - uses: actions/checkout@v4 + - id: create-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "create" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-image: "${{ env.IMAGE }}" + + # Test FSDP2 strategy + e2e_fully_async_policy_fsdp2: + needs: setup + runs-on: [ "${{ needs.setup.outputs.runner-label || 'L20x8' }}" ] + timeout-minutes: 10 # Increase timeout for async training + env: + HTTP_PROXY: ${{ secrets.PROXY_HTTP }} + HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }} + NO_PROXY: "localhost,127.0.0.1,hf-mirror.com" + HF_ENDPOINT: "https://hf-mirror.com" + HF_HUB_ENABLE_HF_TRANSFER: "0" # This is more stable + ACTOR_STRATEGY: "fsdp2" + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + with: + fetch-depth: 0 + - name: Install the current repository + run: | + pip3 install --no-deps -e .[test,gpu] + pip3 install transformers==$TRANSFORMERS_VERSION + - name: Prepare GSM8K dataset + run: | + python3 examples/data_preprocess/gsm8k.py --local_dataset_path ${HOME}/models/hf_data/gsm8k + - name: Running the E2E test with fully_async_policy algorithm (FSDP2) + run: | + ray stop --force + bash tests/special_e2e/run_fully_async_policy.sh + + cleanup: + runs-on: ubuntu-latest + needs: + [ + setup, + e2e_fully_async_policy_fsdp2 + ] + if: always() + steps: + - id: destroy-runner + uses: volcengine/vemlp-github-runner@v1 + with: + mode: "destroy" + faas-url: "${{ env.DYNAMIC_RUNNER_ENDPOINT }}" + mlp-task-id: "${{ needs.setup.outputs.mlp-task-id }}" \ No newline at end of file diff --git a/docs/advance/fully_async.md b/docs/advance/fully_async.md new file mode 100644 index 000000000..a3ad5e5cf --- /dev/null +++ b/docs/advance/fully_async.md @@ -0,0 +1,428 @@ +# Recipe: Fully Async Policy Async Trainer + +**Author:** `https://github.com/meituan-search` + +Last updated: 10/17/2025. + +This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter, +supporting asynchronous sample generation and training. 
+Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs, +without significantly affecting the results. + +## Introduction + +### Background + +Compared to the colocate architecture, the separated rollout and train architecture can allocate resources more +flexibly and allows more flexible training logic, thereby addressing issues such as the low GPU utilization and reduced +training efficiency caused by long-tail samples. +The one_step_off_policy recipe alleviates the problem of long rollout times and achieves some gains in training efficiency +by using a separated architecture and training asynchronously, with rollout running one step ahead of training. +However, it is fixed to exactly one step of asynchrony, which is not flexible enough and cannot +completely eliminate the impact of long-tail samples on training efficiency. +Other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow have implemented asynchronous and streaming training +on top of the separated architecture and have reported gains. +We drew on their methods and implemented them in verl. The fully_async_policy recipe supports asynchronous, streaming, and +partial rollout training. +By setting parameters such as resource allocation and parameter synchronization frequency appropriately, fully_async_policy +can significantly improve training efficiency. + +> Magistral https://arxiv.org/abs/2506.10910 +> +> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language +> Reasoning https://arxiv.org/abs/2505.24298 +> +> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream +> Generation https://arxiv.org/abs/2504.15930 +> +> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663 + +### Core Contributions + +* **Resource Isolation**: Unlike the hybrid_engine setup, Rollouter and Trainer use separate computing resources, and the + resources each occupies must be specified separately. +* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples. +* **Multi-step Asynchrony**: Compared to one-step-off policy, it supports asynchrony ranging from 0.x steps to + multiple steps, making the asynchronous scheme more flexible. +* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter + and Trainer. +* **Streaming Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single + sample as the minimum transmission unit. +* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it + supports training with samples generated by old parameters. +* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter + synchronization, `sleep()` and `resume()` logic + saves samples from ongoing rollouts and continues them in the next rollout, reducing the time spent waiting for + in-flight requests to finish during parameter synchronization. + +Currently, the supported combination is FSDP + vLLM, and vLLM must use the AgentLoop-based server mode. + +## Design + +The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four +parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer.
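Conceptually, the interaction among these four components is a streaming producer/consumer loop. The sketch below is purely illustrative: plain threads and a `queue.Queue` stand in for the recipe's separate workers (`fully_async_rollouter.py`, `message_queue.py`, `fully_async_trainer.py`), and a print statement stands in for NCCL parameter synchronization.

```python
# Minimal, illustrative sketch of the Rollouter -> MessageQueue -> Trainer control flow.
# All names here are hypothetical stand-ins, not the recipe's actual API.
import queue
import threading

ppo_mini_batch_size = 4
require_batches = 2
trigger_parameter_sync_step = 3
total_rollout_steps = 48              # total number of prompts to roll out

message_queue = queue.Queue()         # stand-in for the recipe's MessageQueue

def rollouter():
    # Stream samples out one by one; freshness control would cap production here.
    for sample_id in range(total_rollout_steps):
        message_queue.put({"sample_id": sample_id})
    message_queue.put(None)           # sentinel: rollout finished

def trainer():
    batch, local_updates = [], 0
    while (item := message_queue.get()) is not None:
        batch.append(item)
        if len(batch) == require_batches * ppo_mini_batch_size:
            local_updates += 1
            print(f"local update {local_updates} on {len(batch)} samples")
            batch.clear()
            if local_updates % trigger_parameter_sync_step == 0:
                print("  -> trigger parameter sync with Rollouter")

threading.Thread(target=rollouter, daemon=True).start()
trainer()
```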
+ +![fully_async_policy_structure]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true) + +1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the + production speed controlled by freshness. +2. MessageQueue is used to temporarily store samples generated by Rollouter. +3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size` + samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers + a parameter synchronization with Rollouter. +4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability. + +The source of benefits compared to the base scheme lies in the fact that in the colocate case, using more resources for +rollout cannot solve the idleness caused by long-tail samples. +After we perform resource isolation, the time for rollout and train may be longer than before (because fewer resources +are used), +but the overlap in their time consumption reduces the end-to-end time consumption. + +![fully_async_policy_revenue]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true) + +## Usage + +### Parameter Description + +| super params | implication | +|-----------------------------------------------|------------------------------------------------------------------------------------------------| +| `trainer.nnodes` | Number of nodes for Trainer | +| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer | +| `rollout.nnodes` | Number of nodes for Rollouter | +| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter | +| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) | +| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) | +| `rollout.total_rollout_steps` | Total number of rollout samples | +| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation | +| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | +| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once | +| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | +| `async_training.staleness_threshold` | Freshness control | +| `async_training.partial_rollout` | Whether to perform partial_rollout | +| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout | + +**Further Explanation:** + +* `rollout.total_rollout_steps` + + Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step: + `rollout.total_rollout_steps = data.train_batch_size * step`. + +* `async_training.trigger_parameter_sync_step` + + In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches + `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter. + Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process + `trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples. 
+ To fairly compare speed with colocate, trigger_parameter_sync_step should be set to + `data.train_batch_size / (require_batches * ppo_mini_batch_size)`. + +* `async_training.staleness_threshold` + + In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used. + + * staleness_threshold=0, indicates synchronous training. + Rollouter will generate a fixed number of samples between two parameter updates, the sample count is: + $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$ + * staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous + calls. + Rollouter will generate at most the following number of samples between two parameter updates: + $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$ + + num_staleness_sample represents the number of stale samples generated in excess during the last rollout. + + Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower, + trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples. + When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy. + To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1. + +* `async_training.partial_rollout` + + partial_rollout only actually takes effect when staleness_threshold>0. + +* `async_training.use_rollout_log_probs` + + In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to + the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling, + old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm + correctness. In the fully + async strategy, we default to old_log_prob being calculated by rollout rather than by trainer. + + * `async_training.require_batches` + + In streaming training, require_batches should be set to 1, indicating that training is performed after producing + enough ppo_mini_batch_size samples. + In actual testing, we found that if fewer samples are issued at once, due to the order of data distribution, it can + cause training instability and longer response lengths. + Here, we additionally provide require_batches for streaming distribution and control the number of samples + participating in training at once. + +### Supported Modes + +1. on policy pipeline: + 1. **trigger_parameter_sync_step=1, staleness_threshold=0** + 2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for + training, and after training completes, Trainer and Rollouter perform a parameter synchronization; + 3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill + idle resources, causing some resource waste. + 4. As shown in figure a; + +2. stream off policy pipeline: + 1. **trigger_parameter_sync_step>1, staleness_threshold=0** + 2. Synchronous streaming training will be performed. 
Rollouter produces + `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local + training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training + trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization; + 3. Compared to a, since more samples are generated at once, resource idleness will be lower. + 4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples, + train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter + update, rollout waits for training to complete. + 5. As shown in figure b; + +3. async stream pipeline with stale samples: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False** + 2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number + of samples generated may be less than this value depending on rollout speed). + 3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples + before parameter synchronization for immediate use by Trainer after synchronization. + When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete + and not add new tasks; + 4. Compared to b, except for the first step training, subsequent training will not have the time to wait for the + first batch rollout to finish, but will have the time to wait for active tasks to finish. + 5. As shown in figure c; + +4. async stream pipeline with partial rollout: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True** + 2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will + interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be + generated after synchronization. This reduces the time to wait for active tasks to finish. + 3. As shown in figure d; + +![fully_async_policy_mode]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true) + +### Key Metrics + +| metrics | implication | +|------------------------------------------------|--------------------------------------------------------------------------------------------------------| +| `trainer/idle_ratio` | Trainer idle rate | +| `rollouter/idle_ratio` | Rollouter idle rate | +| `fully_async/count/stale_samples_processed` | Total number of old samples used in training | +| `fully_async/count/stale_trajectory_processed` | Total number of old trajectories used in training (one sample produces rollout.n trajectories) | +| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/max_partial_span` | Maximum parameter span of partial samples processed by Trainer between two trigger_parameter_sync_step | + +### Parameter Tuning Recommendations + +* Resource Allocation and Adjustment: + * Reasonable resource allocation is the prerequisite for achieving good training efficiency. 
The ideal resource + allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire + training process, + avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource + allocation can be adjusted based on the idle time of rollout and train during actual training, + which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and + trainer/idle_ratio is low, + Trainer resources should be increased and Rollouter resources should be reduced, and vice versa. + +* Key Parameters: + * staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It + is recommended to set it to less than 1. + * require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and + the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample + processing; + * trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent + parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in + low resource utilization. + The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy. + * rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small. + +* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at + different levels, suitable for tasks in different scenarios. + * For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed + requirements, the on policy pipeline mode (Mode 1) can be tried. + * For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy + pipeline mode can be tried. That is, by + setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization + mechanism (staleness_threshold=0) (Mode 2). + * For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and + staleness, setting staleness_threshold> + 0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4). 
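To make these recommendations concrete, the short example below applies the formulas from the Further Explanation section. The values are illustrative (they mirror the 7B experiment settings rather than being tuned recommendations), and `num_stale_samples` is a hypothetical carry-over count.

```python
# Worked example of the sample accounting between two parameter synchronizations.
ppo_mini_batch_size = 32
require_batches = 4
trigger_parameter_sync_step = 4
staleness_threshold = 0.3
num_stale_samples = 100   # hypothetical surplus left over from the previous rollout round

# Samples the Trainer consumes between two parameter synchronizations.
samples_per_sync = trigger_parameter_sync_step * require_batches * ppo_mini_batch_size
print(samples_per_sync)   # 512

# Upper bound on samples the Rollouter may produce before the next synchronization
# when staleness_threshold > 0 (a bound, not an exact count).
rollout_num = (1 + staleness_threshold) * samples_per_sync - num_stale_samples
print(rollout_num)        # 565.6

# For a fair speed comparison against a colocate run with data.train_batch_size = 512:
train_batch_size = 512
print(train_batch_size // (require_batches * ppo_mini_batch_size))   # 4 local updates per sync
```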
+ +### Quick Start + +```shell +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=10 +staleness_threshold=0 +trigger_parameter_sync_step=16 +partial_rollout=False + + +python -m recipe.fully_async_policy.fully_async_main \ + train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.partial_rollout="${partial_rollout}" +``` + +## Experiments + +### Asynchronous Training on 7B Model + +We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under long candidates and multiple resources. +Using the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards, +64 cards, and 128 cards without significantly affecting experimental results. + +* Machine: H20 +* Model: Qwen2.5-Math-7B +* Rollout length: max_response_length FSDP2: 28K tokens; +* Algorithm: DAPO +* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet +* Engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step: 400 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*400 + * require_batches: 4 + * trigger_parameter_sync_step: 4 + * staleness_threshold: 0.3 + * partial_rollout: True + +| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time<br/>200 step | total time<br/>300 step | total time<br/>400 step | acc/mean@1 |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313<br/>last: 0.2448 |
+| fully_async_policy | 16:16 | | | \ | | | | | | max:<br/>last: |
+| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365<br/>last: 0.2333 |
+| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m<br/>(2.09x) | 10h 14m<br/>(2.03x) | 16h 58m<br/>(1.83x) | 21h 40m<br/>(1.92x) | max: 0.3677<br/>last: 0.3406 |
+| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br/>last: 0.2958 |
+| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m<br/>(2.67x) | 6h 46m<br/>(2.65x) | 10h 53m<br/>(2.67x) | 17h 22m<br/>(2.35x) | max: 0.3521<br/>
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg + +### 128-card 7B Asynchronous Mode Experiment + +We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async. +We can see that the benefit brought by streaming is approximately 0.6x, and after combining staleness and +partial_rollout, the benefit reaches 2.35x. + +| mode | step | gen | old_log_prob | update_actor | total time
100 step | total time<br/>200 step | total time<br/>300 step | total time<br/>400 step | acc/mean@1 |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| colocate sync (128 GPUs) | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573<br/>last: 0.2958 |
+| `stream off policy pipeline`<br/>(+fully async: trigger_parameter_sync_step=4, require_batches=4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br/>last: 0.2604 |
+| `async stream pipeline with stale samples`<br/>(+staleness_threshold=0.5) | | | | | | | | | |
+| `async stream pipeline with partial rollout`<br/>(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br/>
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card Stale Ablation Experiment + +Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training +efficiency. +We found that the larger the staleness, the more obvious the final gains. +We also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps +increase, the response length changes significantly, causing training instability. +Further analysis and optimization are needed for this issue. + +| staleness_threshold | step | gen | old_log_prob | update_actor | total time
100 step | total time<br/>200 step | total time<br/>300 step | total time<br/>400 step | acc/mean@1 |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844<br/>last: 0.2604 |
+| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542<br/>last: 0.2979 |
+| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469<br/>last: 0.2865 |
+| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521<br/>
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card 7B require_batches Ablation Experiment + +In multiple tests, we found that the number of samples issued each time in streaming affects the response length during +training, which in turn affects training time. We verified the impact on results by modifying +`async_training.require_batches`. + +| require_batches | step | gen | old_log_prob | update_actor | total time
100 step | total time<br/>200 step | total time<br/>300 step | acc/mean@1 |
+|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
+| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349<br/>last: 0.326 |
+| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351<br/>last: 0.3406 |
+| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521<br/>
last: 0.3521 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg + +### 30B Model Mode Experiment + +TODO: The 30B experiment is still in progress. + +* Machine: H20 +* Model: Qwen2.5-32B~~~~ +* Rollout length: max_response_length FSDP2: 20K tokens; +* Algorithm: DAPO +* Engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step:200 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*200 + * trigger_parameter_sync_step: 512/32 = 16 + * staleness_threshold: 0 + * partial_rollout: False + +| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | +|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------| +| colocate sync | 128 | | | | | | | | +| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | | + + +## Future Plans + +* GRPO experiments +* Megatron adaptation +* SGLang integration +* Transfer queue integration +* Asynchronous parameter synchronization +* AReaL asynchronous algorithm implementation +* TPPO algorithm implementation +* Multi-turn and Tool support \ No newline at end of file diff --git a/docs/index.rst b/docs/index.rst index 68e37545d..e8467dc96 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -124,6 +124,7 @@ verl is fast with: advance/rollout_is_migration.md advance/one_step_off advance/agent_loop + advance/fully_async .. toctree:: :maxdepth: 1 diff --git a/recipe/fully_async_policy/README.md b/recipe/fully_async_policy/README.md new file mode 100644 index 000000000..a3ad5e5cf --- /dev/null +++ b/recipe/fully_async_policy/README.md @@ -0,0 +1,428 @@ +# Recipe: Fully Async Policy Async Trainer + +**Author:** `https://github.com/meituan-search` + +Last updated: 10/17/2025. + +This document introduces a fully asynchronous PPO training system that completely decouples the Trainer and Rollouter, +supporting asynchronous sample generation and training. +Under this system, we achieved a 2.35x-2.67x performance improvement when training the Qwen2.5-7B model with 128 GPUs, +without significantly affecting the results. + +## Introduction + +### Background + +The separated rollout and train architecture, compared to the colocate architecture, can allocate resources more +flexibly and design more flexible training logic, thereby addressing issues such as low GPU utilization and training +efficiency caused by long-tail problems. +The one_step_off_policy alleviates the problem of long rollout times and achieves some gains in training efficiency by +designing a separated architecture and performing asynchronous training between rollout and train for one round. +However, it forcibly uses data from one round of asynchronous training, which is not flexible enough and cannot +completely eliminate the impact of long-tail on training efficiency. +In other frameworks such as AReaL, Magistral, StreamRL, and AsyncFlow, asynchronous training and streaming training have +been implemented based on the separated architecture and have achieved gains. +We借鉴 their methods and implemented them in VERL. 
The fully_async_policy supports asynchronous, streaming, and partial +rollout training. +By reasonably setting parameters such as resource allocation and parameter synchronization frequency, fully_async_policy +can significantly improve training efficiency. + +> Magistral https://arxiv.org/abs/2506.10910 +> +> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language +> Reasoning https://arxiv.org/abs/2505.24298 +> +> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream +> Generation https://arxiv.org/abs/2504.15930 +> +> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663 +> + +### Core Contributions + +* **Resource Isolation**: Unlike using hybrid_engine, Rollouter and Trainer use separate computing resources and need to + specify the resources they occupy separately. +* **Parallel Generation and Training**: While the Trainer is training, the Rollouter is generating new samples. +* **Multi-step Asynchronous**: Compared to one step off policy, it supports asynchronous settings from 0.x steps to + multiple steps, making the asynchronous solution more flexible. +* **NCCL Parameter Synchronization**: Uses NCCL communication primitives for parameter communication between Rollouter + and Trainer. +* **Stream Inference and Training**: Rollouter generates data sample by sample, and data transmission uses a single + sample as the minimum transmission unit. +* **Asynchronous Training and Freshness Control**: By setting the parameter async_training.staleness_threshold, it + supports training with samples generated by old parameters. +* **PartialRollout**: The Rollouter's inference process supports partial rollout logic. During parameter + synchronization, by adding `sleep() and resume()` logic, it + saves samples from ongoing rollouts and continues using them in the next rollout, reducing the time spent waiting for + ongoing tasks to finish during parameter synchronization. + +Currently, the supported usage mode is fsdp+vllm. vllm must use the server mode based on AgentLoop. + +## Design + +The overall architecture of fully_async_policy is shown in the figure below. fully_async_policy mainly consists of four +parts: Rollouter, MessageQueue, Trainer, and ParameterSynchronizer. + +![fully_async_policy_structure]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true) + +1. Rollouter generates sequences sample by sample and puts the generated samples into the MessageQueue, with the + production speed controlled by freshness. +2. MessageQueue is used to temporarily store samples generated by Rollouter. +3. Trainer fetches samples from MessageQueue sample by sample. After fetching `require_batches*ppo_mini_batch_size` + samples, it will perform training. After training for async_training.trigger_parameter_sync_step rounds, it triggers + a parameter synchronization with Rollouter. +4. ParameterSynchronizer implements the NCCL synchronous parameter synchronization capability. + +The source of benefits compared to the base scheme lies in the fact that in the colocate case, using more resources for +rollout cannot solve the idleness caused by long-tail samples. +After we perform resource isolation, the time for rollout and train may be longer than before (because fewer resources +are used), +but the overlap in their time consumption reduces the end-to-end time consumption. 
+ +![fully_async_policy_revenue]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true) + +## Usage + +### Parameter Description + +| super params | implication | +|-----------------------------------------------|------------------------------------------------------------------------------------------------| +| `trainer.nnodes` | Number of nodes for Trainer | +| `trainer.n_gpus_per_node` | Number of GPUs per node for Trainer | +| `rollout.nnodes` | Number of nodes for Rollouter | +| `rollout.n_gpus_per_node` | Number of GPUs per node for Rollouter | +| `data.train_batch_size` | In the fully async strategy, this value is not effective (default is 0) | +| `data.gen_batch_size` | In the fully async strategy, uses streaming sample production logic (default is 1) | +| `rollout.total_rollout_steps` | Total number of rollout samples | +| `rollout.test_freq` | How many times Rollouter updates parameters before performing a validation | +| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | +| `async_training.require_batches` | Number of ppo_mini_batch_size that FullyAsyncTrainer fetches at once | +| `async_training.trigger_parameter_sync_step` | Indicates how many local updates FullyAsyncTrainer performs before a parameter synchronization | +| `async_training.staleness_threshold` | Freshness control | +| `async_training.partial_rollout` | Whether to perform partial_rollout | +| `async_training.use_rollout_log_probs` | Use log_probs generated by rollout | + +**Further Explanation:** + +* `rollout.total_rollout_steps` + + Compared to colocate, the quantity can be aligned by multiplying train_batch_size and step: + `rollout.total_rollout_steps = data.train_batch_size * step`. + +* `async_training.trigger_parameter_sync_step` + + In the fully async strategy, it indicates how many local updates the Trainer performs (i.e., how many times it fetches + `require_batches * ppo_mini_batch_size` samples) before a parameter synchronization with Rollouter. + Between every two parameter synchronizations between Rollouter and Trainer, the Trainer will process + `trigger_parameter_sync_step* require_batches*ppo_mini_batch_size` samples. + To fairly compare speed with colocate, trigger_parameter_sync_step should be set to + `data.train_batch_size / (require_batches * ppo_mini_batch_size)`. + +* `async_training.staleness_threshold` + + In the fully async strategy, it indicates the maximum proportion of stale samples allowed to be used. + + * staleness_threshold=0, indicates synchronous training. + Rollouter will generate a fixed number of samples between two parameter updates, the sample count is: + $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$ + * staleness_threshold>0, indicates asynchronous training, can be set to a decimal for more flexible asynchronous + calls. + Rollouter will generate at most the following number of samples between two parameter updates: + $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$ + + num_staleness_sample represents the number of stale samples generated in excess during the last rollout. + + Since it's a streaming system, rollout continues to generate and trainer continues to consume. If rollouter is slower, + trainer will trigger parameter synchronization earlier, and rollouter will not actually produce rollout_num samples. 
+ When rollout is fast enough, setting staleness_threshold to 1 is basically equivalent to one_step_off policy. + To avoid too many expired samples affecting training accuracy, it is recommended to set this value to less than 1. + +* `async_training.partial_rollout` + + partial_rollout only actually takes effect when staleness_threshold>0. + +* `async_training.use_rollout_log_probs` + + In reinforcement learning algorithms, log_probs have implicit correlations with parameter versions and tokens. Due to + the settings of algorithms like PPO/GRPO/DAPO, when calculating importance sampling, + old_log_prob must use the log_probs corresponding to the rollout parameters and tokens to ensure algorithm + correctness. In the fully + async strategy, we default to old_log_prob being calculated by rollout rather than by trainer. + + * `async_training.require_batches` + + In streaming training, require_batches should be set to 1, indicating that training is performed after producing + enough ppo_mini_batch_size samples. + In actual testing, we found that if fewer samples are issued at once, due to the order of data distribution, it can + cause training instability and longer response lengths. + Here, we additionally provide require_batches for streaming distribution and control the number of samples + participating in training at once. + +### Supported Modes + +1. on policy pipeline: + 1. **trigger_parameter_sync_step=1, staleness_threshold=0** + 2. Rollouter produces `require_batches*ppo_mini_batch_size` samples at once, Trainer fetches these samples for + training, and after training completes, Trainer and Rollouter perform a parameter synchronization; + 3. During the rollout phase, if there are long-tail samples but few rollout samples, shorter samples cannot fill + idle resources, causing some resource waste. + 4. As shown in figure a; + +2. stream off policy pipeline: + 1. **trigger_parameter_sync_step>1, staleness_threshold=0** + 2. Synchronous streaming training will be performed. Rollouter produces + `require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` samples at once, Trainer performs a local + training every time it fetches `require_batches*ppo_mini_batch_size` samples, and after training + trigger_parameter_sync_step times, Trainer and Rollouter perform a parameter synchronization; + 3. Compared to a, since more samples are generated at once, resource idleness will be lower. + 4. In one step training, there will be two periods of resource idleness: when fetching the first batch of samples, + train waits for `require_batches*ppo_mini_batch_size` samples to be produced, and during the last parameter + update, rollout waits for training to complete. + 5. As shown in figure b; + +3. async stream pipeline with stale samples: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=False** + 2. After each parameter update, Rollouter will plan to produce at most rollout_num samples (in practice, the number + of samples generated may be less than this value depending on rollout speed). + 3. If the rollout process is relatively fast, Rollouter will generate some additional samples num_stale_samples + before parameter synchronization for immediate use by Trainer after synchronization. + When triggering parameter synchronization, if Rollouter has ongoing tasks, it will wait for the tasks to complete + and not add new tasks; + 4. 
Compared to b, except for the first step training, subsequent training will not have the time to wait for the + first batch rollout to finish, but will have the time to wait for active tasks to finish. + 5. As shown in figure c; + +4. async stream pipeline with partial rollout: + 1. **trigger_parameter_sync_step>=1, staleness_threshold>0, partial_rollout=True** + 2. Compared to c, when triggering parameter synchronization, if Rollouter has samples being produced, it will + interrupt the rollout process and perform parameter synchronization. The interrupted samples will continue to be + generated after synchronization. This reduces the time to wait for active tasks to finish. + 3. As shown in figure d; + +![fully_async_policy_mode]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true) + +### Key Metrics + +| metrics | implication | +|------------------------------------------------|--------------------------------------------------------------------------------------------------------| +| `trainer/idle_ratio` | Trainer idle rate | +| `rollouter/idle_ratio` | Rollouter idle rate | +| `fully_async/count/stale_samples_processed` | Total number of old samples used in training | +| `fully_async/count/stale_trajectory_processed` | Total number of old trajectories used in training (one sample produces rollout.n trajectories) | +| `fully_async/partial/total_partial_num` | Number of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/partial_ratio` | Ratio of partial samples processed by Trainer between two trigger_parameter_sync_step | +| `fully_async/partial/max_partial_span` | Maximum parameter span of partial samples processed by Trainer between two trigger_parameter_sync_step | + +### Parameter Tuning Recommendations + +* Resource Allocation and Adjustment: + * Reasonable resource allocation is the prerequisite for achieving good training efficiency. The ideal resource + allocation should make the rollout time and train time close, thereby minimizing pipeline bubbles in the entire + training process, + avoiding resource idleness, and ensuring Trainer does not use old samples. In real training scenarios, resource + allocation can be adjusted based on the idle time of rollout and train during actual training, + which can be obtained from rollouter/idle_ratio and trainer/idle_ratio. If rollouter/idle_ratio is high and + trainer/idle_ratio is low, + Trainer resources should be increased and Rollouter resources should be reduced, and vice versa. + +* Key Parameters: + * staleness_threshold: Setting it too high will cause more old samples to be used, affecting model performance. It + is recommended to set it to less than 1. + * require_batches: The closer to 1, the closer to a pure streaming process, the smaller the training bubbles, and + the faster the acceleration effect that can be achieved in terms of speed, but it will affect the order of sample + processing; + * trigger_parameter_sync_step: The smaller the setting, the closer to on policy, but it will cause frequent + parameter synchronization. Long-tail samples waste resources that cannot be filled by short samples, resulting in + low resource utilization. + The larger the setting, the higher the computational efficiency, but the accuracy will be affected by off policy. + * rollout.test_freq: It will occupy Rollouter resources and is not recommended to be set too small. 
+ +* Mode Selection: By adjusting different parameters, the Fully Async architecture supports optimization acceleration at + different levels, suitable for tasks in different scenarios. + * For small-scale tasks that need to ensure training stability and on-policy nature, and have low speed + requirements, the on policy pipeline mode (Mode 1) can be tried. + * For scenarios that need to improve training throughput but are sensitive to staleness, the stream off policy + pipeline mode can be tried. That is, by + setting trigger_parameter_sync_step>1 to improve training efficiency, but still maintaining the synchronization + mechanism (staleness_threshold=0) (Mode 2). + * For large-scale tasks with high training speed requirements and can tolerate a certain degree of off-policy and + staleness, setting staleness_threshold> + 0 and partial_rollout=True can improve training efficiency, using the async stream pipeline mode (Mode 3 or 4). + +### Quick Start + +```shell +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=10 +staleness_threshold=0 +trigger_parameter_sync_step=16 +partial_rollout=False + + +python -m recipe.fully_async_policy.fully_async_main \ + train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.partial_rollout="${partial_rollout}" +``` + +## Experiments + +### Asynchronous Training on 7B Model + +We used Qwen2.5-Math-7B to verify the benefits of the fully async strategy under long candidates and multiple resources. +Using the `async stream pipeline with stale samples` strategy, we achieved about 2x performance improvement on 32 cards, +64 cards, and 128 cards without significantly affecting experimental results. + +* Machine: H20 +* Model: Qwen2.5-Math-7B +* Rollout length: max_response_length FSDP2: 28K tokens; +* Algorithm: DAPO +* Dataset: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet +* Engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step: 400 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*400 + * require_batches: 4 + * trigger_parameter_sync_step: 4 + * staleness_threshold: 0.3 + * partial_rollout: True + +| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:| +| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313
last: 0.2448 | +| fully_async_policy | 16:16 | | | \ | | | | | | max:
last: | +| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365
last: 0.2333 | +| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m
(2.09x) | 10h 14m
(2.03x) | 16h 58m
(1.83x) | 21h 40m
(1.92x) | max: 0.3677
last: 0.3406 | +| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m
(2.67x) | 6h 46m
(2.65x) | 10h 53m
(2.67x) | 17h 22m
(2.35x) | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg + +### 128-card 7B Asynchronous Mode Experiment + +We used Qwen2.5-Math-7B to verify the effects of various modes supported by fully async. +We can see that the benefit brought by streaming is approximately 0.6x, and after combining staleness and +partial_rollout, the benefit reaches 2.35x. + +| mode | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| `stream off policy pipeline`
(+fully async: trigger_parameter_sync_step= 4,
require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| `async stream pipeline with stale samples`
(+staleness_threshold=0.5) | | | | | | | | | | +| `async stream pipeline with partial rollout`
(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card Stale Ablation Experiment + +Under the `async stream pipeline with partial rollout` mode, we verified the impact of staleness settings on training +efficiency. +We found that the larger the staleness, the more obvious the final gains. +We also noticed that the times for staleness values of 0.3 and 0.5 are quite close, because as the training steps +increase, the response length changes significantly, causing training instability. +Further analysis and optimization are needed for this issue. + +| staleness_threshold | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542
last: 0.2979 | +| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469
last: 0.2865 | +| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128-card 7B require_batches Ablation Experiment + +In multiple tests, we found that the number of samples issued each time in streaming affects the response length during +training, which in turn affects training time. We verified the impact on results by modifying +`async_training.require_batches`. + +| require_batches | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | acc/mean@1 | +|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349
last: 0.326 | +| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351
last: 0.3406 | +| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521
last: 0.3521 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg + +### 30B Model Mode Experiment + +TODO: The 30B experiment is still in progress. + +* Machine: H20 +* Model: Qwen2.5-32B~~~~ +* Rollout length: max_response_length FSDP2: 20K tokens; +* Algorithm: DAPO +* Engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step:200 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*200 + * trigger_parameter_sync_step: 512/32 = 16 + * staleness_threshold: 0 + * partial_rollout: False + +| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | +|--------------------|---------------------|--------------------------------------------|------|--------------------|--------------|--------------|------------|------------------| +| colocate sync | 128 | | | | | | | | +| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with stale samples | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | | + + +## Future Plans + +* GRPO experiments +* Megatron adaptation +* SGLang integration +* Transfer queue integration +* Asynchronous parameter synchronization +* AReaL asynchronous algorithm implementation +* TPPO algorithm implementation +* Multi-turn and Tool support \ No newline at end of file diff --git a/recipe/fully_async_policy/README_zh.md b/recipe/fully_async_policy/README_zh.md new file mode 100644 index 000000000..fbbed992d --- /dev/null +++ b/recipe/fully_async_policy/README_zh.md @@ -0,0 +1,373 @@ +# Recipe: Fully Async Policy Async Trainer + +**Author:** `https://github.com/meituan-search` + +Last updated: 10/17/2025. 
+ +本文档介绍了完全异步PPO训练系统,该系统实现了 Trainer 和 Rollouter 的完全解耦,支持异步样本生成和训练。 +在该系统下,我们使用128卡训练qwen2.5-7B模型取得了2.35x-2.67x的性能提升,同时效果没有显著受到影响。 + +## Introduction + +### Background + +rollout和train分离架构相较于colocate的架构能够更加灵活地分配资源,设计更加灵活的训练逻辑,从而处理长尾等问题带来的GPU利用率低,训练效率低的问题。 +one_step_off_policy通过分离架构的设计并进行rollout和train一轮异步的训练方法,缓解了rollout时间过长的问题,并在训练效率上取得了一些收益, +但其强制使用一轮异步的数据,存在不够灵活等问题,而且并不能完全去除长尾对训练效率带来的的影响;在其他框架如areal、Magistral、streamrl、asyncflow上, +已经基于分离架构实现了异步训练、流式训练,并取得了收益;我们借鉴其方法,在verl上进行了实现。fully_async_policy支持异步、流式、partial +rollout的训练, 通过合理设置资源分配情况、参数同步频率等参数,fully_async_policy能够显著提高训练效率。 + +> Magistral https://arxiv.org/abs/2506.10910 +> +> AReaL: A Large-Scale Asynchronous Reinforcement Learning System for Language +> Reasoning https://arxiv.org/abs/2505.24298 +> +> StreamRL: Scalable, Heterogeneous, and Elastic RL for LLMs with Disaggregated Stream +> Generation https://arxiv.org/abs/2504.15930 +> +> AsyncFlow: An Asynchronous Streaming RL Framework for Efficient LLM Post-Training https://arxiv.org/abs/2507.01663 +> + +### 核心贡献 + +* **资源隔离**:与使用hybrid_engine不同,Rollouter和Trainer使用分离的计算资源,需要分别指定所占用的资源。 +* **生成与训练并行**:Trainer在训练的同时,Rollouter在生成新的样本。 +* **多步异步**: 相比 one step off policy 支持0.x步到多步的异步设定,异步方案更加灵活。 +* **nccl参数同步**:使用nccl通信原语进行Rollouter与Trainer参数的通信。 +* **Stream推理与训练**:Rollouter逐样本生成数据,同时数据传输以单个sample为最小传输单位。 +* **异步训练与新鲜度控制**:通过设置参数async_training.staleness_threshold,支持使用旧参数生成的样本进行训练。 +* **PartialRollout**: Rollouter推理过程支持partial rollout逻辑,通过参数同步时,添加`sleep()`和`resume()` + 逻辑,保存进行中的rollout的样本,并在下一次rollout中继续使用,减少参数同步等待进行中的任务结束时间。 + +目前支持使用模式为 fsdp+vllm。vllm必须使用基于AgentLoop的server模式。 + +## 设计 + +fully_async_policy的整体架构如下图所示,fully_async_policy主要由Rollouter、MessageQueue、Trainer、ParameterSynchronizer四部分组成。 + +![fully_async_policy_structure]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_structure.svg?raw=true) + +1. Rollouter逐样本生成序列,并将生成的sample放入MessageQueue中,生产的速度受新鲜度控制。 +2. MessageQueue用于暂存Rollouter生成的sample。 +3. Trainer逐样本从MessageQueue中获取,获取到`require_batches*ppo_mini_batch_size` + 数量的样本后,就会进行训练,训练async_training.trigger_parameter_sync_step轮后,触发与Rollouter的一次参数同步。 +4. 
ParameterSynchronizer 实现了Nccl的同步参数同步能力。 + +当前方案对比base的收益来源,在于colocate情况下,rollout使用更多的资源无法解决长尾样本带来的空闲, +当我们进行资源隔离后,rollout的时间和train的时间都可能相较于之前更长(因为使用的资源变少了), +但是相互之间的耗时overlap,端到端的耗时反而有所缩减。 + +![fully_async_policy_revenue]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_revenue.svg?raw=true) + +## 使用方式 + +### 参数说明 + +| super params | implication | +|-----------------------------------------------|-----------------------------------------------------------------| +| `trainer.nnodes` | Trainer的node数量 | +| `trainer.n_gpus_per_node` | Trainer每个node上gpu的数量 | +| `rollout.nnodes` | Rollouter的node数量 | +| `rollout.n_gpus_per_node` | Rollouter每个node上gpu的数量 | +| `data.train_batch_size` | 在fully async策略中,该值不生效(默认设置为0) | +| `data.gen_batch_size` | 在fully async策略中,使用流式的样本生产逻辑(默认设置为1) | +| `rollout.total_rollout_steps` | 总的rollout的sample数量 | +| `rollout.test_freq` | Rollouter每更新多少次参数,进行一次validation | +| `actor_rollout_ref.actor.ppo_mini_batch_size` | The ppo_mini_batch_size is a global num across all workers/gpus | +| `async_training.require_batches` | FullyAsyncTrainer一次性获取的ppo_mini_batch_size的数量 | +| `async_training.trigger_parameter_sync_step` | 表示FullyAsyncTrainer进行多少次本地更新后,进行一次参数同步 | +| `async_training.staleness_threshold` | 新鲜度控制 | +| `async_training.partial_rollout` | 是否进行partial_rollout | +| `async_training.use_rollout_log_probs` | 使用rollout产生的log_probs | + +**进一步的解释:** + +* `rollout.total_rollout_steps` + + 与 colocate 相比,数量可以通过 train_batch_size 与 step 相乘对齐: + `rollout.total_rollout_steps = data.train_batch_size * step`。 + +* `async_training.trigger_parameter_sync_step` + + 在fully async策略中,表示Trainer进行多少次本地更新后(也就是获取多少次`require_batches * ppo_mini_batch_size`数量样本), + 与Rollouter之间进行一次参数同步。 + 每两次Rollouter和Trainer参数同步之间,Trainer将会处理`trigger_parameter_sync_step* require_batches\ + ppo_mini_batch_size`份sample。 + 如果为了与colocate在公平的情况下对比速度,trigger_parameter_sync_step应该设置为 `data.train_batch_size / ( + require_batches * ppo_mini_batch_size)`。 + +* `async_training.staleness_threshold` + + 在fully async策略中,表示最大允许使用的staleness样本的比例。 + + * staleness_threshold=0,表示同步训练。 + Rollouter两次参数更新之间将会生成固定数量的样本,样本数为: + $$rollout\_num = (trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size)$$ + * staleness_threshold>0,表示异步训练, 可以设置为小数,支持更灵活的异步调用。 + Rollouter两次参数更新之间将会最多生成的样本数为: + $$rollout\_num = (1+staleness\_threshold)*(trigger\_parameter\_sync\_step*require\_batches*ppo\_mini\_batch\_size) - num\_staleness\_sample $$ + + num_staleness_sample 表示上一次rollout多生成的陈旧样本数。 + + 由于是流式系统,rollout持续生成,trainer持续消费。如果rollouter较慢,trainer会更早触发参数同步,rollouter并不会实际生产rollout_num个样本。 + 当rollout 足够快时,staleness_threshold设置为1,基本上等价于one_step_off policy。 + 为了避免过期样本太多影响训练精度,建议该值设置小于1。 + +* `async_training.partial_rollout` + + partial_rollout只会在staleness_threshold>0时才实际上起作用。 + +* `async_training.use_rollout_log_probs` + + 在强化学习算法中,log_probs与参数版本,token都存在隐性的相关性。由于PPO/GRPO/DAPO等算法的设定,我们在计算重要性采样时, + 即 old_log_prob必须使用rollout参数及token所对应log_probs,才能保证算法的正确性。在fully + async策略中,我们默认old_log_prob是有rollout所计算的,而不是由trainer所计算。 + + * `async_training.require_batches` + + 在流式训练中,require_batches 应该设置为1,表示生产够ppo_mini_batch_size样本后,就进行训练。 + 在实际测试中,我们发现,如果单次下发的样本较少,由于数据分发的顺序,会导致训练不稳定,response 长度变长。 + 在这里,我们额外提供 require_batches 进行流式分发,单次参与训练的样本数量控制。 + +### 模式支持 + +1. on policy pipeline: + 1. **trigger_parameter_sync_step=1,staleness_threshold=0** + 2. Rollouter一次生产`require_batches*ppo_mini_batch_size` + 的samples,Trainer获取这些samples后进行训练,训练完后Trainer和Rollouter之间进行一次参数同步; + 3. 
在rollout阶段,如果存在长尾的样本,但是rollout样本数较少时,较短的样本无法填充到空闲的资源中,会造成一定的资源浪费。 + 4. 如图a所示; + +2. stream off policy pipeline: + 1. **trigger_parameter_sync_step>1,staleness_threshold=0** + 2. 将会进行同步的流式训练,Rollouter一次生产`require_batches*ppo_mini_batch_size*trigger_parameter_sync_step` + 的samples,Trainer每获取`require_batches*ppo_mini_batch_size` + 就进行一次本地训练,训练trigger_parameter_sync_step次后,Trainer和Rollouter之间进行一次参数同步; + 3. 相较于a,由于一次生成的样本更多,资源的空闲会更低。 + 4. 在一次step训练中,会存在两次资源闲置的时间,分别是在第一次获取样本时,train等待`require_batches*ppo_mini_batch_size` + 个样本生产,以及最后一次参数更新时,rollout等待训练完成。 + 5. 如图b所示; + +3. async stream pipeline with staleness samples: + 1. **trigger_parameter_sync_step>=1,staleness_threshold>0,partial_rollout=Flase** + 2. Rollouter在每次参数更新后将计划最多生产rollout_num个样本(实际根据rollout速度,生成的样本可能会少与这个值)。 + 3. 如果rollout过程比较快,Rollouter将会在参数同步前额外生成一部分样本num_stale_samples,用于参数同步后立即给Trainer使用。 + 触发参数同步时,如果Rollouter有正在生产的任务,将会等待任务完成,同时不会添加新的任务; + 4. 相较于b,除第一次step训练外,后续的训练都不会有wait first batch rollout finish的时间,但是会有wait active task + finish的时间。 + 5. 如图c所示; + +4. async stream pipeline with partial rollout: + 1. **trigger_parameter_sync_step>=1,staleness_threshold>0,partial_rollout=True** + 2. 相较于c,触发参数同步时,Rollouter如果有正在生产的sample,会打断rollout过程并进行参数同步,被中断的sample会在参数同步后继续生成。减少了wait + active task finish的时间。 + 3. 如图d所示; + +![fully_async_policy_mode]( +https://github.com/ArronHZG/verl-community/blob/recipe/async_policy/docs/fully_async_policy_mode.svg?raw=true) + +### 关键指标 + +| metrics | implication | +|------------------------------------------------|-----------------------------------------------------------| +| `trainer/idle_ratio` | Trainer闲置率 | +| `rollouter/idle_ratio` | Rollouter闲置率 | +| `fully_async/count/stale_samples_processed` | 训练使用的旧sample总数 | +| `fully_async/count/stale_trajectory_processed` | 训练使用的旧trajectory总数(一个sample会生产rollout.n条trajectory) | +| `fully_async/partial/total_partial_num` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本数 | +| `fully_async/partial/partial_ratio` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的比例 | +| `fully_async/partial/max_partial_span` | 两次trigger_parameter_sync_step之间Trainer处理的partial样本的最大参数跨度 | + +### 调参建议 + +* 资源分配与调整: + * 合理的资源分配是获得好的训练效率的前提。理想的资源分配情况应该是使得Rollout的时间和Train的时间接近,从而使得整个训练过程流水气泡最小, + 避免资源闲置,同时Trainer不会使用旧样本。在真实训练场景下,可以根据实际训练过程中rollout和train的空闲时间调整资源分配, + 可从rollouter/idle_ratio和trainer/idle_ratio获得,如果rollouter/idle_ratio较高trainer/idle_ratio较低, + 应该增多Trainer的资源减少Rollouter的资源,反之亦然。 + +* 关键参数: + * staleness_threshold: 设置太大会导致较多的旧样本使用,影响模型效果,建议设置小于1。 + * require_batches:越接近1,越接近纯流式过程,训练过程中bubble越小,能够在速度上获得更快的加速效果,但会对样本的处理顺序产生影响; + * trigger_parameter_sync_step: 设置的越小越接近on policy,但会导致频繁的参数同步,长尾样本浪费的资源无法被短样本填充,资源利用率低。 + 设置的越大有更高的计算效率,但是精度上会受到off policy的影响。 + * rollout.test_freq: 会占用Rollouter资源,不建议设置太小。 + +* 模式选择:通过调整不同的参数,Fully Async架构支持不同程度上的优化加速,适用于不同场景的任务。 + * 对于小规模任务,需要保证训练的稳定性和 on-policy 性,对速度要求不高的场景,可以尝试使用on policy pipeline的模式(模式1)。 + * 对于需要提高训练吞吐量,但对 staleness 敏感的场景,可以尝试使用 stream off policy pipeline 的模式。即通过 + 设置trigger_parameter_sync_step>1 ,提高 训练效率,但仍保持同步机制 (staleness_threshold=0 )(模式2)。 + * 对于大规模任务,对训练速度有较高要求,且可以容忍一定 off-policy 程度、staleness的场景,可以设置staleness_threshold> + 0、partial_rollout=True提高训练效率,使用 async stream pipeline 模式(模式 3 或 4)。 + +### 快速开始 + +```shell +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=10 +staleness_threshold=0 
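+# staleness_threshold=0 表示同步流式训练;设置 >0 时允许使用旧参数生成的样本(见上文"参数说明")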
+trigger_parameter_sync_step=16 +partial_rollout=False + + +python -m recipe.fully_async_policy.fully_async_main \ + train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.hybrid_engine=False \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.partial_rollout="${partial_rollout}" +``` + +## 实验 + +### 在7B模型上进行异步训练 + +我们使用 Qwen2.5-Math-7B 验证 fully async 策略在长候选下,多种资源下的收益情况。 +使用`async stream pipeline with staleness samples` 策略,我们在32卡,64卡,128卡都取得2x左右的性能提升,同时没有显著影响实验效果。 + +* 机器:H20 +* 模型:Qwen2.5-Math-7B +* rollout长度:max_response_length FSDP2: 28K tokens; +* 算法:DAPO +* 数据集: TRAIN_FILE: dapo-math-17k.parquet TEST_FILE: aime-2024.parquet +* engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colocate sync: + * step: 400 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*400 + * require_batches: 4 + * trigger_parameter_sync_step: 4 + * staleness_threshold: 0.3 + * partial_rollout: True + +| training mode | resource allocation | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:--------------------:|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------------:| +| colocate sync | 32 | 790.10 | 357.41 | 107.71 | 313.81 | 13h 44m | 1d 3h 43m | 2d 9h 22m | 3d 17h 5m | max: 0.3313
last: 0.2448 | +| fully_async_policy | 16:16 | | | \ | | | | | | max:
last: | +| colocate sync | 64 | 365.28 | 150.72 | 70.26 | 133.41 | 10h 22m | 20h 45m | 1d 7h 6m | 1d 17h 32m | max: 0.3365
last: 0.2333 | +| fully_async_policy | 32:32 | 189.26 | 28.46 | \ | 156.98 | 4h 57m
(2.09x) | 10h 14m
(2.03x) | 16h 58m
(1.83x) | 21h 40m
(1.92x) | max: 0.3677
last: 0.3406 | +| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| fully_async_policy | 64:64 | 150.63 | 33.14 | \ | 113.16 | 3h 13m
(2.67x) | 6h 46m
(2.65x) | 10h 53m
(2.67x) | 17h 22m
(2.35x) | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-colocate_async?nw=nwuserhouzg + +### 128卡 7B 异步模式实验 + +我们使用 Qwen2.5-Math-7B 验证 fully async 所支持的各个模式的效果。 +我们可以看到 stream 带来的收益大约0.6x,叠加 staleness 和 partial_rollout 后,收益为2.35x。 + +| mode | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:-------------------------------------------------------------------------------------------------------:|:---------------------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| colocate sync | 128 | 356.30 | 177.85 | 53.92 | 113.81 | 8h 36m | 17h 56m | 1d 5h 6m | 1d 16h 48m | max: 0.3573
last: 0.2958 | +| `stream off policy pipeline`
(+fully async: trigger_parameter_sync_step= 4,
require_batches= 4) | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| `async stream pipeline with staleness samples`
(+staleness_threshold=0.5) | | | | | | | | | | +| `async stream pipeline with partial rollout`
(+partial_rollout=True) | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-stream_stale_partial?nw=nwuserhouzg + +### 128卡 stale 消融实验 + +在 `async stream pipeline with partial rollout` 模式下,我们验证 staleness 的设置对于训练效率的影响。 +我们可以发现,staleness 越大,最终取得的收益越明显。 +同时我们也注意到 staleness 取 0.3 和 0.5 的时间比较接近,原因是随着训练步数的增量,response 长度变化较大,训练出现了不稳定的问题。 +后续还需要针对该问题进行进一步的分析和优化。 + +| staleness_threshold | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | total time
400 step | acc/mean@1 | +|:---------------------:|:--------:|:--------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 0 | 231.34 | 128.47 | \ | 98.77 | 4h 25m | 9h 41m | 15h 2m | 1d 1h 53m | max: 0.2844
last: 0.2604 | +| 0.1 | 171.30 | 58.17 | \ | 109.12 | 3h 53m | 8h 37m | 14h 25m | 19h 59m | max: 0.3542
last: 0.2979 | +| 0.3 | 146.11 | 38.88 | \ | 103.22 | 3h 18m | 6h 49m | 11h 40m | 17h 20m | max: 0.3469
last: 0.2865 | +| 0.5 | 150.63 | 33.14 | \ | 113.16 | 3h 13m | 6h 46m | 10h 53m | 17h 22m | max: 0.3521
last: 0.3094 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_stale?nw=nwuserhouzg + +### 128卡 7B require_batches 消融实验 + +在多次测试下,我们发现流式每次下发样本的数量会影响训练的response长度,进而影响训练时长,我们通过修改 +`async_training.require_batches` 验证对与结果的影响。 + +| require_batches | step | gen | old_log_prob | update_actor | total time
100 step | total time
200 step | total time
300 step | acc/mean@1 | +|:-----------------:|:--------:|:-------:|:--------------:|:--------------:|:------------------------:|:------------------------:|:------------------------:|:-----------------------------:| +| 1 | 203.47 | 30.88 | \ | 181.08 | 3h 31m | 8h 29m | 17h 36m | max: 0.349
last: 0.326 | +| 2 | 158.72 | 26.32 | \ | 128.08 | 3h 35m | 7h 38m | 13h 57m | max: 0.351
last: 0.3406 | +| 4 | 124.64 | 25.62 | \ | 95.06 | 3h 13m | 6h 46m | 10h 53m | max: 0.3521
last: 0.3521 | + +> source data: https://wandb.ai/hou-zg-meituan/fully-async-policy-ablation_require_batches?nw=nwuserhouzg + +### 30B模型模式实验 + +TODO: 30B 的实验,还在完善中。 + +* 机器: H20 +* 模型:Qwen2.5-32B +* rollout长度:max_response_length FSDP2: 20K tokens; +* 算法:DAPO +* engine: vllm+FSDP2 +* rollout.n: 16 +* ppo_mini_batch_size: 32 +* test_freq: 20 + +* colacate sync: + * step:200 + * train_batch_size: 512 + +* fully_async_policy + * total_rollout_steps: 512*200 + * trigger_parameter_sync_step: 512/32 = 16 + * staleness_threshold: 0 + * partial_rollout: False + +| training mode | Resource allocation | mode | step | generate_sequences | old_log_prob | update_actor | total time | acc/best@32/mean | +|--------------------|---------------------|----------------------------------------------|------|--------------------|--------------|--------------|------------|------------------| +| colocate sync | 128 | | | | | | | | +| fully_async_policy | 64:64 | stream off policy pipeline | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with staleness samples | | | | | | | +| fully_async_policy | 64:64 | async stream pipeline with partial rollout | | | | | | | + + +## 后续计划 + +* GRPO实验 +* megatron 适配 +* sglang 集成 +* transfer queue 集成 +* 异步参数同步 +* Areal异步算法实现 +* TPPO算法实现 +* 多轮及Tool的支持 \ No newline at end of file diff --git a/recipe/fully_async_policy/agent_loop/__init__.py b/recipe/fully_async_policy/agent_loop/__init__.py new file mode 100644 index 000000000..e30d78f1a --- /dev/null +++ b/recipe/fully_async_policy/agent_loop/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .agent_loop import FullyAsyncAgentLoopManager +from .partial_single_turn_agent_loop import PartialSingleTurnAgentLoop + +_ = [PartialSingleTurnAgentLoop] +__all__ = [FullyAsyncAgentLoopManager] diff --git a/recipe/fully_async_policy/agent_loop/agent_loop.py b/recipe/fully_async_policy/agent_loop/agent_loop.py new file mode 100644 index 000000000..55489d705 --- /dev/null +++ b/recipe/fully_async_policy/agent_loop/agent_loop.py @@ -0,0 +1,275 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
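+
+"""Agent-loop components for the fully async policy recipe.
+
+Extends verl's agent loop stack with partial-rollout support: the server manager adds
+`generate_for_partial`, workers return `FullyAsyncAgentLoopOutput` (cancellation flag,
+rollout log probs, and parameter-version span), and the manager schedules per-sample
+asynchronous generation across rollout replicas and supports cancel/resume.
+"""
+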
+import asyncio +import logging +import os +from typing import Any, Optional + +import hydra +import numpy as np +import ray +import torch +from omegaconf import DictConfig + +from recipe.fully_async_policy.vllm_rollout.vllm_async_server import FullyAsyncvLLMReplica +from verl.experimental.agent_loop.agent_loop import ( + AgentLoopManager, + AgentLoopOutput, + AgentLoopWorkerBase, + AsyncLLMServerManager, + BatchExecutor, + _agent_loop_registry, + _DummyConfig, + get_trajectory_info, +) +from verl.protocol import DataProto +from verl.single_controller.ray import RayWorkerGroup +from verl.utils.rollout_trace import rollout_trace_attr +from verl.workers.rollout.replica import TokenOutput + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +class FullyAsyncLLMServerManager(AsyncLLMServerManager): + async def generate_for_partial(self, request_id, prompt_ids, sampling_params) -> TokenOutput: + """Generate tokens from prompt ids. with partial rollout function""" + server = self._choose_server(request_id) + output = await server.generate_for_partial.remote( + request_id=request_id, + prompt_ids=prompt_ids, + sampling_params=sampling_params, + ) + return output + + +class FullyAsyncAgentLoopOutput(AgentLoopOutput): + """Agent loop output.""" + + is_cancel: bool = False + """Indicates whether the request was interrupted""" + log_probs: list[float] = None + """Response token log probs including LLM generated token, tool response token.""" + param_version_start: int = 0 + """Indicate start parameter version when this response is generated""" + param_version_end: int = 0 + """Indicate end parameter version when this response is generated, used for partial rollout""" + + +@ray.remote +class FullyAsyncAgentLoopWorker(AgentLoopWorkerBase): + def __init__( + self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None + ): + self.server_manager = FullyAsyncLLMServerManager(config, server_handles) + super().__init__(config, server_handles, rm_executor) + + async def generate_sequences_no_post( + self, batch: DataProto, partial_output_list: Optional[list[AgentLoopOutput]] + ) -> list[AgentLoopOutput]: + """Generate sequences from agent loop. + + Args: + batch (DataProto): Input batch. + partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. + + Returns: + list[FullyAsyncAgentLoopOutput]: List of agent loop outputs, one per sample in the batch. 
+ """ + config = self.config.actor_rollout_ref.rollout + sampling_params = dict( + temperature=config.temperature, + top_p=config.top_p, + repetition_penalty=1.0, + logprobs=config.calculate_log_probs, + ) + + # override sampling params for validation + if batch.meta_info.get("validate", False): + sampling_params["top_p"] = config.val_kwargs.top_p + sampling_params["temperature"] = config.val_kwargs.temperature + + # by default, we assume it's a single turn agent + if "agent_name" not in batch.non_tensor_batch: + batch.non_tensor_batch["agent_name"] = np.array(["single_turn_agent"] * len(batch), dtype=object) + + if "index" in batch.non_tensor_batch: + index = batch.non_tensor_batch["index"] + else: + index = np.arange(len(batch)) + + trajectory_info = await get_trajectory_info( + batch.meta_info.get("global_steps", -1), index, batch.meta_info.get("validate", False) + ) + + if not partial_output_list: + partial_output_list = [None] * len(batch) + + tasks = [] + for i in range(len(batch)): + kwargs = {k: v[i] for k, v in batch.non_tensor_batch.items()} + kwargs["output"] = partial_output_list[i] + tasks.append( + asyncio.create_task(self._partial_run_agent_loop(sampling_params, trajectory_info[i], **kwargs)) + ) + return await asyncio.gather(*tasks) + + async def _partial_run_agent_loop( + self, + sampling_params: dict[str, Any], + trajectory: dict[str, Any], + *, + agent_name: str, + **kwargs, + ) -> AgentLoopOutput: + with rollout_trace_attr( + step=trajectory["step"], + sample_index=trajectory["sample_index"], + rollout_n=trajectory["rollout_n"], + validate=trajectory["validate"], + name="agent_loop", + ): + assert agent_name in _agent_loop_registry, ( + f"Agent loop {agent_name} not registered, registered agent loops: {_agent_loop_registry.keys()}" + ) + + agent_loop_config = _agent_loop_registry[agent_name] + agent_loop = hydra.utils.instantiate( + config=agent_loop_config, + trainer_config=_DummyConfig(config=self.config), + server_manager=self.server_manager, + tokenizer=self.tokenizer, + processor=self.processor, + ) + return await agent_loop.run(sampling_params, **kwargs) + + +class FullyAsyncAgentLoopManager(AgentLoopManager): + def __init__(self, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None): + self.config = config + self.worker_group = worker_group + self.rm_executor = None + self.rm_micro_batch_size = None + self.agent_loop_workers_class = FullyAsyncAgentLoopWorker + self.rollout_replica_class = FullyAsyncvLLMReplica + + self.rm_wg = rm_wg + self.rollout_replicas = None + self.server_handles = None + self.server_addresses = None + self.agent_loop_workers = None + + @classmethod + async def create(cls, config: DictConfig, worker_group: RayWorkerGroup = None, rm_wg: RayWorkerGroup = None): + instance = cls(config, worker_group, rm_wg) + await instance._async_init() + return instance + + async def _async_init(self): + if self.rm_wg: + + def batch_fn(data_list: list[DataProto]) -> list[torch.Tensor]: + new_data_list = [] + for data in data_list: + temp_non_tensor_batch = {"__num_turns__": data.non_tensor_batch["__num_turns__"]} + temp_data = DataProto(batch=data.batch, non_tensor_batch=temp_non_tensor_batch) + new_data_list.append(temp_data) + + new_batch = DataProto.concat(new_data_list) + out_data = self.rm_wg.compute_rm_score(new_batch) + return out_data.split(1) + + self.rm_executor = BatchExecutor.options( + scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( + 
node_id=ray.get_runtime_context().get_node_id(), + soft=False, + ), + ).remote(batch_fn, self.rm_wg.world_size) + + self.rm_micro_batch_size = self.rm_wg.world_size + + await self._initialize_llm_servers_async() + self._init_agent_loop_workers() + + async def _initialize_llm_servers_async(self): + rollout_world_size = self.config.actor_rollout_ref.rollout.tensor_model_parallel_size + world_size = ( + self.worker_group.world_size + if self.worker_group + else self.config.trainer.n_gpus_per_node * self.config.trainer.nnodes + ) + num_replicas = world_size // rollout_world_size + + rollout_config = self.config.actor_rollout_ref.rollout + model_config = self.config.actor_rollout_ref.model + self.rollout_replicas = [ + self.rollout_replica_class( + replica_rank=replica_rank, + config=rollout_config, + model_config=model_config, + gpus_per_node=self.config.trainer.n_gpus_per_node, + ) + for replica_rank in range(num_replicas) + ] + + if self.worker_group: + await asyncio.gather(*[server.init_hybrid(self.worker_group) for server in self.rollout_replicas]) + else: + await asyncio.gather(*[server.init_standalone() for server in self.rollout_replicas]) + + self.server_handles = [server._server_handle for server in self.rollout_replicas] + self.server_addresses = [server._server_address for server in self.rollout_replicas] + + async def generate_single_sample_async( + self, + sample: DataProto, + partial_output_list: Optional[list[AgentLoopOutput]], + ) -> list[AgentLoopOutput]: + """ + Asynchronously process a single sample + + Args: + sample: Single sample data + partial_output_list: Optional[List[AgentLoopOutput]]: already rollout result. + + Returns: + list[AgentLoopOutput]: Processing results + """ + worker = self._select_best_worker() + output_future = worker.generate_sequences_no_post.remote(sample, partial_output_list) + return await asyncio.wrap_future(output_future.future()) + + def _select_best_worker(self): + """Select the best worker, simple round-robin load balancing""" + if not hasattr(self, "_worker_index"): + self._worker_index = 0 + + worker = self.agent_loop_workers[self._worker_index] + self._worker_index = (self._worker_index + 1) % len(self.agent_loop_workers) + return worker + + async def cancel(self): + await asyncio.gather(*[replica.cancel() for replica in self.rollout_replicas]) + + async def resume(self): + await asyncio.gather(*[replica.resume() for replica in self.rollout_replicas]) + + async def wake_up(self): + await asyncio.gather(*[replica.wake_up() for replica in self.rollout_replicas]) + + async def sleep(self): + await asyncio.gather(*[replica.sleep() for replica in self.rollout_replicas]) + + async def reset_prefix_cache(self): + await asyncio.gather(*[replica.reset_prefix_cache() for replica in self.rollout_replicas]) diff --git a/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py b/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py new file mode 100644 index 000000000..246085454 --- /dev/null +++ b/recipe/fully_async_policy/agent_loop/partial_single_turn_agent_loop.py @@ -0,0 +1,92 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os +from typing import Any, Optional +from uuid import uuid4 + +from recipe.fully_async_policy.agent_loop.agent_loop import AgentLoopOutput, FullyAsyncAgentLoopOutput +from verl.experimental.agent_loop import AgentLoopBase +from verl.experimental.agent_loop.agent_loop import register +from verl.utils.profiler import simple_timer + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + + +@register("partial_single_turn_agent") +class PartialSingleTurnAgentLoop(AgentLoopBase): + """Naive agent loop that only do single turn chat completion.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.prompt_length = self.config.actor_rollout_ref.rollout.prompt_length + self.response_length = self.config.actor_rollout_ref.rollout.response_length + self.apply_chat_template_kwargs = self.config.data.get("apply_chat_template_kwargs", {}) + + async def run(self, sampling_params: dict[str, Any], **kwargs) -> AgentLoopOutput: + output: Optional[FullyAsyncAgentLoopOutput] = kwargs.get("output", None) + messages = list(kwargs["raw_prompt"]) + param_version = kwargs.get("param_version", 0) + + metrics = {} + request_id = uuid4().hex + + param_version_start = param_version + param_version_end = param_version + + if not output: + prompt_ids = await self.loop.run_in_executor( + None, + lambda: self.tokenizer.apply_chat_template( + messages, add_generation_prompt=True, tokenize=True, **self.apply_chat_template_kwargs + ), + ) + else: + if output.is_cancel: + # Resume the paused sample, + # add the result directly after prompt_ids, + # and reset generate_sequences metric + prompt_ids = output.prompt_ids + output.response_ids + metrics["generate_sequences"] = output.metrics.generate_sequences + param_version_start = output.param_version_start + else: + # In the same batch of samples, + # ome are canceled and some are not. + # The samples without partial rollout are returned directly. 
+ return output + with simple_timer("generate_sequences", metrics): + response_ids, log_probs, is_cancel = await self.server_manager.generate_for_partial( + request_id=request_id, prompt_ids=prompt_ids, sampling_params=sampling_params + ) + if not output: + response_mask = [1] * len(response_ids) + else: + # Pause the sample to be resumed, add the output result to response_ids, and reset response_mask + prompt_ids = output.prompt_ids + log_probs = output.log_probs + log_probs + response_ids = output.response_ids + response_ids + response_mask = [1] * len(response_ids) + + return FullyAsyncAgentLoopOutput( + prompt_ids=prompt_ids, + response_ids=response_ids[: self.response_length], + response_mask=response_mask[: self.response_length], + num_turns=2, + metrics=metrics, + is_cancel=is_cancel, + log_probs=log_probs, + param_version_start=param_version_start, + param_version_end=param_version_end, + ) diff --git a/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml new file mode 100644 index 000000000..4a8b8fc32 --- /dev/null +++ b/recipe/fully_async_policy/config/fully_async_ppo_trainer.yaml @@ -0,0 +1,55 @@ +hydra: + searchpath: + - file://verl/trainer/config + +defaults: + - ppo_trainer + - _self_ + +async_training: + + # Maximum samples staleness threshold + staleness_threshold: 0.1 + + # Frequency of parameter synchronization between rollouter and trainer, + # One step means trainer obtains a batch of required samples + trigger_parameter_sync_step: 4 + + # The number of ppo_mini_batches that the FullyAsyncTrainer obtains once + require_batches: 1 + + # When synchronizing parameters, whether to interrupt rollouter and perform partial rollout + partial_rollout: True + + # Whether to use rollout log probs for training + use_rollout_log_probs: True + +# Rollout config +rollout: + + # Number of nodes used in the rollout + nnodes: 1 + + # Number of GPUs per node + n_gpus_per_node: 8 + + # number of responses (i.e. num sample times). > 1 for grpo + n: 4 + + # total rollout samples # TODO rename to total_rollout_samples + total_rollout_steps: 100 + + # Number of epochs in training + total_epochs: 10 + + # Test frequency, how many times a parameter update triggers a validation + test_freq: 1 + +data: + # Number of samples generated, currently only support 1 + gen_batch_size: 1 + +actor_rollout_ref: + actor: + # Whether to use rollout log probs for training + use_rollout_log_probs: ${oc.select:async_training.use_rollout_log_probs, True} diff --git a/recipe/fully_async_policy/detach_utils.py b/recipe/fully_async_policy/detach_utils.py new file mode 100644 index 000000000..545211198 --- /dev/null +++ b/recipe/fully_async_policy/detach_utils.py @@ -0,0 +1,474 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
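+
+"""Utilities shared by the fully async Rollouter and Trainer.
+
+Includes: postprocessing of `AgentLoopOutput` lists into `DataProto` batches, the
+`RolloutSample`/`ValidateMetrics` containers, helpers for preparing per-sample generation
+data and aligning rollout log probs with the response mask, batch assembly from rollout
+samples, and a `MetricsAggregator` that merges metrics across training steps.
+"""
+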
+import time +from collections import defaultdict +from dataclasses import dataclass +from typing import Any, Optional + +import numpy as np +import torch +from tensordict import TensorDict + +from verl import DataProto +from verl.experimental.agent_loop.agent_loop import AgentLoopOutput +from verl.trainer.ppo.ray_trainer import compute_response_mask + + +def postprocess_agent_loop_outputs(inputs: list[AgentLoopOutput], tokenizer, config) -> DataProto: + """Static method to postprocess a list of AgentLoopOutput into DataProto + + Args: + inputs: List of AgentLoopOutput + tokenizer: Tokenizer instance + config: Configuration object + + Returns: + DataProto: Processed batch data + """ + # NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py + # prompts: left pad + # responses: right pad + # input_ids: prompt + response + # attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + # position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11] + + # prompts + tokenizer.padding_side = "left" + outputs = tokenizer.pad( + [{"input_ids": input.prompt_ids} for input in inputs], + padding="max_length", + max_length=config.actor_rollout_ref.rollout.prompt_length, + return_tensors="pt", + return_attention_mask=True, + ) + prompt_ids, prompt_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # responses + tokenizer.padding_side = "right" + outputs = tokenizer.pad( + [{"input_ids": input.response_ids} for input in inputs], + padding="max_length", + max_length=config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=True, + ) + response_ids, response_attention_mask = outputs["input_ids"], outputs["attention_mask"] + + # response_mask + outputs = tokenizer.pad( + [{"input_ids": input.response_mask} for input in inputs], + padding="max_length", + max_length=config.actor_rollout_ref.rollout.response_length, + return_tensors="pt", + return_attention_mask=False, + ) + response_mask = outputs["input_ids"] + assert response_ids.shape == response_mask.shape, ( + f"mismatch in response_ids and response_mask shape: {response_ids.shape} vs {response_mask.shape}" + ) + response_mask = response_mask * response_attention_mask + + input_ids = torch.cat([prompt_ids, response_ids], dim=1) + attention_mask = torch.cat([prompt_attention_mask, response_attention_mask], dim=1) + position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask + + batch = TensorDict( + { + "prompts": prompt_ids, # [bsz, prompt_length] + "responses": response_ids, # [bsz, response_length] + "response_mask": response_mask, # [bsz, response_length] + "input_ids": input_ids, # [bsz, prompt_length + response_length] + "attention_mask": attention_mask, # [bsz, prompt_length + response_length] + "position_ids": position_ids, # [bsz, prompt_length + response_length] + }, + batch_size=len(input_ids), + ) + + num_turns = np.array([input.num_turns for input in inputs], dtype=np.int32) + metrics = [input.metrics.model_dump() for input in inputs] + return DataProto(batch=batch, non_tensor_batch={"__num_turns__": num_turns}, meta_info={"metrics": metrics}) + + +@dataclass +class RolloutSample: + """Enhanced rollout sample containing both original batch info and AgentLoopOutput""" + + # Original batch information + full_batch: Any + + # AgentLoopOutput from generation + agent_loop_output_list: list[Any] # AgentLoopOutput + + # Metadata + sample_id: str + epoch: int + + # Processing metadata + processing_times: list[float] + param_version: int + param_version_start: list[int] + 
param_version_end: list[int] + rollout_status: dict[str, Any] + + +@dataclass +class ValidateMetrics: + """Metrics for validation""" + + timing_raw: dict[str, Any] + metrics: Optional[dict[str, Any]] = None + global_steps: Optional[int] = None + param_version: Optional[int] = None + + +def prepare_single_generation_data(batch_dict, global_steps, rollout_n) -> DataProto: + """ + Similar to the logic of ray_trainer._prepare_generate_batch, but for a single sample. + Separate the data used for generation from the original data. + + Returns: + tuple: (original_batch_dict, gen_data_for_single_sample) + """ + + full_batch = DataProto.from_single_dict(batch_dict) + + batch_keys_to_pop = ["input_ids", "attention_mask", "position_ids"] + non_tensor_batch_keys_to_pop = ["raw_prompt_ids"] + + full_batch.pop( + batch_keys=batch_keys_to_pop, + non_tensor_batch_keys=non_tensor_batch_keys_to_pop, + ) + + # Setting agent - partial_single_turn_agent, that supports partial + full_batch.non_tensor_batch["agent_name"] = np.array(["partial_single_turn_agent"] * len(full_batch), dtype=object) + + # Add global step count to generated data + full_batch = full_batch.repeat(repeat_times=rollout_n, interleave=True) + return full_batch + + +def process_rollout_log_probs(data_proto: DataProto, rollout_log_probs: list[list[float]]) -> torch.Tensor: + """ + Process rollout_log_probs according to the mask in DataProto + mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0] + + Args: + data_proto: A DataProto object containing batch information + rollout_log_probs: A two-dimensional list, each sublist containing the log_probs of a sample + + Returns: + torch.Tensor: The processed log_probs tensor, with shape: [bsz, response_length] + """ + + batch = data_proto.batch + response_mask = batch["response_mask"] + rollout_log_probs_tensor = torch.zeros(response_mask.shape, dtype=torch.float32) - 1 + + for i, log_probs_seq in enumerate(rollout_log_probs): + # Get the effective length of the current sample (the number of positions with 1 in the mask) + valid_length = response_mask[i].sum().item() + + # Ensure that the length of log_probs_seq does not exceed the valid length + actual_length = min(len(log_probs_seq), valid_length) + + # Fill log_probs into the corresponding position + if actual_length > 0: + rollout_log_probs_tensor[i, :actual_length] = torch.tensor(log_probs_seq[:actual_length]) + + rollout_log_probs_tensor = rollout_log_probs_tensor.to(torch.float32) + return rollout_log_probs_tensor + + +def merge_rollout_sample(config, tokenizer, rs: RolloutSample): + """ + Supplement and refine the RolloutSample object, + """ + # Step 1: Create a DataProto from the AgentLoopOutput to generate the result + gen_batch_output = postprocess_agent_loop_outputs(rs.agent_loop_output_list, tokenizer, config) + rollout_log_probs = [x.log_probs for x in rs.agent_loop_output_list] + rollout_log_probs = process_rollout_log_probs(gen_batch_output, rollout_log_probs) + gen_batch_output.batch["rollout_log_probs"] = rollout_log_probs.to(torch.float32) + + # Step 2: Add uid + rs.full_batch.non_tensor_batch["uid"] = np.array([f"uid_{rs.sample_id}"] * len(rs.full_batch), dtype=object) + + # Step 2: Merge batches + # Merge the non_tensor_batch and meta_info of original_batch into final_batch + for key, value in rs.full_batch.non_tensor_batch.items(): + gen_batch_output.non_tensor_batch[key] = value + gen_batch_output.meta_info.update(rs.full_batch.meta_info) + + # Step 3, set full_batch + rs.full_batch = gen_batch_output + rs.processing_times = [] + for 
agent_loop in rs.agent_loop_output_list: + rs.processing_times.append(agent_loop.metrics.generate_sequences) + rs.param_version_start = [agent_loop.param_version_start for agent_loop in rs.agent_loop_output_list] + rs.param_version_end = [agent_loop.param_version_end for agent_loop in rs.agent_loop_output_list] + # Step 4, clear agent_loop_output_list + rs.agent_loop_output_list = [] + return rs + + +def assemble_batch_from_rollout_samples( + rollout_samples: list[RolloutSample], tokenizer, config, balance_batch=None +) -> DataProto: + """ + Assemble gen_batch_output from RolloutSample objects + Assembles batches from RolloutSample objects, similar to the _post_generate_batch logic in ray_trainer. + + Args: + rollout_samples: List of RolloutSample objects + tokenizer: Tokenizer instance + config: Configuration object containing trainer settings + balance_batch: Whether to balance the batch (simplified version) + + Returns: + DataProto: Assembled gen_batch_output + + Raises: + ValueError: If rollout_samples is empty + """ + start_time = time.time() + + if not rollout_samples: + raise ValueError("Empty rollout_samples provided for batch assembly") + + print(f"[BatchUtils] Assembling batch from {len(rollout_samples)} RolloutSample objects") + + rollout_samples_batch = [] + processing_times = [] + rollout_status = rollout_samples[0].rollout_status + # Add a prefix to all rollout_status keys + rollout_status = {f"fully_async/{key}": value for key, value in rollout_status.items()} + + for rs in rollout_samples: + rollout_samples_batch.append(rs.full_batch) + processing_times.extend(rs.processing_times) + final_batch = DataProto.concat(rollout_samples_batch) + + # Calculate response_mask (if not present) + if "response_mask" not in final_batch.batch.keys(): + final_batch.batch["response_mask"] = compute_response_mask(final_batch) + + if balance_batch: + balance_batch(final_batch, metrics={}) + + # Calculate the global valid token number + if "attention_mask" in final_batch.batch: + final_batch.meta_info["global_token_num"] = torch.sum(final_batch.batch["attention_mask"], dim=-1).tolist() + + # Collect statistics + param_versions = [rs.param_version for rs in rollout_samples] + trajectorys_param_versions = [version for rs in rollout_samples for version in rs.param_version_end] + + processing_time_stats = { + "processing_time/avg": np.mean(processing_times), + "processing_time/max": np.max(processing_times), + "processing_time/min": np.min(processing_times), + "processing_time/tp50": np.percentile(processing_times, 50), + "processing_time/tp99": np.percentile(processing_times, 99), + "processing_time/tp95": np.percentile(processing_times, 95), + } + processing_time_stats = {f"fully_async/{key}": value for key, value in processing_time_stats.items()} + + param_version_diff = [abs(a - b) for a, b in zip(rs.param_version_end, rs.param_version_start, strict=False)] + num_diff0 = param_version_diff.count(0) + partial_stats = { + "fully_async/partial/total_partial_num": len(param_version_diff) - num_diff0, + "fully_async/partial/partial_ratio": (len(param_version_diff) - num_diff0) / len(param_version_diff), + "fully_async/partial/max_partial_span": max(param_version_diff), + } + # add meta_info + final_batch.meta_info.update( + { + "rollout_param_versions": param_versions, + "param_version_diversity": len(set(param_versions)) if param_versions else 0, + "trajectory_param_versions": trajectorys_param_versions, + **processing_time_stats, + **rollout_status, + **partial_stats, + } + ) + + 
print(f"[BatchUtils] Batch assembly completed in {time.time() - start_time:.2f}s") + + return final_batch + + +class MetricsAggregator: + """Metrics aggregator, used to combine metrics from multiple training steps""" + + def __init__(self, total_gpus: int): + # Store all values ​​for each metric + self.metric_values: dict[str, list[float]] = defaultdict(list) + # Store the number of samples at each step for weighted averaging + self.sample_counts: list[int] = [] + # Store the timestamp of each step for time-related calculations + self.timestamps: list[float] = [] + # Step Count + self.step_count = 0 + # total num gpus used + self.total_gpus = total_gpus + + # Metric aggregation rule configuration + self.aggregation_rules = self._init_aggregation_rules() + + def _init_aggregation_rules(self) -> dict[str, dict[str, list[str]]]: + """Initialize metrics aggregation rules""" + return { + # Time-Based metrics, can add metrics here + "time_sum": ["perf/time_per_step"], + "last": [ + "fully_async/count/total_generated_samples", + "fully_async/count/stale_samples_processed", + "fully_async/count/stale_trajectory_processed", + "fully_async/count/current_param_version", + "fully_async/count/dropped_stale_samples", + "training/global_step", # TODO change name to: total_step + ], + } + + def add_step_metrics(self, metrics: dict[str, Any], sample_count: int, timestamp: float = None): + """Adding a single-step metrics""" + if timestamp is None: + timestamp = time.time() + + self.sample_counts.append(sample_count) + self.timestamps.append(timestamp) + self.step_count += 1 + + # Store all metrics values + for key, value in metrics.items(): + if isinstance(value, int | float | np.number): + self.metric_values[key].append(float(value)) + elif isinstance(value, torch.Tensor): + self.metric_values[key].append(float(value.item())) + + def _get_aggregation_type(self, metric_name: str) -> str: + """Determine the aggregation type based on the metric name""" + for agg_type, metric_list in self.aggregation_rules.items(): + if metric_name in metric_list: + return agg_type + + metric_lower = metric_name.lower() + if any(keyword in metric_lower for keyword in ["timing_s/"]): + return "time_sum" + if any(keyword in metric_lower for keyword in ["mean", "avg", "average"]): + return "avg" + if any(keyword in metric_lower for keyword in ["max", "maximum"]): + return "max" + if any(keyword in metric_lower for keyword in ["min", "minimum"]): + return "min" + if any(keyword in metric_lower for keyword in ["sum", "total"]): + return "sum" + if any(keyword in metric_lower for keyword in ["weighted_avg"]): + return "weighted_avg" + + return "avg" + + def _aggregate_single_metric(self, metric_name: str, values: list[float]) -> float: + """Aggregating a single metric""" + if not values: + return 0.0 + + agg_type = self._get_aggregation_type(metric_name) + + if agg_type == "last": + return values[-1] + + elif agg_type == "weighted_avg": + # Weighted average + if len(values) != len(self.sample_counts): + # If the lengths do not match, use a simple average + return sum(values) / len(values) + + total_samples = sum(self.sample_counts) + if total_samples == 0: + return sum(values) / len(values) + + weighted_sum = sum(v * c for v, c in zip(values, self.sample_counts, strict=False)) + return weighted_sum / total_samples + + elif agg_type == "sum" or agg_type == "time_sum": + return sum(values) + + elif agg_type == "avg": + return sum(values) / len(values) + + elif agg_type == "max": + return max(values) + + elif agg_type == "min": + 
return min(values) + + else: + # Default average + return sum(values) / len(values) + + def get_aggregated_metrics(self) -> dict[str, Any]: + """aggregated metrics""" + t = time.time() + if self.step_count == 0: + return {} + + aggregated = {} + + # Aggregate all metrics + for metric_name, values in self.metric_values.items(): + aggregated[metric_name] = self._aggregate_single_metric(metric_name, values) + + # Aggregate special metrics + aggregated = self._special_metrics_aggergate(aggregated) + + print(f"aggregated metrics done. cost {time.time() - t}") + + return aggregated + + def _special_metrics_aggergate(self, aggregated: dict[str, Any]) -> dict[str, Any]: + """calculate special metrics""" + + # global_seqlen/minmax_diff + if "global_seqlen/minmax_diff" in aggregated.keys(): + aggregated["global_seqlen/minmax_diff"] = aggregated["global_seqlen/max"] - aggregated["global_seqlen/min"] + + # perf/throughput + REQUIRED_PERF_KEYS = {"perf/throughput", "perf/total_num_tokens", "perf/time_per_step"} + if REQUIRED_PERF_KEYS.issubset(aggregated): + aggregated["perf/throughput"] = aggregated["perf/total_num_tokens"] / ( + aggregated["perf/time_per_step"] * self.total_gpus + ) + + # trainer/idle_ratio + if "timing_s/gen" in aggregated.keys() and "timing_s/step" in aggregated.keys(): + aggregated["trainer/idle_ratio"] = aggregated["timing_s/gen"] / aggregated["timing_s/step"] + + return aggregated + + def reset(self): + """Reset Aggregator""" + self.metric_values.clear() + self.sample_counts.clear() + self.timestamps.clear() + self.step_count = 0 + + def get_current_stats(self) -> dict[str, Any]: + """Get statistics about the current aggregation state (for debugging)""" + return { + "step_count": self.step_count, + "metric_count": len(self.metric_values), + "total_samples": sum(self.sample_counts), + "metric_names": list(self.metric_values.keys()), + } diff --git a/recipe/fully_async_policy/fsdp_workers.py b/recipe/fully_async_policy/fsdp_workers.py new file mode 100644 index 000000000..ad6b0db8b --- /dev/null +++ b/recipe/fully_async_policy/fsdp_workers.py @@ -0,0 +1,136 @@ +# Copyright 2025 Bytedance Ltd. and/or its affiliates +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
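+
+"""FSDP workers for the fully async policy recipe.
+
+`DetachNcclSync` broadcasts actor weights to detached rollout workers over a Ray collective
+group ("actor_rollout"); `DetachActorWorker` exposes the actor-side weight metadata and
+parameters, while `DetachAsyncRolloutWorker` receives the metadata and loads the broadcast
+weights into the inference engine.
+"""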
+ +import logging +import os + +import torch +import torch.distributed +from omegaconf import DictConfig +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + +from verl.single_controller.base.decorator import Dispatch, register +from verl.utils.device import ( + get_device_name, + get_torch_device, +) +from verl.utils.fsdp_utils import ( + fsdp_version, +) +from verl.workers.fsdp_workers import ActorRolloutRefWorker, AsyncActorRolloutRefWorker, CriticWorker + +logger = logging.getLogger(__file__) +logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN")) + +device_name = get_device_name() + +__all__ = ["DetachActorWorker", "DetachAsyncRolloutWorker", "CriticWorker"] + + +def get_inference_model(rollout): + """ + get models according to different types of inference_engine + Args: + rollout: rollout object + Returns: + model: model object + """ + inference_engine = rollout.inference_engine + if hasattr(inference_engine, "llm_engine"): + inference_model = inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model + elif hasattr(inference_engine, "worker"): + inference_model = inference_engine.worker.model_runner.model + else: + raise AttributeError( + f"Unsupported inference_engine type: {type(inference_engine)}. " + f"Expected LLM (with llm_engine attribute) or WorkerWrapperBase (with worker attribute)." + ) + return inference_model + + +class DetachNcclSync(AsyncActorRolloutRefWorker): + def _get_actor_params(self): + pass + + @register(dispatch_mode=Dispatch.ONE_TO_ALL, blocking=False) + def sync_rollout_weights(self): + assert (self._is_actor or self._is_rollout) and not self.config.hybrid_engine + assert hasattr(self, "_weights_info") and self._weights_info is not None + + params = self._get_actor_params() if self._is_actor else None + if self._is_rollout: + inference_model = get_inference_model(self.rollout) + + from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader + + patch_vllm_moe_model_weight_loader(inference_model) + for key, shape, dtype in self._weights_info: + tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device()) + if self._is_actor: + assert key in params + origin_data = params[key] + if hasattr(origin_data, "full_tensor"): + origin_data = origin_data.full_tensor() + if torch.distributed.get_rank() == 0: + tensor.copy_(origin_data) + from ray.util.collective import collective + + collective.broadcast(tensor, src_rank=0, group_name="actor_rollout") + if self._is_rollout: + inference_model.load_weights([(key, tensor)]) + get_torch_device().empty_cache() + + +class DetachActorWorker(DetachNcclSync): + def _get_actor_params(self): + assert self._is_actor + params = self.actor_module_fsdp.state_dict() + from verl.utils.model import convert_weight_keys + + params = convert_weight_keys( + params, getattr(self.actor_module_fsdp, "_fsdp_wrapped_module", self.actor_module_fsdp) + ) + return params + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def get_actor_weights_info(self): + assert self._is_actor + if hasattr(self, "_weights_info"): + return self._weights_info + if fsdp_version(self.actor_module_fsdp) == 1: + from torch.distributed.fsdp.api import ShardedStateDictConfig, StateDictType + + FSDP.set_state_dict_type( + self.actor_module_fsdp, + state_dict_type=StateDictType.SHARDED_STATE_DICT, + state_dict_config=ShardedStateDictConfig(), + ) + params = self._get_actor_params() + ret = [] + for key, tensor in params.items(): + ret.append((key, tensor.size(), tensor.dtype)) + self._weights_info = 
ret + return ret + + +class DetachAsyncRolloutWorker(DetachNcclSync): + def __init__(self, config: DictConfig, role: str): + print(f"[DetachAsyncRolloutWorker] {DetachAsyncRolloutWorker.__mro__}") + ActorRolloutRefWorker.__init__(self, config, role) + + @register(dispatch_mode=Dispatch.ONE_TO_ALL) + def set_actor_weights_info(self, weights_info): + assert self._is_rollout + self._weights_info = weights_info diff --git a/recipe/fully_async_policy/fully_async_main.py b/recipe/fully_async_policy/fully_async_main.py new file mode 100644 index 000000000..4dafd4484 --- /dev/null +++ b/recipe/fully_async_policy/fully_async_main.py @@ -0,0 +1,306 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import socket +import threading +from pprint import pprint + +import hydra +import ray +from omegaconf import OmegaConf + +from recipe.fully_async_policy.fully_async_rollouter import FullyAsyncRollouter +from recipe.fully_async_policy.fully_async_trainer import FullyAsyncTrainer +from recipe.fully_async_policy.message_queue import MessageQueue, MessageQueueClient +from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.trainer.ppo.reward import load_reward_manager +from verl.trainer.ppo.utils import Role +from verl.utils.fs import copy_to_local + + +def create_resource_pool_manager(config, roles: list) -> ResourcePoolManager: + """ + Create resource pool manager + + Args: + config: Configuration object + roles: List of roles that need to create resource pools + + Returns: + ResourcePoolManager: Resource pool manager + """ + resource_pool_spec = {} + mapping = {} + + # Actor/Critic resource pool + if any(role in roles for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]): + assert config.trainer.n_gpus_per_node > 0, "config.trainer.n_gpus_per_node must be greater than 0" + assert config.trainer.nnodes > 0, "config.trainer.nnodes must be greater than 0" + + trainer_pool = [config.trainer.n_gpus_per_node] * config.trainer.nnodes + resource_pool_spec["trainer_pool"] = trainer_pool + + # Map training-related roles to the same resource pool + for role in [Role.Actor, Role.Critic, Role.RefPolicy, Role.RewardModel]: + if role in roles: + mapping[role] = "trainer_pool" + + # Rollout resource pool + if Role.Rollout in roles: + assert config.rollout.n_gpus_per_node > 0, "config.rollout.n_gpus_per_node must be greater than 0" + assert config.rollout.nnodes > 0, "config.rollout.nnodes must be greater than 0" + + rollout_pool = [config.rollout.n_gpus_per_node] * config.rollout.nnodes + resource_pool_spec["rollout_pool"] = rollout_pool + mapping[Role.Rollout] = "rollout_pool" + + return ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping) + + +def create_role_worker_mapping(config): + """ + Create mapping from roles to worker classes + + Args: + config: Configuration object + + Returns: + dict: Mapping from roles to worker classes + """ + # Select worker class based on strategy + if 
config.actor_rollout_ref.actor.strategy == "fsdp2": + assert config.actor_rollout_ref.actor.strategy == config.critic.strategy + from recipe.fully_async_policy.fsdp_workers import ( + CriticWorker, + DetachActorWorker, + DetachAsyncRolloutWorker, + ) + from verl.single_controller.ray import RayWorkerGroup + + ray_worker_group_cls = RayWorkerGroup + + # TODO megatron support + else: + raise NotImplementedError(f"Unsupported strategy: {config.actor_rollout_ref.actor.strategy}") + + role_worker_mapping = { + Role.Actor: ray.remote(DetachActorWorker), + Role.Rollout: ray.remote(DetachAsyncRolloutWorker), + Role.Critic: ray.remote(CriticWorker), + } + + if config.reward_model.enable: + if config.reward_model.strategy == "fsdp2": + from verl.workers.fsdp_workers import RewardModelWorker + # TODO megatron support + else: + raise NotImplementedError(f"Unsupported reward model strategy: {config.reward_model.strategy}") + + role_worker_mapping[Role.RewardModel] = ray.remote(RewardModelWorker) + + # Add reference policy (if KL loss or reward is required) + if config.algorithm.use_kl_in_reward or config.actor_rollout_ref.actor.use_kl_loss: + role_worker_mapping[Role.RefPolicy] = ray.remote(DetachActorWorker) + + return role_worker_mapping, ray_worker_group_cls + + +@ray.remote(num_cpus=1) +class FullyAsyncTaskRunner: + """ + Ray remote class for executing distributed PPO training tasks. + """ + + def __init__(self): + self.running = False + self.components = {} + self.shutdown_event = threading.Event() + + def run(self, config): + print("[ASYNC MAIN] Starting fully async PPO training...") + self._initialize_components(config) + self._run_training_loop() + + def _initialize_components(self, config) -> None: + print(f"[ASYNC MAIN] TaskRunner hostname: {socket.gethostname()}, PID: {os.getpid()}") + pprint(OmegaConf.to_container(config, resolve=True)) + OmegaConf.resolve(config) + + print("[ASYNC MAIN] Initializing model and tokenizer...") + local_path = copy_to_local( + config.actor_rollout_ref.model.path, use_shm=config.actor_rollout_ref.model.get("use_shm", False) + ) + from verl.utils import hf_processor, hf_tokenizer + + trust_remote_code = config.data.get("trust_remote_code", False) + tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code) + + # Used for multimodal LLM, could be None + processor = hf_processor(local_path, trust_remote_code=trust_remote_code, use_fast=True) + + self.components["tokenizer"] = tokenizer + self.components["processor"] = processor + self.components["config"] = config + + print("[ASYNC MAIN] Creating worker mapping and resource pools...") + role_worker_mapping, ray_worker_group_cls = create_role_worker_mapping(config) + self.components["role_worker_mapping"] = role_worker_mapping + self.components["ray_worker_group_cls"] = ray_worker_group_cls + + print("[ASYNC MAIN] Loading reward functions...") + reward_fn = load_reward_manager( + config, tokenizer, num_examine=0, **config.reward_model.get("reward_kwargs", {}) + ) + val_reward_fn = load_reward_manager( + config, tokenizer, num_examine=1, **config.reward_model.get("reward_kwargs", {}) + ) + self.components["reward_fn"] = reward_fn + self.components["val_reward_fn"] = val_reward_fn + + print("[ASYNC MAIN] Creating FullyAsyncRollouter...") + self._create_rollouter(config) + + print("[ASYNC MAIN] Creating FullyAsyncTrainer...") + self._create_trainer(config) + + # sync total_train_steps between rollouter and trainer + total_train_steps = ray.get(self.components["rollouter"].get_total_train_steps.remote()) 
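+        # The rollouter derives total_train_steps from its dataloader and the async config
+        # (see FullyAsyncRollouter.set_max_required_samples):
+        #     total_train_steps = int(total_rollout_steps / (required_samples * trigger_parameter_sync_step))
+        # with required_samples = ppo_mini_batch_size * require_batches. Hypothetical example:
+        # 1280 rollout prompts, ppo_mini_batch_size=32, require_batches=1, trigger_parameter_sync_step=4
+        # => 1280 / (32 * 4) = 10 parameter-sync rounds shown on the trainer's progress bar.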
+ print(f"total_train_steps {total_train_steps}") + ray.get(self.components["trainer"].set_total_train_steps.remote(total_train_steps)) + + # max_queue_size + max_queue_size = ray.get(self.components["rollouter"].get_max_queue_size.remote()) + print(f"[ASYNC MAIN] Creating MessageQueue... max_queue_size {max_queue_size}") + message_queue = MessageQueue.remote(config, max_queue_size) + message_queue_client = MessageQueueClient(message_queue) + self.components["message_queue"] = message_queue + self.components["message_queue_client"] = message_queue_client + + ray.get(self.components["rollouter"].set_message_queue_client.remote(self.components["message_queue_client"])) + ray.get(self.components["trainer"].set_message_queue_client.remote(self.components["message_queue_client"])) + + print("[ASYNC MAIN] Setting up parameter synchronization...") + from recipe.fully_async_policy.param_sync import ParameterSynchronizer + + param_synchronizer = ParameterSynchronizer.remote( + config=config, + trainer=self.components["trainer"], + rollouter=self.components["rollouter"], + mq=self.components["message_queue_client"], + ) + ray.get(self.components["trainer"].set_parameter_synchronizer.remote(param_synchronizer)) + + # load checkpoint and sync parameter before doing anything + val_before_train = val_reward_fn is not None and config.trainer.get("val_before_train", True) + ray.get(self.components["trainer"].load_checkpoint.remote()) + ray.get(param_synchronizer.sync_weights.remote(version=0, validate=val_before_train)) + + self.components["param_synchronizer"] = param_synchronizer + print("[ASYNC MAIN] All components initialized successfully") + + def _create_rollouter(self, config) -> None: + rollouter = FullyAsyncRollouter.remote( + config=config, + tokenizer=self.components["tokenizer"], + role_worker_mapping={Role.Rollout: self.components["role_worker_mapping"][Role.Rollout]}, + resource_pool_manager=create_resource_pool_manager(config, roles=[Role.Rollout]), + ray_worker_group_cls=self.components["ray_worker_group_cls"], + processor=self.components["processor"], + reward_fn=self.components["reward_fn"], + val_reward_fn=self.components["val_reward_fn"], + device_name=config.trainer.device, + ) + + ray.get(rollouter.init_workers.remote()) + ray.get(rollouter.set_max_required_samples.remote()) + + self.components["rollouter"] = rollouter + print("[ASYNC MAIN] Rollouter created and initialized successfully") + + def _create_trainer(self, config) -> None: + trainer_role_mapping = { + role: worker_cls + for role, worker_cls in self.components["role_worker_mapping"].items() + if role != Role.Rollout + } + + trainer = FullyAsyncTrainer.remote( + config=config, + tokenizer=self.components["tokenizer"], + role_worker_mapping=trainer_role_mapping, + resource_pool_manager=create_resource_pool_manager(config, roles=list(trainer_role_mapping.keys())), + ray_worker_group_cls=self.components["ray_worker_group_cls"], + processor=self.components["processor"], + reward_fn=self.components["reward_fn"], + val_reward_fn=self.components["val_reward_fn"], + device_name=config.trainer.device, + ) + + ray.get(trainer.init_workers.remote()) + self.components["trainer"] = trainer + print("[ASYNC MAIN] FullyAsyncTrainer created and initialized successfully") + + def _run_training_loop(self): + self.running = True + + print("[ASYNC MAIN] Starting Rollouter and Trainer...") + rollouter_future = self.components["rollouter"].fit.remote() + trainer_future = self.components["trainer"].fit.remote() + + futures = [rollouter_future, 
trainer_future] + + try: + while futures: + # Use ray.wait to monitor all futures and return when any one is completed. + done_futures, remaining_futures = ray.wait(futures, num_returns=1, timeout=None) + + for future in done_futures: + try: + ray.get(future) + print("[ASYNC MAIN] One component completed successfully") + except Exception as e: + print(f"[ASYNC MAIN] Component failed with error: {e}") + for remaining_future in remaining_futures: + ray.cancel(remaining_future) + raise e + + futures = remaining_futures + + except Exception as e: + print(f"[ASYNC MAIN] Training failed: {e}") + for future in futures: + ray.cancel(future) + raise + finally: + self.components["message_queue_client"].clear_queue() + print("[ASYNC MAIN] Training completed or interrupted") + + +@hydra.main(config_path="config", config_name="fully_async_ppo_trainer", version_base=None) +def main(config): + from verl.trainer.main_ppo import run_ppo + + # Ensure async training config exists + if not hasattr(config, "async_training"): + raise RuntimeError("must set async_training config") + from time import time + + start_time = time() + run_ppo(config, task_runner_class=FullyAsyncTaskRunner) + print(f"total time: {time() - start_time:.2f} seconds") + + +if __name__ == "__main__": + main() diff --git a/recipe/fully_async_policy/fully_async_rollouter.py b/recipe/fully_async_policy/fully_async_rollouter.py new file mode 100644 index 000000000..503c6ae6d --- /dev/null +++ b/recipe/fully_async_policy/fully_async_rollouter.py @@ -0,0 +1,646 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
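+
+# Data flow inside FullyAsyncRollouter (coroutines started in _streaming_generation_main):
+#     _feed_samples -> pending_queue -> _processor_worker -> active generation tasks
+#     finished tasks -> result_queue -> _consumer_worker -> MessageQueue
+#     cancelled (partial-rollout) tasks -> cancel_queue -> re-submitted by _processor_worker
+# Generation pauses when the MessageQueue is full or staleness_samples reaches
+# max_required_samples, and resumes after the next parameter sync
+# (see pause() / resume() and _async_monitor_loop).
+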
+import asyncio
+import time
+from pprint import pformat
+
+import ray
+from ray import ObjectRef
+
+from recipe.fully_async_policy.detach_utils import (
+    RolloutSample,
+    ValidateMetrics,
+    merge_rollout_sample,
+    prepare_single_generation_data,
+)
+from recipe.fully_async_policy.message_queue import MessageQueueClient
+from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer
+from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
+from verl.trainer.ppo.ray_trainer import ResourcePoolManager
+from verl.trainer.ppo.utils import Role, WorkerType
+from verl.utils.profiler import marked_timer
+from verl.utils.tracking import ValidationGenerationsLogger
+
+
+@ray.remote(num_cpus=10, max_concurrency=100)
+class FullyAsyncRollouter(FullyAsyncRayPPOTrainer):
+    """
+    Asynchronous sample generator that continuously produces training samples
+    and puts them into the MessageQueue.
+    Based on and extending the OneStepOffRayTrainer implementation.
+    """
+
+    def __init__(
+        self,
+        config,
+        tokenizer,
+        role_worker_mapping: dict[Role, WorkerType],
+        resource_pool_manager: ResourcePoolManager,
+        ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup,
+        processor=None,
+        reward_fn=None,
+        val_reward_fn=None,
+        device_name=None,
+    ):
+        # Store the tokenizer for text processing
+        self.tokenizer = tokenizer
+        self.processor = processor
+        self.config = config
+        self.reward_fn = reward_fn
+        self.val_reward_fn = val_reward_fn
+
+        self.hybrid_engine = config.actor_rollout_ref.hybrid_engine
+
+        assert not self.hybrid_engine
+        assert self.config.data.train_batch_size == 0, "train_batch_size must be zero"
+        assert self.config.data.gen_batch_size == 1, "gen_batch_size must be one"
+        assert self.config.async_training.staleness_threshold >= 0, "staleness_threshold must be non-negative"
+        assert self.config.async_training.trigger_parameter_sync_step >= 1, (
+            "trigger_parameter_sync_step must be at least 1"
+        )
+
+        self.role_worker_mapping = role_worker_mapping
+        self.resource_pool_manager = resource_pool_manager
+        self.ray_worker_group_cls = ray_worker_group_cls
+        self.device_name = device_name if device_name else self.config.trainer.device
+        self.validation_generations_logger = ValidationGenerationsLogger(
+            project_name=self.config.trainer.project_name,
+            experiment_name=self.config.trainer.experiment_name,
+        )
+
+        self.ref_in_actor = False
+        self.kl_ctrl_in_reward = False
+        self.use_critic = False
+        self.use_reference_policy = False
+        self.use_rm = False
+
+        print("[FullyAsyncRollouter] Creating datasets...")
+        from verl.trainer.main_ppo import create_rl_dataset, create_rl_sampler
+        from verl.utils.dataset.rl_dataset import collate_fn
+
+        train_dataset = create_rl_dataset(config.data.train_files, config.data, tokenizer, processor)
+        val_dataset = create_rl_dataset(config.data.val_files, config.data, tokenizer, processor)
+        train_sampler = create_rl_sampler(config.data, train_dataset)
+
+        self._validate_config()
+        print(f"[FullyAsyncRollouter] Rollouter _create_dataloader...\n{train_dataset}\n{val_dataset}")
+
+        self._create_dataloader(train_dataset, val_dataset, collate_fn, train_sampler)
+
+        # ==================== fully async config ====================
+
+        self.total_rollout_steps = len(self.train_dataloader) * self.config.trainer.total_epochs
+        if self.config.rollout.total_rollout_steps is not None:
+            self.total_rollout_steps = min(self.config.rollout.total_rollout_steps, self.total_rollout_steps)
+        print(f"[FullyAsyncRollouter] Total rollout steps: 
{self.total_rollout_steps}") + self.total_train_steps = None + + # Rollouter parameter configuration + self.message_queue_client = None + + # Worker groups: rollout_wg is same to actor_rollout_wg + self.rollout_wg = None + self.actor_rollout_wg = None + self.async_rollout_manager = None + + # Config + self.staleness_threshold: float = config.async_training.get("staleness_threshold", 1) + # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. + self.require_batches = config.async_training.require_batches + self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches + self.max_required_samples = None + self.max_concurrent_samples = None + # queue size + self.max_queue_size = None + + # Statistics + self.current_param_version = 0 + self.total_generated_samples = 0 + self.staleness_samples = 0 + self.dropped_stale_samples = 0 + self.processed_sample_count = 0 + self.global_steps = 0 + self.idle_start_time = None + self.version_start_time = None + + # Concurrency control + # Modified by self.pause() or self._should_pause_generation() + self.paused = False + self.running = True + self.monitor_loop_trigger = True + + # Initialize async locks directly + self.lock = asyncio.Lock() + self.condition = asyncio.Condition(self.lock) + + # Initialize async queues + self.pending_queue = asyncio.Queue(maxsize=128) + self.active_tasks = set() + self.result_queue = asyncio.Queue() + self.cancel_queue = asyncio.Queue() + + async def set_message_queue_client(self, message_queue_client: MessageQueueClient): + """Set message queue client""" + async with self.lock: + self.message_queue_client = message_queue_client + + async def set_max_required_samples(self): + async with self.lock: + self.max_required_samples = int( + self.required_samples + * (self.staleness_threshold + 1) + * self.config.async_training.trigger_parameter_sync_step + ) + self.total_train_steps = int( + self.total_rollout_steps + / (self.required_samples * self.config.async_training.trigger_parameter_sync_step) + ) + + self.max_concurrent_samples = len(self.async_rollout_manager.server_handles) * 16 + self.max_concurrent_samples = min(self.max_concurrent_samples, self.max_required_samples) + self.max_queue_size = self.max_required_samples + + print( + f"[FullyAsyncRollouter] required_samples : {self.required_samples} " + f"max_required_samples: {self.max_required_samples} " + f"max_queue_size: {self.max_queue_size} " + f"total_train_steps: {self.total_train_steps} " + f"total_rollout_steps: {self.total_rollout_steps} " + f"max_concurrent_samples: {self.max_concurrent_samples} " + ) + + def get_rollout_wg(self): + """Get rollout worker group""" + return self.rollout_wg + + def get_max_queue_size(self): + return self.max_queue_size + + def get_total_train_steps(self): + return self.total_train_steps + + async def update_param_version(self, version: int, validate: bool = False, global_steps: int = 0): + """Update current parameter version""" + async with self.lock: + old_version = self.current_param_version + self.current_param_version = version + # every time param change, reset staleness_samples + self.staleness_samples = ( + len(self.active_tasks) + + self.result_queue.qsize() + + self.cancel_queue.qsize() + + await self.message_queue_client.get_queue_size() + ) + timing_raw = {} + idle_ratio = None + if self.idle_start_time is not None and self.version_start_time is not None: + rollout_active_time = self.idle_start_time - self.version_start_time + rollout_version_time = 
time.time() - self.version_start_time + idle_ratio = 1 - rollout_active_time / rollout_version_time + timing_raw["rollouter/active_time"] = rollout_active_time + timing_raw["rollouter/version_time"] = rollout_version_time + timing_raw["rollouter/idle_ratio"] = idle_ratio + self.idle_start_time = None + print( + f"[FullyAsyncRollouter][Public][update_param_version] " + f"Parameter version updated from {old_version} to {version} " + f",reset staleness_samples to: {self.staleness_samples}" + f",idle_ratio: {idle_ratio}" + ) + val_metrics = None + if ( + self.val_reward_fn is not None + and self.config.rollout.test_freq > 0 + and self.current_param_version % self.config.rollout.test_freq == 0 + and self.current_param_version > 0 # don't test here in the initial parameter sync + ) or (validate and self.val_reward_fn is not None): + with marked_timer("rollouter/validate_time", timing_raw, color="green"): + val_metrics: dict = self._validate() + data = ValidateMetrics( + timing_raw=timing_raw, metrics=val_metrics, global_steps=global_steps, param_version=version + ) + await self.message_queue_client.put_validate(ray.cloudpickle.dumps(data)) + + self.version_start_time = time.time() + + def _validate_config(self): + # Validate asynchronous training configuration + if not hasattr(self.config, "async_training"): + raise ValueError("[FullyAsyncRollouter] Missing async_training configuration") + assert self.config.actor_rollout_ref.rollout.calculate_log_probs, "must rollout calculate log_probs" + + async def init_workers(self): + """Initialize distributed training workers using Ray backend. + + Creates: + 1. Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self._init_resource_pools() + self._create_worker_classes() + self._init_worker_groups() + self._init_models() + await self._init_async_rollout_manager() + + def _create_actor_rollout_classes(self): + # only create rollout + for role in [Role.Rollout]: + resource_pool = self.resource_pool_manager.get_resource_pool(role) + role_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[role], + config=self.config.actor_rollout_ref, + role=str(role), + ) + self.resource_pool_to_cls[resource_pool][str(role)] = role_cls + + def _init_models(self): + self.rollout_wg = self.all_wg[str(Role.Rollout)] + self.rollout_wg.init_model() + self.actor_rollout_wg = self.rollout_wg + + def _create_continuous_iterator(self): + """ + Create a continuous data iterator across epoch + """ + for epoch in range(self.config.rollout.total_epochs): + iterator = iter(self.train_dataloader) + for batch_dict in iterator: + yield epoch, batch_dict + + async def _init_async_rollout_manager(self): + # create async rollout manager and request scheduler + assert self.config.actor_rollout_ref.rollout.mode == "async" + from recipe.fully_async_policy.agent_loop import FullyAsyncAgentLoopManager + + self.async_rollout_mode = True + self.async_rollout_manager = await FullyAsyncAgentLoopManager.create( + config=self.config, + worker_group=self.rollout_wg, + ) + + # Add samples to the pending_queue + async def _feed_samples(self): + continuous_iterator = self._create_continuous_iterator() + + for epoch, batch_dict in continuous_iterator: + # Similar to _prepare_generate_batch: Separate data + full_batch = prepare_single_generation_data( + batch_dict, self.global_steps, self.config.actor_rollout_ref.rollout.n + ) + + sample_id = f"sample_{epoch}_{self.global_steps}" + + rollout_sample = RolloutSample( + full_batch=full_batch, + 
agent_loop_output_list=[None] * self.config.actor_rollout_ref.rollout.n, + sample_id=sample_id, + epoch=epoch, + param_version=0, + param_version_start=[], + param_version_end=[], + processing_times=[], + rollout_status={}, + ) + + await self.pending_queue.put(rollout_sample) + + # Check if have reached the last step + if self.global_steps >= self.total_rollout_steps: + print( + f"[FullyAsyncRollouter][Feed] " + f"Maximum count has been reached, stop adding new samples" + f"{self.global_steps} >= {self.total_rollout_steps}" + ) + break + + self.global_steps += 1 + + # End signal + await self.pending_queue.put("DONE") + print(f"[FullyAsyncRollouter][Feed] Sample addition is complete, {self.global_steps} samples have been added") + + async def _processor_worker(self): + """ + Streaming worker coroutines, a sample is submitted for processing without waiting for batches + """ + while True: + if self.paused or await self._should_pause_generation(): + print( + "[FullyAsyncRollouter][Processor] Received pause signal, waiting for remaining tasks to return..." + ) + async with self.lock: + self.paused = True + while self.active_tasks: + async with self.lock: + # After acquiring the lock, the number of active_tasks may change, need to be verified again + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + async with self.lock: + while self.paused: + self.idle_start_time = time.time() + await self.condition.wait() + continue + + simple_from_cancel_queue = False + if not self.cancel_queue.empty(): + rollout_sample = await self.cancel_queue.get() + simple_from_cancel_queue = True + else: + rollout_sample = await self.pending_queue.get() + self.staleness_samples += 1 + + if rollout_sample == "DONE": + print( + "[FullyAsyncRollouter][Processor] Received end signal, waiting for remaining tasks to complete..." 
+ ) + while self.active_tasks: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + break + + # Check whether the number of concurrent tasks exceeds the limit + while len(self.active_tasks) >= self.max_concurrent_samples: + async with self.lock: + if self.active_tasks: + done_tasks, self.active_tasks = await asyncio.wait( + self.active_tasks, return_when=asyncio.FIRST_COMPLETED + ) + for task in done_tasks: + await task + + # Submit single sample processing + async with self.lock: + # After the pause is over, the lock is acquired and it is necessary + # to determine whether it is the pause phase, otherwise continue to wait + while self.paused: + await self.condition.wait() + task = asyncio.create_task( + self._process_single_sample_streaming(rollout_sample), + name=rollout_sample.sample_id, + ) + self.active_tasks.add(task) + + if simple_from_cancel_queue: + self.cancel_queue.task_done() + else: + self.pending_queue.task_done() + + async def _process_single_sample_streaming(self, rollout_sample: RolloutSample): + """Process a single sample streamingly""" + # Calling asynchronous generation methods + rollout_sample.full_batch.non_tensor_batch["param_version"] = [self.current_param_version] * len( + rollout_sample.full_batch + ) + agent_loop_output_list = await self.async_rollout_manager.generate_single_sample_async( + rollout_sample.full_batch, rollout_sample.agent_loop_output_list + ) + rollout_sample.agent_loop_output_list = agent_loop_output_list + + is_cancel = False + for agent_loop in agent_loop_output_list: + if not is_cancel and agent_loop.is_cancel: + is_cancel = True + + if is_cancel: + # Put in the cancel queue and wait for the generation to resume + await self.cancel_queue.put(rollout_sample) + else: + # put into the result_queue + rollout_sample.param_version = self.current_param_version + rollout_sample.rollout_status = await self.get_statistics() + await self.result_queue.put(rollout_sample) + + self.processed_sample_count += 1 + + async def _consumer_worker(self): + """ + The consumer coroutine is responsible for obtaining the processing results + from the result queue and putting them into the message queue + """ + while True: + rollout_sample = await self.result_queue.get() + rollout_sample = merge_rollout_sample(self.config, self.tokenizer, rollout_sample) + + # Put RolloutSample into the message queue + success = await self.message_queue_client.put_sample( + sample=ray.cloudpickle.dumps(rollout_sample), + param_version=rollout_sample.param_version, + ) + if success: + self.total_generated_samples += 1 + else: + self.dropped_stale_samples += 1 + + self.result_queue.task_done() + + async def _streaming_generation_main(self): + """The main entry method for stream processing""" + + # we start from step 1 + self.global_steps += 1 + + if self.async_rollout_manager is None: + await self._init_async_rollout_manager() + + # Start the streaming loop + print(f"[FullyAsyncRollouter] Start streaming mode, maximum concurrent samples: {self.max_concurrent_samples}") + + # Start sample feed coroutine, streaming process coroutine and consumer coroutine + self.feed_task = asyncio.create_task(self._feed_samples()) + self.processor_task = asyncio.create_task(self._processor_worker()) + self.consumer_task = asyncio.create_task(self._consumer_worker()) + + try: + # Wait for sample feed to complete + await self.feed_task + 
print("[FullyAsyncRollouter] Sample feed completed") + + # Wait for streaming to complete + await self.processor_task + print("[FullyAsyncRollouter] Streaming process completed") + + # Waiting for the result queue to clear + await self.result_queue.join() + print("[FullyAsyncRollouter] Result queue cleared") + + except Exception as e: + print(f"[FullyAsyncRollouter] Streaming process exception:{e}") + + finally: + if self.processor_task: + self.processor_task.cancel() + if self.consumer_task: + self.consumer_task.cancel() + + await asyncio.gather(self.processor_task, self.consumer_task, return_exceptions=True) + + # Send a finish signal + await self.message_queue_client.put_sample( + sample=None, + param_version=self.current_param_version, + ) + + async with self.lock: + self.running = False + + async def fit(self): + """ + Start the async rollouter - entry point that sets up and runs async tasks + Main async fit method that coordinates all coroutines + """ + + print("[FullyAsyncRollouter] Starting FullyAsyncRollouter...") + + if self.message_queue_client is None: + raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") + + # Set the running status flag + async with self.lock: + self.paused = False + self.running = True + + # Create the main asynchronous task + generation_task = asyncio.create_task(self._streaming_generation_main()) + monitor_task = asyncio.create_task(self._async_monitor_loop()) + + try: + # Run build and monitoring tasks concurrently + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + except Exception as e: + print(f"[FullyAsyncRollouter] Asynchronous task execution error: {e}") + finally: + if not generation_task.done(): + generation_task.cancel() + if not monitor_task.done(): + monitor_task.cancel() + + # Wait for the task to complete + await asyncio.gather(generation_task, monitor_task, return_exceptions=True) + + print("[FullyAsyncRollouter] Rollouter fit completed") + + async def _async_monitor_loop(self): + """ + Async coroutine for monitoring: + Function 1: Log information output + Function 2: Trigger rollout recovery + """ + last_stats_time = time.time() + stats_interval = 60.0 + check_interval = 10.0 + + while True: + async with self.lock: + if not self.running: + break + await asyncio.sleep(check_interval) + # Print statistics periodically + current_time = time.time() + if current_time - last_stats_time >= stats_interval: + stats = await self.get_statistics() + print(f"[FullyAsyncRollouter][MonitorLoop][Statistics] {pformat(stats)}") + last_stats_time = current_time + + # Trigger rollout recovery + if self.monitor_loop_trigger: + if not await self._should_pause_generation(): + async with self.lock: + self.paused = False + self.condition.notify_all() + + async def _should_pause_generation(self) -> bool: + """Determine whether the build should be paused""" + queue_stats = self.message_queue_client.get_statistics_sync() + queue_size = queue_stats["queue_size"] + + if queue_size >= self.max_queue_size: + if not self.paused: + print( + f"[FullyAsyncRollouter][ShouldPause] " + f"due to full queue: size={queue_size}, max={self.max_queue_size}" + ) + return True + + if self.staleness_samples >= self.max_required_samples: + if not self.paused: + print( + "[FullyAsyncRollouter][ShouldPause] " + f"due to " + f"staleness_samples {self.staleness_samples} >= max_required_samples {self.max_required_samples} " + ) + return True + + return False + + async def pause(self): + """pause rollout""" + 
print("[FullyAsyncRollouter][Public][Pause]") + async with self.lock: + self.paused = True + # Cancel all rollout tasks + if self.config.async_training.partial_rollout: + await self.async_rollout_manager.cancel() + if self.active_tasks: + await asyncio.gather(*self.active_tasks, return_exceptions=True) + self.active_tasks.clear() + print("[FullyAsyncRollouter][Public][Pause] All active tasks completed") + await self.async_rollout_manager.reset_prefix_cache() + self.monitor_loop_trigger = False + + async def resume(self, dependency_ref: ObjectRef = None): + if dependency_ref is not None: + ray.get(dependency_ref) + print("[FullyAsyncRollouter][Public][Resume]") + async with self.lock: + self.paused = False + self.monitor_loop_trigger = True + self.condition.notify_all() + + if self.config.async_training.partial_rollout: + await self.async_rollout_manager.resume() + + async def get_statistics(self) -> dict: + queue_stats = self.message_queue_client.get_statistics_sync() + + stats = { + # monitor stats + "monitor/active_tasks_size": len(self.active_tasks), + "monitor/queue/pending_queue_size": self.pending_queue.qsize(), + "monitor/queue/cancel_queue_size": self.cancel_queue.qsize(), + "monitor/queue/result_queue_size": self.result_queue.qsize(), + "monitor/queue/mq_queue_size": queue_stats["queue_size"], + # counting stats + "count/current_param_version": self.current_param_version, + "count/total_generated_samples": self.total_generated_samples, + "count/staleness_samples": self.staleness_samples, + "count/dropped_stale_samples": self.dropped_stale_samples, + # static stats + "static/max_required_samples": self.max_required_samples, + "static/required_samples": self.required_samples, + "static/staleness_threshold": self.staleness_threshold, + "static/max_queue_size": self.max_queue_size, + "static/max_concurrent_samples": self.max_concurrent_samples, + } + + return stats diff --git a/recipe/fully_async_policy/fully_async_trainer.py b/recipe/fully_async_policy/fully_async_trainer.py new file mode 100644 index 000000000..6693eac74 --- /dev/null +++ b/recipe/fully_async_policy/fully_async_trainer.py @@ -0,0 +1,360 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from datetime import datetime +from pprint import pprint +from typing import Any + +import ray +from omegaconf import OmegaConf +from tqdm import tqdm + +from recipe.fully_async_policy.detach_utils import ( + MetricsAggregator, + ValidateMetrics, + assemble_batch_from_rollout_samples, +) +from recipe.fully_async_policy.message_queue import MessageQueueClient +from recipe.fully_async_policy.ray_trainer import FullyAsyncRayPPOTrainer +from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup +from verl.trainer.ppo import core_algos +from verl.trainer.ppo.ray_trainer import ResourcePoolManager +from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model +from verl.utils.debug import marked_timer + + +@ray.remote(num_cpus=10) +class FullyAsyncTrainer(FullyAsyncRayPPOTrainer): + """ + A fully asynchronous PPO trainer that obtains samples from a MessageQueue for training. + Based on an improved implementation of OneStepOffRayTrainer + """ + + def __init__( + self, + config, + tokenizer, + role_worker_mapping: dict[Role, WorkerType], + resource_pool_manager: ResourcePoolManager, + ray_worker_group_cls: RayWorkerGroup = RayWorkerGroup, + processor=None, + reward_fn=None, + val_reward_fn=None, + device_name=None, + ): + # Store the tokenizer for text processing + self.tokenizer = tokenizer + self.processor = processor + self.config = config + self.reward_fn = reward_fn + self.val_reward_fn = val_reward_fn + + self.hybrid_engine = config.actor_rollout_ref.hybrid_engine + assert not self.hybrid_engine + + self.role_worker_mapping = role_worker_mapping + self.resource_pool_manager = resource_pool_manager + self.use_reference_policy = need_reference_policy(self.role_worker_mapping) + self.use_rm = need_reward_model(self.role_worker_mapping) + self.use_critic = need_critic(self.config) + self.ray_worker_group_cls = ray_worker_group_cls + self.device_name = device_name if device_name else self.config.trainer.device + + # if ref_in_actor is True, the reference policy will be actor without lora applied + self.ref_in_actor = config.actor_rollout_ref.model.get("lora_rank", 0) > 0 + + # define in-reward KL control + # kl loss control currently not suppoorted + if self.config.algorithm.use_kl_in_reward: + self.kl_ctrl_in_reward = core_algos.get_kl_controller(self.config.algorithm.kl_ctrl) + + # ==================== fully async config ==================== + + self.message_queue_client = None + self.param_synchronizer = None + + # Statistics + # we start from step 1 + self.global_steps = 1 + self.local_trigger_step = 1 + self.processed_samples = 0 + self.stale_samples_processed = 0 + self.stale_trajectory_processed = 0 + self.current_param_version = 0 + self.total_train_steps = None + self.progress_bar = None + self.trigger_parameter_sync_step = config.async_training.trigger_parameter_sync_step + + # required_samples use ppo_mini_batch_size*require_batches as the minimum number of samples. 
+ self.require_batches = config.async_training.require_batches + self.required_samples = config.actor_rollout_ref.actor.ppo_mini_batch_size * self.require_batches + total_gpus = ( + config.trainer.nnodes * config.trainer.n_gpus_per_node + + config.rollout.nnodes * config.rollout.n_gpus_per_node + ) + self.metrics_aggregator = MetricsAggregator(total_gpus=total_gpus) + + def set_message_queue_client(self, message_queue_client: MessageQueueClient): + """Set message queue client""" + self.message_queue_client = message_queue_client + + def set_parameter_synchronizer(self, param_synchronizer): + """Set parameter synchronizer""" + self.param_synchronizer = param_synchronizer + + def set_total_train_steps(self, total_train_steps): + self.total_train_steps = total_train_steps + self.progress_bar = tqdm(total=self.total_train_steps, initial=0, desc="Training Progress") + + def get_actor_wg(self): + """Get actor worker group""" + return self.actor_wg + + def _get_samples_from_queue(self) -> tuple[None, None] | tuple[int, Any]: + """ + Get samples from message queue and compose gen_batch_output + Uses a loop to continuously collect samples until enough are gathered + + Returns: + tuple: (epoch, batch_dict, gen_batch_output) + """ + print( + f"[FullyAsyncTrainer] Requesting {self.required_samples} samples from queue", + flush=True, + ) + + # Collect samples using a simple loop calling get_sample + consumer_start = time.time() + queue_samples = [] + queue_len = 0 + while len(queue_samples) < self.required_samples: + # Get a single sample and wait until there is a sample or None is received + sample, queue_len = self.message_queue_client.get_sample_sync() + + if sample is None: + print( + f"[FullyAsyncTrainer] Detected termination signal (None), stopping sample collection. " + f"Collected {len(queue_samples)}/{self.required_samples} samples" + ) + break + + queue_samples.append(sample) + + if len(queue_samples) % 64 == 0: + print( + f"[FullyAsyncTrainer] Collected {len(queue_samples)}/{self.required_samples} samples. " + f"mq_len: {queue_len}" + ) + + consumer_end = time.time() + + if not queue_samples or len(queue_samples) < self.required_samples: + print("[FullyAsyncTrainer] not enough samples collected after loop") + return None, None + total_wait_time = consumer_end - consumer_start + + print( + f"[FullyAsyncTrainer] Loop collection completed: {len(queue_samples)}/{self.required_samples} samples, " + f"total wait time: {total_wait_time:.2f} seconds." 
+ f"mq_len: {queue_len}" + ) + + queue_samples = [ray.cloudpickle.loads(x) for x in queue_samples] + # Assemble batch - now working directly with RolloutSample objects + if self.config.trainer.balance_batch: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, self._balance_batch) + else: + batch = assemble_batch_from_rollout_samples(queue_samples, self.tokenizer, self.config, None) + + batch.meta_info["fully_async/total_wait_time"] = total_wait_time + return 0, batch + + def _create_actor_rollout_classes(self): + # create actor + for role in [Role.Actor]: + resource_pool = self.resource_pool_manager.get_resource_pool(role) + role_cls = RayClassWithInitArgs( + cls=self.role_worker_mapping[role], + config=self.config.actor_rollout_ref, + role=str(role), + ) + self.resource_pool_to_cls[resource_pool][str(role)] = role_cls + + def _init_models(self): + if self.use_critic: + self.critic_wg = self.all_wg[str(Role.Critic)] + self.critic_wg.init_model() + + if self.use_reference_policy and not self.ref_in_actor: + self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)] + self.ref_policy_wg.init_model() + + if self.use_rm: + self.rm_wg = self.all_wg[str(Role.RewardModel)] + self.rm_wg.init_model() + + self.actor_wg = self.all_wg[str(Role.Actor)] + self.actor_wg.init_model() + self.actor_rollout_wg = self.actor_wg # to be compatible with the functions that not be modified + + def _init_async_rollout_manager(self): + pass + + def fit(self): + """ + The training loop of PPO. + The driver process only need to call the compute functions of the worker group through RPC + to construct the PPO dataflow. + The light-weight advantage computation is done on the driver process. + """ + print("[FullyAsyncTrainer] Starting FullyAsyncTrainer...") + if self.message_queue_client is None: + raise ValueError("MessageQueue client not set. Call set_message_queue_client() first.") + if self.param_synchronizer is None: + raise ValueError("param_synchronizer client not set. 
Call set_parameter_synchronizer() first.") + + from verl.utils.tracking import Tracking + + self.logger = Tracking( + project_name=self.config.trainer.project_name, + experiment_name=self.config.trainer.experiment_name, + default_backend=self.config.trainer.logger, + config=OmegaConf.to_container(self.config, resolve=True), + ) + + self.max_steps_duration = 0 + + # get validate data before training + if self.config.trainer.val_before_train and self.reward_fn is not None: + ray.get(self.param_synchronizer.wait_last_valid.remote()) + val_data = self.message_queue_client.get_validate_sync() + if val_data: + val_data: ValidateMetrics = ray.cloudpickle.loads(val_data) + if val_data.metrics: + self.logger.log(data=val_data.metrics, step=val_data.param_version) + pprint(f"[FullyAsyncTrainer] Initial validation metrics: {val_data.metrics}") + self.logger.log(data=val_data.timing_raw, step=val_data.param_version) + + # Use queue mode, no need for traditional dataloader iterator + # Initialize to get the first batch of data + while True: + metrics = {} + timing_raw = {} + + with marked_timer("step", timing_raw): + with marked_timer("gen", timing_raw, color="red"): + epoch, batch = self._get_samples_from_queue() + if batch is None: + break + self._collect_metrics_from_samples(batch, metrics) + + batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw) + self._log_rollout(batch, reward_extra_infos_dict, timing_raw) + self._check_save_checkpoint(False, timing_raw) + + self._collect_metrics(batch, 0, metrics, timing_raw) + self.metrics_aggregator.add_step_metrics( + metrics=metrics, sample_count=self.required_samples, timestamp=time.time() + ) + # Trigger parameter synchronization after training step + time_str = datetime.now().strftime("%H:%M:%S.%f")[:-3] + print( + f"[FullyAsyncTrainer] global_steps: {self.global_steps} " + f"local_trigger_step: {self.local_trigger_step} " + f"trigger_parameter_sync_step: {self.trigger_parameter_sync_step} " + f"{time_str}" + ) + self._trigger_parameter_sync_after_step(global_steps=self.global_steps) + val_data = self.message_queue_client.get_validate_sync() + if val_data: + val_data: ValidateMetrics = ray.cloudpickle.loads(val_data) + if val_data.metrics: + self.logger.log(data=val_data.metrics, step=val_data.param_version) + pprint( + f"[FullyAsyncTrainer] parameter version: {val_data.param_version} \ + Validation metrics: {val_data.metrics}" + ) + self.logger.log(data=val_data.timing_raw, step=val_data.param_version) + self.global_steps += 1 + + # final parameter sync and validate + if val_data is None or val_data.metrics is None: + self._trigger_parameter_sync_after_step(validate=True, global_steps=self.global_steps - 1) + ray.get(self.param_synchronizer.wait_last_valid.remote()) + val_data = self.message_queue_client.get_validate_sync() + if val_data: + val_data: ValidateMetrics = ray.cloudpickle.loads(val_data) + if val_data.metrics: + self.logger.log(data=val_data.metrics, step=val_data.param_version) + pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}") + self.logger.log(data=val_data.timing_raw, step=val_data.param_version) + else: + pprint(f"[FullyAsyncTrainer] Final validation metrics: {val_data.metrics}") + self.progress_bar.close() + + self._check_save_checkpoint(True, timing_raw) # TODO: check checkpoint + + def load_checkpoint(self): + return self._load_checkpoint() + + def _collect_metrics_from_samples(self, batch, metrics): + """ + Collect metrics from samples + """ + if hasattr(batch, 
"meta_info") and batch.meta_info: + samples_param_versions = batch.meta_info["rollout_param_versions"] + stale_count = sum(1 for v in samples_param_versions if self.current_param_version - v >= 1) + self.stale_samples_processed += stale_count + trajectory_param_versions = batch.meta_info["trajectory_param_versions"] + stale_traj_count = sum(1 for v in trajectory_param_versions if self.current_param_version - v >= 1) + self.stale_trajectory_processed += stale_traj_count + metrics.update( + { + "fully_async/count/stale_samples_processed": self.stale_samples_processed, + "fully_async/count/stale_trajectory_processed": self.stale_trajectory_processed, + "fully_async/count/current_param_version": self.current_param_version, + } + ) + for key, value in batch.meta_info.items(): + if key.startswith("fully_async"): + metrics[key] = value + + def _trigger_parameter_sync_after_step(self, validate: bool = False, global_steps: int = None): + """ + Trigger parameter synchronization after training step + This ensures rollouter always uses the latest trained parameters + """ + if self.local_trigger_step < self.trigger_parameter_sync_step and not validate: + self.local_trigger_step += 1 + return + + self.current_param_version += 1 + self.local_trigger_step = 1 + self.logger.log( + data=self.metrics_aggregator.get_aggregated_metrics(), + step=self.current_param_version, + ) + self.progress_bar.update(1) + self.metrics_aggregator.reset() + timing_param_sync = {} + with marked_timer("timing_s/wait_last_valid", timing_param_sync): + ray.get(self.param_synchronizer.wait_last_valid.remote()) + with marked_timer("timing_s/param_sync", timing_param_sync): + ray.get( + self.param_synchronizer.sync_weights.remote( + self.current_param_version, validate=validate, global_steps=global_steps + ) + ) + self.logger.log(data=timing_param_sync, step=self.current_param_version) diff --git a/recipe/fully_async_policy/message_queue.py b/recipe/fully_async_policy/message_queue.py new file mode 100644 index 000000000..85860c6f2 --- /dev/null +++ b/recipe/fully_async_policy/message_queue.py @@ -0,0 +1,265 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import asyncio +import logging +from collections import deque +from typing import Any + +import ray +from omegaconf import DictConfig + +logger = logging.getLogger(__name__) + + +@ray.remote(num_cpus=2, max_concurrency=20) +class MessageQueue: + """ + Simplified Ray-based asynchronous message queue for communication between Rollouter and Trainer + """ + + def __init__(self, config: DictConfig, max_queue_size: int = 1000): + self.config = config + if max_queue_size is None: + raise ValueError(f"max_queue_size cannot be None, got: {max_queue_size}") + self.max_queue_size = int(max_queue_size) + self.queue = deque(maxlen=self.max_queue_size) + self.current_param_version = 0 + + self.val_queue = deque() + + try: + if hasattr(config, "async_training") and config.async_training is not None: + self.staleness_threshold = getattr(config.async_training, "staleness_threshold", 3) + else: + self.staleness_threshold = 3 + except (AttributeError, RecursionError): + self.staleness_threshold = 3 + + # Asyncio for message handling + self.running = True + + # async safe + self._lock = asyncio.Lock() + self._consumer_condition = asyncio.Condition(self._lock) + + # statistic message + self.total_produced = 0 + self.total_consumed = 0 + self.dropped_samples = 0 + + print( + f"[MessageQueue] initialized with max_queue_size={max_queue_size}," + f"staleness_threshold={self.staleness_threshold}" + ) + + async def put_sample(self, sample: Any, param_version: int) -> bool: + """ + Put a batch sample into the queue + + Args: + sample: Sample data + param_version: Parameter version number + + Returns: + bool: Whether the sample was successfully put into the queue + """ + async with self._lock: + # If queue is full, remove the oldest sample (rarely happens) + is_drop = False + if len(self.queue) >= self.max_queue_size: + self.queue.popleft() + self.dropped_samples += 1 + is_drop = True + logger.warning("Queue full, dropped sample") + self.queue.append(sample) + self.total_produced += 1 + + # Notify waiting consumers + self._consumer_condition.notify_all() + + if self.total_produced % 100 == 0: + print(f"MessageQueue stats: produced={self.total_produced}, queue_size={len(self.queue)}") + if is_drop: + return False + return True + + async def get_sample(self) -> Any | None: + """ + Get a single sample from the queue, wait until one is available + + Returns: + Any: Single sample data or None if queue is closed + """ + async with self._lock: + while len(self.queue) == 0 and self.running: + await self._consumer_condition.wait() + + # If queue is closed and empty, return None + if not self.running and len(self.queue) == 0: + return None + + # Get one sample + data = self.queue.popleft() + self.total_consumed += 1 + return data, len(self.queue) + + async def update_param_version(self, version: int): + """Update current parameter version""" + async with self._lock: + old_version = self.current_param_version + self.current_param_version = version + print(f"Parameter version updated from {old_version} to {version}") + + async def get_queue_size(self) -> int: + """Get current queue length""" + async with self._lock: + return len(self.queue) + + async def get_statistics(self) -> dict[str, Any]: + """Get queue statistics""" + async with self._lock: + return { + "queue_size": len(self.queue), + "total_produced": self.total_produced, + "total_consumed": self.total_consumed, + "dropped_samples": self.dropped_samples, + "current_param_version": self.current_param_version, + "staleness_threshold": self.staleness_threshold, + 
"max_queue_size": self.max_queue_size, + } + + async def clear_queue(self): + """Clear the queue""" + async with self._lock: + cleared_count = len(self.queue) + self.queue.clear() + logger.info(f"Cleared {cleared_count} samples from queue") + + async def shutdown(self): + """Shutdown the message queue""" + async with self._lock: + self.running = False + # Notify all waiting coroutines so they can exit + self._consumer_condition.notify_all() + logger.info("MessageQueue shutdown") + + async def get_memory_usage(self) -> dict: + """Get memory usage statistics""" + async with self._lock: + # Estimate memory usage of samples in queue + import sys + + total_size = 0 + sample_count = len(self.queue) + + if sample_count > 0: + # Estimate size of a single sample (simplified estimation) + sample = list(self.queue)[0] + try: + sample_size = sys.getsizeof(sample) + # Since we now store RolloutSample directly, estimate based on its components + if hasattr(sample, "original_batch_dict") and sample.original_batch_dict: + # Estimate batch data size + batch_data = sample.original_batch_dict.get("batch", {}) + sample_size += len(batch_data) * 1000 # Roughly estimate 1KB per batch entry + if hasattr(sample, "agent_loop_output"): + # Estimate AgentLoopOutput size + sample_size += 5000 # Roughly estimate 5KB for AgentLoopOutput + total_size = sample_size * sample_count + except Exception: + total_size = sample_count * 15000 # Roughly estimate 15KB per RolloutSample + + return { + "queue_samples": sample_count, + "estimated_memory_bytes": total_size, + "estimated_memory_mb": total_size / (1024 * 1024), + } + + async def put_validate(self, data): + async with self._lock: + self.val_queue.append(data) + + async def get_validate(self): + async with self._lock: + if self.val_queue: + return self.val_queue.popleft() + else: + return None + + +class MessageQueueClient: + """Asyncio-compatible MessageQueue client for communicating with MessageQueue Actor""" + + def __init__(self, queue_actor: Any): + self.queue_actor = queue_actor + + async def put_sample(self, sample: Any, param_version: int) -> bool: + """Put batch into queue (async)""" + future = self.queue_actor.put_sample.remote(sample, param_version) + return await asyncio.wrap_future(future.future()) + + async def put_validate(self, data: Any) -> bool: + future = self.queue_actor.put_validate.remote(data) + return await asyncio.wrap_future(future.future()) + + def get_validate_sync(self) -> Any | None: + return ray.get(self.queue_actor.get_validate.remote()) + + async def get_sample(self) -> Any | None: + """Get single sample from queue, wait until one is available (async)""" + future = self.queue_actor.get_sample.remote() + return await asyncio.wrap_future(future.future()) + + async def get_queue_size(self) -> int: + """Get queue size (async)""" + future = self.queue_actor.get_queue_size.remote() + return await asyncio.wrap_future(future.future()) + + async def get_statistics(self) -> dict[str, Any]: + """Get statistics (async)""" + future = self.queue_actor.get_statistics.remote() + return await asyncio.wrap_future(future.future()) + + async def clear_queue(self): + """Clear queue (async)""" + future = self.queue_actor.clear_queue.remote() + await asyncio.wrap_future(future.future()) + + async def shutdown(self): + """Shutdown queue (async)""" + future = self.queue_actor.shutdown.remote() + await asyncio.wrap_future(future.future()) + + async def get_memory_usage(self) -> dict: + """Get memory usage statistics (async)""" + future = 
self.queue_actor.get_memory_usage.remote() + return await asyncio.wrap_future(future.future()) + + # Synchronous version of the method (deprecated) + def put_sample_sync(self, sample: Any, param_version: int) -> bool: + """Put batch into queue (sync - deprecated, use put_sample instead)""" + return ray.get(self.queue_actor.put_sample.remote(sample, param_version)) + + def get_sample_sync(self) -> Any | None: + """Get single sample from queue (sync - deprecated, use get_sample instead)""" + return ray.get(self.queue_actor.get_sample.remote()) + + def get_statistics_sync(self) -> dict[str, Any]: + """Get statistics (sync - deprecated, use get_statistics instead)""" + return ray.get(self.queue_actor.get_statistics.remote()) + + def update_param_version_sync(self, version: int): + """Update parameter version (async)""" + return ray.get(self.queue_actor.update_param_version.remote(version)) diff --git a/recipe/fully_async_policy/param_sync.py b/recipe/fully_async_policy/param_sync.py new file mode 100644 index 000000000..d6c67ceb4 --- /dev/null +++ b/recipe/fully_async_policy/param_sync.py @@ -0,0 +1,105 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import time + +import ray +from ray.util.collective import collective + +logger = logging.getLogger(__name__) + + +@ray.remote +class ParameterSynchronizer: + """ + Unified parameter synchronizer, responsible for synchronizing model parameters between actor and rollout + Based on the mature synchronization mode implementation of one_step_off_policy + Merges the functions of the original multiple synchronizer classes + """ + + def __init__(self, config, trainer, rollouter, mq): + self.config = config + self.trainer = trainer + self.rollouter = rollouter + self.mq_client = mq + self.actor_wg = ray.get(trainer.get_actor_wg.remote()) + self.rollout_wg = ray.get(rollouter.get_rollout_wg.remote()) + + # Basic attributes + self.weights_info = None + self.sync_group_initialized = False + self.sync_group_name = "actor_rollout" + self.wait_last_update = None + self.wait_last_resume = None + + # Statistics + self.current_version = 0 + + self._init_weights_info() + self._init_sync_group() + + def get_current_param_version(self) -> int: + """Get current parameter version number""" + return self.current_version + + def get_weights_info(self): + """Get weights info""" + return self.weights_info + + def _init_weights_info(self): + self.weights_info = self.actor_wg.get_actor_weights_info()[0] + self.rollout_wg.set_actor_weights_info(self.weights_info) + + def _init_sync_group(self): + print("[ParameterSynchronizer] Initializing parameter synchronization group...") + actor_rollout_workers = self.actor_wg.workers + self.rollout_wg.workers + collective.create_collective_group( + actor_rollout_workers, + len(actor_rollout_workers), + list(range(0, len(actor_rollout_workers))), + backend="nccl", + group_name=self.sync_group_name, + ) + + def sync_weights(self, version, validate=False, 
global_steps=0):
+        """Sync weights between trainer and rollouter, and update the parameter version"""
+        start_time = time.time()
+
+        self.current_version = version
+        print(f"[ParameterSynchronizer] Starting weight synchronization (version {self.current_version})...")
+
+        ray.get(self.rollouter.pause.remote())
+
+        # Update the parameter version recorded in the message queue
+        self.mq_client.update_param_version_sync(version)
+
+        # Sync weights from the actor workers to the rollout workers
+        self.actor_wg.sync_rollout_weights()
+        ray.get(self.rollout_wg.sync_rollout_weights())
+        end_time = time.time()
+        print(f"[ParameterSynchronizer] sync_weights success. cost {end_time - start_time:.2f} seconds")
+
+        # Asynchronously update the rollout parameter version and trigger validation
+        self.wait_last_update = self.rollouter.update_param_version.remote(version, validate, global_steps)
+        self.wait_last_resume = self.rollouter.resume.remote(self.wait_last_update)
+
+    def wait_last_valid(self):
+        print("[ParameterSynchronizer] Waiting for the last sync and validation...")
+        start_time = time.time()
+        if self.wait_last_update:
+            ray.get(self.wait_last_update)
+        if self.wait_last_resume:
+            ray.get(self.wait_last_resume)
+        print(f"[ParameterSynchronizer] Wait last validate cost: {time.time() - start_time:.2f} seconds")
diff --git a/recipe/fully_async_policy/ray_trainer.py b/recipe/fully_async_policy/ray_trainer.py
new file mode 100644
index 000000000..b82d9fe0a
--- /dev/null
+++ b/recipe/fully_async_policy/ray_trainer.py
@@ -0,0 +1,528 @@
+# Copyright 2024 Bytedance Ltd. and/or its affiliates
+# Copyright 2023-2024 SGLang Team
+# Copyright 2025 ModelBest Inc. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+PPO Trainer with Ray-based single controller.
+This trainer supports model-agnostic model initialization with HuggingFace models.
+"""
+
+import uuid
+from copy import deepcopy
+from pprint import pprint
+
+import numpy as np
+import ray
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from verl import DataProto
+from verl.experimental.dataset.sampler import AbstractCurriculumSampler
+from verl.single_controller.ray import RayClassWithInitArgs
+from verl.single_controller.ray.base import create_colocated_worker_cls
+from verl.trainer.ppo.core_algos import AdvantageEstimator, agg_loss
+from verl.trainer.ppo.metric_utils import (
+    compute_data_metrics,
+    compute_throughout_metrics,
+    compute_timing_metrics,
+)
+from verl.trainer.ppo.ray_trainer import RayPPOTrainer, apply_kl_penalty, compute_advantage, compute_response_mask
+from verl.trainer.ppo.reward import compute_reward, compute_reward_async
+from verl.trainer.ppo.utils import Role
+from verl.utils.checkpoint.checkpoint_manager import should_save_ckpt_esi
+from verl.utils.config import omega_conf_to_dataclass
+from verl.utils.debug import marked_timer
+from verl.utils.metric import (
+    reduce_metrics,
+)
+from verl.utils.rollout_skip import RolloutSkip
+
+
+class FullyAsyncRayPPOTrainer(RayPPOTrainer):
+    def init_workers(self):
+        """Initialize distributed training workers using Ray backend.
+
+        Creates:
+            1. 
Ray resource pools from configuration + 2. Worker groups for each role (actor, critic, etc.) + """ + self._init_resource_pools() + self._create_worker_classes() + self._init_worker_groups() + self._init_models() + self._init_async_rollout_manager() + + def _init_resource_pools(self): + self.resource_pool_manager.create_resource_pool() + + self.resource_pool_to_cls = {pool: {} for pool in self.resource_pool_manager.resource_pool_dict.values()} + + def _create_worker_classes(self): + self._create_actor_rollout_classes() + self._create_critic_class() + self._create_reference_policy_class() + self._create_reward_model_class() + + def _create_actor_rollout_classes(self): + raise NotImplementedError + + def _create_critic_class(self): + # create critic + if self.use_critic: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) + critic_cfg = omega_conf_to_dataclass(self.config.critic) + critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls + + def _create_reference_policy_class(self): + # create reference policy if needed + if self.use_reference_policy: + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RefPolicy) + ref_policy_cls = RayClassWithInitArgs( + self.role_worker_mapping[Role.RefPolicy], + config=self.config.actor_rollout_ref, + role=str(Role.RefPolicy), + profile_option=self.config.trainer.npu_profile.options, + ) + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls + + def _create_reward_model_class(self): + # create a reward model if reward_fn is None + if self.use_rm: + # we create a RM here + resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) + rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls + + def _init_worker_groups(self): + # initialize WorkerGroup + # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, + # you should not use `create_colocated_worker_cls`. + # Instead, directly pass different resource pool to different worker groups. + # See https://github.com/volcengine/verl/blob/master/examples/ray/tutorial.ipynb for more information. 
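+        # Illustrative sketch only (commented out, not part of this recipe): with a separate resource
+        # pool per role, the `create_colocated_worker_cls` wrapper used below could be skipped and one
+        # worker group built per pool, e.g.
+        #   critic_wg = self.ray_worker_group_cls(resource_pool=critic_pool, ray_cls_with_init=critic_cls)
+        #   critic_wg.init_model()
+        # `critic_pool` / `critic_cls` are placeholder names for illustration, not attributes of this class.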
+
+        all_wg = {}
+        wg_kwargs = {}  # Setting up kwargs for RayWorkerGroup
+        if OmegaConf.select(self.config.trainer, "ray_wait_register_center_timeout") is not None:
+            wg_kwargs["ray_wait_register_center_timeout"] = self.config.trainer.ray_wait_register_center_timeout
+        if OmegaConf.select(self.config.global_profiler, "steps") is not None:
+            wg_kwargs["profile_steps"] = OmegaConf.select(self.config.global_profiler, "steps")
+            # Only require nsight worker options when tool is nsys
+            if OmegaConf.select(self.config.global_profiler, "tool") == "nsys":
+                assert (
+                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+                    is not None
+                ), "worker_nsight_options must be set when using nsys with profile_steps"
+                wg_kwargs["worker_nsight_options"] = OmegaConf.to_container(
+                    OmegaConf.select(self.config.global_profiler.global_tool_config.nsys, "worker_nsight_options")
+                )
+        wg_kwargs["device_name"] = self.device_name
+
+        for resource_pool, class_dict in self.resource_pool_to_cls.items():
+            worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict)
+            wg_dict = self.ray_worker_group_cls(
+                resource_pool=resource_pool,
+                ray_cls_with_init=worker_dict_cls,
+                **wg_kwargs,
+            )
+            spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
+            all_wg.update(spawn_wg)
+        self.all_wg = all_wg
+
+    def _init_models(self):
+        if self.use_critic:
+            self.critic_wg = self.all_wg[str(Role.Critic)]
+            self.critic_wg.init_model()
+
+        if self.use_reference_policy and not self.ref_in_actor:
+            self.ref_policy_wg = self.all_wg[str(Role.RefPolicy)]
+            self.ref_policy_wg.init_model()
+
+        if self.use_rm:
+            self.rm_wg = self.all_wg[str(Role.RewardModel)]
+            self.rm_wg.init_model()
+
+        # Create the rollout last so that vLLM can better estimate the available KV cache memory
+        self.actor_rollout_wg = self.all_wg[str(Role.ActorRollout)]
+        self.actor_rollout_wg.init_model()
+
+    def _init_async_rollout_manager(self):
+        pass
+
+    def fit(self):
+        """
+        The training loop of PPO.
+        The driver process only needs to call the compute functions of the worker group through RPC
+        to construct the PPO dataflow.
+        The lightweight advantage computation is done on the driver process.
+        """
+        from omegaconf import OmegaConf
+
+        from verl.utils.tracking import Tracking
+
+        logger = Tracking(
+            project_name=self.config.trainer.project_name,
+            experiment_name=self.config.trainer.experiment_name,
+            default_backend=self.config.trainer.logger,
+            config=OmegaConf.to_container(self.config, resolve=True),
+        )
+
+        self.global_steps = 0
+
+        # load checkpoint before doing anything
+        self._load_checkpoint()
+
+        # perform validation before training
+        # currently, we only support validation using the reward_function.
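+        # The block below runs one validation pass when `trainer.val_before_train` is enabled
+        # (default True) and returns early right after logging it when `trainer.val_only` is set.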
+ if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True): + val_metrics = self._validate() + assert val_metrics, f"{val_metrics=}" + pprint(f"Initial validation metrics: {val_metrics}") + logger.log(data=val_metrics, step=self.global_steps) + if self.config.trainer.get("val_only", False): + return + + if self.config.actor_rollout_ref.rollout.get("skip_rollout", False): + rollout_skip = RolloutSkip(self.config, self.actor_rollout_wg) + rollout_skip.wrap_generate_sequences() + + # add tqdm + progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress") + + # we start from step 1 + self.global_steps += 1 + last_val_metrics = None + self.max_steps_duration = 0 + + prev_step_profile = False + curr_step_profile = ( + self.global_steps in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + next_step_profile = False + + for epoch in range(self.config.trainer.total_epochs): + for batch_dict in self.train_dataloader: + metrics = {} + timing_raw = {} + + with marked_timer("start_profile", timing_raw): + self._start_profiling( + not prev_step_profile and curr_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + + batch, gen_batch = self._prepare_generate_batch(batch_dict) + + is_last_step = self.global_steps >= self.total_training_steps + + with marked_timer("step", timing_raw): + # generate a batch + with marked_timer("gen", timing_raw, color="red"): + if not self.async_rollout_mode: + gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) + else: + gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch) + timing_raw.update(gen_batch_output.meta_info["timing"]) + gen_batch_output.meta_info.pop("timing", None) + + if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX: + if self.reward_fn is None: + raise ValueError("A reward_fn is required for REMAX advantage estimation.") + + with marked_timer("gen_max", timing_raw, color="purple"): + gen_baseline_batch = deepcopy(gen_batch) + gen_baseline_batch.meta_info["do_sample"] = False + if not self.async_rollout_mode: + gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch) + else: + gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_batch) + batch = batch.union(gen_baseline_output) + reward_baseline_tensor = self.reward_fn(batch) + reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1) + + batch.pop(batch_keys=list(gen_baseline_output.batch.keys())) + + batch.batch["reward_baselines"] = reward_baseline_tensor + + del gen_baseline_batch, gen_baseline_output + + batch = self._post_generate_batch(batch, gen_batch_output, metrics) + batch, reward_extra_infos_dict = self._process_batch_common(batch, metrics, timing_raw) + self._log_rollout(batch, reward_extra_infos_dict, timing_raw) + + last_val_metrics = self._validate_metrics(is_last_step, last_val_metrics, metrics, timing_raw) + self._check_save_checkpoint(is_last_step, timing_raw) + + with marked_timer("stop_profile", timing_raw): + next_step_profile = ( + self.global_steps + 1 in self.config.global_profiler.steps + if self.config.global_profiler.steps is not None + else False + ) + self._stop_profiling( + curr_step_profile and not next_step_profile + if self.config.global_profiler.profile_continuous_steps + else curr_step_profile + ) + prev_step_profile = curr_step_profile + curr_step_profile = next_step_profile + + 
self._collect_metrics(batch, epoch, metrics, timing_raw) + self._post_batch_processing(batch) + + # TODO: make a canonical logger that supports various backend + logger.log(data=metrics, step=self.global_steps) + + progress_bar.update(1) + self.global_steps += 1 + + if ( + hasattr(self.config.actor_rollout_ref.actor, "profiler") + and self.config.actor_rollout_ref.actor.profiler.tool == "torch_memory" + ): + self.actor_rollout_wg.dump_memory_snapshot( + tag=f"post_update_step{self.global_steps}", sub_dir=f"step{self.global_steps}" + ) + + if is_last_step: + pprint(f"Final validation metrics: {last_val_metrics}") + progress_bar.close() + return + + def _prepare_generate_batch(self, batch_dict): + batch: DataProto = DataProto.from_single_dict(batch_dict) + + # add uid to batch + batch.non_tensor_batch["uid"] = np.array([str(uuid.uuid4()) for _ in range(len(batch.batch))], dtype=object) + + gen_batch = self._get_gen_batch(batch) + + # pass global_steps to trace + gen_batch.meta_info["global_steps"] = self.global_steps + gen_batch = gen_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + return batch, gen_batch + + def _post_generate_batch(self, batch, gen_batch_output, metrics): + # repeat to align with repeated responses in rollout + batch = batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True) + batch = batch.union(gen_batch_output) + + if "response_mask" not in batch.batch.keys(): + batch.batch["response_mask"] = compute_response_mask(batch) + # Balance the number of valid tokens across DP ranks. + # NOTE: This usually changes the order of data in the `batch`, + # which won't affect the advantage calculation (since it's based on uid), + # but might affect the loss calculation (due to the change of mini-batching). + # TODO: Decouple the DP balancing and mini-batching. 
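+        # Rough sketch of the idea behind `_balance_batch` (not the actual implementation): samples are
+        # reordered so every DP rank receives a similar number of valid tokens, e.g.
+        #   seqlens = batch.batch["attention_mask"].sum(dim=-1)   # valid tokens per sample
+        #   order = balanced_partition(seqlens, k=num_dp_ranks)   # hypothetical partitioning helper
+        #   batch.reorder(order)
+        # `balanced_partition` / `num_dp_ranks` are placeholder names used only for illustration.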
+ if self.config.trainer.balance_batch: + self._balance_batch(batch, metrics=metrics) + + # compute global_valid tokens + batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist() + + return batch + + def _process_batch_common(self, batch, metrics, timing_raw): + with marked_timer("reward", timing_raw, color="yellow"): + # compute reward model score + if self.use_rm: + reward_tensor = self.rm_wg.compute_rm_score(batch) + batch = batch.union(reward_tensor) + + if self.config.reward_model.launch_reward_fn_async: + future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn) + else: + reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn) + + # recompute old_log_probs + with marked_timer("old_log_prob", timing_raw, color="blue"): + async_training = self.config.get("async_training", None) + if async_training and async_training.use_rollout_log_probs: + batch.batch["old_log_probs"] = batch.batch["rollout_log_probs"] + batch.meta_info["temperature"] = self.config.actor_rollout_ref.rollout.temperature + + else: + old_log_prob = self.actor_rollout_wg.compute_log_prob(batch) + entropys = old_log_prob.batch["entropys"] + response_masks = batch.batch["response_mask"] + loss_agg_mode = self.config.actor_rollout_ref.actor.loss_agg_mode + entropy_agg = agg_loss(loss_mat=entropys, loss_mask=response_masks, loss_agg_mode=loss_agg_mode) + old_log_prob_metrics = {"actor/entropy": entropy_agg.detach().item()} + metrics.update(old_log_prob_metrics) + old_log_prob.batch.pop("entropys") + batch = batch.union(old_log_prob) + + if "rollout_log_probs" in batch.batch.keys(): + # TODO: we may want to add diff of probs too. + from verl.utils.debug.metrics import calculate_debug_metrics + + metrics.update(calculate_debug_metrics(batch)) + + if self.use_reference_policy: + # compute reference log_prob + with marked_timer("ref", timing_raw, color="olive"): + if not self.ref_in_actor: + ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) + else: + ref_log_prob = self.actor_rollout_wg.compute_ref_log_prob(batch) + batch = batch.union(ref_log_prob) + + # compute values + if self.use_critic: + with marked_timer("values", timing_raw, color="cyan"): + values = self.critic_wg.compute_values(batch) + batch = batch.union(values) + + with marked_timer("adv", timing_raw, color="brown"): + # we combine with rule-based rm + reward_extra_infos_dict: dict[str, list] + if self.config.reward_model.launch_reward_fn_async: + reward_tensor, reward_extra_infos_dict = ray.get(future_reward) + batch.batch["token_level_scores"] = reward_tensor + + if reward_extra_infos_dict: + batch.non_tensor_batch.update({k: np.array(v) for k, v in reward_extra_infos_dict.items()}) + + # compute rewards. 
apply_kl_penalty if available + if self.config.algorithm.use_kl_in_reward: + batch, kl_metrics = apply_kl_penalty( + batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty + ) + metrics.update(kl_metrics) + else: + batch.batch["token_level_rewards"] = batch.batch["token_level_scores"] + + # compute advantages, executed on the driver process + + norm_adv_by_std_in_grpo = self.config.algorithm.get( + "norm_adv_by_std_in_grpo", True + ) # GRPO adv normalization factor + + batch = compute_advantage( + batch, + adv_estimator=self.config.algorithm.adv_estimator, + gamma=self.config.algorithm.gamma, + lam=self.config.algorithm.lam, + num_repeat=self.config.actor_rollout_ref.rollout.n, + norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo, + config=self.config.algorithm, + ) + + # update critic + if self.use_critic: + with marked_timer("update_critic", timing_raw, color="pink"): + critic_output = self.critic_wg.update_critic(batch) + critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"]) + metrics.update(critic_output_metrics) + + # implement critic warmup + if self.config.trainer.critic_warmup <= self.global_steps: + # update actor + with marked_timer("update_actor", timing_raw, color="red"): + batch.meta_info["multi_turn"] = self.config.actor_rollout_ref.rollout.multi_turn.enable + actor_output = self.actor_rollout_wg.update_actor(batch) + actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) + metrics.update(actor_output_metrics) + return batch, reward_extra_infos_dict + + def _log_rollout(self, batch, reward_extra_infos_dict, timing_raw): + # Log rollout generations if enabled + rollout_data_dir = self.config.trainer.get("rollout_data_dir", None) + if rollout_data_dir: + with marked_timer("dump_rollout_generations", timing_raw, color="green"): + inputs = self.tokenizer.batch_decode(batch.batch["prompts"], skip_special_tokens=True) + outputs = self.tokenizer.batch_decode(batch.batch["responses"], skip_special_tokens=True) + scores = batch.batch["token_level_scores"].sum(-1).cpu().tolist() + sample_gts = [item.non_tensor_batch.get("reward_model", {}).get("ground_truth", None) for item in batch] + + if "request_id" in batch.non_tensor_batch: + reward_extra_infos_dict.setdefault( + "request_id", + batch.non_tensor_batch["request_id"].tolist(), + ) + + self._dump_generations( + inputs=inputs, + outputs=outputs, + gts=sample_gts, + scores=scores, + reward_extra_infos_dict=reward_extra_infos_dict, + dump_path=rollout_data_dir, + ) + + def _validate_metrics(self, is_last_step, last_val_metrics, metrics, timing_raw): + if ( + self.val_reward_fn is not None + and self.config.trainer.test_freq > 0 + and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0) + ): + with marked_timer("testing", timing_raw, color="green"): + val_metrics: dict = self._validate() + if is_last_step: + last_val_metrics = val_metrics + metrics.update(val_metrics) + return last_val_metrics + + def _check_save_checkpoint(self, is_last_step, timing_raw): + # Check if the ESI (Elastic Server Instance)/training plan is close to expiration. + esi_close_to_expiration = should_save_ckpt_esi( + max_steps_duration=self.max_steps_duration, + redundant_time=self.config.trainer.esi_redundant_time, + ) + # Check if the conditions for saving a checkpoint are met. + # The conditions include a mandatory condition (1) and + # one of the following optional conditions (2/3/4): + # 1. The save frequency is set to a positive value. + # 2. It's the last training step. 
+ # 3. The current step number is a multiple of the save frequency. + # 4. The ESI(Elastic Server Instance)/training plan is close to expiration. + if self.config.trainer.save_freq > 0 and ( + is_last_step or self.global_steps % self.config.trainer.save_freq == 0 or esi_close_to_expiration + ): + if esi_close_to_expiration: + print("Force saving checkpoint: ESI instance expiration approaching.") + with marked_timer("save_checkpoint", timing_raw, color="green"): + self._save_checkpoint() + + def _collect_metrics(self, batch, epoch, metrics, timing_raw): + steps_duration = timing_raw["step"] + self.max_steps_duration = max(self.max_steps_duration, steps_duration) + + # training metrics + metrics.update( + { + "training/global_step": self.global_steps, + "training/epoch": epoch, + } + ) + # collect metrics + metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) + # TODO: implement actual tflpo and theoretical tflpo + n_gpus = self.resource_pool_manager.get_n_gpus() + metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus)) + + def _post_batch_processing(self, batch: DataProto): + # this is experimental and may be changed/removed in the future in favor of a general-purpose one + if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler): + self.train_dataloader.sampler.update(batch=batch) + + # this is experimental and may be changed/removed in the future + # in favor of a general-purpose data buffer pool + if hasattr(self.train_dataset, "on_batch_end"): + # The dataset may be changed after each training batch + self.train_dataset.on_batch_end(batch=batch) diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh new file mode 100644 index 000000000..82072c3a0 --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_16-16.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_16-16' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=4 +sp_size=4 +fsdp_size=8 + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-2} +NNODES_TRAIN=${NNODES_TRAIN:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh new file mode 100644 index 000000000..ded0b0d42 --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_32_32.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_32-32' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=4 +sp_size=4 +fsdp_size=8 + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-4} +NNODES_TRAIN=${NNODES_TRAIN:-4} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh new file mode 100644 index 000000000..18888fd16 --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_12.sh @@ -0,0 +1,164 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-12' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=1 +sp_size=1 +fsdp_size=2 + +# Fully async specific parameters +NNODES=${NNODES:-2} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=2 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*100))) +test_freq=10 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.test_freq="${test_freq}" \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh new file mode 100644 index 000000000..bd56bdd42 --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_4_4.sh @@ -0,0 +1,164 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-4-4' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=1 +sp_size=1 +fsdp_size=2 + +# Fully async specific parameters +NNODES=${NNODES:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +n_gpus_rollout=4 +n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout)) + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*100))) +test_freq=10 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=False \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES}" \ + trainer.n_gpus_per_node="${n_gpus_training}" \ + rollout.nnodes="${NNODES}" \ + rollout.n_gpus_per_node="${n_gpus_rollout}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh new file mode 100644 index 000000000..c03e880ee --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_64_64.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='dapo_qwen2-7B-math_28k_fsdp2_fully-async_64-64' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 28)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=4 +sp_size=4 +fsdp_size=8 + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-8} +NNODES_TRAIN=${NNODES_TRAIN:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*400))) +test_freq=20 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True diff --git a/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh new file mode 100644 index 000000000..ab9c98b1f --- /dev/null +++ b/recipe/fully_async_policy/shell/dapo_7b_math_fsdp2_8_8.sh @@ -0,0 +1,162 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +project_name='DAPO' +exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-fully-async-8-8' + +# Ray +# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"} +# WORKING_DIR=${WORKING_DIR:-"${PWD}"} +# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"} +# Paths +RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"} +# very important! 
please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface +MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"} +CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"} +TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"} +TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"} + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=$((1024 * 2)) +max_response_length=$((1024 * 8)) +enable_overlong_buffer=True +overlong_buffer_len=$((1024 * 4)) +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Algorithm +temperature=1.0 +top_p=1.0 +top_k=-1 # 0 for HF rollout, -1 for vLLM rollout +val_top_p=0.7 + +# Performance Related Parameter +use_dynamic_bsz=True +actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2)) +infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3)) +ref_offload=True +actor_offload=False +gen_tp=1 +sp_size=1 +fsdp_size=2 + +# Fully async specific parameters +NNODES_ROLLOUT=${NNODES_ROLLOUT:-1} +NNODES_TRAIN=${NNODES_TRAIN:-1} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=32 +total_rollout_steps=$(((512*100))) +test_freq=10 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +require_batches=4 +partial_rollout=True + +python -m recipe.fully_async_policy.fully_async_main \ + data.train_files="${TRAIN_FILE}" \ + data.val_files="${TEST_FILE}" \ + data.prompt_key=prompt \ + data.truncation='left' \ + data.max_prompt_length=${max_prompt_length} \ + data.max_response_length=${max_response_length} \ + data.train_batch_size=${train_prompt_bsz} \ + data.gen_batch_size=${gen_prompt_bsz} \ + data.return_raw_chat=${return_raw_chat} \ + actor_rollout_ref.rollout.n=${n_resp_per_prompt} \ + algorithm.adv_estimator=${adv_estimator} \ + algorithm.use_kl_in_reward=${use_kl_in_reward} \ + algorithm.kl_ctrl.kl_coef=${kl_coef} \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \ + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \ + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \ + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \ + actor_rollout_ref.actor.clip_ratio_c=10.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.hybrid_engine=False \ + +actor_rollout_ref.model.override_config.max_position_embeddings=32768 \ + actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \ + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \ + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \ + actor_rollout_ref.model.path="${MODEL_PATH}" \ + actor_rollout_ref.actor.optim.lr=1e-6 \ + actor_rollout_ref.actor.optim.lr_warmup_steps=10 \ + actor_rollout_ref.actor.optim.weight_decay=0.1 \ + 
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.entropy_coeff=0 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.rollout.enable_chunked_prefill=True \ + actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \ + actor_rollout_ref.rollout.temperature=${temperature} \ + actor_rollout_ref.rollout.top_p=${top_p} \ + actor_rollout_ref.rollout.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \ + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \ + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \ + actor_rollout_ref.rollout.val_kwargs.do_sample=True \ + actor_rollout_ref.rollout.val_kwargs.n=1 \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \ + actor_rollout_ref.rollout.name=${rollout_name} \ + actor_rollout_ref.rollout.mode=${rollout_mode} \ + actor_rollout_ref.rollout.calculate_log_probs=True \ + reward_model.reward_manager=dapo \ + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \ + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False \ + +reward_model.reward_kwargs.max_resp_len=${max_response_length} \ + trainer.logger=['console','tensorboard'] \ + trainer.project_name="${project_name}" \ + trainer.experiment_name="${exp_name}" \ + trainer.val_before_train=True \ + trainer.save_freq=-1 \ + trainer.default_local_dir="${CKPTS_DIR}" \ + trainer.resume_mode=auto \ + trainer.nnodes="${NNODES_TRAIN}" \ + trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.nnodes="${NNODES_ROLLOUT}" \ + rollout.n_gpus_per_node="${NGPUS_PER_NODE}" \ + rollout.total_rollout_steps="${total_rollout_steps}" \ + rollout.total_epochs=10 \ + rollout.test_freq="${test_freq}" \ + async_training.staleness_threshold="${staleness_threshold}" \ + async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" \ + async_training.require_batches="${require_batches}" \ + async_training.partial_rollout="${partial_rollout}" \ + async_training.use_rollout_log_probs=True \ No newline at end of file diff --git a/recipe/fully_async_policy/shell/runtime_env.yaml b/recipe/fully_async_policy/shell/runtime_env.yaml new file mode 100644 index 000000000..88467b8c2 --- /dev/null +++ b/recipe/fully_async_policy/shell/runtime_env.yaml @@ -0,0 +1,4 @@ +env_vars: + VLLM_USE_V1: "1" + NCCL_DEBUG: "INFO" + HYDRA_FULL_ERROR: "1" \ No newline at end of file diff --git a/recipe/fully_async_policy/unittest/simple_streaming_demo.py b/recipe/fully_async_policy/unittest/simple_streaming_demo.py new file mode 100644 index 000000000..209c2aae3 --- /dev/null +++ b/recipe/fully_async_policy/unittest/simple_streaming_demo.py @@ -0,0 +1,176 @@ +# Copyright 2025 Meituan Ltd. 
and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import random +import time + + +class SimpleStreamingSystem: + """Simplified streaming system demonstration""" + + def __init__(self, max_concurrent_tasks: int = 4): + self.max_concurrent_tasks = max_concurrent_tasks + self.data_queue = asyncio.Queue() + self.result_queue = asyncio.Queue() + self.consumer_count = 0 + + # Data stream coroutine + async def data_stream(self): + # Add initial data + # Prepare test data + test_data = [{"id": f"task_{i}", "content": f"data_{i}"} for i in range(8)] + await self.add_data_stream(test_data) + + # Simulate subsequent data stream + await asyncio.sleep(3) + print("\nAdding second batch of data...") + extra_data = [{"id": f"extra_{i}", "content": f"extra_data_{i}"} for i in range(5)] + await self.add_data_stream(extra_data) + + # Send termination signal + await asyncio.sleep(1) + await self.data_queue.put("DONE") + print("Sending termination signal") + + async def add_data_stream(self, data_list: list[dict]): + """Simulate data stream""" + print("Starting to add data stream...") + + for i, data_item in enumerate(data_list): + await self.data_queue.put(data_item) + print(f"Data {data_item['id']} added to pending queue") + + # Simulate interval between data streams + if i < len(data_list) - 1: # Don't wait after the last item + await asyncio.sleep(0.8) + + print("Initial data stream added successfully") + + async def _process_data_async(self, data_item: dict): + """Asynchronously process a single data item""" + data_id = data_item["id"] + content = data_item["content"] + + # Simulate different processing times (1-3 seconds) + processing_time = random.uniform(1, 3) + + print(f" Starting to process {data_id}, estimated time {processing_time:.1f}s") + + # Asynchronously wait for processing completion + await asyncio.sleep(processing_time) + + result = { + "id": data_id, + "processed_content": f"Processed {content}", + "processing_time": round(processing_time, 2), + "completed_at": time.time(), + } + + # Immediately put into result queue + await self.result_queue.put(result) + print(f" {data_id} processing completed! 
(took {processing_time:.1f}s) -> Added to result queue") + + async def _submit_worker(self): + """Stream submission worker coroutine""" + active_tasks = set() + + print("Stream submitter started...") + + while True: + # Get data to process + data_item = await self.data_queue.get() + + if data_item == "DONE": + print("Received termination signal, waiting for remaining tasks to complete...") + if active_tasks: + await asyncio.gather(*active_tasks, return_exceptions=True) + break + + # Check concurrent limit + while len(active_tasks) >= self.max_concurrent_tasks: + print(f"Reached maximum concurrency {self.max_concurrent_tasks}, waiting for tasks to complete...") + done_tasks, active_tasks = await asyncio.wait(active_tasks, return_when=asyncio.FIRST_COMPLETED) + + # Clean up completed tasks + for task in done_tasks: + try: + await task + print(f"Task completed {task}") + except Exception as e: + print(f"Task execution failed: {e}") + + # Immediately submit new task + task = asyncio.create_task(self._process_data_async(data_item), name=f"active {data_item}") + active_tasks.add(task) + + print(f"Submitted task {data_item['id']}, current concurrency: {len(active_tasks)}") + + async def _consumer_worker(self): + """Result consumer coroutine""" + print("Consumer started...") + + while True: + try: + # Get processing result from result queue + result = await asyncio.wait_for(self.result_queue.get(), timeout=2.0) + + self.consumer_count += 1 + + print( + f"Consumed #{self.consumer_count}: {result['id']} " + f"(processing time {result['processing_time']}s) - {result['processed_content']}" + ) + + except asyncio.TimeoutError: + print(" Consumer waiting...") + await asyncio.sleep(0.5) + + async def run_demo(self): + """Run demonstration""" + print("=" * 60) + print(f"Maximum concurrency: {self.max_concurrent_tasks}") + print("=" * 60) + + # Start core coroutines + stream_task = asyncio.create_task(self.data_stream()) + submit_task = asyncio.create_task(self._submit_worker()) + consumer_task = asyncio.create_task(self._consumer_worker()) + + try: + # Wait for data stream to complete + await stream_task + print("Data stream completed") + + # Wait for processing to complete + await submit_task + print("All tasks processed") + + finally: + # Cleanup + submit_task.cancel() + consumer_task.cancel() + await asyncio.gather(submit_task, consumer_task, return_exceptions=True) + + print(f"\nFinal statistics: Consumed {self.consumer_count} results") + + +async def main(): + """Main function""" + system = SimpleStreamingSystem(max_concurrent_tasks=3) + await system.run_demo() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/recipe/fully_async_policy/vllm_rollout/__init__.py b/recipe/fully_async_policy/vllm_rollout/__init__.py new file mode 100644 index 000000000..9cd3ed5b8 --- /dev/null +++ b/recipe/fully_async_policy/vllm_rollout/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py b/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py new file mode 100644 index 000000000..93381e1bf --- /dev/null +++ b/recipe/fully_async_policy/vllm_rollout/vllm_async_server.py @@ -0,0 +1,154 @@ +# Copyright 2025 Meituan Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import asyncio +import logging +from typing import Any, Optional, Sequence + +import ray +from ray.actor import ActorHandle +from vllm import SamplingParams +from vllm.inputs import TokensPrompt +from vllm.outputs import RequestOutput + +from verl.workers.config import HFModelConfig, RewardModelConfig, RolloutConfig +from verl.workers.rollout.replica import RolloutMode +from verl.workers.rollout.vllm_rollout.vllm_async_server import ( + _qwen2_5_vl_dedup_image_tokens, + vLLMHttpServerBase, + vLLMReplica, +) + +logger = logging.getLogger(__file__) +logger.setLevel(logging.INFO) + + +@ray.remote(num_cpus=1) +class vLLMHttpServerForPartial(vLLMHttpServerBase): + def __init__( + self, + config: RolloutConfig | RewardModelConfig, + model_config: HFModelConfig, + rollout_mode: RolloutMode, + workers: list[ActorHandle], + replica_rank: int, + node_rank: int, + gpus_per_node: int, + nnodes: int, + ): + super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes) + + # for cancel LLMServer + self.paused = False + self.lock = asyncio.Lock() + self.cancel_event: dict[str, asyncio.Event] = {} + self.req_output: dict[str, Optional[RequestOutput]] = {} + + async def _generate_step( + self, + prompt_ids: list[int], + sampling_params: dict[str, Any], + request_id: str, + image_data: Optional[list[Any]] = None, + ): + max_tokens = self.config.max_model_len - len(prompt_ids) + sampling_params["logprobs"] = 1 + sampling_params.setdefault("repetition_penalty", self.config.get("repetition_penalty", 1.0)) + sampling_params = SamplingParams(max_tokens=max_tokens, **sampling_params) + prompt_ids = _qwen2_5_vl_dedup_image_tokens(prompt_ids, self.model_config.processor) + prompt = TokensPrompt( + prompt_token_ids=prompt_ids, multi_modal_data={"image": image_data} if image_data else None + ) + generator = self.engine.generate(prompt=prompt, sampling_params=sampling_params, request_id=request_id) + + # Get final response + self.req_output[request_id]: Optional[RequestOutput] = None + async for output in generator: + self.req_output[request_id] = output + assert self.req_output[request_id] is not None + + async def generate_for_partial( + self, + prompt_ids: list[int], + sampling_params: dict[str, Any], + request_id: str, + image_data: Optional[list[Any]] = None, + ) -> tuple[list[Any], list[Any], bool] | tuple[Sequence[int], list[float], Any]: + async with self.lock: + if self.paused: + # After cancel, all tasks will return directly and wait for the next submission + return [], [], True + self.cancel_event[request_id] = asyncio.Event() + cancel_handle = 
asyncio.create_task(self.cancel_event[request_id].wait()) + generation_handle = asyncio.create_task( + self._generate_step(prompt_ids, sampling_params, request_id, image_data) + ) + + done, pend = await asyncio.wait([generation_handle, cancel_handle], return_when=asyncio.FIRST_COMPLETED) + + for task in done: + await task + + for task in pend: + task.cancel() + + async with self.lock: + token_ids = self.req_output[request_id].outputs[0].token_ids + log_probs: list[float] = [] + for i, x in enumerate(self.req_output[request_id].outputs[0].logprobs): + # In sampling_params, logprobs is set to 1, which should return 1, + # but in practice there are multiple. Take the log_prob corresponding to token_id + token_id = self.req_output[request_id].outputs[0].token_ids[i] + log_probs.append(x[token_id].logprob) + is_cancel = generation_handle not in done + self.cancel_event.pop(request_id, None) + self.req_output.pop(request_id, None) + return token_ids, log_probs, is_cancel + + async def cancel(self): + async with self.lock: + self.paused = True + for request_id in self.cancel_event: + self.cancel_event[request_id].set() + + async def resume(self): + async with self.lock: + self.paused = False + + async def reset_prefix_cache(self): + async with self.lock: + await self.engine.reset_prefix_cache() + + +class FullyAsyncvLLMReplica(vLLMReplica): + def __init__( + self, + replica_rank: int, + config: RolloutConfig | RewardModelConfig, + model_config: HFModelConfig, + gpus_per_node: int = 8, + ): + super().__init__(replica_rank, config, model_config, gpus_per_node) + self.server_class = vLLMHttpServerForPartial + + async def cancel(self): + """Cancel each rollout server.""" + await asyncio.gather(*[server.cancel.remote() for server in self.servers]) + + async def resume(self): + """Resume each rollout server.""" + await asyncio.gather(*[server.resume.remote() for server in self.servers]) + + async def reset_prefix_cache(self): + """reset kv cache in each rollout server.""" + await asyncio.gather(*[server.reset_prefix_cache.remote() for server in self.servers]) diff --git a/tests/special_e2e/run_fully_async_policy.sh b/tests/special_e2e/run_fully_async_policy.sh new file mode 100644 index 000000000..a2f99f0d6 --- /dev/null +++ b/tests/special_e2e/run_fully_async_policy.sh @@ -0,0 +1,196 @@ +#!/usr/bin/env bash +set -xeuo pipefail + +# Test script for fully_async_policy E2E regression testing +# This script runs fully async PPO training with both FSDP2 and Megatron backends +# to ensure the asynchronous training mechanism works correctly + +NUM_GPUS=${NUM_GPUS:-8} +ACTOR_STRATEGY=${ACTOR_STRATEGY:-"fsdp2"} # fsdp2 or megatron + +# Download model if not exists +MODEL_ID=${MODEL_ID:-Qwen/Qwen2.5-0.5B-Instruct} +MODEL_PATH=${MODEL_PATH:-${HOME}/models/${MODEL_ID}} +huggingface-cli download "${MODEL_ID}" --local-dir "${MODEL_PATH}" + + +rollout_mode="async" +rollout_name="vllm" # sglang or vllm +if [ "$rollout_mode" = "async" ]; then + export VLLM_USE_V1=1 + return_raw_chat="True" +fi + +# Algorithm parameters +adv_estimator=grpo + +use_kl_in_reward=False +kl_coef=0.0 +use_kl_loss=False +kl_loss_coef=0.0 + +clip_ratio_low=0.2 +clip_ratio_high=0.28 + +# Response length parameters +max_prompt_length=1024 +max_response_length=2048 +enable_overlong_buffer=True +overlong_buffer_len=128 +overlong_penalty_factor=1.0 + +# Training parameters +loss_agg_mode="token-mean" + +# Temperature parameters +temperature=1.0 +top_p=1.0 +top_k=-1 +val_top_p=0.7 + +# Fully async specific parameters +n_gpus_rollout=4 
+n_gpus_training=4 + +train_prompt_bsz=0 +gen_prompt_bsz=1 +n_resp_per_prompt=16 +train_prompt_mini_bsz=16 +total_rollout_steps=$(((128))) +test_freq=-1 +staleness_threshold=0.1 +trigger_parameter_sync_step=4 +partial_rollout=True + +exp_name="$(basename "${MODEL_ID,,}")-fully-async-policy-${ACTOR_STRATEGY}-minimal" + +echo "Running fully_async_policy with ${ACTOR_STRATEGY} strategy" +echo "Total GPUs: ${NUM_GPUS}, Rollout GPUs: ${n_gpus_rollout}, Training GPUs: ${n_gpus_training}" + +# Common parameters for both FSDP2 and Megatron +common_params=( + data.train_files="${HOME}/data/gsm8k/train.parquet" + data.val_files="${HOME}/data/gsm8k/test.parquet" + data.prompt_key=prompt + data.truncation='left' + data.max_prompt_length=${max_prompt_length} + data.max_response_length=${max_response_length} + data.train_batch_size=${train_prompt_bsz} + data.gen_batch_size=${gen_prompt_bsz} + data.return_raw_chat=${return_raw_chat} + actor_rollout_ref.rollout.n=${n_resp_per_prompt} + actor_rollout_ref.rollout.calculate_log_probs=True + algorithm.adv_estimator=${adv_estimator} + algorithm.use_kl_in_reward=${use_kl_in_reward} + algorithm.kl_ctrl.kl_coef=${kl_coef} + actor_rollout_ref.hybrid_engine=False + actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} + actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} + actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} + actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} + actor_rollout_ref.actor.clip_ratio_c=10.0 + actor_rollout_ref.model.path="${MODEL_PATH}" + actor_rollout_ref.model.enable_gradient_checkpointing=True + actor_rollout_ref.actor.optim.lr=1e-6 + actor_rollout_ref.actor.optim.lr_warmup_steps=-1 + actor_rollout_ref.actor.optim.weight_decay=0.1 + actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} + actor_rollout_ref.rollout.gpu_memory_utilization=0.80 + actor_rollout_ref.rollout.temperature=${temperature} + actor_rollout_ref.rollout.top_p=${top_p} + actor_rollout_ref.rollout.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} + actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} + actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.n=1 + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.name=${rollout_name} + actor_rollout_ref.rollout.mode=${rollout_mode} + reward_model.reward_manager=dapo + +reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} + +reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} + +reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} + +reward_model.reward_kwargs.overlong_buffer_cfg.log=False + +reward_model.reward_kwargs.max_resp_len=${max_response_length} + trainer.logger=['console'] + trainer.project_name='verl-test-fully-async' + trainer.experiment_name="${exp_name}" + trainer.val_before_train=True + trainer.save_freq=-1 + trainer.resume_mode=disable + trainer.nnodes=1 + trainer.n_gpus_per_node=${n_gpus_training} + rollout.nnodes=1 + rollout.n_gpus_per_node=${n_gpus_rollout} + rollout.total_rollout_steps=${total_rollout_steps} + rollout.total_epochs=2 + rollout.test_freq=${test_freq} + # Fully async specific configurations + async_training.staleness_threshold=${staleness_threshold} + async_training.partial_rollout="${partial_rollout}" + 
async_training.trigger_parameter_sync_step="${trigger_parameter_sync_step}" +) + +if [ "${ACTOR_STRATEGY}" == "fsdp2" ]; then + echo "Running fully async training with FSDP2 strategy..." + # FSDP2 specific parameters + gen_tp=1 + sp_size=1 + fsdp_size=1 + ref_offload=True + actor_offload=False + + python3 -m recipe.fully_async_policy.fully_async_main \ + "${common_params[@]}" \ + actor_rollout_ref.actor.strategy=fsdp2 \ + critic.strategy=fsdp2 \ + actor_rollout_ref.actor.grad_clip=1.0 \ + actor_rollout_ref.model.use_remove_padding=True \ + actor_rollout_ref.actor.use_dynamic_bsz=True \ + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \ + actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \ + actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \ + actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \ + actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} $@ + +elif [ "${ACTOR_STRATEGY}" == "megatron" ]; then + echo "Running fully async training with Megatron strategy..." + # Megatron specific parameters + gen_tp=2 + train_tp=1 + train_pp=2 + ref_offload=True + actor_offload=False + + python3 -m recipe.fully_async_policy.fully_async_main \ + --config-path=config \ + --config-name='fully_async_ppo_megatron_trainer.yaml' \ + "${common_params[@]}" \ + actor_rollout_ref.actor.strategy=megatron \ + critic.strategy=megatron \ + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=2 \ + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=4 \ + actor_rollout_ref.actor.megatron.param_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.optimizer_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.grad_offload=${actor_offload} \ + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \ + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${train_pp} \ + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${train_tp} \ + actor_rollout_ref.ref.megatron.param_offload=${ref_offload} $@ +else + echo "Error: Unknown strategy ${ACTOR_STRATEGY}. 
Please use 'fsdp2' or 'megatron'" + exit 1 +fi + +echo "Fully async policy E2E test completed successfully with ${ACTOR_STRATEGY} strategy" + diff --git a/tests/special_sanity/check_device_api_usage.py b/tests/special_sanity/check_device_api_usage.py index dae5ac4b4..8d3cfda27 100644 --- a/tests/special_sanity/check_device_api_usage.py +++ b/tests/special_sanity/check_device_api_usage.py @@ -48,6 +48,7 @@ CUDA_KEYWORD_CHECK_WHITELIST = [ NCCL_KEYWORD_CHECK_WHITELIST = [ "verl/utils/device.py", "verl/third_party/sglang/parallel_state.py", # appear in default backend + "verl/recipe/fully_async_policy/param_sync.py", # fully_async_policy in default backend ] SEARCH_WHITELIST = CUDA_KEYWORD_CHECK_WHITELIST + NCCL_KEYWORD_CHECK_WHITELIST diff --git a/tests/special_sanity/check_license.py b/tests/special_sanity/check_license.py index a4ade0244..7e099ccb3 100644 --- a/tests/special_sanity/check_license.py +++ b/tests/special_sanity/check_license.py @@ -24,6 +24,7 @@ license_head_sglang = "Copyright 2023-2024 SGLang Team" license_head_modelbest = "Copyright 2025 ModelBest Inc. and/or its affiliates" license_head_amazon = "Copyright 2025 Amazon.com Inc and/or its affiliates" license_head_facebook = "Copyright (c) 2016- Facebook, Inc" +license_head_meituan = "Copyright 2025 Meituan Ltd. and/or its affiliates" license_headers = [ license_head_bytedance, license_head_bytedance_25, @@ -33,6 +34,7 @@ license_headers = [ license_head_modelbest, license_head_amazon, license_head_facebook, + license_head_meituan, ] diff --git a/verl/experimental/agent_loop/__init__.py b/verl/experimental/agent_loop/__init__.py index 88d61ee41..d43683df3 100644 --- a/verl/experimental/agent_loop/__init__.py +++ b/verl/experimental/agent_loop/__init__.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .agent_loop import AgentLoopBase, AgentLoopManager, AsyncLLMServerManager +from .agent_loop import AgentLoopBase, AgentLoopManager, AgentLoopWorker, AsyncLLMServerManager from .single_turn_agent_loop import SingleTurnAgentLoop from .tool_agent_loop import ToolAgentLoop _ = [SingleTurnAgentLoop, ToolAgentLoop] -__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager"] +__all__ = ["AgentLoopBase", "AgentLoopManager", "AsyncLLMServerManager", "AgentLoopWorker"] diff --git a/verl/experimental/agent_loop/agent_loop.py b/verl/experimental/agent_loop/agent_loop.py index 5a699da0e..fe54edd51 100644 --- a/verl/experimental/agent_loop/agent_loop.py +++ b/verl/experimental/agent_loop/agent_loop.py @@ -370,8 +370,7 @@ class RewardManagerWorker: return self.reward_manager(data, return_dict) -@ray.remote -class AgentLoopWorker: +class AgentLoopWorkerBase: """Agent loop worker takes a batch of messages and run each message in an agent loop.""" def __init__( @@ -384,7 +383,11 @@ class AgentLoopWorker: server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. 
""" self.config = config - self.server_manager = AsyncLLMServerManager(config, server_handles) + + # for recipe to change + if not hasattr(self, "server_manager"): + self.server_manager = AsyncLLMServerManager(config, server_handles) + self.rm_executor = rm_executor model_path = config.actor_rollout_ref.model.path @@ -720,6 +723,22 @@ class AgentLoopWorker: ) +@ray.remote +class AgentLoopWorker(AgentLoopWorkerBase): + """Agent loop worker takes a batch of messages and run each message in an agent loop.""" + + def __init__( + self, config: DictConfig, server_handles: list[ray.actor.ActorHandle], rm_executor: BatchExecutor = None + ): + """Initialize agent loop manager. + + Args: + config (DictConfig): YAML config. + server_handles (List[ray.actor.ActorHandle]): OpenAI compatible LLM server actor handles. + """ + super().__init__(config, server_handles, rm_executor) + + async def get_trajectory_info(step, index, validate): """Get trajectory info. @@ -778,6 +797,12 @@ class AgentLoopManager: self.rm_micro_batch_size = rm_wg.world_size + # for recipe to change + if not hasattr(self, "rollout_replica_class"): + self.rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name) + if not hasattr(self, "agent_loop_workers_class"): + self.agent_loop_workers_class = AgentLoopWorker + self._initialize_llm_servers() self._init_agent_loop_workers() @@ -798,11 +823,10 @@ class AgentLoopManager: ) num_replicas = world_size // rollout_world_size - rollout_replica_class = get_rollout_replica_class(self.config.actor_rollout_ref.rollout.name) rollout_config = self.config.actor_rollout_ref.rollout model_config = self.config.actor_rollout_ref.model self.rollout_replicas = [ - rollout_replica_class( + self.rollout_replica_class( replica_rank=replica_rank, config=rollout_config, model_config=model_config, @@ -826,7 +850,7 @@ class AgentLoopManager: # Round-robin scheduling over the all nodes node_id = node_ids[i % len(node_ids)] self.agent_loop_workers.append( - AgentLoopWorker.options( + self.agent_loop_workers_class.options( name=f"agent_loop_worker_{i}", scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id, soft=True diff --git a/verl/trainer/config/actor/dp_actor.yaml b/verl/trainer/config/actor/dp_actor.yaml index ab27304f7..a2ff54d48 100644 --- a/verl/trainer/config/actor/dp_actor.yaml +++ b/verl/trainer/config/actor/dp_actor.yaml @@ -39,4 +39,4 @@ entropy_from_logits_with_chunking: False entropy_checkpointing: False # Whether to remove padding tokens in inputs during training -use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} +use_remove_padding: ${oc.select:actor_rollout_ref.model.use_remove_padding,false} \ No newline at end of file diff --git a/verl/trainer/main_ppo.py b/verl/trainer/main_ppo.py index 1320cdaa1..b9a357dcf 100644 --- a/verl/trainer/main_ppo.py +++ b/verl/trainer/main_ppo.py @@ -43,13 +43,14 @@ def main(config): # Define a function to run the PPO-like training process -def run_ppo(config) -> None: +def run_ppo(config, task_runner_class=None) -> None: """Initialize Ray cluster and run distributed PPO training process. Args: config: Training configuration object containing all necessary parameters for distributed PPO training including Ray initialization settings, model paths, and training hyperparameters. + task_runner_class: For recipe to change TaskRunner. 
""" # Check if Ray is not initialized if not ray.is_initialized(): @@ -65,6 +66,9 @@ def run_ppo(config) -> None: print(f"ray init kwargs: {ray_init_kwargs}") ray.init(**OmegaConf.to_container(ray_init_kwargs)) + if task_runner_class is None: + task_runner_class = TaskRunner + # Create a remote instance of the TaskRunner class, and # Execute the `run` method of the TaskRunner instance remotely and wait for it to complete if ( @@ -79,9 +83,9 @@ def run_ppo(config) -> None: nsight_options = OmegaConf.to_container( config.global_profiler.global_tool_config.nsys.controller_nsight_options ) - runner = TaskRunner.options(runtime_env={"nsight": nsight_options}).remote() + runner = task_runner_class.options(runtime_env={"nsight": nsight_options}).remote() else: - runner = TaskRunner.remote() + runner = task_runner_class.remote() ray.get(runner.run.remote(config)) # [Optional] get the path of the timeline trace file from the configuration, default to None diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index cff021b25..5f33f6e3a 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -602,11 +602,9 @@ class RayPPOTrainer: sample_scores.extend(scores) reward_extra_infos_dict["reward"].extend(scores) - print(f"len reward_extra_infos_dict['reward']: {len(reward_extra_infos_dict['reward'])}") if "reward_extra_info" in result: for key, lst in result["reward_extra_info"].items(): reward_extra_infos_dict[key].extend(lst) - print(f"len reward_extra_infos_dict['{key}']: {len(reward_extra_infos_dict[key])}") # collect num_turns of each prompt if "__num_turns__" in test_batch.non_tensor_batch: @@ -676,9 +674,9 @@ class RayPPOTrainer: actor_rollout_cls = RayClassWithInitArgs( cls=self.role_worker_mapping[Role.ActorRollout], config=self.config.actor_rollout_ref, - role="actor_rollout", + role=str(Role.ActorRollout), ) - self.resource_pool_to_cls[resource_pool]["actor_rollout"] = actor_rollout_cls + self.resource_pool_to_cls[resource_pool][str(Role.ActorRollout)] = actor_rollout_cls else: raise NotImplementedError @@ -687,7 +685,7 @@ class RayPPOTrainer: resource_pool = self.resource_pool_manager.get_resource_pool(Role.Critic) critic_cfg = omega_conf_to_dataclass(self.config.critic) critic_cls = RayClassWithInitArgs(cls=self.role_worker_mapping[Role.Critic], config=critic_cfg) - self.resource_pool_to_cls[resource_pool]["critic"] = critic_cls + self.resource_pool_to_cls[resource_pool][str(Role.Critic)] = critic_cls # create reference policy if needed if self.use_reference_policy: @@ -695,16 +693,16 @@ class RayPPOTrainer: ref_policy_cls = RayClassWithInitArgs( self.role_worker_mapping[Role.RefPolicy], config=self.config.actor_rollout_ref, - role="ref", + role=str(Role.RefPolicy), ) - self.resource_pool_to_cls[resource_pool]["ref"] = ref_policy_cls + self.resource_pool_to_cls[resource_pool][str(Role.RefPolicy)] = ref_policy_cls # create a reward model if reward_fn is None if self.use_rm: # we create a RM here resource_pool = self.resource_pool_manager.get_resource_pool(Role.RewardModel) rm_cls = RayClassWithInitArgs(self.role_worker_mapping[Role.RewardModel], config=self.config.reward_model) - self.resource_pool_to_cls[resource_pool]["rm"] = rm_cls + self.resource_pool_to_cls[resource_pool][str(Role.RewardModel)] = rm_cls # initialize WorkerGroup # NOTE: if you want to use a different resource pool for each role, which can support different parallel size, @@ -739,20 +737,20 @@ class RayPPOTrainer: all_wg.update(spawn_wg) if self.use_critic: - 
self.critic_wg = all_wg["critic"] + self.critic_wg = all_wg[str(Role.Critic)] self.critic_wg.init_model() if self.use_reference_policy and not self.ref_in_actor: - self.ref_policy_wg = all_wg["ref"] + self.ref_policy_wg = all_wg[str(Role.RefPolicy)] self.ref_policy_wg.init_model() self.rm_wg = None if self.use_rm: - self.rm_wg = all_wg["rm"] + self.rm_wg = all_wg[str(Role.RewardModel)] self.rm_wg.init_model() # we should create rollout at the end so that vllm can have a better estimation of kv cache memory - self.actor_rollout_wg = all_wg["actor_rollout"] + self.actor_rollout_wg = all_wg[str(Role.ActorRollout)] self.actor_rollout_wg.init_model() # create async rollout manager and request scheduler @@ -800,11 +798,13 @@ class RayPPOTrainer: ) if self.use_critic: - critic_local_path = os.path.join(local_global_step_folder, "critic") + critic_local_path = os.path.join(local_global_step_folder, str(Role.Critic)) critic_remote_path = ( None if self.config.trainer.default_hdfs_dir is None - else os.path.join(self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", "critic") + else os.path.join( + self.config.trainer.default_hdfs_dir, f"global_step_{self.global_steps}", str(Role.Critic) + ) ) self.critic_wg.save_checkpoint( critic_local_path, critic_remote_path, self.global_steps, max_ckpt_to_keep=max_critic_ckpt_to_keep @@ -860,7 +860,7 @@ class RayPPOTrainer: print(f"Resuming from {global_step_folder}") actor_path = os.path.join(global_step_folder, "actor") - critic_path = os.path.join(global_step_folder, "critic") + critic_path = os.path.join(global_step_folder, str(Role.Critic)) # load actor self.actor_rollout_wg.load_checkpoint( actor_path, del_local_after_load=self.config.trainer.del_local_ckpt_after_load @@ -1129,7 +1129,7 @@ class RayPPOTrainer: if self.use_reference_policy: # compute reference log_prob - with marked_timer("ref", timing_raw, color="olive"): + with marked_timer(str(Role.RefPolicy), timing_raw, color="olive"): if not self.ref_in_actor: ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) else: diff --git a/verl/trainer/ppo/utils.py b/verl/trainer/ppo/utils.py index 22d00a450..31e886fd6 100644 --- a/verl/trainer/ppo/utils.py +++ b/verl/trainer/ppo/utils.py @@ -36,6 +36,37 @@ class Role(Enum): RewardModel = 5 ActorRolloutRef = 6 + def __str__(self): + return self._get_role_string() + + def _get_role_string(self): + role_mapping = { + Role.Actor: "actor", + Role.Rollout: "rollout", + Role.ActorRollout: "actor_rollout", + Role.Critic: "critic", + Role.RefPolicy: "ref", + Role.RewardModel: "rm", + Role.ActorRolloutRef: "actor_rollout_ref", + } + return role_mapping.get(self, self.name.lower()) + + @classmethod + def from_string(cls, name: str): + string_mapping = { + "actor": cls.Actor, + "rollout": cls.Rollout, + "actor_rollout": cls.ActorRollout, + "critic": cls.Critic, + "ref": cls.RefPolicy, + "rm": cls.RewardModel, + "actor_rollout_ref": cls.ActorRolloutRef, + } + role = string_mapping.get(name.lower()) + if role is None: + raise ValueError(f"No Role found for string: {name}") + return role + def need_reference_policy( role_worker_mapping: dict[Role, WorkerType], diff --git a/verl/workers/actor/dp_actor.py b/verl/workers/actor/dp_actor.py index f78233107..7dd531ad2 100644 --- a/verl/workers/actor/dp_actor.py +++ b/verl/workers/actor/dp_actor.py @@ -78,7 +78,7 @@ class DataParallelPPOActor(BasePPOActor): self.compute_entropy_from_logits = ( torch.compile(entropy_from_logits, dynamic=True) - if self.config.get("use_torch_compile", True) # use torch 
compile by default + if self.config.get("use_torch_compile", True) # use torch compile by default else entropy_from_logits ) self.device_name = get_device_name() @@ -427,10 +427,14 @@ class DataParallelPPOActor(BasePPOActor): model_inputs, temperature=temperature, calculate_entropy=calculate_entropy ) - if on_policy: - old_log_prob = log_prob.detach() - else: + # for fully_async_policy recipe + if hasattr(self.config, "use_rollout_log_probs") and self.config.use_rollout_log_probs: old_log_prob = model_inputs["old_log_probs"] + else: + if on_policy: + old_log_prob = log_prob.detach() + else: + old_log_prob = model_inputs["old_log_probs"] loss_mode = self.config.policy_loss.get("loss_mode", "vanilla") # vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla diff --git a/verl/workers/config/actor.py b/verl/workers/config/actor.py index 60ba0308a..fe5b3e119 100644 --- a/verl/workers/config/actor.py +++ b/verl/workers/config/actor.py @@ -231,6 +231,7 @@ class FSDPActorConfig(ActorConfig): fsdp_config: FSDPEngineConfig = field(default_factory=FSDPEngineConfig) use_remove_padding: bool = False profiler: ProfilerConfig = field(default_factory=ProfilerConfig) + use_rollout_log_probs: bool = False def __post_init__(self): """Validate FSDP actor configuration parameters.""" diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 3cb6c8f80..aceec6c2e 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -110,8 +110,7 @@ class ExternalZeroMQDistributedExecutor(Executor): return -@ray.remote(num_cpus=1) -class vLLMHttpServer: +class vLLMHttpServerBase: """vLLM http server in single node, this is equivalent to launch server with command line: ``` vllm serve --tensor-parallel-size=8 ... @@ -399,10 +398,42 @@ class vLLMHttpServer: await self.engine.wait_for_requests_to_drain() +@ray.remote(num_cpus=1) +class vLLMHttpServer(vLLMHttpServerBase): + """vLLM http server in single node, this is equivalent to launch server with command line: + ``` + vllm serve --tensor-parallel-size=8 ... + ``` + """ + + def __init__( + self, + config: RolloutConfig | RewardModelConfig, + model_config: HFModelConfig, + rollout_mode: RolloutMode, + workers: list[ActorHandle], + replica_rank: int, + node_rank: int, + gpus_per_node: int, + nnodes: int, + ): + super().__init__(config, model_config, rollout_mode, workers, replica_rank, node_rank, gpus_per_node, nnodes) + + _rollout_worker_actor_cls = ray.remote(vLLMAsyncRollout) class vLLMReplica(RolloutReplica): + def __init__( + self, + replica_rank: int, + config: RolloutConfig | RewardModelConfig, + model_config: HFModelConfig, + gpus_per_node: int = 8, + ): + super().__init__(replica_rank, config, model_config, gpus_per_node) + self.server_class = vLLMHttpServer + def get_ray_class_with_init_args(self) -> RayClassWithInitArgs: """Get rollout worker actor class for colocated and standalone mode.""" worker_dict_cls = RayClassWithInitArgs( @@ -437,7 +468,7 @@ class vLLMReplica(RolloutReplica): for node_rank in range(nnodes): workers = self.workers[node_rank * gpus_per_node : (node_rank + 1) * gpus_per_node] node_id = worker_node_ids[node_rank * gpus_per_node] - server = vLLMHttpServer.options( + server = self.server_class.options( scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy( node_id=node_id, soft=False,