Compare commits

...

17 Commits

Author SHA1 Message Date
ddd86f527a [misc] chore: bump version to v0.6.0 (#3773)
### What does this PR do?

bump version to v0.6.0
2025-10-15 13:19:38 +08:00
22d082f9a4 [recipe] feat: add open math reasoning (#3767)
### What does this PR do?

- Add open math reasoning recipe using sft trainer with model engine
- Support setting none to val dataset in sft trainer
- Fix main_eval
- Using aiohttp for main_generation_server to avoid hang in AsyncOpenAI

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-15 12:11:41 +08:00
8ec9bf64a1 [ci] fix: fix test_engine ci (#3771)
### What does this PR do?

- fix test_engine ci for latest transformers

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-15 12:11:17 +08:00
231d725f69 Revert "[trainer] feat: set interleave to False in dapo trainer" (#3770)
Reverts volcengine/verl#3760
2025-10-15 11:41:33 +08:00
d69164e1cb [misc] feat: bump version to 0.6.0.dev (#3768)
### What does this PR do?

- Bump version

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-15 10:47:13 +08:00
2181d5b33a [recipe] fix: update readme for gmpo-trainer (#3764)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Co-authored-by: 刘悦 <liuyue127@xiaohongshu.com>
2025-10-15 10:24:24 +08:00
33eb86f54f [megatron] feat: support qwen3vl (#3763)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

support training qwen3vl with megatron

1. add an image with vllm0.11 and nemo's dedicated megatron that support
gpt-oss with optimized fused kernels.
2. add a script of training qwen3vl-30b with megatron
3. necessary changes to support qwen3vl megatron. (just register forward
functions, the modeling is through mbridge)


### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.
<img width="372" height="314" alt="image"
src="https://github.com/user-attachments/assets/f1126e46-51a9-4e00-958f-5d034b8f94bd"
/>

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-15 10:19:22 +08:00
67f9a21b8e [trainer] feat: set interleave to False in dapo trainer (#3760)
### What does this PR do?

Set interleave to False. This way, during inference, if rollout.n is set
to a large value, it can prevent multiple identical samples from being
run on the same instance, which would otherwise lead to excessive
inference overhead.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-14 21:13:57 +08:00
d2c51dc186 Add Meta-Bandit-LLM, a long-horizon multiturn interative awesome use case of verl (#3756)
[Meta-Bandit-LLM](https://github.com/sanxing-chen/meta-bandit-llm/)
utilizes verl to train on-policy LLM agent with up to 50-turn
interations, with support of async vLLM and LoRA.

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-10-14 12:01:13 +08:00
16c2a21064 Add ARES and Revisual-R1 two awesome multimodal reasoning work using verl. (#3755)
…verl to project list

### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-14 10:51:32 +08:00
3abcc09d44 [sglang, recipe] feat: add SGLang as rollout engine for one-step-off-policy (#3531)
### What does this PR do?

This PR extends the one-step-off-policy recipe by adding SGLang as an
alternative rollout engine to vLLM, allowing flexible backend selection
and improving training efficiency.

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here:
https://github.com/volcengine/verl/pull/3460
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

To validate this solution, we adopted the existing experimental
configuration from the recipe one-step-off-policy.

The evaluation demonstrates that the proposed SGLang rollout engine
integration achieves effective acceleration in one-step-off-policy
asynchronous training, providing users with enhanced rollout engine
options for diverse deployment scenarios.

**Experimental Results**

- **Machine Configuration**: 2 nodes with 16 H20 GPUs each
    - Generation: 4 GPUs
    - Training: 12 GPUs
- **Model**: Qwen2.5-Math-7B
- **Max Response Length**: 8,192 tokens
- **Algorithm**: DAPO
- **Rollout Engine**: vLLM, SGLang

| training mode | engine | step | gen | wait_prev_gen |
generate_sequences | old_log_prob | update_actor | total time |
acc/best@32/mean | acc/maj@32/mean |

|------------------------|----------------|------|-----|---------------|--------------------|--------------|--------------|---------------|------------------|-----------------|
| colocate sync | SGLang+FSDP2 | 452 | 131 | - | 125 | 54 | 199 | 12h25m
| 0.6560 | 0.4471 |
| one-step-overlap async | SGLang+FSDP2 | 406 | - | 12 | 305 | 58 | 245
| 11h12m (+11%) | 0.6303 | 0.4443 |

* colocate sync: step ≈ gen + old_log_prob + update_actor
* one-step-overlap async: step ≈ max(wait_prev_gen + generate_sequences,
old_log_prob + update_actor)

<img width="1218" height="777" alt="image"
src="https://github.com/user-attachments/assets/58734164-2534-492f-bf00-1e80faae0fe7"
/>

### API and Usage Example

**Configuration Example**
```bash
# Using SGLang engine
python3 -m recipe.one_step_off_policy.main_ppo \
    actor_rollout_ref.rollout.name=sglang \
    # ... other configuration parameters

# Using vLLM engine
python3 -m recipe.one_step_off_policy.main_ppo \
    actor_rollout_ref.rollout.name=vllm \
    # ... other configuration parameters
```

**Script Usage**
```bash
# Using SGLang engine
bash dapo_7b_math_fsdp2_sglang_4_12.sh
bash dapo_7b_math_fsdp2_sglang_colocate.sh

# Using vLLM engine
bash dapo_7b_math_fsdp2_4_12.sh
bash dapo_7b_math_fsdp2_colocate.sh
```

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: wuxibin <wuxibin@bytedance.com>
2025-10-14 10:48:29 +08:00
5d378b5f95 [rollout] refactor: rename "clip" mode back to "mask" mode (#3750)
# Rollout Importance Sampling Framework

related to https://github.com/volcengine/verl/pull/3694

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)**
framework to correct distribution mismatch between data-collecting
(rollout) and training policies, a critical factor for ensuring stable
and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed
Kills Stability: Demystifying RL Collapse from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).
If you find this implementation useful in your research, please consider
citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with
BFloat16) and training (e.g., FSDP with FP32), distribution mismatch
occurs, leading to:
- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance
sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**
Three methods for calculating IS weights:
- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**
Two strategies to control weight variance:
- **`truncate`** (TIS): Caps weights at upper threshold only, preserving
gradients
- **`mask`** (MIS): Zeros out weights outside bounds, more aggressive
filtering

### 3. **Comprehensive Diagnostics**
Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):
- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`,
`rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`,
`rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`,
`rollout_is_catastrophic_token_fraction`, `rollout_is_masked_fraction`
(mask mode)
- Sequence-level statistics (for sequence/geometric modes):
`rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`,
`rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight
computation):
- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3
estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`,
`mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`,
`mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`,
`mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**
- **Veto mechanism**: Automatically discards samples with catastrophic
importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**
- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clamping and bounding to maintain numerical
precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**
- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**
- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight
computation)

### 8. **Universal PPO Support**
- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov,
KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

####  **Before (REMOVED)**
```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:
- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

####  **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "mask"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level masking (more aggressive)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: mask
```

**3. Metrics-only mode (monitoring without correction)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash
examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, mask
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99
lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
- Conditional weight distribution to workers based on
`algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
- Added default Rollout IS configuration section with explanatory
comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, mask)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
    └─> compute_rollout_importance_weights_and_add_to_batch()
         └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/mask)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
    └─> Receive batch with rollout_is_weights (if enabled)
    └─> Pass weights to policy loss function

core_algos.py
    └─> All policy loss functions accept rollout_is_weights
    └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not
per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm
config
- Better conceptual organization (algorithm-level concern, not
worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
- `rollout_is`: Controls weight application (true = apply, false =
metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight
computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:
- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:
- Increase `rollout_is_threshold`
- Try `rollout_is_mode: mask` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests
- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests
- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests
```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS |
> 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (mask mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | <
1% |
| `rollout_is_masked_fraction` | % tokens masked (mask mode) | Variable
|

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference
(rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small
divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction
difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction
difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative
prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level
PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude
of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case
mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case
mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training
perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout
perplexity |

---

## Performance Impact

### Memory
- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation
- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability
- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed
- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required
All users of the legacy TIS implementation must update their
configuration files. See the migration guide above or
`docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility
- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation
errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs:
[https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by
CI)
- **Suggested title:** `[BREAKING][rollout, trainer, algo] feat:
implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update
[documentation](https://github.com/volcengine/verl/tree/main/docs) (3
new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse
from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`
2025-10-13 11:06:36 -07:00
21271aabb9 [BREAKING][rollout, trainer, algo] feat: comprehensive rollout importance sampling implementation (#3694)
# Rollout Importance Sampling Framework

## Summary

This PR introduces a comprehensive **Rollout Importance Sampling (IS)**
framework to correct distribution mismatch between data-collecting
(rollout) and training policies, a critical factor for ensuring stable
and efficient model training in RL fine-tuning.

This work is motivated by the analysis in our blog post, [When Speed
Kills Stability: Demystifying RL Collapse from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda).
If you find this implementation useful in your research, please consider
citing:

```bibtex
@misc{liu-li-2025,
  title = {When Speed Kills Stability: Demystifying RL Collapse from the Inference-Training Mismatch},
  url = {https://yingru.notion.site/When-Speed-Kills-Stability-Demystifying-RL-Collapse-from-the-Inference-Training-Mismatch-271211a558b7808d8b12d403fd15edda},
  author = {Jiacai Liu and Yingru Li and Yuqian Fu and Jiawei Wang and Qian Liu and Yu Shen},
  year = {2025},
  month = {September},
}
```

---

## Problem Statement

When using different policies for rollout generation (e.g., vLLM with
BFloat16) and training (e.g., FSDP with FP32), distribution mismatch
occurs, leading to:
- Biased gradient estimates
- Training instability and collapse
- Reduced sample efficiency
- Poor convergence properties

This framework addresses these issues through principled importance
sampling correction.

---

## Key Features & Improvements

### 1. **Flexible Aggregation Levels**
Three methods for calculating IS weights:
- **`token`**: Per-token importance ratios
- **`sequence`**: Product of per-token ratios
- **`geometric`**: Geometric mean of ratios

### 2. **Advanced Bounding Modes**
Two strategies to control weight variance:
- **`truncate`** (TIS): Caps weights at upper threshold only, preserving
gradients
- **`clip`** (CIS): Zeros out weights outside bounds, more aggressive
filtering

### 3. **Comprehensive Diagnostics**
Detailed metrics to monitor distribution mismatch and training health:

**Rollout IS Metrics** (automatically prefixed with `mismatch/`):
- Health indicators: `rollout_is_eff_sample_size`, `rollout_is_mean`
- Distribution statistics: `rollout_is_p25`, `rollout_is_p50`,
`rollout_is_p75`, `rollout_is_p95`, `rollout_is_p99`, `rollout_is_max`,
`rollout_is_min`, `rollout_is_std`
- Diagnostics: `rollout_is_veto_fraction`,
`rollout_is_catastrophic_token_fraction`, `rollout_is_clipped_fraction`
(clip mode)
- Sequence-level statistics (for sequence/geometric modes):
`rollout_is_seq_mean`, `rollout_is_seq_std`, `rollout_is_seq_max`,
`rollout_is_seq_min`, etc.

**Mismatch Metrics** (computed efficiently within IS weight
computation):
- KL Divergence: `mismatch_kl` (forward KL), `mismatch_k3_kl` (K3
estimator for stability)
- Perplexity: `mismatch_training_ppl`, `mismatch_rollout_ppl`,
`mismatch_ppl_ratio`
- Log perplexity statistics: `mismatch_log_ppl_diff`,
`mismatch_log_ppl_abs_diff`, `mismatch_log_ppl_diff_max`,
`mismatch_log_ppl_diff_min`

### 4. **Outlier Mitigation**
- **Veto mechanism**: Automatically discards samples with catastrophic
importance weights (per-token ratios below threshold)
- Prevents gradient corruption from extreme outliers
- Configurable threshold (default: 1e-4)

### 5. **Numerical Stability**
- All core computations in **log-space** to prevent underflow/overflow
- Carefully designed clipping and bounding to maintain numerical
precision
- Safe handling of edge cases (zero probabilities, extreme ratios)

### 6. **Memory Efficiency**
- Optimized computation to minimize CUDA memory usage
- Efficient metric aggregation without large intermediate tensors
- Suitable for large-scale distributed training

### 7. **Metrics-Only Mode**
- Compute and monitor mismatch metrics **without** applying IS weights
- Useful for:
  - Understanding distribution mismatch before intervention
  - Deciding whether IS correction is needed
  - A/B testing IS impact
- Controlled by `algorithm.rollout_is` flag (independent of weight
computation)

### 8. **Universal PPO Support**
- Integrated with **all PPO variants**: vanilla, GSPO, GPG, Clip-Cov,
KL-Cov, geo_mean
- Consistent interface across different policy loss functions
- Automatic weight application when enabled

---

## API and Configuration Changes

### Migration from Legacy TIS

####  **Before (REMOVED)**
```yaml
# Old TIS configuration - NO LONGER SUPPORTED
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0  # Removed from actor config
```

The legacy implementation:
- Only supported token-level truncation
- No metrics tracking
- Lacked numerical stability
- Limited configurability

####  **After (New Framework)**

Configuration moved to `algorithm` section for better organization:

```yaml
algorithm:
  # Main on/off switch: null = disabled, float = enabled
  rollout_is_threshold: 2.0

  # Control weight application (independent of metrics computation)
  rollout_is: true  # true = apply weights, false = metrics only

  # Optional: lower threshold (defaults to 1/upper if null)
  rollout_is_threshold_lower: null

  # Aggregation level: "token", "sequence", or "geometric"
  rollout_is_level: token

  # Bounding mode: "truncate" or "clip"
  rollout_is_mode: truncate

  # Veto threshold for catastrophic outliers (null = disabled)
  rollout_is_veto_threshold: 1e-4

# REQUIRED: Enable log probability calculation
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

### Configuration Examples

**1. Token-level truncation (recommended starting point)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate
```

**2. Sequence-level clipping (more aggressive)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: sequence
  rollout_is_mode: clip
```

**3. Metrics-only mode (monitoring without correction)**
```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: false  # Compute metrics but don't apply weights
  rollout_is_level: token
  rollout_is_mode: truncate
```

**Example script:** `bash
examples/rollout_importance_sampling/run_with_rollout_is.sh`

---

## Code Changes Overview

### New Files (4 files, 1,442 lines)

1. **`verl/trainer/ppo/mismatch_helper.py`** (459 lines)
   - Core implementation of IS weight computation
   - Three aggregation levels: token, sequence, geometric
   - Two bounding modes: truncate, clip
   - Veto mechanism for outlier detection
   - Comprehensive metrics computation (IS + mismatch)
   - All computations in log-space for numerical stability
   - Memory-efficient design

2. **`docs/advance/rollout_is_migration.md`** (642 lines)
   - Comprehensive migration guide from legacy TIS
   - Detailed explanation of all configuration options
   - Recommended threshold ranges for each aggregation level
   - Troubleshooting guide and best practices
   - Metrics interpretation guide

3. **`examples/rollout_importance_sampling/README.md`** (242 lines)
   - Quick start guide with working examples
   - Configuration templates for common scenarios
   - Threshold tuning guidelines
   - Metrics monitoring instructions

4. **`examples/rollout_importance_sampling/run_with_rollout_is.sh`** (99
lines)
   - Complete working example script
   - Demonstrates token-level and sequence-level configurations
   - Ready to run with minimal modifications

### Modified Core Files (9 files)

1. **`verl/trainer/ppo/core_algos.py`** (~50 lines changed)
   - Removed legacy TIS logic (`tis_imp_ratio_cap`)
   - Added `rollout_is_weights` parameter to all policy loss functions
   - Unified IS weight application interface across all PPO variants:
     - `compute_policy_loss_vanilla`
     - `compute_policy_loss_gspo`
     - `compute_policy_loss_gpg`
     - `compute_policy_loss_clip_cov`
     - `compute_policy_loss_kl_cov`
     - `compute_policy_loss_geo_mean`
   - Special handling for `geo_mean` (sequence-level aggregation)

2. **`verl/trainer/ppo/ray_trainer.py`** (~52 lines added)
   - New method: `compute_rollout_importance_weights_and_add_to_batch()`
   - Centralized IS computation (once per batch, on driver)
- Conditional weight distribution to workers based on
`algorithm.rollout_is`
   - Metrics collection and aggregation
   - Integration with existing training loop

3. **`verl/trainer/config/algorithm.py`** (+18 lines)
   - Added 6 new Rollout IS parameters:
     - `rollout_is_threshold` (main on/off switch)
     - `rollout_is` (weight application control)
     - `rollout_is_threshold_lower`
     - `rollout_is_level`
     - `rollout_is_mode`
     - `rollout_is_veto_threshold`
   - Comprehensive docstrings explaining each parameter

4. **`verl/workers/config/actor.py`** (-1 line)
   - Removed deprecated `tis_imp_ratio_cap` parameter

5. **`verl/workers/actor/dp_actor.py`** (~26 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

6. **`verl/workers/actor/megatron_actor.py`** (~15 lines changed)
   - Updated to use new `rollout_is_weights` parameter
   - Removed legacy TIS logic

7. **Configuration Files** (4 files updated)
   - `verl/trainer/config/ppo_trainer.yaml`
   - `verl/trainer/config/ppo_megatron_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_trainer.yaml`
   - `verl/trainer/config/_generated_ppo_megatron_trainer.yaml`
- Added default Rollout IS configuration section with explanatory
comments

### Testing (2 files, 530 lines)

1. **`tests/trainer/ppo/test_rollout_is.py`** (289 lines)
   - Unit tests for `mismatch_helper.py`
   - Coverage for all aggregation levels (token, sequence, geometric)
   - Coverage for all bounding modes (truncate, clip)
   - Veto mechanism tests
   - Edge case handling (zeros, extremes, empty sequences)
   - Numerical stability verification
   - Metrics correctness validation

2. **`tests/trainer/ppo/test_rollout_is_integration.py`** (241 lines)
   - Integration tests with PPO training loop
   - End-to-end workflow validation
   - Batch processing tests
   - Configuration validation
   - Metrics collection verification
   - Compatibility with distributed training

### Updated Recipes (2 files)

1. **`recipe/dapo/dapo_ray_trainer.py`** (+5 lines)
   - Updated imports to use new framework

2. **`recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`** (~42 lines changed)
   - Migrated from legacy TIS to new Rollout IS configuration
   - Updated documentation and comments

### Documentation Updates (2 files)

1. **`docs/examples/config.rst`** (~22 lines changed)
   - Updated configuration examples
   - Added Rollout IS section

2. **`docs/index.rst`** (+1 line)
   - Added link to Rollout IS migration guide

---

## Implementation Highlights

### Centralized Architecture

The new design follows a clean separation of concerns:

```
ray_trainer.py (driver)
    └─> compute_rollout_importance_weights_and_add_to_batch()
         └─> mismatch_helper.compute_rollout_importance_weights()
              ├─> Computes IS weights (token/sequence/geometric)
              ├─> Applies bounding (truncate/clip)
              ├─> Veto mechanism for outliers
              ├─> Computes IS metrics
              └─> Computes mismatch metrics (KL, PPL)
    └─> Conditionally adds weights to batch (if rollout_is=True)
    └─> Distributes batch to workers

actor workers (dp_actor, megatron_actor)
    └─> Receive batch with rollout_is_weights (if enabled)
    └─> Pass weights to policy loss function

core_algos.py
    └─> All policy loss functions accept rollout_is_weights
    └─> Apply weights if provided: pg_losses *= rollout_is_weights
```

### Key Design Decisions

1. **Centralized Computation**: IS weights computed once on driver, not
per worker
   - Reduces redundant computation
   - Ensures consistency across workers
   - Simplifies debugging and metrics collection

2. **Configuration in Algorithm**: Moved from actor config to algorithm
config
- Better conceptual organization (algorithm-level concern, not
worker-level)
   - Easier to manage and validate
   - Consistent with other algorithm parameters

3. **Two-Level Control**:
   - `rollout_is_threshold`: Enables/disables entire system (null = off)
- `rollout_is`: Controls weight application (true = apply, false =
metrics only)
   - Allows flexible monitoring and gradual rollout

4. **Metrics Consolidation**: Mismatch metrics computed within IS weight
computation
   - Eliminates duplicate computation
   - Reduces memory overhead
   - Maintains metric accuracy

5. **Universal PPO Support**: Single interface for all PPO variants
   - Minimal code changes required
   - Consistent behavior across algorithms
   - Easy to add new variants

---

## Migration Guide

### For Users of Legacy TIS

**Step 1: Update your configuration file**

```yaml
# OLD (remove this)
actor_rollout_ref:
  actor:
    tis_imp_ratio_cap: 2.0

# NEW (add this)
algorithm:
  rollout_is_threshold: 2.0  # Use same value as old tis_imp_ratio_cap
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

# REQUIRED (add if not present)
actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

**Step 2: Monitor metrics**

The first time you run with the new configuration, check these metrics:
- `mismatch/rollout_is_eff_sample_size`: Should be > 80% of batch size
- `mismatch/rollout_is_veto_fraction`: Should be < 5%
- `mismatch/rollout_is_mean`: Should be close to 1.0

**Step 3: Tune if needed**

If effective sample size is too low:
- Increase `rollout_is_threshold`
- Try `rollout_is_mode: clip` with appropriate lower bound
- Consider `rollout_is_level: sequence` for more aggressive correction

For detailed guidance, see `docs/advance/rollout_is_migration.md`.

### For New Users

Start with recommended defaults:

```yaml
algorithm:
  rollout_is_threshold: 2.0
  rollout_is: true
  rollout_is_level: token
  rollout_is_mode: truncate

actor_rollout_ref:
  rollout:
    calculate_log_probs: true
```

Run the example script to see it in action:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```

---

## Testing

### Unit Tests
- **289 lines** of comprehensive unit tests in `test_rollout_is.py`
- Covers all aggregation levels, bounding modes, and edge cases
- Validates numerical stability and correctness
- Fast execution (~1-2 seconds)

### Integration Tests
- **241 lines** of integration tests in `test_rollout_is_integration.py`
- End-to-end workflow with PPO training loop
- Distributed training compatibility
- Metrics collection validation
- Moderate execution time (~10-20 seconds)

### Running Tests
```bash
# Run all Rollout IS tests
pytest tests/trainer/ppo/test_rollout_is.py -v
pytest tests/trainer/ppo/test_rollout_is_integration.py -v

# Run specific test
pytest tests/trainer/ppo/test_rollout_is.py::test_token_level_truncate -v
```

---

## Metrics Reference

### Rollout IS Metrics (all prefixed with `mismatch/`)

| Metric | Description | Ideal Range |
|--------|-------------|-------------|
| `rollout_is_eff_sample_size` | Effective number of samples after IS |
> 80% of batch |
| `rollout_is_mean` | Mean IS weight | ~1.0 |
| `rollout_is_std` | Standard deviation of IS weights | Low variance |
| `rollout_is_p25` | 25th percentile | ~0.8-1.0 |
| `rollout_is_p50` | Median IS weight | ~1.0 |
| `rollout_is_p75` | 75th percentile | ~1.0-1.2 |
| `rollout_is_p95` | 95th percentile | < threshold |
| `rollout_is_p99` | 99th percentile | < threshold |
| `rollout_is_max` | Maximum weight | ≤ threshold |
| `rollout_is_min` | Minimum weight | ≥ lower threshold (clip mode) |
| `rollout_is_veto_fraction` | % sequences vetoed | < 5% |
| `rollout_is_catastrophic_token_fraction` | % catastrophic tokens | <
1% |
| `rollout_is_clipped_fraction` | % tokens clipped (clip mode) |
Variable |

### Mismatch Metrics (all prefixed with `mismatch/`)

| Metric | Description | What It Means |
|--------|-------------|---------------|
| `mismatch_kl` | Forward KL divergence | Distribution difference
(rollout vs training) |
| `mismatch_k3_kl` | K3 KL estimator | Stable KL estimate for small
divergences |
| `mismatch_training_ppl` | Training policy perplexity | Prediction
difficulty of training policy |
| `mismatch_rollout_ppl` | Rollout policy perplexity | Prediction
difficulty of rollout policy |
| `mismatch_ppl_ratio` | Ratio of training to rollout PPL | Relative
prediction difficulty |
| `mismatch_log_ppl_diff` | Log perplexity difference | Sequence-level
PPL mismatch |
| `mismatch_log_ppl_abs_diff` | Absolute log PPL difference | Magnitude
of mismatch |
| `mismatch_log_ppl_diff_max` | Max log PPL difference | Worst-case
mismatch |
| `mismatch_log_ppl_diff_min` | Min log PPL difference | Best-case
mismatch |
| `mismatch_training_log_ppl` | Log of training PPL | Log-scale training
perplexity |
| `mismatch_rollout_log_ppl` | Log of rollout PPL | Log-scale rollout
perplexity |

---

## Performance Impact

### Memory
- Minimal overhead: ~1-2% increase in peak memory usage
- Efficient log-space computation
- No large intermediate tensors

### Computation
- Negligible impact on training speed: < 1% overhead
- Centralized computation on driver (no per-worker redundancy)
- Optimized tensor operations

### Training Stability
- Significant improvement in stability when distribution mismatch exists
- Faster convergence in many scenarios
- Reduced risk of training collapse

---

## Breaking Changes

> [!IMPORTANT]
> This PR contains **BREAKING CHANGES** to the configuration API.

### Removed
- `actor_rollout_ref.actor.tis_imp_ratio_cap`: No longer supported

### Migration Required
All users of the legacy TIS implementation must update their
configuration files. See the migration guide above or
`docs/advance/rollout_is_migration.md` for detailed instructions.

### Backward Compatibility
- No backward compatibility with legacy TIS
- Configuration files with `tis_imp_ratio_cap` will raise validation
errors
- Affected recipes have been updated in this PR

---

## Pre-Submission Checklist

- [x] Search for similar PRs:
[https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling](https://github.com/volcengine/verl/pulls?q=is%3Apr+importance+sampling)
- [x] Format PR title as `[{modules}] {type}: {description}` (checked by
CI)
- **Suggested title:** `[BREAKING][rollout, trainer, algo] feat:
implement comprehensive Rollout Importance Sampling framework`
- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md)
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting)
- [x] Add/update
[documentation](https://github.com/volcengine/verl/tree/main/docs) (3
new docs, 2 updated)
- [x] Add unit and integration tests (530 lines of tests)
- [x] Once PR is ready for CI, send message in `ci-request` channel

---

## References

- **Blog post:** [When Speed Kills Stability: Demystifying RL Collapse
from the Inference-Training
Mismatch](https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda)
- **Migration guide:** `docs/advance/rollout_is_migration.md`
- **Examples:** `examples/rollout_importance_sampling/`
- **Tests:** `tests/trainer/ppo/test_rollout_is*.py`

---------

Co-authored-by: Yan Bai <bayan@nvidia.com>
2025-10-13 17:05:29 +08:00
7f27789961 [fsdp,doc] refactor: rename warmup_style@FSDPOptimizerConfig -> lr_scheduler_type (#3739)
### What does this PR do?

> Rename `warmup_style` in FSDPOptimizerConfig to `lr_scheduler_type` to
align with Hugging Face Trainer API。

The following pull request is for refactoring the optimizer, however,
the naming issue persists.
https://github.com/volcengine/verl/pull/3656 
### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

---------

Co-authored-by: weiqi.li <weiqi.li@bytedance.com>
2025-10-13 15:58:59 +08:00
e9ee6b39c6 [model] fix: qwen3vl models shape mismatch error with SP (#3735) 2025-10-13 13:09:10 +08:00
9d4554b931 [model] fix: qwen3vl training stuck with mixed text-image data (#3734) 2025-10-13 13:08:13 +08:00
71cf69e7ad [ci] feat: increase sft e2e time (#3738)
### What does this PR do?

- As title

### Checklist Before Starting

- [ ] Search for similar PRs. Paste at least one query link here: ...
- [ ] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
2025-10-13 11:29:39 +08:00
69 changed files with 3399 additions and 175 deletions

View File

@ -91,7 +91,7 @@ jobs:
e2e_sft:
needs: setup
runs-on: ["${{ needs.setup.outputs.runner-label || 'L20x8' }}"]
timeout-minutes: 25 # Increase this timeout value as needed
timeout-minutes: 30 # Increase this timeout value as needed
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}

View File

@ -208,6 +208,7 @@ jobs:
- name: Running mcore engine tests on 8 L20 GPUs
run: |
ray stop --force
pytest -s -x tests/models/test_engine.py
cleanup:

View File

@ -238,6 +238,9 @@ verl is inspired by the design of Nemo-Aligner, Deepspeed-chat and OpenRLHF. The
- [Vision-SR1](https://github.com/zli12321/Vision-SR1): Self-Rewarding Vision-Language Model via Reasoning Decomposition ![GitHub Repo stars](https://img.shields.io/github/stars/zli12321/Vision-SR1)
- [SimpleVLA-RL](https://github.com/PRIME-RL/SimpleVLA-RL): SimpleVLA-RL: A Simple yet Effective Vision-Language Action Model for Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/PRIME-RL/SimpleVLA-RL)
- [Table-R1](https://github.com/Table-R1/Table-R1): Table-R1: Inference-Time Scaling for Table Reasoning ![GitHub Repo stars](https://img.shields.io/github/stars/Table-R1/Table-R1)
- [Revisual-R1](https://github.com/CSfufu/Revisual-R1): Revisual-R1: Advancing Multimodal Reasoning From Optimized Cold Start to Staged Reinforcement Learning ![GitHub Repo stars](https://img.shields.io/github/stars/CSfufu/Revisual-R1)
- [ARES](https://github.com/shawn0728/ARES): ARES: Multimodal Adaptive Reasoning via Difficulty-Aware Token-Level Entropy Shaping ![GitHub Repo stars](https://img.shields.io/github/stars/shawn0728/ARES)
- [Meta-Bandit-LLM](https://github.com/sanxing-chen/meta-bandit-llm): Meta-Bandit-LLM: Long-horizon multiturn interactive training for meta-bandit agents ![GitHub Repo stars](https://img.shields.io/github/stars/sanxing-chen/meta-bandit-llm)
and many more awesome work listed in [recipe](recipe/README.md).

View File

@ -36,6 +36,8 @@ For vLLM with FSDP, please refer to [hiyouga/verl](https://hub.docker.com/r/hiyo
For SGLang with FSDP, please refer to [ocss884/verl-sglang](https://hub.docker.com/r/ocss884/verl-sglang) repository and the latest version is ``ocss884/verl-sglang:ngc-th2.6.0-cu126-sglang0.4.6.post5`` which is provided by SGLang RL Group.
For latest vLLM with Megatron, please refer to [iseekyan/verl](https://hub.docker.com/r/iseekyan/verl) repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
See files under ``docker/`` for NGC-based image or if you want to build your own.
Note that For aws instances with EFA net interface (Sagemaker AI Pod), you need to install EFA driver as shown in ``docker/Dockerfile.extenstion.awsefa``

View File

@ -0,0 +1,15 @@
FROM nvcr.io/nvidia/nemo:25.07.gpt_oss
RUN git clone -b v0.11.0 --depth 1 https://github.com/vllm-project/vllm.git /opt/vllm
RUN pip install setuptools_scm
RUN cd /opt/vllm && pip install --no-deps --no-build-isolation --no-cache-dir -e .
RUN pip install cbor2 setproctitle blake3 openai_harmony pybase64 msgspec partial_json_parser py-cpuinfo diskcache gguf
RUN pip install --upgrade transformers tokenizers
RUN pip install codetiming tensordict mathruler pylatexenc
RUN pip3 install --no-cache-dir mbridge

View File

@ -0,0 +1,642 @@
# Rollout Importance Sampling - Migration Guide
Last updated: 10/11/2025.
This document provides a comprehensive overview of the Rollout Importance Sampling (IS) implementation merged from aiic_verl into verl.
## References
- **When Speed Kills Stability**: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- **Off-policy RL**: https://fengyao.notion.site/off-policy-rl
## Overview
Rollout Importance Sampling corrects for distribution mismatch between:
- **Rollout policy**: e.g., vLLM with BFloat16
- **Training policy**: e.g., FSDP with FP32
This mismatch can lead to biased gradient estimates and unstable training. Rollout IS applies importance sampling weights to correct these biases.
## What Changed
### **Removed (Old Implementation)**
```yaml
# Old TIS configuration (REMOVED)
actor:
tis_imp_ratio_cap: 2.0 # ❌ No longer supported
```
The old implementation:
- Only supported token-level truncate mode
- Had no metrics tracking
- Lacked numerical stability safeguards
- No configurability for different scenarios
### **Added (New Implementation)**
```yaml
# New Rollout IS configuration (all in algorithm config)
algorithm:
# Main control: set threshold to enable (null = disabled)
rollout_is_threshold: 2.0
# Whether to apply weights to loss (default: false = metrics only)
rollout_is: true
rollout_is_threshold_lower: null # Auto-reciprocal
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
# REQUIRED: Enable log prob calculation
actor_rollout_ref:
rollout:
calculate_log_probs: true
```
The new implementation:
- ✅ Three aggregation levels: token, sequence, geometric
- ✅ Two bounding modes: truncate, mask
- ✅ Dual threshold support (upper/lower)
- ✅ Veto mechanism for catastrophic outliers
- ✅ 30+ comprehensive metrics
- ✅ Log-space computation for numerical stability
- ✅ Memory-efficient implementation
## Files Modified
### **Core Implementation**
1. **NEW**: `verl/trainer/ppo/mismatch_helper.py`
- Contains `compute_rollout_importance_weights()` - main function
- Contains `compute_is_metrics()` - comprehensive metrics
2. **MODIFIED**: `verl/trainer/ppo/core_algos.py` (lines 962-991)
- Replaced old TIS implementation (lines 962-967)
- Added new rollout IS with metrics support
3. **MODIFIED**: `verl/workers/actor/dp_actor.py`
- Updated to use `rollout_is_threshold` instead of `tis_imp_ratio_cap`
- Collects and logs all rollout IS metrics
### **Configuration Files**
4. **MODIFIED**: `verl/trainer/config/algorithm.py` (lines 95-100)
- Added 6 new rollout IS parameters to `AlgoConfig`
5. **MODIFIED**: `verl/workers/config/actor.py` (lines 110-115)
- Added 6 new rollout IS parameters to `ActorConfig`
6. **MODIFIED**: `verl/trainer/config/actor/actor.yaml` (lines 77-89)
- Added rollout IS configuration section
7. **MODIFIED**: `verl/trainer/config/ppo_trainer.yaml` (lines 116-133)
- Added rollout IS to algorithm config
### **Documentation**
8. **MODIFIED**: `docs/examples/config.rst`
- Updated actor config with rollout IS parameters
- Updated algorithm config with rollout IS parameters
- Added detailed parameter descriptions
### **Example Scripts**
9. **MODIFIED**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
- Updated from `tis_imp_ratio_cap` to rollout IS parameters
- Added comprehensive comments
10. **NEW**: `examples/rollout_importance_sampling/README.md`
- Comprehensive guide with usage patterns
- Troubleshooting section
- Performance considerations
11. **NEW**: `examples/rollout_importance_sampling/run_with_rollout_is.sh`
- Basic example with token-level truncate
### **Tests**
12. **NEW**: `tests/trainer/ppo/test_rollout_is.py`
- Unit tests for rollout IS functionality
13. **NEW**: `tests/trainer/ppo/test_rollout_is_integration.py`
- Integration tests with PPO
## Configuration Parameters
### `algorithm.rollout_is_threshold` (float or null)
**Main on/off switch.** Upper threshold for IS weights.
- `null` = disabled (no computation, no metrics)
- `float` value (e.g., 2.0) = enabled (compute weights and metrics)
### `algorithm.rollout_is` (bool)
Whether to apply IS weights to policy loss. Default: `False`
- `true` = apply weights to loss (full IS correction)
- `false` = compute metrics only (useful for monitoring before enabling)
**Recommended threshold ranges:**
- Token level: 1.5 - 5.0
- Sequence level: 2.0 - 10.0
- Geometric level: 1.0002 - 1.001
### `algorithm.rollout_is_threshold_lower` (float or null)
Lower threshold for IS weights. If `null`, defaults to 1/upper (reciprocal).
### `algorithm.rollout_is_level` (str)
Aggregation level for IS weights:
- `"token"`: Per-token ratios
- `"sequence"`: Product of ratios
- `"geometric"`: Geometric mean (experimental)
### `algorithm.rollout_is_mode` (str)
Bounding mode:
- `"truncate"`: Cap weights at upper threshold only
- `"mask"`: Zero out weights outside [lower, upper]
### `algorithm.rollout_is_veto_threshold` (float)
Per-token veto threshold. If any token ratio < this, entire sequence is rejected.
Default: `1e-4` (ratio 10,000x off)
## Migration Steps
### Step 1: Update Your Configuration
**Before (Old):**
```yaml
actor_rollout_ref:
actor:
tis_imp_ratio_cap: 2.0
rollout:
calculate_log_probs: true
```
**After (New):**
```yaml
algorithm:
rollout_is_threshold: 2.0 # Main control
rollout_is: true # Apply to loss (default: false)
rollout_is_level: token
rollout_is_mode: truncate
actor_rollout_ref:
rollout:
calculate_log_probs: true # Still required!
```
### Step 2: Monitor New Metrics
All metrics are prefixed with `mismatch/`. For example, `rollout_is_mean` appears as `mismatch/rollout_is_mean` in logs.
#### **Core IS Weight Metrics**
- **`rollout_is_mean`**: Mean importance sampling weight across all valid tokens
- **Ideal value**: Close to 1.0 (indicates minimal distribution mismatch)
- **Warning**: < 0.5 or > 2.0 suggests significant policy mismatch
- **`rollout_is_std`**: Standard deviation of IS weights
- **Ideal value**: < 0.5 for stable training
- **Warning**: > 1.0 indicates high variance, may need tighter thresholds
- **`rollout_is_min`**: Minimum IS weight observed
- Shows the most underweighted token/sequence
- **`rollout_is_max`**: Maximum IS weight observed (before truncation/masking)
- Shows the most overweighted token/sequence
- Compare with `rollout_is_threshold` to see truncation impact
#### **Percentile Metrics**
- **`rollout_is_p25`**: 25th percentile of IS weights
- **`rollout_is_p50`**: Median IS weight (50th percentile)
- Should be close to `rollout_is_mean` if distribution is symmetric
- **`rollout_is_p75`**: 75th percentile of IS weights
- **`rollout_is_p95`**: 95th percentile of IS weights
- Use to detect outliers
- **`rollout_is_p99`**: 99th percentile of IS weights
- Should be close to `rollout_is_threshold` if truncation is working
#### **Effective Sample Size**
- **`rollout_is_eff_sample_size`**: Effective sample size after IS weighting
- **Formula**: `1 / mean(weights²)` where weights are normalized
- **Range**: 0.0 to 1.0 (as fraction of original batch)
- **Ideal value**: > 0.5 (retaining at least 50% effective samples)
- **Warning**: < 0.3 means high variance, losing too many effective samples
#### **Veto Mechanism Metrics**
- **`rollout_is_veto_fraction`**: Fraction of sequences rejected by veto mechanism
- **Ideal value**: < 0.05 (less than 5% vetoed)
- **Warning**: > 0.1 suggests policies are too different or numerical issues
- **`rollout_is_catastrophic_token_fraction`**: Fraction of tokens below veto threshold
- Identifies problematic tokens before sequence-level veto
- **Warning**: > 0.01 indicates widespread distribution issues
#### **Threshold Exceedance Metrics**
- **`rollout_is_ratio_fraction_high`**: Fraction of weights exceeding upper threshold
- Shows how often truncation/masking occurs on high end
- **Ideal value**: < 0.1 (most weights within bounds)
- **`rollout_is_ratio_fraction_low`**: Fraction of weights below lower threshold
- Shows how often masking occurs on low end (mask mode only)
- **Ideal value**: < 0.1
#### **Sequence-Level Metrics** (for sequence/geometric modes)
- **`rollout_is_seq_mean`**: Mean IS weight at sequence level
- Should match `rollout_is_mean` for sequence-level aggregation
- **`rollout_is_seq_std`**: Standard deviation of sequence-level IS weights
- **`rollout_is_seq_min`**: Minimum sequence-level IS weight
- **`rollout_is_seq_max`**: Maximum sequence-level IS weight
- **`rollout_is_seq_max_deviation`**: Maximum absolute deviation from 1.0 at sequence level
- **Ideal value**: < 1.0
- Shows worst-case sequence mismatch
- **`rollout_is_seq_fraction_high`**: Fraction of sequences exceeding upper threshold
- **`rollout_is_seq_fraction_low`**: Fraction of sequences below lower threshold
#### **Masking Metrics** (mask mode only)
- **`rollout_is_masked_fraction`**: Fraction of tokens masked (set to zero)
- **Ideal value**: < 0.1
- **Warning**: > 0.3 means losing too much data
- **`rollout_is_seq_masked_fraction`**: Fraction of sequences with at least one masked token
- Shows sequence-level impact of masking
#### **Distribution Mismatch Metrics** (Training vs Rollout Policy)
- **`mismatch_training_ppl`**: Perplexity of training policy (e.g., FSDP FP32)
- **Formula**: `exp(-mean(log_probs))`
- Lower is better (model is more confident)
- **`mismatch_rollout_ppl`**: Perplexity of rollout policy (e.g., vLLM BF16)
- Should be close to `mismatch_training_ppl` if policies match well
- **`mismatch_ppl_ratio`**: Ratio of training PPL to rollout PPL
- **Formula**: `exp(mean(log(training_ppl / rollout_ppl)))`
- **Ideal value**: Close to 1.0
- **Meaning**: > 1.0 means training is less confident than rollout
- **`mismatch_training_log_ppl`**: Log perplexity of training policy
- Useful for identifying trends (linear scale)
- **`mismatch_rollout_log_ppl`**: Log perplexity of rollout policy
- **`mismatch_log_ppl_diff`**: Mean difference in log perplexities
- **Formula**: `mean(log_ppl_rollout - log_ppl_training)`
- **Ideal value**: Close to 0.0
- Sign indicates which policy is more confident
- **`mismatch_log_ppl_abs_diff`**: Mean absolute log perplexity difference
- Magnitude of mismatch regardless of direction
- **`mismatch_log_ppl_diff_max`**: Maximum log perplexity difference across sequences
- Identifies worst-case sequence
- **`mismatch_log_ppl_diff_min`**: Minimum log perplexity difference across sequences
- **`mismatch_kl`**: KL divergence KL(π_rollout || π_training)
- **Formula**: `mean(log_prob_rollout - log_prob_training)`
- **Ideal value**: Close to 0.0 (policies match)
- **Warning**: > 0.1 indicates significant mismatch
- **Note**: Can be negative (rollout is less confident)
- **`mismatch_k3_kl`**: K3 KL estimator
- **Formula**: `mean(exp(log_ratio) - log_ratio - 1)`
- More stable for small KL values
- Always non-negative
#### **Example: Accessing Metrics in Code**
```python
# Metrics are returned from compute_rollout_importance_weights
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=training_log_probs, # from training policy
rollout_log_prob=rollout_log_probs, # from rollout policy
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# All metrics have 'mismatch/' prefix
print(f"Mean IS weight: {metrics['mismatch/rollout_is_mean']:.3f}")
print(f"Effective sample size: {metrics['mismatch/rollout_is_eff_sample_size']:.3f}")
print(f"Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.3f}")
print(f"KL divergence: {metrics['mismatch/mismatch_kl']:.3f}")
# Check for warning conditions
if metrics['mismatch/rollout_is_mean'] < 0.5 or metrics['mismatch/rollout_is_mean'] > 2.0:
print("⚠️ Warning: Mean IS weight far from 1.0, significant policy mismatch detected")
if metrics['mismatch/rollout_is_eff_sample_size'] < 0.3:
print("⚠️ Warning: Low effective sample size, high variance in IS weights")
if metrics['mismatch/rollout_is_veto_fraction'] > 0.1:
print("⚠️ Warning: High veto fraction, policies may be too different")
```
#### **Example: Monitoring Metrics During Training**
```python
# In your training loop
for epoch in range(num_epochs):
for batch_idx, batch in enumerate(dataloader):
# ... rollout phase ...
# Compute IS weights and get metrics
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=batch.old_log_prob,
rollout_log_prob=batch.rollout_log_prob,
response_mask=batch.response_mask,
rollout_is_level=config.rollout_is_level,
rollout_is_mode=config.rollout_is_mode,
rollout_is_threshold=config.rollout_is_threshold,
rollout_is_veto_threshold=config.rollout_is_veto_threshold,
)
# Log to tensorboard/wandb
for metric_name, metric_value in metrics.items():
logger.log_scalar(metric_name, metric_value, step=global_step)
# Use IS weights in training
is_weights = weights_proto.batch["rollout_is_weights"]
# ... apply weights to policy gradient ...
```
#### **Example: Conditional Alerting Based on Metrics**
```python
def check_rollout_is_health(metrics, config):
"""Check if rollout IS metrics indicate healthy training."""
warnings = []
# Check mean IS weight
mean_weight = metrics['mismatch/rollout_is_mean']
if mean_weight < 0.5 or mean_weight > 2.0:
warnings.append(f"Mean IS weight {mean_weight:.3f} is far from 1.0")
# Check effective sample size
ess = metrics['mismatch/rollout_is_eff_sample_size']
if ess < 0.3:
warnings.append(f"Effective sample size {ess:.3f} is too low")
# Check veto fraction
veto_frac = metrics['mismatch/rollout_is_veto_fraction']
if veto_frac > 0.1:
warnings.append(f"Veto fraction {veto_frac:.3f} is too high")
# Check variance
std = metrics['mismatch/rollout_is_std']
if std > 1.0:
warnings.append(f"IS weight std {std:.3f} is too high")
# Check KL divergence
kl = metrics['mismatch/mismatch_kl']
if abs(kl) > 0.1:
warnings.append(f"KL divergence {kl:.3f} indicates significant mismatch")
if warnings:
print("⚠️ Rollout IS Health Warnings:")
for warning in warnings:
print(f" - {warning}")
return False
else:
print("✅ Rollout IS metrics look healthy")
return True
# Use in training
_, metrics = compute_rollout_importance_weights(...)
is_healthy = check_rollout_is_health(metrics, config)
if not is_healthy:
# Consider adjusting config or investigating issues
print("Consider:")
print(" - Tightening rollout_is_threshold")
print(" - Switching to geometric aggregation level")
print(" - Checking if rollout and training policies are too different")
```
### Step 3: Test Your Training
Start with the basic token-level truncate configuration:
```bash
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```
Monitor metrics for 1-2 epochs before adjusting parameters.
## Configuration Examples
### Example 1: Full IS Correction
```yaml
algorithm:
rollout_is_threshold: 2.0
rollout_is: true # Apply weights to loss
rollout_is_level: token
rollout_is_mode: truncate
```
### Example 2: Metrics Only (Monitoring Mode)
```yaml
algorithm:
rollout_is_threshold: 2.0
rollout_is: false # Compute metrics, don't apply weights
rollout_is_level: token
rollout_is_mode: truncate
```
### Example 3: Geometric Mean with Mask
```yaml
algorithm:
rollout_is_threshold: 1.0002
rollout_is: true
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: mask
```
### Example 4: Asymmetric Thresholds
```yaml
algorithm:
rollout_is_threshold: 5.0
rollout_is: true
rollout_is_threshold_lower: 0.8
rollout_is_level: token
rollout_is_mode: mask
```
## Troubleshooting
### Issue: High variance in IS weights
**Symptoms:** `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
**Solutions:**
1. Switch from `sequence` to `geometric` level
2. Tighten thresholds
3. Verify rollout and training aren't too different
### Issue: Too many sequences vetoed
**Symptoms:** `rollout_is_veto_fraction` > 0.1
**Solutions:**
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
2. Check for numerical issues in log prob computation
3. Verify policies aren't completely different
### Issue: Mean IS weight far from 1.0
**Symptoms:** `rollout_is_mean` < 0.5 or > 2.0
**Solutions:**
1. Verify `calculate_log_probs=True` is set
2. Check rollout_log_probs are correctly passed
3. Check for systematic bias
### Debugging: Visualizing Metrics
**Example: Plot IS weight distribution**
```python
import matplotlib.pyplot as plt
import numpy as np
def plot_is_metrics(metrics_history):
"""Plot rollout IS metrics over training steps."""
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
# Plot 1: Mean IS weight over time
axes[0, 0].plot(metrics_history['mismatch/rollout_is_mean'])
axes[0, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[0, 0].set_title('Mean IS Weight')
axes[0, 0].set_xlabel('Step')
axes[0, 0].legend()
# Plot 2: Effective sample size
axes[0, 1].plot(metrics_history['mismatch/rollout_is_eff_sample_size'])
axes[0, 1].axhline(y=0.5, color='g', linestyle='--', label='Good')
axes[0, 1].axhline(y=0.3, color='r', linestyle='--', label='Warning')
axes[0, 1].set_title('Effective Sample Size')
axes[0, 1].set_xlabel('Step')
axes[0, 1].legend()
# Plot 3: Veto fraction
axes[0, 2].plot(metrics_history['mismatch/rollout_is_veto_fraction'])
axes[0, 2].axhline(y=0.1, color='r', linestyle='--', label='Warning')
axes[0, 2].set_title('Veto Fraction')
axes[0, 2].set_xlabel('Step')
axes[0, 2].legend()
# Plot 4: IS weight distribution (latest step)
latest_idx = -1
percentiles = [25, 50, 75, 95, 99]
values = [metrics_history[f'mismatch/rollout_is_p{p}'][latest_idx] for p in percentiles]
axes[1, 0].bar([f'p{p}' for p in percentiles], values)
axes[1, 0].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 0].set_title('IS Weight Percentiles (Latest)')
axes[1, 0].legend()
# Plot 5: KL divergence over time
axes[1, 1].plot(metrics_history['mismatch/mismatch_kl'], label='KL')
axes[1, 1].plot(metrics_history['mismatch/mismatch_k3_kl'], label='K3 KL')
axes[1, 1].axhline(y=0, color='g', linestyle='--', alpha=0.3)
axes[1, 1].set_title('KL Divergence')
axes[1, 1].set_xlabel('Step')
axes[1, 1].legend()
# Plot 6: PPL ratio over time
axes[1, 2].plot(metrics_history['mismatch/mismatch_ppl_ratio'])
axes[1, 2].axhline(y=1.0, color='r', linestyle='--', label='Ideal')
axes[1, 2].set_title('PPL Ratio (Training/Rollout)')
axes[1, 2].set_xlabel('Step')
axes[1, 2].legend()
plt.tight_layout()
plt.savefig('rollout_is_metrics.png', dpi=150)
print("Saved plot to rollout_is_metrics.png")
```
**Example: Metric collection during training**
```python
# Collect metrics over time
metrics_history = {
'mismatch/rollout_is_mean': [],
'mismatch/rollout_is_eff_sample_size': [],
'mismatch/rollout_is_veto_fraction': [],
'mismatch/rollout_is_p25': [],
'mismatch/rollout_is_p50': [],
'mismatch/rollout_is_p75': [],
'mismatch/rollout_is_p95': [],
'mismatch/rollout_is_p99': [],
'mismatch/mismatch_kl': [],
'mismatch/mismatch_k3_kl': [],
'mismatch/mismatch_ppl_ratio': [],
}
# In training loop
for step in range(num_steps):
# ... compute IS weights ...
_, metrics = compute_rollout_importance_weights(...)
# Store metrics
for key in metrics_history.keys():
if key in metrics:
metrics_history[key].append(metrics[key])
# Plot every 100 steps
if step % 100 == 0:
plot_is_metrics(metrics_history)
```
## Performance Impact
- **Memory overhead**: ~1% of model memory
- **Computational overhead**: 1-3% depending on level
- **Training stability**: Significantly improved when mismatch exists
## Backward Compatibility
**The old `tis_imp_ratio_cap` parameter is completely removed.** There is no backward compatibility mode.
All scripts and configurations must be updated to use the new rollout IS parameters.
## Testing
Run the test suite to verify everything works:
```bash
# Basic unit tests
python test_rollout_is.py
# Integration tests (if pytest is available)
pytest tests/trainer/ppo/test_rollout_is_integration.py -v
```
Expected output: All tests pass ✓
## Additional Resources
- **Implementation**: `verl/trainer/ppo/mismatch_helper.py`
- **Examples**: `examples/rollout_importance_sampling/`
- **DAPO Example**: `recipe/dapo/run_dapo_qwen2.5_32b_tis.sh`
## Summary
The new Rollout Importance Sampling implementation provides:
- ✅ More robust handling of distribution mismatch
- ✅ Better numerical stability
- ✅ Comprehensive metrics for monitoring
- ✅ Flexibility for different scenarios
- ✅ Memory-efficient computation
Migration is straightforward: replace `tis_imp_ratio_cap` with the new `rollout_is_*` parameters in the `algorithm` config section.

View File

@ -118,7 +118,13 @@ Actor/Rollout/Reference Policy
clip_ratio: 0.2
entropy_coeff: 0.0
use_kl_loss: False # True for GRPO
tis_imp_ratio_cap: -1 # set to positive values for Truncated Importance Sampling (requires setting `rollout.calculate_log_probs` as True)
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
rollout_is: False # Enable IS correction
rollout_is_threshold: null # Upper threshold for IS weights (null to disable)
rollout_is_threshold_lower: null # Lower threshold (null = auto 1/upper)
rollout_is_level: token # Aggregation: token/sequence/geometric
rollout_is_mode: truncate # Bounding: truncate/mask
rollout_is_veto_threshold: 1e-4 # Catastrophic outlier threshold
use_torch_compile: True # False to disable torch compile
kl_loss_coef: 0.001 # for grpo
kl_loss_type: low_var_kl # for grpo
@ -132,7 +138,7 @@ Actor/Rollout/Reference Policy
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: 0.0 # only used with cosine lr scheduler, default to 0.0
num_cycles: 0.5 # only used with cosine lr scheduler, default to 0.5
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
@ -415,7 +421,7 @@ ____________________________________________________
Notice that there are some differences in APIs between Megatron optimizer and FSDP optimizer.
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``warmup_style`` actually means the style of lr decay after warmup.
- Megatron optimizer scheduler names the period after lr_warmup as lr_decay_steps, so the ``lr_scheduler_type`` actually means the style of lr decay after warmup.
- Megatron optimizer also support weight decay decay mechanism
- ``use_checkpoint_opt_param_scheduler`` determines whether to use the checkpoint optimizer parameter scheduler. If set to True, the optimizer parameter scheduler will be saved in the checkpoint and loaded from the checkpoint during resuming training.
@ -498,6 +504,13 @@ Algorithm
kl_coef: 0.005
horizon: 10000
target_kl: 0.1
# Rollout Importance Sampling
rollout_is: False
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
- ``gamma``: discount factor
- ``lam``: Trade-off between bias and variance in the GAE estimator
@ -510,6 +523,13 @@ Algorithm
- ``kl_coef``: The (initial) coefficient of in-reward kl_penalty. Default is 0.001.
- ``type``: 'fixed' for FixedKLController and 'adaptive' for AdaptiveKLController.
- ``horizon`` and ``target_kl``: See source code of AdaptiveKLController for details.
- ``rollout_is``: Whether to enable rollout importance sampling correction. Default is False.
- ``rollout_is_threshold``: Upper threshold for IS weights. Set to ``null`` to disable IS completely.
- ``rollout_is_threshold_lower``: Lower threshold for IS weights. If ``null``, defaults to reciprocal of upper (1/upper).
- ``rollout_is_level``: Aggregation level: ``token`` (biased), ``sequence`` (unbiased), or ``geometric`` (experimental).
- ``rollout_is_mode``: Bounding mode: ``truncate`` (cap upper only) or ``mask`` (zero outside bounds).
- ``rollout_is_veto_threshold``: Per-token veto threshold for catastrophic outliers. Default is 1e-4.
Note: Rollout IS requires setting ``actor_rollout_ref.rollout.calculate_log_probs=True``.
Trainer
~~~~~~~

View File

@ -121,6 +121,7 @@ verl is fast with:
examples/sandbox_fusion_example
advance/rollout_trace.rst
advance/rollout_skip.rst
advance/rollout_is_migration.md
advance/one_step_off
advance/agent_loop

View File

@ -79,7 +79,7 @@ For latest vLLM with FSDP, please refer to `hiyouga/verl <https://hub.docker.com
For latest SGLang with FSDP, please refer to `hebiaobuaa/verl <https://hub.docker.com/r/hebiaobuaa/verl>`_ repository and the latest version is ``hebiaobuaa/verl:app-verl0.5-sglang0.4.9.post6-mcore0.12.2-te2.2`` which is provided by SGLang RL Group.
For latest vLLM with Megatron, please refer to `iseekyan/verl:app-verl0.5-transformers4.55.4-vllm0.10.0-mcore0.15.0-te2.7`
For latest vLLM with Megatron, please refer to `iseekyan/verl <https://hub.docker.com/r/iseekyan/verl>`_ repository and the latest version is ``iseekyan/verl:nemo.gptoss_vllm0.11.0``.
See files under ``docker/`` for NGC-based image or if you want to build your own.

View File

@ -6,21 +6,20 @@
This is the official implementaion of paper [***Geometric-Mean Policy Optimization***](https://arxiv.org/abs/2507.20673).
<div align=center>
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/af4c7e0f-923a-45ef-9bcf-57109b8ee61e" />
<img width="3092" height="864" alt="image" src="https://github.com/user-attachments/assets/20b04c4e-7ee8-4775-9af8-33c0158336e2" />
</div>
## 1. Contents
- Geometric-Mean Policy Optimization
- [1. Contents](#1-contents)
- [2. Introduction](#2-introduction)
- [3. Code Usage](#4-code-usage)
- [4. Contacts](#5-contacts)
- [5. Citation](#7-citation)
- [3. Code Usage](#3-code-usage)
- [4. Contacts](#4-contacts)
- [5. Citation](#5-citation)
## 2. Introduction
Recent advancements, such as Group Relative Policy Optimization (GRPO), have enhanced the reasoning capabilities of large language models by optimizing the arithmetic mean of token-level rewards. However, GRPO suffers from unstable policy updates when processing tokens with outlier importance-weighted rewards, which manifests as extreme importance sampling ratios during training, i.e., the ratio between the sampling probabilities assigned to a token by the current and old policies. In this work, we propose Geometric-Mean Policy Optimization (GMPO), a stabilized variant of GRPO. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. In addition, we provide comprehensive theoretical and experimental analysis to justify the design and stability benefits of GMPO. Beyond improved stability, GMPO-7B outperforms GRPO by an average of 4.1% on multiple mathematical benchmarks and 1.4% on multimodal reasoning benchmark, including AIME24, AMC, MATH500, OlympiadBench, Minerva, and Geometry3K.
Group Relative Policy Optimization (GRPO) has significantly enhanced the reasoning capability of large language models by optimizing the arithmetic mean of token-level rewards. Unfortunately, GRPO is observed to suffer from unstable policy updates when facing tokens with outlier importance-weighted rewards, which manifest as extreme importance sampling ratios during training. In this study, we propose Geometric-Mean Policy Optimization (GMPO), with the aim to improve the stability of GRPO through suppressing token reward outliers. Instead of optimizing the arithmetic mean, GMPO maximizes the geometric mean of token-level rewards, which is inherently less sensitive to outliers and maintains a more stable range of importance sampling ratio. GMPO is plug-and-play—simply replacing GRPO's arithmetic mean with the geometric mean of token-level rewards, as the latter is inherently less sensitive to outliers. GMPO is theoretically plausible—analysis reveals that both GMPO and GRPO are weighted forms of the policy gradient while the former enjoys more stable weights, which consequently benefits policy optimization and performance. Experiments on multiple mathematical reasoning benchmarks show that GMPO-7B improves the average Pass@1 of GRPO by up to 4.1%, outperforming many state-of-the-art approaches.
## 3. Code Usage
@ -30,7 +29,7 @@ clip_ratio_low=0.4
clip_ratio_high=0.4
loss_mode=geo_mean
```
We observed that using a large clip ratio during Mixture-of-Experts (MoE) model training often leads to optimization instability. When training MoE models, consider lowering the clip ratio to achieve more stable convergence.
To get started quickly, run:
```
bash examples/gmpo_trainer/run_qwen2_5-7b_math.sh
@ -51,13 +50,10 @@ If you have any question about our work or this repository, please don't hesitat
## 5. Citation
```
@misc{zhao2025geometricmeanpolicyoptimization,
title={Geometric-Mean Policy Optimization},
author={Yuzhong Zhao and Yue Liu and Junpeng Liu and Jingye Chen and Xun Wu and Yaru Hao and Tengchao Lv and Shaohan Huang and Lei Cui and Qixiang Ye and Fang Wan and Furu Wei},
year={2025},
eprint={2507.20673},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2507.20673},
@article{zhao2025geometric,
title={Geometric-mean policy optimization},
author={Zhao, Yuzhong and Liu, Yue and Liu, Junpeng and Chen, Jingye and Wu, Xun and Hao, Yaru and Lv, Tengchao and Huang, Shaohan and Cui, Lei and Ye, Qixiang and others},
journal={arXiv preprint arXiv:2507.20673},
year={2025}
}
```

View File

@ -0,0 +1,79 @@
set -x
ENGINE=${1:-vllm}
export CUDA_DEVICE_MAX_CONNECTIONS=1 # For megatron communication/computation overlapping
# VLLM version >= 0.11.0 for qwen3-vl support, recommend to use container docker://iseekyan/verl:nemo.gptoss_vllm0.11.0
# pip install -U git+https://github.com/ISEEKYAN/mbridge.git # for latest mbridge
# pip install -U transformers # for qwen3-vl support
# pip install --no-deps --no-cache-dir git+https://github.com/NVIDIA/Megatron-LM.git@core_v0.13.1 # for megatron-lm0.13.1
export VLLM_ALLREDUCE_USE_SYMM_MEM=0 # for vllm0.11.0 with TP
HF_MODEL_PATH=${HF_MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-VL-30B-A3B-Instruct"}
train_path=$HOME/data/geo3k/train.parquet
test_path=$HOME/data/geo3k/test.parquet
python3 -m verl.trainer.main_ppo --config-path=config \
--config-name='ppo_megatron_trainer.yaml'\
algorithm.adv_estimator=grpo \
data.train_files="$train_path" \
data.val_files="$test_path" \
data.train_batch_size=512 \
data.max_prompt_length=1024 \
data.max_response_length=2048 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.model.path=$HF_MODEL_PATH \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.megatron.expert_model_parallel_size=8 \
actor_rollout_ref.actor.megatron.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.01 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.rollout.tensor_model_parallel_size=4 \
actor_rollout_ref.actor.use_dynamic_bsz=True \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=5120 \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=20480 \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=20480 \
actor_rollout_ref.rollout.name=$ENGINE \
+actor_rollout_ref.rollout.engine_kwargs.vllm.disable_mm_preprocessor_cache=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.7 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 \
actor_rollout_ref.actor.megatron.use_mbridge=True \
actor_rollout_ref.actor.megatron.param_offload=True \
actor_rollout_ref.actor.megatron.optimizer_offload=True \
actor_rollout_ref.actor.megatron.grad_offload=True \
actor_rollout_ref.ref.megatron.param_offload=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=1 \
+actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True \
+actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_router_dtype=fp32 \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_enable_deepep=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_token_dispatcher_type=flex \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full \
+actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 \
+actor_rollout_ref.actor.megatron.override_transformer_config.gradient_accumulation_fusion=True \
+actor_rollout_ref.actor.megatron.override_transformer_config.moe_permute_fusion=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.logger='["console","wandb"]' \
trainer.project_name='verl_grpo_example_geo3k' \
trainer.experiment_name='qwen3_vl_30b_megatron' \
trainer.n_gpus_per_node=8 \
trainer.nnodes=1 \
trainer.save_freq=20 \
trainer.test_freq=5 \
trainer.total_epochs=15 $@

View File

@ -0,0 +1,242 @@
# Rollout Importance Sampling (IS) Examples
This directory contains examples and documentation for using Rollout Importance Sampling to correct distribution mismatch between rollout and training policies.
**References:**
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
## Overview
Rollout Importance Sampling corrects for distribution mismatch when:
1. **Rollout generation** uses one policy (e.g., vLLM with BFloat16)
2. **Training** uses another policy (e.g., FSDP with FP32)
3. This mismatch leads to biased gradient estimates
## Quick Start
### Basic Configuration
```yaml
algorithm:
# Main control: set threshold to enable (null = disabled)
rollout_is_threshold: 2.0
# Whether to apply weights to policy loss (true) or just compute metrics (false)
rollout_is: true
rollout_is_level: token
rollout_is_mode: truncate
# IMPORTANT: Must enable log prob calculation
actor_rollout_ref:
rollout:
calculate_log_probs: true
```
### Running the Example
```bash
# Basic example with token-level truncate
bash examples/rollout_importance_sampling/run_with_rollout_is.sh
```
## Configuration Options
### Aggregation Levels (`rollout_is_level`)
| Level | Properties | Threshold Range |
|-------|-----------|-----------------|
| **token** | Per-token | 1.5 - 5.0 |
| **sequence** | Per-sequence | 2.0 - 10.0 |
| **geometric** | Geometric mean | 1.0002 - 1.001 |
### Bounding Modes (`rollout_is_mode`)
| Mode | Behavior |
|------|----------|
| **truncate** | Cap weights at upper threshold only |
| **clip** | Zero out weights outside [lower, upper] |
### Key Parameters
- `rollout_is_threshold`: Upper threshold for IS weights (null = disabled, float = enabled). **Main on/off switch.**
- `rollout_is`: Whether to apply weights to loss (true) or just compute metrics (false). Default: false.
- `rollout_is_threshold_lower`: Lower threshold (null = auto 1/upper)
- `rollout_is_veto_threshold`: Catastrophic outlier threshold (default: 1e-4)
## Configuration Examples
### Example 1: Full IS Correction (Apply Weights)
```yaml
algorithm:
rollout_is_threshold: 2.0
rollout_is: true # Apply to loss
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
```
### Example 2: Metrics Only (No Weight Application)
```yaml
algorithm:
rollout_is_threshold: 2.0
rollout_is: false # Compute metrics only, don't apply to loss
rollout_is_level: token
rollout_is_mode: truncate
```
### Example 3: Geometric Mean with Mask
```yaml
algorithm:
rollout_is_threshold: 1.0002
rollout_is: true
rollout_is_threshold_lower: 0.9998
rollout_is_level: geometric
rollout_is_mode: mask
rollout_is_veto_threshold: 1e-4
```
### Example 4: Sequence-level with Truncate
```yaml
algorithm:
rollout_is_threshold: 5.0
rollout_is: true
rollout_is_threshold_lower: null # Auto-reciprocal: 0.2
rollout_is_level: sequence
rollout_is_mode: truncate
rollout_is_veto_threshold: 1e-4
```
### Example 5: Asymmetric Thresholds
```yaml
algorithm:
rollout_is_threshold: 5.0
rollout_is: true
rollout_is_threshold_lower: 0.8
rollout_is_level: token
rollout_is_mode: mask
```
## Monitoring Metrics
Key metrics to watch (all prefixed with `mismatch/` in logs):
### Health Indicators
- `rollout_is_mean`: Mean IS weight across sequences
- `rollout_is_eff_sample_size`: Effective sample size after weighting
- `rollout_is_veto_fraction`: Fraction of sequences vetoed
### Distribution Metrics
- `rollout_is_max`, `rollout_is_min`: Weight extremes
- `rollout_is_std`: Standard deviation
- `rollout_is_p50`, `rollout_is_p95`, `rollout_is_p99`: Percentiles
### Diagnostic Metrics
- `rollout_is_ratio_fraction_high`: Fraction exceeding upper threshold
- `rollout_is_ratio_fraction_low`: Fraction below lower threshold
- `rollout_is_catastrophic_token_fraction`: Catastrophic tokens detected
### Mismatch Metrics (Training vs Rollout Policy)
These metrics help diagnose the distribution mismatch between rollout and training policies:
**Perplexity Metrics:**
- `mismatch_training_ppl`: Perplexity of training policy
- `mismatch_rollout_ppl`: Perplexity of rollout policy
- `mismatch_ppl_ratio`: Ratio of training PPL to rollout PPL
- `mismatch_log_ppl_diff`: Log perplexity difference
**KL Divergence Metrics:**
- `mismatch_kl`: KL divergence KL(π_rollout || π_training)
- `mismatch_k3_kl`: K3 KL estimator
## Troubleshooting
### Issue: High Variance in IS Weights
**Symptoms**: `rollout_is_std` > 1.0, `rollout_is_eff_sample_size` < 0.3
**Solutions**:
1. Switch from `sequence` to `geometric` level
2. Tighten thresholds
3. Check if rollout and training are too different
### Issue: Too Many Sequences Vetoed
**Symptoms**: `rollout_is_veto_fraction` > 0.1
**Solutions**:
1. Relax veto threshold: `rollout_is_veto_threshold: 1e-3`
2. Check for numerical issues in log prob computation
3. Verify rollout and training policies aren't completely different
### Issue: Mean IS Weight Far from 1.0
**Symptoms**: `rollout_is_mean` < 0.5 or > 2.0
**Solutions**:
1. Check that `calculate_log_probs=True` is set
2. Verify rollout_log_probs are correctly passed
3. Check for systematic bias in rollout vs training
### Issue: Too Much Data Discarded (Mask Mode)
**Symptoms**: `rollout_is_masked_fraction` > 0.5
**Solutions**:
1. Widen thresholds
2. Switch to `truncate` mode
3. Use `geometric` level for better stability
## Performance Considerations
### Memory Usage
- Rollout IS adds minimal memory overhead (~1% of model memory)
- Log-space computation prevents numerical overflow
### Computational Cost
- Token-level: ~1-2% overhead
- Sequence-level: ~2-3% overhead
- Geometric: ~2-3% overhead
## Advanced Topics
### Dual Thresholds
Specify both upper and lower explicitly:
```yaml
rollout_is_threshold: 2.0 # Upper
rollout_is_threshold_lower: 0.5 # Lower (not 1/2.0 = 0.5)
```
Or use auto-reciprocal:
```yaml
rollout_is_threshold: 2.0 # Upper = 2.0, Lower = 0.5 (auto)
rollout_is_threshold_lower: null
```
### Veto Mechanism
The veto mechanism zeros out entire sequences containing catastrophic outliers:
- If any token has ratio < `rollout_is_veto_threshold`, the entire sequence is rejected
- This prevents extreme outliers from dominating training
- Default threshold: 1e-4 (ratio 10,000x off)
- Set to `null` to disable: `rollout_is_veto_threshold: null`
## Examples
See the script in this directory:
- `run_with_rollout_is.sh`: Basic example with token-level truncate mode
## References
- Implementation: `verl/trainer/ppo/mismatch_helper.py`
- Core algorithm: `verl/trainer/ppo/core_algos.py`
- Paper: "Your Efficient RL Framework Secretly Brings You Off-Policy RL Training"

View File

@ -0,0 +1,99 @@
#!/usr/bin/env bash
# Example: Basic PPO training with Rollout Importance Sampling
# This demonstrates the standard setup for correcting distribution mismatch
set -xeuo pipefail
# ==============================================================================
# Rollout Importance Sampling Configuration
# ==============================================================================
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
rollout_is_threshold=2.0
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only
rollout_is=true
# Lower threshold (null = auto-reciprocal, i.e., 1/upper = 0.5)
rollout_is_threshold_lower=null
# Aggregation level: token | sequence | geometric (experimental)
rollout_is_level=token
# Bounding mode: truncate (cap upper) | mask (zero outside bounds)
rollout_is_mode=truncate
# Catastrophic outlier veto threshold
rollout_is_veto_threshold=1e-4
# ==============================================================================
# Model and Data Configuration
# ==============================================================================
MODEL_PATH=${MODEL_PATH:-"Qwen/Qwen2.5-7B"}
TRAIN_FILE=${TRAIN_FILE:-"data/train.parquet"}
TEST_FILE=${TEST_FILE:-"data/test.parquet"}
max_prompt_length=512
max_response_length=1024
# ==============================================================================
# Training Configuration
# ==============================================================================
train_batch_size=128
ppo_mini_batch_size=32
ppo_epochs=1
learning_rate=5e-7
# ==============================================================================
# Algorithm Configuration
# ==============================================================================
adv_estimator=gae
gamma=1.0
lam=0.95
# ==============================================================================
# Launch Training
# ==============================================================================
python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_batch_size} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.gamma=${gamma} \
algorithm.lam=${lam} \
algorithm.rollout_is=${rollout_is} \
algorithm.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
algorithm.rollout_is_level=${rollout_is_level} \
algorithm.rollout_is_mode=${rollout_is_mode} \
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=${learning_rate} \
actor_rollout_ref.actor.ppo_mini_batch_size=${ppo_mini_batch_size} \
actor_rollout_ref.actor.ppo_epochs=${ppo_epochs} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.name=vllm \
trainer.logger='["console","wandb"]' \
trainer.project_name="rollout_is_example" \
trainer.experiment_name="basic_token_truncate" \
trainer.total_epochs=10
echo "Training completed!"
echo ""
echo "Rollout IS Configuration:"
echo " - Threshold: ${rollout_is_threshold}"
echo " - Apply to loss: ${rollout_is}"
echo " - Level: ${rollout_is_level}"
echo " - Mode: ${rollout_is_mode}"
echo ""
echo "Monitor these key metrics in wandb:"
echo " - mismatch/rollout_is_mean (should be ~1.0)"
echo " - mismatch/rollout_is_eff_sample_size (should be >0.5)"
echo " - mismatch/rollout_is_veto_fraction (should be <0.1)"

View File

@ -0,0 +1,23 @@
hydra:
searchpath:
- file://verl/trainer/config
defaults:
- ppo_trainer
- _self_
data:
max_prompt_length: 1024
max_response_length: 1024
train_batch_size: 256
return_raw_chat: True
shuffle: False
actor_rollout_ref:
hybrid_engine: True
rollout:
name: sglang
multi_turn:
enable: True
max_assistant_turns: 2
format: qwen

View File

@ -51,7 +51,7 @@ actor_rollout_ref:
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
fsdp_config:
wrap_policy:
@ -105,7 +105,7 @@ critic:
lr: 1e-5
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null # only useful for warmup with cosine
warmup_style: constant # select from constant/cosine
lr_scheduler_type: constant # select from constant/cosine
total_training_steps: -1 # must be override by program
model:
path: ~/models/deepseek-llm-7b-chat

View File

@ -304,6 +304,11 @@ class RayDAPOTrainer(RayPPOTrainer):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
with marked_timer("adv", timing_raw, "brown"):
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)

View File

@ -1,8 +1,13 @@
#!/usr/bin/env bash
set -xeuo pipefail
# Rollout Importance Sampling Example
# References:
# - When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
# - Off-policy RL: https://fengyao.notion.site/off-policy-rl
project_name='DAPO'
exp_name='DAPO-Qwen2.5-32B-TIS' # Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
exp_name='DAPO-Qwen2.5-32B-RolloutIS' # Rollout Importance Sampling
adv_estimator=grpo
@ -10,7 +15,14 @@ use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
tis_imp_ratio_cap=2.0
# Rollout Importance Sampling parameters (matches original TIS with threshold=2)
rollout_is=True
rollout_is_threshold=2.0
rollout_is_threshold_lower=null # No lower bound (original TIS behavior)
rollout_is_level=token # token-level (original TIS behavior)
rollout_is_mode=truncate # truncate mode (original TIS behavior)
rollout_is_veto_threshold=null # No veto (original TIS behavior)
clip_ratio_low=0.2
clip_ratio_high=0.28
@ -58,14 +70,17 @@ offload=True
gen_tp=4
# Truncated Importance Sampling (TIS) -> https://fengyao.notion.site/off-policy-rl
# Please note that server mode(agent loop) hasn't return rollout_log_probs for now.
# so currently, server mode is not supported for TIS.
# To turn on TIS, you need to set the following parameters. Note 2.0 is a hyper-parameter and can be tuned.
# actor_rollout_ref.actor.tis_imp_ratio_cap=2.0
# actor_rollout_ref.rollout.calculate_log_probs=True
# Rollout Importance Sampling (corrects distribution mismatch between rollout and training)
#
# Please note that server mode (agent loop) hasn't returned rollout_log_probs for now,
# so currently server mode is not supported for Rollout IS.
#
# Rollout IS parameters (configured at top of script):
# algorithm.rollout_is=True
# algorithm.rollout_is_threshold=2.0 # Upper threshold (can be tuned)
# algorithm.rollout_is_level=token # Aggregation level
# algorithm.rollout_is_mode=truncate # Bounding mode
# actor_rollout_ref.rollout.calculate_log_probs=True # Required!
ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
--working-dir "${WORKING_DIR}" \
@ -109,7 +124,12 @@ ray job submit --no-wait --runtime-env="${RUNTIME_ENV}" \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.tis_imp_ratio_cap=${tis_imp_ratio_cap} \
algorithm.rollout_is=${rollout_is} \
algorithm.rollout_is_threshold=${rollout_is_threshold} \
algorithm.rollout_is_threshold_lower=${rollout_is_threshold_lower} \
algorithm.rollout_is_level=${rollout_is_level} \
algorithm.rollout_is_mode=${rollout_is_mode} \
algorithm.rollout_is_veto_threshold=${rollout_is_veto_threshold} \
actor_rollout_ref.rollout.calculate_log_probs=True \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \

View File

@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -100,7 +100,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -103,7 +103,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -99,7 +99,7 @@ HYDRA_FULL_ERROR=1 python -m recipe.entropy.main_entropy \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.weight_decay=0 \
actor_rollout_ref.actor.optim.warmup_style=constant \
actor_rollout_ref.actor.optim.lr_scheduler_type=constant \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.ppo_micro_batch_size=${train_micro_batch_size} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \

View File

@ -293,6 +293,6 @@ python3 -m recipe.one_step_off_policy.async_main_ppo \
| Category | Support Situation |
|--------------------|-----------------------------------------------------------------------------------------------------------------|
| train engine | FSDP2 <br/> Megatron |
| rollout engine | vLLM |
| rollout engine | vLLM <br/> SGLang |
| AdvantageEstimator | GRPO <br/> GRPO_PASSK <br/> REINFORCE_PLUS_PLUS <br/> RLOO <br/> OPO <br/> REINFORCE_PLUS_PLUS_BASELINE<br/>GPG |
| Reward | all |

View File

@ -0,0 +1,140 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-one-step-off-4-12'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
ref_offload=True
actor_offload=False
gen_tp=2
sp_size=4
fsdp_size=2
python3 -m recipe.one_step_off_policy.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.hybrid_engine=False \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${actor_offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${actor_offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.ref.fsdp_config.param_offload=${ref_offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}"

View File

@ -0,0 +1,133 @@
#!/usr/bin/env bash
set -xeuo pipefail
project_name='DAPO'
exp_name='DAPO-Qwen2.5-7b-MATH-0527a1-fsdp2-sglang-colocate'
adv_estimator=grpo
use_kl_in_reward=False
kl_coef=0.0
use_kl_loss=False
kl_loss_coef=0.0
clip_ratio_low=0.2
clip_ratio_high=0.28
max_prompt_length=$((1024 * 2))
max_response_length=$((1024 * 8))
enable_overlong_buffer=True
overlong_buffer_len=$((1024 * 4))
overlong_penalty_factor=1.0
loss_agg_mode="token-mean"
train_prompt_bsz=512
n_resp_per_prompt=12
train_prompt_mini_bsz=32
# Ray
# RAY_ADDRESS=${RAY_ADDRESS:-"http://localhost:8265"}
# WORKING_DIR=${WORKING_DIR:-"${PWD}"}
# RUNTIME_ENV=${RUNTIME_ENV:-"${WORKING_DIR}/verl/trainer/runtime_env.yaml"}
NNODES=${NNODES:-2}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
# very important! please modify the max_position_embeddings in config.json to 32768 after downloading from huggingface
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen2.5-Math-7B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/dapo-math-17k.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/aime-2024.parquet"}
# Algorithm
temperature=1.0
top_p=1.0
top_k=-1 # 0 for HF rollout, -1 for vLLM rollout
val_top_p=0.7
# Performance Related Parameter
use_dynamic_bsz=True
actor_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 2))
infer_ppo_max_token_len=$(((max_prompt_length + max_response_length) * 3))
offload=True
gen_tp=2
sp_size=4
fsdp_size=2
# reference run wandb: https://wandb.ai/verl-org/DAPO%20Reproduction%20on%20verl/runs/ow47vvon?nw=nwusertongyuxuan361
python3 -m verl.trainer.main_ppo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.prompt_key=prompt \
data.truncation='left' \
data.max_prompt_length=${max_prompt_length} \
data.max_response_length=${max_response_length} \
data.train_batch_size=${train_prompt_bsz} \
actor_rollout_ref.rollout.n=${n_resp_per_prompt} \
algorithm.adv_estimator=${adv_estimator} \
algorithm.use_kl_in_reward=${use_kl_in_reward} \
algorithm.kl_ctrl.kl_coef=${kl_coef} \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.actor.use_kl_loss=${use_kl_loss} \
actor_rollout_ref.actor.kl_loss_coef=${kl_loss_coef} \
actor_rollout_ref.actor.clip_ratio_low=${clip_ratio_low} \
actor_rollout_ref.actor.clip_ratio_high=${clip_ratio_high} \
actor_rollout_ref.actor.clip_ratio_c=10.0 \
actor_rollout_ref.model.use_remove_padding=True \
+actor_rollout_ref.model.override_config.max_position_embeddings=32768 \
actor_rollout_ref.actor.use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.ref.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=${use_dynamic_bsz} \
actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${actor_ppo_max_token_len} \
actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${infer_ppo_max_token_len} \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.actor.optim.lr_warmup_steps=10 \
actor_rollout_ref.actor.optim.weight_decay=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=${train_prompt_mini_bsz} \
actor_rollout_ref.actor.fsdp_config.param_offload=${offload} \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=${offload} \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.actor.grad_clip=1.0 \
actor_rollout_ref.actor.loss_agg_mode=${loss_agg_mode} \
actor_rollout_ref.actor.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.rollout.gpu_memory_utilization=0.80 \
actor_rollout_ref.rollout.tensor_model_parallel_size=${gen_tp} \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.max_num_batched_tokens=$((max_prompt_length + max_response_length)) \
actor_rollout_ref.rollout.temperature=${temperature} \
actor_rollout_ref.rollout.top_p=${top_p} \
actor_rollout_ref.rollout.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.temperature=${temperature} \
actor_rollout_ref.rollout.val_kwargs.top_p=${val_top_p} \
actor_rollout_ref.rollout.val_kwargs.top_k=${top_k} \
actor_rollout_ref.rollout.val_kwargs.do_sample=True \
actor_rollout_ref.rollout.val_kwargs.n=1 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.ref.fsdp_config.param_offload=${offload} \
actor_rollout_ref.ref.ulysses_sequence_parallel_size=${sp_size} \
actor_rollout_ref.actor.fsdp_config.fsdp_size=${fsdp_size} \
reward_model.reward_manager=dapo \
+reward_model.reward_kwargs.overlong_buffer_cfg.enable=${enable_overlong_buffer} \
+reward_model.reward_kwargs.overlong_buffer_cfg.len=${overlong_buffer_len} \
+reward_model.reward_kwargs.overlong_buffer_cfg.penalty_factor=${overlong_penalty_factor} \
+reward_model.reward_kwargs.overlong_buffer_cfg.log=False \
+reward_model.reward_kwargs.max_resp_len=${max_response_length} \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
trainer.nnodes="${NNODES}" \
trainer.val_before_train=True \
trainer.test_freq=10 \
trainer.save_freq=-1 \
trainer.total_epochs=10 \
trainer.total_training_steps=100 \
trainer.default_local_dir="${CKPTS_DIR}" \
trainer.resume_mode=auto \
trainer.log_val_generations=10

View File

@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
import logging
import os
@ -83,13 +84,20 @@ class ActorRolloutRefWorker(ARRWorker):
assert hasattr(self, "_weights_info") and self._weights_info is not None
params = self._get_actor_params() if self._is_actor else None
rollout_name = self.config.rollout.name
if self._is_rollout:
inference_model = (
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
)
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
if rollout_name == "vllm":
inference_model = (
self.rollout.inference_engine.llm_engine.model_executor.driver_worker.worker.model_runner.model
)
from verl.utils.vllm.patch import patch_vllm_moe_model_weight_loader
patch_vllm_moe_model_weight_loader(inference_model)
patch_vllm_moe_model_weight_loader(inference_model)
elif rollout_name == "sglang":
inference_model = self.rollout._engine
else:
raise NotImplementedError(f"Unknown rollout name: {rollout_name}")
loop = asyncio.get_event_loop()
for key, shape, dtype in self._weights_info:
tensor = torch.empty(shape, dtype=dtype, device=get_torch_device().current_device())
if self._is_actor:
@ -102,7 +110,23 @@ class ActorRolloutRefWorker(ARRWorker):
self._weight_sync_group.broadcast(tensor, src=0, stream=get_torch_device().current_stream())
if self._is_rollout:
inference_model.load_weights([(key, tensor)])
if rollout_name == "vllm":
inference_model.load_weights([(key, tensor)])
elif rollout_name == "sglang":
loop.run_until_complete(self.update_weights(inference_model, [(key, tensor)]))
async def update_weights(self, inference_engine, params):
from sglang.srt.weight_sync.utils import update_weights as sgl_update_weights
await sgl_update_weights(
engine=inference_engine,
params_batch=params,
device_mesh_key="infer_tp",
device_mesh=self.rollout_device_mesh,
)
if self.rollout_device_mesh["infer_tp"].get_local_rank() == 0:
await inference_engine.flush_cache()
@register(dispatch_mode=Dispatch.ONE_TO_ALL)
def get_actor_weights_info(self):
@ -209,6 +233,7 @@ class RolloutWorker(ActorRolloutRefWorker):
rollout_device_mesh = init_device_mesh(
device_name, mesh_shape=(dp, infer_tp), mesh_dim_names=["dp", "infer_tp"]
)
self.rollout_device_mesh = rollout_device_mesh
is_collect = rollout_device_mesh["infer_tp"].get_local_rank() == 0
self._register_dispatch_collect_info(
@ -216,7 +241,8 @@ class RolloutWorker(ActorRolloutRefWorker):
)
rollout_name = self.config.rollout.name
assert rollout_name == "vllm"
if rollout_name not in ("vllm", "sglang"):
raise NotImplementedError(f"rollout_name: {rollout_name} is not supported")
rollout_config: RolloutConfig = omega_conf_to_dataclass(self.config.rollout)
model_config: HFModelConfig = omega_conf_to_dataclass(self.config.model, dataclass_type=HFModelConfig)
@ -227,14 +253,23 @@ class RolloutWorker(ActorRolloutRefWorker):
config=rollout_config, model_config=model_config, device_mesh=rollout_device_mesh
)
log_gpu_memory_usage(f"After building {rollout_name} rollout", logger=logger)
from .vllm_sharding_manager import VLLMShardingManager
rollout_sharding_manager = VLLMShardingManager(
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
)
if rollout_name == "vllm":
from .vllm_sharding_manager import VLLMShardingManager
log_gpu_memory_usage("After building sharding manager", logger=logger)
rollout_sharding_manager = VLLMShardingManager(
inference_engine=rollout.inference_engine, device_mesh=rollout_device_mesh
)
log_gpu_memory_usage("After building sharding manager", logger=logger)
elif rollout_name == "sglang":
from .sglang_sharding_manager import SGLangShardingManager
rollout_sharding_manager = SGLangShardingManager(device_mesh=rollout_device_mesh)
log_gpu_memory_usage("After building sharding manager", logger=logger)
self.model_config = model_config
self.rollout = rollout
self.rollout_sharding_manager = rollout_sharding_manager

View File

@ -0,0 +1,65 @@
set -x
project_name='GRPO'
exp_name='GRPO-Qwen3-0.6b-gsm8k-fsdp2-sglang-one-step-off-2-6'
# Paths
RAY_DATA_HOME=${RAY_DATA_HOME:-"${HOME}/verl"}
MODEL_PATH=${MODEL_PATH:-"${RAY_DATA_HOME}/models/Qwen3-0.6B"}
CKPTS_DIR=${CKPTS_DIR:-"${RAY_DATA_HOME}/ckpts/${project_name}/${exp_name}"}
TRAIN_FILE=${TRAIN_FILE:-"${RAY_DATA_HOME}/data/gsm8k/train.parquet"}
TEST_FILE=${TEST_FILE:-"${RAY_DATA_HOME}/data/gsm8k/test.parquet"}
NNODES=${NNODES:-1}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
n_gpus_rollout=2
n_gpus_training=$((NGPUS_PER_NODE - n_gpus_rollout))
python3 -m recipe.one_step_off_policy.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files="${TRAIN_FILE}" \
data.val_files="${TEST_FILE}" \
data.train_batch_size=1152 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
data.filter_overlong_prompts=True \
data.truncation='error' \
actor_rollout_ref.actor.strategy=fsdp2 \
critic.strategy=fsdp2 \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.hybrid_engine=False \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=192 \
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=32 \
actor_rollout_ref.actor.use_kl_loss=True \
actor_rollout_ref.actor.kl_loss_coef=0.001 \
actor_rollout_ref.actor.kl_loss_type=low_var_kl \
actor_rollout_ref.actor.entropy_coeff=0 \
actor_rollout_ref.model.enable_gradient_checkpointing=True \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=sglang \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.rollout.load_format=safetensors \
actor_rollout_ref.rollout.layered_summon=True \
actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=32 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
algorithm.use_kl_in_reward=False \
trainer.critic_warmup=0 \
trainer.val_before_train=True \
trainer.logger=['console','tensorboard'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.save_freq=-1 \
trainer.test_freq=5 \
trainer.total_epochs=2 \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${n_gpus_training}" \
rollout.nnodes="${NNODES}" \
rollout.n_gpus_per_node="${n_gpus_rollout}" $@

View File

@ -577,6 +577,11 @@ class OneStepOffRayTrainer(RayPPOTrainer):
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout IS weights and mismatch metrics (inherited from RayPPOTrainer)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(

View File

@ -0,0 +1,70 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
# Copyright 2025 Meituan Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from torch.distributed.device_mesh import DeviceMesh
from verl import DataProto
from verl.protocol import all_gather_data_proto
from verl.utils.debug import GPUMemoryLogger
from verl.utils.device import get_torch_device
from verl.utils.torch_functional import check_device_is_available
from verl.workers.sharding_manager.base import BaseShardingManager
logger = logging.getLogger(__file__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class SGLangShardingManager(BaseShardingManager):
@check_device_is_available()
def __init__(self, device_mesh: DeviceMesh):
self.device_mesh = device_mesh
self.tp_size = self.device_mesh["infer_tp"].size()
self.tp_rank = self.device_mesh["infer_tp"].get_local_rank()
self.timing = {}
gen_dp_rank = self.device_mesh["dp"].get_local_rank()
get_torch_device().manual_seed(gen_dp_rank + 1000)
self.gen_random_states = get_torch_device().get_rng_state()
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def __enter__(self):
get_torch_device().set_rng_state(self.gen_random_states)
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def __exit__(self, exc_type, exc_value, traceback):
self.gen_random_states = get_torch_device().get_rng_state()
get_torch_device().empty_cache()
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def preprocess_data(self, data: DataProto) -> DataProto:
"""All gather across tp group to make each rank has identical input."""
if self.tp_size == 1:
return data
# TODO: Current impl doesn't consider FSDP with torch micro-dp
group = self.device_mesh["infer_tp"].get_group()
all_gather_data_proto(data=data, process_group=group)
return data
@GPUMemoryLogger(role="vllm sharding_manager", logger=logger)
def postprocess_data(self, data: DataProto) -> DataProto:
"""Get chunk data of this tp rank since we do all gather in preprocess."""
if self.tp_size == 1:
return data
return data.chunk(chunks=self.tp_size)[self.tp_rank]

View File

@ -0,0 +1,55 @@
# Open math reasoning
## Introduction
In this recipe, we perform SFT on the [open math reasoning](https://huggingface.co/datasets/nvidia/OpenMathReasoning) dataset using the new SFT trainer with backend agostic model engine. Note that our goal is not to replicate the [AIMO-2 Winning Solution](https://arxiv.org/abs/2504.16891) work, but to demonstrate a SFT demo from end to end.
Note that you may need to modify the path as needed in the following scripts.
## Dataset Preprocessing
### Download Dataset
```bash
hf download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* --local-dir /path/to/dataset/nvidia/OpenMathReasoning
hf download math-ai/aime24 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime24
hf download math-ai/aime25 --repo-type dataset --local-dir /path/to/dataset/math-ai/aime25
```
### Preprocess the dataset
```bash
python3 recipe/open_math_reasoning/prepare_nvidia-OpenMathReasoning_sft.py --local_dataset_path /path/to/nvidia/OpenMathReasoning --local_save_dir /path/to/open_math_reasoning
```
### Prepare the eval dataset
```bash
python3 recipe/open_math_reasoning/prepare_eval_dataset.py --local_dataset_path /path/to/dataset --local_save_dir /path/to/eval_dataset
```
## Train the model using SFT
### FSDP backend
export CKPT_HOME=/path/to/ckpt
export BACKEND=fsdp2
export MODEL_ID=Qwen/Qwen3-8B-Base
export TRAIN_FILES=/path/to/open_math_reasoning/cot_dataset.parquet
bash recipe/open_math_reasoning/run_sft_qwen3_8b.sh
### Megatron backend
TODO
## Eval the model
### Merge checkpoint into huggingface format
```bash
python -m verl.model_merger merge --backend fsdp --local_dir /path/to/ckpt/global_step_19751 --target_dir /path/to/ckpt/global_step_19751/huggingface
```
### Generate the responses
```bash
export MODEL_PATH=/path/to/ckpt/global_step_19751/huggingface
bash recipe/open_math_reasoning/run_generation.sh
```
### Evaluate the responses
```bash
bash recipe/open_math_reasoning/run_eval.sh
```
You should see the results like:
```python
{'test_score/aime24': 0.584375, 'test_score/aime25': 0.43333333333333335}
```

View File

@ -0,0 +1,22 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def compute_score_data_source(data_source, response, ground_truth):
from verl.utils.reward_score.math_reward import compute_score
if data_source in ["aime24", "aime25"]:
return compute_score(response, ground_truth)
else:
raise ValueError(f"Unknown data source: {data_source}")

View File

@ -0,0 +1,96 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# prepare eval dataset including AIME'24, AIME'25
# hf download math-ai/aime24 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime24
# hf download math-ai/aime25 --repo-type dataset --local-dir /opt/tiger/datasets/math-ai/aime25
import os
import datasets
from verl.utils.reward_score.math_reward import remove_boxed
instruction_following = "Please reason step by step, and put your final answer within \\boxed{}."
def make_map_fn(data_source):
def process_fn(example, idx):
question_raw = example.pop("problem")
question = question_raw + " " + instruction_following
if "solution" not in example:
example["solution"] = example["answer"]
answer_raw = example.pop("solution")
example.clear()
try:
solution = remove_boxed(answer_raw)
except Exception:
solution = answer_raw
data = {
"data_source": data_source,
"prompt": [
{
"role": "user",
"content": question,
}
],
"ability": "math",
"reward_model": {"style": "rule", "ground_truth": solution},
"extra_info": {
"index": idx,
"answer": answer_raw,
"question": question_raw,
},
}
return data
return process_fn
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
parser.add_argument(
"--local_save_dir", default="~/data/math-ai", help="The save directory for the preprocessed dataset."
)
args = parser.parse_args()
if args.local_dataset_path is not None:
aime24_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime24")
aime25_dataset_path = os.path.join(args.local_dataset_path, "math-ai/aime25")
else:
aime24_dataset_path = "math-ai/aime24"
aime25_dataset_path = "math-ai/aime25"
aime24_dataset = datasets.load_dataset(aime24_dataset_path, split="test")
aime25_dataset = datasets.load_dataset(aime25_dataset_path, split="test")
aime24_dataset = aime24_dataset.map(function=make_map_fn("aime24"), with_indices=True)
aime25_dataset = aime25_dataset.map(function=make_map_fn("aime25"), with_indices=True)
local_save_dir = os.path.expanduser(args.local_save_dir)
os.makedirs(local_save_dir, exist_ok=True)
aime24_dataset.to_parquet(os.path.join(local_save_dir, "aime24_test.parquet"))
aime25_dataset.to_parquet(os.path.join(local_save_dir, "aime25_test.parquet"))

View File

@ -0,0 +1,72 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
--local-dir /path/to/nvidia/OpenMathReasoning
huggingface-cli download nvidia/OpenMathReasoning --repo-type dataset --include data/cot* \
--local-dir /opt/tiger/nvidia/OpenMathReasoning
"""
import argparse
import os
import datasets
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--local_dataset_path", default=None, help="The local path to the raw dataset, if it exists.")
parser.add_argument(
"--local_save_dir",
default="~/data/open_math_reasoning",
help="The save directory for the preprocessed dataset.",
)
args = parser.parse_args()
local_dataset_path = args.local_dataset_path
data_source = "nvidia/OpenMathReasoning"
if local_dataset_path is not None:
dataset = datasets.load_dataset(local_dataset_path, split="cot")
else:
dataset = datasets.load_dataset(data_source, split="cot")
def make_map_fn(split):
def process_fn(example, idx):
question = example.pop("problem")
solution = example.pop("generated_solution")
extra_info = {}
for key, value in example.items():
extra_info[key] = value
example.clear()
data = {
"messages": [
{"role": "user", "content": question, "loss_mask": 0},
{"role": "assistant", "content": solution, "loss_mask": 1},
],
"extra_info": extra_info,
}
return data
return process_fn
# filter out data where the problem_type is not has_answer_extracted
dataset = dataset.filter(lambda example: example["problem_type"] == "has_answer_extracted")
dataset = dataset.map(function=make_map_fn("cot"), with_indices=True)
local_save_dir = os.path.expanduser(args.local_save_dir)
os.makedirs(local_save_dir, exist_ok=True)
dataset.to_parquet(os.path.join(local_save_dir, "cot_dataset.parquet"))

View File

@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Evaluation
python3 -m verl.trainer.main_eval \
data.path=$HOME/data/gen/qwen_8b_gen_test.parquet \
custom_reward_function.path=recipe/open_math_reasoning/compute_score.py \
custom_reward_function.name=compute_score_data_source

View File

@ -0,0 +1,32 @@
#!/usr/bin/env bash
MODEL_PATH=${MODEL_PATH:-/path/to/ckpt/global_step_19751/huggingface}
NGPUS_PER_NODE=${NGPUS_PER_NODE:-8}
NNODES=${NNODES:-1}
OUTPUT_PATH=${OUTPUT_PATH:-$HOME/data/gen/qwen_8b_gen_test.parquet}
GEN_TP=${GEN_TP:-1} # Default tensor parallel size to 2
aime24_test_path=${HOME}/data/math-ai/aime24_test.parquet
aime25_test_path=${HOME}/data/math-ai/aime25_test.parquet
train_files="['$aime24_test_path', '$aime25_test_path']"
python3 -m verl.trainer.main_generation_server \
trainer.nnodes="${NNODES}" \
trainer.n_gpus_per_node="${NGPUS_PER_NODE}" \
actor_rollout_ref.model.path="${MODEL_PATH}" \
actor_rollout_ref.model.trust_remote_code=True \
actor_rollout_ref.rollout.temperature=1.0 \
actor_rollout_ref.rollout.top_p=0.7 \
actor_rollout_ref.rollout.prompt_length=2048 \
actor_rollout_ref.rollout.response_length=20480 \
actor_rollout_ref.rollout.tensor_model_parallel_size="${GEN_TP}" \
actor_rollout_ref.rollout.gpu_memory_utilization=0.9 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.n=32 \
data.train_files="$train_files" \
data.prompt_key=prompt \
+data.output_path="${OUTPUT_PATH}" \

View File

@ -0,0 +1,94 @@
#!/usr/bin/env bash
set -xeuo pipefail
ENTRYPOINT=${ENTRYPOINT:-"-m verl.trainer.sft_trainer"}
TRAIN_FILES=${TRAIN_FILES:-/path/to/cot_dataset.parquet}
backend=${BACKEND:-fsdp}
project_name=verl_sft_test
RESUME_MODE=auto
MODEL_ID=${MODEL_ID:-Qwen/Qwen3-8B-Base}
SP_SIZE=${SP_SIZE:-8}
FSDP_SIZE=${FSDP_SIZE:-16}
FSDP_STRATEGY=${FSDP_STRATEGY:-"fsdp2"}
TP_SIZE=${TP_SIZE:-1}
PP_SIZE=${PP_SIZE:-1}
VPP_SIZE=${VPP_SIZE:-null}
CP_SIZE=${CP_SIZE:-1}
PAD_MODE=${PAD_MODE:-no_padding}
USE_REMOVE_PADDING=${USE_REMOVE_PADDING:-True}
FSDP_ENGINE_CONFIG="\
engine=${backend} \
optim=${backend} \
optim.lr=2e-5 \
optim.lr_warmup_steps_ratio=0.01 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_ratio=0.1 \
optim.warmup_style=cosine \
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
engine.strategy=${FSDP_STRATEGY} \
engine.fsdp_size=${FSDP_SIZE}"
MEGATRON_ENGINE_CONFIG="\
engine=${backend} \
optim=${backend} \
optim.lr=1e-5 \
optim.lr_warmup_steps_ratio=0.2 \
optim.weight_decay=0.1 \
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.lr_warmup_init=0 \
optim.lr_decay_style=cosine \
optim.min_lr=1e-6 \
engine.tensor_model_parallel_size=${TP_SIZE} \
engine.pipeline_model_parallel_size=${PP_SIZE} \
engine.virtual_pipeline_model_parallel_size=${VPP_SIZE} \
engine.context_parallel_size=${CP_SIZE}"
if [ "$backend" = "fsdp" ]; then
ENGINE_CONFIG="$FSDP_ENGINE_CONFIG"
echo "Using fsdp engine"
exp_name=nvidia-openmathreasoning-qwen3-8b-${backend}-${FSDP_STRATEGY}-sp${SP_SIZE}-fsdp-1008a1
else
ENGINE_CONFIG="$MEGATRON_ENGINE_CONFIG"
echo "Using megatron engine"
exp_name=nvidia-openmathreasoning-${backend}-tp${TP_SIZE}-pp${PP_SIZE}-vpp${VPP_SIZE}-cp${CP_SIZE}-pad-${PAD_MODE}-use_remove_padding-${USE_REMOVE_PADDING}
fi
CKPT_HOME=${CKPT_HOME:-$HOME/open_verl/sft/${project_name}/${exp_name}}
mkdir -p "${CKPT_HOME}"
torchrun --standalone --nnodes=1 --nproc-per-node=${NUM_TRAINERS:-8} \
${ENTRYPOINT} \
data.train_files="${TRAIN_FILES}" \
data.train_batch_size=96 \
data.max_length=32768 \
data.pad_mode=${PAD_MODE} \
data.truncation=error \
data.use_dynamic_bsz=True \
data.max_token_len_per_gpu=65536 \
data.messages_key=messages \
model.path=$MODEL_ID \
model.use_remove_padding=${USE_REMOVE_PADDING} \
${ENGINE_CONFIG} \
trainer.test_freq=-1 \
trainer.save_freq=4000 \
trainer.logger=['console','wandb'] \
trainer.project_name="${project_name}" \
trainer.experiment_name="${exp_name}" \
trainer.total_epochs=1 \
trainer.default_local_dir="${CKPT_HOME}" \
trainer.resume_mode=${RESUME_MODE} \
trainer.max_ckpt_to_keep=5 \
checkpoint.save_contents=[model,optimizer,extra]

View File

@ -48,7 +48,8 @@ reward_model:
lr_warmup_steps: -1 # Prioritized. Negative values mean delegating to lr_warmup_steps_ratio.
lr_warmup_steps_ratio: 0. # the total steps will be injected during runtime
min_lr_ratio: null
warmup_style: constant
warmup_style: null # deprecated
lr_scheduler_type: constant
total_training_steps: -1 # must be overridden by program
weight_decay: 0.
grad_clip: 10.0

View File

@ -24,7 +24,7 @@ import ray
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from transformers import AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig
from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForTokenClassification, Qwen3Config, Qwen3MoeConfig
from verl import DataProto
from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup
@ -289,8 +289,9 @@ def _worker(rank: int, world_size: int, rendezvous_file: str, strategy: str, mod
world_size=world_size,
)
ref_model_config = AutoConfig.from_pretrained(model_path)
with torch.device("meta"):
ref_model = AutoModelForCausalLM.from_pretrained(model_path)
ref_model = AutoModelForCausalLM.from_config(ref_model_config)
from verl.workers.engine import BaseEngine, EngineRegistry

View File

@ -42,7 +42,7 @@ FSDP_ENGINE_CONFIG="\
optim.betas="[0.9,0.95]" \
optim.clip_grad=1.0 \
optim.min_lr_ratio=0.1 \
optim.warmup_style=cosine \
optim.lr_scheduler_type=cosine \
engine.ulysses_sequence_parallel_size=${SP_SIZE} \
engine.strategy=${FSDP_STRATEGY} \
engine.fsdp_size=${FSDP_SIZE}"

View File

@ -301,8 +301,8 @@ actor_rollout_ref:
# Number of cosine cycles in LR schedule
num_cycles: 0.5
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# Total training steps (must be overridden at runtime)
total_training_steps: -1
@ -605,8 +605,8 @@ critic:
# Minimum LR ratio for cosine schedule
min_lr_ratio: 0.0
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# Total training steps (must be overridden at runtime)
total_training_steps: -1

View File

@ -0,0 +1,289 @@
#!/usr/bin/env python3
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Quick Sanity Test for Rollout Importance Sampling
This is a standalone test script that can be run without pytest to quickly verify
the rollout IS implementation is working correctly. For comprehensive integration
tests, see: tests/trainer/ppo/test_rollout_is_integration.py
Usage:
python test_rollout_is.py
This tests:
- Basic rollout IS functionality (3 levels, 2 modes)
- Metrics completeness (32 total: 21 IS + 11 mismatch metrics)
- Veto mechanism
- Edge cases
"""
import torch
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
def test_basic_rollout_is():
"""Test basic rollout IS functionality."""
print("Testing basic rollout IS functionality...")
# Create test data
batch_size, seq_length = 4, 10
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create slightly different log probs (simulating BF16 vs FP32 mismatch)
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.1
eos_mask = torch.ones(batch_size, seq_length, device=device)
# Test token-level truncate mode (equivalent to old TIS)
print("\n1. Testing token-level truncate mode...")
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
weights = weights_proto.batch["rollout_is_weights"]
print(f" Weights shape: {weights.shape}")
print(f" Mean weight: {metrics['mismatch/rollout_is_mean']:.4f}")
print(f" Max weight: {metrics['mismatch/rollout_is_max']:.4f}")
print(f" Min weight: {metrics['mismatch/rollout_is_min']:.4f}")
print(f" Veto fraction: {metrics['mismatch/rollout_is_veto_fraction']:.4f}")
assert weights.shape == old_log_prob.shape
assert weights.max() <= 2.0, "Weights should be capped at threshold"
print(" ✓ Token-level truncate mode passed")
# Test sequence-level mode
print("\n2. Testing sequence-level mode...")
weights_seq_proto, metrics_seq = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="sequence",
rollout_is_mode="truncate",
rollout_is_threshold=5.0,
rollout_is_veto_threshold=1e-4,
)
weights_seq = weights_seq_proto.batch["rollout_is_weights"]
print(f" Mean weight: {metrics_seq['mismatch/rollout_is_mean']:.4f}")
print(f" Effective sample size: {metrics_seq['mismatch/rollout_is_eff_sample_size']:.4f}")
# Check that all tokens in a sequence have the same weight
for i in range(batch_size):
seq_weights = weights_seq[i, eos_mask[i].bool()]
assert torch.allclose(seq_weights, seq_weights[0]), "All tokens in sequence should have same weight"
print(" ✓ Sequence-level mode passed")
# Test geometric mean mode
print("\n3. Testing geometric mean mode...")
weights_geo_proto, metrics_geo = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="geometric",
rollout_is_mode="mask",
rollout_is_threshold=1.5,
rollout_is_threshold_lower=0.5,
rollout_is_veto_threshold=1e-4,
)
print(f" Mean weight: {metrics_geo['mismatch/rollout_is_mean']:.4f}")
print(f" Masked fraction: {metrics_geo['mismatch/rollout_is_masked_fraction']:.4f}")
print(" ✓ Geometric mean mode passed")
# Test veto mechanism
print("\n4. Testing veto mechanism...")
# Create data with catastrophic outliers
old_log_prob_veto = torch.randn(2, 5, device=device)
rollout_log_prob_veto = old_log_prob_veto.clone()
# Make one token have catastrophically low ratio
rollout_log_prob_veto[0, 2] = old_log_prob_veto[0, 2] + 15.0 # ratio ~= 3e-7
eos_mask_veto = torch.ones(2, 5, device=device)
weights_veto_proto, metrics_veto = compute_rollout_importance_weights(
old_log_prob=old_log_prob_veto,
rollout_log_prob=rollout_log_prob_veto,
response_mask=eos_mask_veto,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
weights_veto = weights_veto_proto.batch["rollout_is_weights"]
print(f" Veto fraction: {metrics_veto['mismatch/rollout_is_veto_fraction']:.4f}")
# Check that the sequence with catastrophic token has all weights zeroed
assert weights_veto[0].sum() == 0, "Sequence with catastrophic token should be vetoed"
assert weights_veto[1].sum() > 0, "Normal sequence should not be vetoed"
print(" ✓ Veto mechanism passed")
# Test disabled IS (threshold=None)
print("\n5. Testing disabled IS...")
weights_disabled, metrics_disabled = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_threshold=None,
)
assert weights_disabled is None, "Should return None when threshold is None"
assert len(metrics_disabled) == 0, "Should return empty metrics when disabled"
print(" ✓ Disabled IS passed")
print("\n✓ All tests passed!")
def test_metrics_completeness():
"""Test that all expected metrics are returned."""
print("\nTesting metrics completeness...")
batch_size, seq_length = 3, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob + torch.randn(batch_size, seq_length, device=device) * 0.2
eos_mask = torch.ones(batch_size, seq_length, device=device)
_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=eos_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.5,
)
# Expected IS metrics
expected_is_metrics = [
"mismatch/rollout_is_mean",
"mismatch/rollout_is_max",
"mismatch/rollout_is_min",
"mismatch/rollout_is_std",
"mismatch/rollout_is_eff_sample_size",
"mismatch/rollout_is_veto_fraction",
"mismatch/rollout_is_catastrophic_token_fraction",
"mismatch/rollout_is_ratio_fraction_high",
"mismatch/rollout_is_ratio_fraction_low",
"mismatch/rollout_is_p25",
"mismatch/rollout_is_p50",
"mismatch/rollout_is_p75",
"mismatch/rollout_is_p95",
"mismatch/rollout_is_p99",
]
# Expected mismatch/diagnostic metrics (also included now)
expected_mismatch_metrics = [
"mismatch/mismatch_training_ppl",
"mismatch/mismatch_training_log_ppl",
"mismatch/mismatch_kl",
"mismatch/mismatch_k3_kl",
"mismatch/mismatch_rollout_ppl",
"mismatch/mismatch_rollout_log_ppl",
"mismatch/mismatch_log_ppl_diff",
"mismatch/mismatch_log_ppl_abs_diff",
"mismatch/mismatch_log_ppl_diff_max",
"mismatch/mismatch_log_ppl_diff_min",
"mismatch/mismatch_ppl_ratio",
]
expected_metrics = expected_is_metrics + expected_mismatch_metrics
missing_metrics = [m for m in expected_metrics if m not in metrics]
if missing_metrics:
print(f" ✗ Missing metrics: {missing_metrics}")
return False
print(f" ✓ All {len(expected_metrics)} expected metrics present")
print(f" Total metrics returned: {len(metrics)}")
return True
def test_mismatch_metrics():
"""Test mismatch metrics computation."""
print("\nTesting mismatch metrics computation...")
batch_size, seq_length = 4, 12
device = "cuda" if torch.cuda.is_available() else "cpu"
# Create test data with some mismatch
old_log_prob = torch.randn(batch_size, seq_length, device=device) - 2.0 # training policy
rollout_log_prob = torch.randn(batch_size, seq_length, device=device) - 1.5 # rollout policy (more confident)
response_mask = torch.ones(batch_size, seq_length, device=device)
# Test with rollout log probs
metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
)
expected_metrics = [
"mismatch_training_ppl",
"mismatch_training_log_ppl",
"mismatch_kl",
"mismatch_k3_kl",
"mismatch_rollout_ppl",
"mismatch_rollout_log_ppl",
"mismatch_log_ppl_diff",
"mismatch_log_ppl_abs_diff",
"mismatch_log_ppl_diff_max",
"mismatch_log_ppl_diff_min",
"mismatch_ppl_ratio",
]
for metric in expected_metrics:
assert metric in metrics, f"Missing metric: {metric}"
print(f" Training PPL: {metrics['mismatch_training_ppl']:.4f}")
print(f" Rollout PPL: {metrics['mismatch_rollout_ppl']:.4f}")
print(f" KL divergence: {metrics['mismatch_kl']:.6f}")
print(f" K3 KL: {metrics['mismatch_k3_kl']:.6f}")
print(f" PPL ratio: {metrics['mismatch_ppl_ratio']:.4f}")
print(f" ✓ All {len(expected_metrics)} mismatch metrics present")
# Test without rollout log probs
metrics_no_rollout = compute_mismatch_metrics(
old_log_prob=old_log_prob,
rollout_log_prob=None,
response_mask=response_mask,
)
assert "mismatch_training_ppl" in metrics_no_rollout
assert "mismatch_rollout_ppl" not in metrics_no_rollout
print(" ✓ Mismatch metrics work without rollout log probs")
if __name__ == "__main__":
print("=" * 60)
print("Rollout Importance Sampling Test Suite")
print("=" * 60)
try:
test_basic_rollout_is()
test_metrics_completeness()
test_mismatch_metrics()
print("\n" + "=" * 60)
print("ALL TESTS PASSED ✓")
print("=" * 60)
except Exception as e:
print(f"\n✗ Test failed with error: {e}")
import traceback
traceback.print_exc()
exit(1)

View File

@ -0,0 +1,241 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Integration tests for Rollout Importance Sampling."""
import pytest
import torch
from verl.trainer.ppo.core_algos import compute_policy_loss_vanilla
from verl.trainer.ppo.mismatch_helper import compute_mismatch_metrics, compute_rollout_importance_weights
from verl.workers.config.actor import ActorConfig
class TestRolloutISIntegration:
"""Integration tests for Rollout IS with PPO."""
@pytest.fixture
def sample_data(self):
"""Create sample training data."""
batch_size, seq_length = 4, 16
device = "cuda" if torch.cuda.is_available() else "cpu"
return {
"old_log_prob": torch.randn(batch_size, seq_length, device=device),
"log_prob": torch.randn(batch_size, seq_length, device=device),
"rollout_log_prob": torch.randn(batch_size, seq_length, device=device),
"advantages": torch.randn(batch_size, seq_length, device=device),
"response_mask": torch.ones(batch_size, seq_length, device=device),
}
@pytest.fixture
def config_with_rollout_is(self):
"""Create config for policy loss computation.
Note: rollout_is config has been moved to algorithm config.
This config only needs fields used by policy loss (clip_ratio, etc).
"""
config = ActorConfig(
strategy="fsdp",
rollout_n=1,
ppo_micro_batch_size=2,
clip_ratio=0.2,
)
return config
def test_policy_loss_with_rollout_is(self, sample_data, config_with_rollout_is):
"""Test that policy loss computation works with rollout IS weights.
Note: In production, IS weights are computed centrally in the trainer
(before advantage computation) and passed to policy loss.
This test simulates that workflow.
"""
# First compute IS weights (as trainer would do centrally)
rollout_is_weights_proto, _ = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
# Policy loss function receives pre-computed IS weights
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)
# Check loss is valid
assert isinstance(pg_loss, torch.Tensor)
assert pg_loss.ndim == 0 # Scalar
assert not torch.isnan(pg_loss)
assert not torch.isinf(pg_loss)
def test_rollout_is_weights_computation(self, sample_data):
"""Test rollout IS weights and metrics computation."""
weights_proto, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# Check weights
from verl.protocol import DataProto
assert isinstance(weights_proto, DataProto)
weights = weights_proto.batch["rollout_is_weights"]
assert isinstance(weights, torch.Tensor)
assert weights.shape == sample_data["old_log_prob"].shape
# Check metrics are returned
assert isinstance(metrics, dict)
assert len(metrics) > 0
assert "mismatch/rollout_is_mean" in metrics
def test_all_aggregation_levels(self, sample_data):
"""Test all three aggregation levels."""
levels = ["token", "sequence", "geometric"]
for level in levels:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level=level,
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)
assert "mismatch/rollout_is_mean" in metrics
def test_both_bounding_modes(self, sample_data):
"""Test both truncate and mask modes."""
modes = ["truncate", "mask"]
for mode in modes:
_, metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode=mode,
rollout_is_threshold=2.0,
rollout_is_threshold_lower=0.5,
)
assert "mismatch/rollout_is_mean" in metrics
def test_mismatch_metrics(self, sample_data):
"""Test mismatch diagnostic metrics computation."""
metrics = compute_mismatch_metrics(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
)
# Check key metrics are present
assert "mismatch_training_ppl" in metrics
assert "mismatch_rollout_ppl" in metrics
assert "mismatch_kl" in metrics
assert isinstance(metrics["mismatch_kl"], float)
def test_veto_mechanism(self):
"""Test veto mechanism with catastrophic outliers."""
batch_size, seq_length = 2, 5
device = "cuda" if torch.cuda.is_available() else "cpu"
old_log_prob = torch.randn(batch_size, seq_length, device=device)
rollout_log_prob = old_log_prob.clone()
# Create catastrophic outlier in first sequence
rollout_log_prob[0, 2] += 15.0 # Makes ratio ~3e-7
response_mask = torch.ones(batch_size, seq_length, device=device)
_, metrics = compute_rollout_importance_weights(
old_log_prob=old_log_prob,
rollout_log_prob=rollout_log_prob,
response_mask=response_mask,
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
rollout_is_veto_threshold=1e-4,
)
# Should have vetoed one sequence
assert metrics["mismatch/rollout_is_veto_fraction"] > 0
assert metrics["mismatch/rollout_is_veto_fraction"] <= 1.0
def test_metrics_only_mode(self, sample_data, config_with_rollout_is):
"""Test metrics-only mode: compute IS weights/metrics but don't apply to loss.
This tests the use case where rollout_is_threshold is set (enables computation)
but rollout_is=False (disables weight application to policy loss).
"""
# Compute IS weights (as trainer would do)
rollout_is_weights_proto, is_metrics = compute_rollout_importance_weights(
old_log_prob=sample_data["old_log_prob"],
rollout_log_prob=sample_data["rollout_log_prob"],
response_mask=sample_data["response_mask"],
rollout_is_level="token",
rollout_is_mode="truncate",
rollout_is_threshold=2.0,
)
# Metrics should be computed
assert len(is_metrics) > 0
assert "mismatch/rollout_is_mean" in is_metrics
# In metrics-only mode, we compute loss WITHOUT applying weights
# (simulating rollout_is=False)
pg_loss_no_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=None, # Don't apply weights
)
# Compare to loss WITH weights (rollout_is=True)
rollout_is_weights = rollout_is_weights_proto.batch["rollout_is_weights"]
pg_loss_with_weights, _, _, _ = compute_policy_loss_vanilla(
old_log_prob=sample_data["old_log_prob"],
log_prob=sample_data["log_prob"],
advantages=sample_data["advantages"],
response_mask=sample_data["response_mask"],
loss_agg_mode="token-mean",
config=config_with_rollout_is,
rollout_is_weights=rollout_is_weights,
)
# Losses should be different (weights have an effect)
assert not torch.allclose(pg_loss_no_weights, pg_loss_with_weights)
if __name__ == "__main__":
pytest.main([__file__, "-v", "-s"])

View File

@ -21,15 +21,24 @@ class TestFSDPOptimizerConfigCPU:
def test_default_configuration(self):
config = FSDPOptimizerConfig(lr=0.1)
assert config.min_lr_ratio is None
assert config.warmup_style == "constant"
assert config.lr_scheduler_type == "constant"
assert config.num_cycles == 0.5
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_styles(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.warmup_style == warmup_style
@pytest.mark.parametrize("lr_scheduler_type", ["constant", "cosine"])
def test_valid_lr_scheduler_types(self, lr_scheduler_type):
config = FSDPOptimizerConfig(lr_scheduler_type=lr_scheduler_type, lr=0.1)
assert config.lr_scheduler_type == lr_scheduler_type
def test_invalid_warmup_style(self):
@pytest.mark.parametrize("warmup_style", ["constant", "cosine"])
def test_valid_warmup_style_types(self, warmup_style):
config = FSDPOptimizerConfig(warmup_style=warmup_style, lr=0.1)
assert config.lr_scheduler_type == warmup_style
def test_invalid_lr_scheduler_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(lr_scheduler_type="invalid_style", lr=0.1)
def test_invalid_warmup_style_type(self):
with pytest.raises((ValueError, AssertionError)):
FSDPOptimizerConfig(warmup_style="invalid_style", lr=0.1)

View File

@ -113,7 +113,7 @@ def gptmodel_forward_qwen2_5_vl(
output_orig = model(
input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids,
position_ids=None, # model will calculate position_ids
packed_seq_params=packed_seq_params,
pixel_values=pixel_values,
image_grid_thw=image_grid_thw,

View File

@ -74,6 +74,7 @@ class SupportedModel(Enum):
GLM4_MOE = "Glm4MoeForCausalLM"
QWEN3_TOKEN_CLASSIFICATION = "Qwen3ForTokenClassification"
QWEN3_MOE_VL = "Qwen3VLMoeForConditionalGeneration"
# Registry for model configuration converters
@ -118,6 +119,7 @@ MODEL_FORWARD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.QWEN3: gptmodel_forward,
SupportedModel.QWEN3_MOE: gptmodel_forward,
SupportedModel.QWEN2_5_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_qwen2_5_vl,
SupportedModel.DEEPSEEK_V3: gptmodel_forward,
SupportedModel.GLM4_MOE: gptmodel_forward,
SupportedModel.QWEN3_TOKEN_CLASSIFICATION: gptmodel_forward,
@ -131,6 +133,7 @@ MODEL_FORWARD_NOPAD_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: gptmodel_forward_no_padding,
SupportedModel.DEEPSEEK_V3: gptmodel_forward_no_padding,
SupportedModel.QWEN2_5_VL: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE_VL: gptmodel_forward_no_padding,
SupportedModel.LLAMA4: gptmodel_forward_no_padding,
SupportedModel.QWEN3: gptmodel_forward_no_padding,
SupportedModel.QWEN3_MOE: gptmodel_forward_no_padding,
@ -148,6 +151,7 @@ MODEL_FORWARD_FUSED_REGISTRY: dict[SupportedModel, Callable] = {
SupportedModel.MIXTRAL: fused_forward_gptmodel,
SupportedModel.DEEPSEEK_V3: fused_forward_gptmodel,
SupportedModel.QWEN2_5_VL: fused_forward_qwen2_5_vl,
SupportedModel.QWEN3_MOE_VL: fused_forward_qwen2_5_vl,
SupportedModel.LLAMA4: fused_forward_gptmodel,
SupportedModel.QWEN3: fused_forward_gptmodel,
SupportedModel.QWEN3_MOE: fused_forward_gptmodel,

View File

@ -127,6 +127,8 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
def ulysses_wrapped_decoder_forward(self, *args, **kwargs):
inputs_embeds = kwargs.get("inputs_embeds")
position_ids = kwargs.get("position_ids")
visual_pos_masks = kwargs.get("visual_pos_masks")
deepstack_visual_embeds = kwargs.get("deepstack_visual_embeds")
call_kwargs = kwargs.copy()
current_ulysses_sp_size = get_ulysses_sequence_parallel_world_size()
@ -139,6 +141,43 @@ def patch_vlm_for_ulysses_input_slicing(model_class: type):
if slice_now:
call_kwargs["inputs_embeds"] = slice_input_tensor(inputs_embeds, dim=1, padding=False)
call_kwargs["position_ids"] = slice_input_tensor(position_ids, dim=-1, padding=False)
# Also slice visual_pos_masks and deepstack_visual_embeds for Qwen3 VL models
if visual_pos_masks is not None:
original_visual_mask = visual_pos_masks
sliced_visual_mask = slice_input_tensor(visual_pos_masks, dim=1, padding=False)
call_kwargs["visual_pos_masks"] = sliced_visual_mask
if deepstack_visual_embeds is not None:
sliced_embeds = []
num_visual_before = original_visual_mask.sum().item()
num_visual_in_shard = sliced_visual_mask.sum().item()
if num_visual_in_shard > 0 and num_visual_before > 0:
# Calculate which visual embeddings belong to this shard
# We need to find the offset of visual tokens in this shard
from verl.utils.ulysses import get_ulysses_sequence_parallel_rank
rank = get_ulysses_sequence_parallel_rank()
seq_len = original_visual_mask.shape[1]
local_seq_len = seq_len // current_ulysses_sp_size
start_idx = rank * local_seq_len
end_idx = start_idx + local_seq_len
# Get total visual tokens before and up to the end of the shard's sequence slice
# This correctly handles batches by summing across all samples
visual_start = original_visual_mask[:, :start_idx].sum().item() if start_idx > 0 else 0
visual_end = original_visual_mask[:, :end_idx].sum().item()
# Slice each tensor in deepstack_visual_embeds
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[visual_start:visual_end])
else:
# No visual tokens in this shard, create empty tensors to maintain gradient flow
for embed in deepstack_visual_embeds:
sliced_embeds.append(embed[:0])
call_kwargs["deepstack_visual_embeds"] = sliced_embeds
self._needs_initial_slice = False
try:
return original_forward(self, *args, **call_kwargs)
@ -290,9 +329,7 @@ def apply_monkey_patch(
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
Qwen2_5_VLFlashAttention2 as Qwen2_5_VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import (
Qwen2VLFlashAttention2 as Qwen2VLAttention,
)
from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLFlashAttention2 as Qwen2VLAttention
if use_remove_padding or ulysses_sp_size > 1:
from verl.models.transformers.qwen2_vl import qwen2_vl_attn_forward

View File

@ -209,8 +209,10 @@ def _get_input_embeds(
patch_dim = config.in_channels * config.temporal_patch_size * config.patch_size**2
pixel_values = torch.zeros((16, patch_dim), dtype=inputs_embeds.dtype, device=inputs_embeds.device)
image_grid_thw = torch.tensor([[1, 4, 4]], dtype=torch.long, device=inputs_embeds.device)
image_embeds, _ = model.visual(pixel_values, grid_thw=image_grid_thw)
image_embeds, dummy_deepstack_image_embeds = model.visual(pixel_values, grid_thw=image_grid_thw)
inputs_embeds += 0.0 * image_embeds.mean()
for emb in dummy_deepstack_image_embeds or []:
inputs_embeds += 0.0 * emb.mean()
if attention_mask is not None:
attention_mask = attention_mask.to(inputs_embeds.device)

View File

@ -75,7 +75,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -484,6 +483,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

View File

@ -18,7 +18,8 @@ actor_rollout_ref:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
wrap_policy:
@ -59,7 +60,6 @@ actor_rollout_ref:
clip_ratio_c: 3.0
loss_agg_mode: token-mean
entropy_coeff: 0
tis_imp_ratio_cap: -1
use_kl_loss: false
use_torch_compile: true
kl_loss_coef: 0.001
@ -315,7 +315,8 @@ critic:
clip_grad: 1.0
min_lr_ratio: 0.0
num_cycles: 0.5
warmup_style: constant
lr_scheduler_type: constant
warmup_style: null
model:
fsdp_config:
_target_: verl.workers.config.FSDPEngineConfig
@ -462,6 +463,12 @@ algorithm:
pf_ppo:
reweight_method: pow
weight_pow: 2.0
rollout_is_threshold: null
rollout_is_threshold_lower: null
rollout_is_level: token
rollout_is_mode: truncate
rollout_is_veto_threshold: 0.0001
rollout_is: false
trainer:
balance_batch: true
total_epochs: 30

View File

@ -74,10 +74,6 @@ loss_agg_mode: token-mean
# Entropy regularization coefficient in PPO loss
entropy_coeff: 0
# Truncated Importance Sampling (TIS): https://fengyao.notion.site/off-policy-rl
# the truncation value C of truncated Importance Sampling (-1 for disable TIS)
tis_imp_ratio_cap: -1
# Whether to use KL loss instead of KL reward penalty. True for GRPO
use_kl_loss: false

View File

@ -73,6 +73,14 @@ class AlgoConfig(BaseConfig):
use_pf_ppo (bool): Whether to enable preference feedback PPO.
pf_ppo (dict[str, Any]): Preference feedback PPO settings.
filter_groups (Optional[FilterGroupsConfig]): Filter groups configuration, used in DAPO and Entropy
rollout_is_threshold (Optional[float]): Upper threshold for IS weights. null = disabled,
float value = enabled (compute weights and metrics). This is the main on/off switch.
rollout_is_threshold_lower (Optional[float]): Lower threshold for IS weights. If None, defaults to 1/upper.
rollout_is_level (str): Aggregation level: "token", "sequence", or "geometric".
rollout_is_mode (str): Bounding mode: "truncate" (cap upper only) or "mask" (zero outside bounds).
rollout_is_veto_threshold (float): Per-token veto threshold for catastrophic outliers.
rollout_is (bool): Whether to apply IS weights to policy loss. True = apply weights,
False = compute metrics only (useful for monitoring before enabling correction). Default: False.
"""
gamma: float = 1.0
@ -85,3 +93,13 @@ class AlgoConfig(BaseConfig):
use_pf_ppo: bool = False
pf_ppo: dict[str, Any] = field(default_factory=dict)
filter_groups: Optional[FilterGroupsConfig] = None
# Rollout Importance Sampling (replaces legacy tis_imp_ratio_cap)
# Controls computation of IS weights and mismatch metrics
rollout_is_threshold: Optional[float] = None # null = disabled, float = enabled
rollout_is_threshold_lower: Optional[float] = None
rollout_is_level: str = "token"
rollout_is_mode: str = "truncate"
rollout_is_veto_threshold: Optional[float] = 1e-4
# Controls whether to apply IS weights to policy loss (only if rollout_is_threshold is set)
# True = apply weights to loss, False = compute metrics only (no weight application)
rollout_is: bool = False

View File

@ -28,6 +28,8 @@ min_lr_ratio: 0.0
# Number of cosine cycles in LR schedule
num_cycles: 0.5
# LR warmup style: "constant" or "cosine"
warmup_style: constant
# LR scheduler type: "constant" or "cosine"
lr_scheduler_type: constant
# deprecated
warmup_style: null

View File

@ -73,6 +73,28 @@ algorithm:
reweight_method: pow # ["pow", "max_min", "max_random"]
weight_pow: 2.0
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null
# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false
trainer:
balance_batch: True
total_epochs: 30

View File

@ -113,6 +113,28 @@ algorithm:
# Power used for weight scaling in "pow" method
weight_pow: 2.0
# Rollout Importance Sampling: corrects distribution mismatch between rollout and training policies
# Main control: Upper threshold for IS weights (null = disabled, float = enabled)
# When enabled, computes IS weights and mismatch metrics (KL, PPL, etc.)
rollout_is_threshold: null
# Lower threshold for IS weights (null = auto-reciprocal of upper)
rollout_is_threshold_lower: null
# Aggregation level: "token" (biased), "sequence" (unbiased), "geometric" (experimental)
rollout_is_level: token
# Bounding mode: "truncate" (cap upper only), "mask" (zero outside bounds)
rollout_is_mode: truncate
# Per-token veto threshold for catastrophic outliers
rollout_is_veto_threshold: 1e-4
# Whether to apply IS weights to policy loss
# true = apply weights to loss, false = compute metrics only (no weight application)
# Useful for monitoring mismatch before enabling correction
rollout_is: false
# config for the trainer
trainer:

View File

@ -18,7 +18,7 @@ data:
max_token_len_per_gpu: 8192
use_dynamic_bsz: True
train_files: ~/data/gsm8k/train.parquet
val_files: ~/data/gsm8k/test.parquet
val_files: null
# Multi-turn settings
messages_key: messages # Key for messages list in multi-turn mode
tools_key: tools # Key for tools list in multi-turn mode

View File

@ -31,7 +31,8 @@ from verl.utils.fs import copy_to_local
@ray.remote
def process_item(reward_fn, data_source, response_lst, reward_data):
def process_item(config, data_source, response_lst, reward_data):
reward_fn = get_custom_reward_fn(config)
ground_truth = reward_data["ground_truth"]
score_lst = [reward_fn(data_source, r, ground_truth) for r in response_lst]
return data_source, np.mean(score_lst)
@ -53,11 +54,9 @@ def main(config):
# evaluate test_score based on data source
data_source_reward = defaultdict(list)
compute_score = get_custom_reward_fn(config)
# Create remote tasks
remote_tasks = [
process_item.remote(compute_score, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
process_item.remote(config, data_sources[i], responses[i], reward_model_data[i]) for i in range(total)
]
# Process results as they come in

View File

@ -17,6 +17,7 @@ Generate responses given a dataset of prompts
import os
import aiohttp
import hydra
import numpy as np
import ray
@ -30,31 +31,12 @@ from pprint import pprint
import pandas as pd
from omegaconf import OmegaConf
from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion
from verl.utils.hdfs_io import makedirs
from verl.workers.rollout.replica import get_rollout_replica_class
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
run_generation(config)
def run_generation(config) -> None:
if not ray.is_initialized():
# this is for local ray cluster
default_runtime_env = {"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}}
ray_init_kwargs = config.ray_kwargs.get("ray_init", {})
runtime_env_kwargs = ray_init_kwargs.get("runtime_env", {})
runtime_env = OmegaConf.merge(default_runtime_env, runtime_env_kwargs)
ray_init_kwargs = OmegaConf.create({**ray_init_kwargs, "runtime_env": runtime_env})
print(f"ray init kwargs: {ray_init_kwargs}")
ray.init(**OmegaConf.to_container(ray_init_kwargs))
ray.get(main_task.remote(config))
async def start_server(config):
tp_size = config.actor_rollout_ref.rollout.tensor_model_parallel_size
num_replicas = (config.trainer.n_gpus_per_node * config.trainer.nnodes) // tp_size
@ -81,23 +63,42 @@ async def start_server(config):
return server_handles, server_addresses
async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
# here we should sample n_samples for each chat_lst
client = AsyncOpenAI(
api_key="123-abc",
base_url=f"http://{server_address}/v1",
)
async def submit_request(server_address, **chat_complete_request):
try:
extra_headers = chat_complete_request.pop("extra_headers", {})
timeout = aiohttp.ClientTimeout(total=None)
session = aiohttp.ClientSession(timeout=timeout)
async with session.post(
url=f"http://{server_address}/v1/chat/completions",
headers={"Authorization": "Bearer token-abc123", **extra_headers},
json=chat_complete_request,
) as resp:
data = await resp.json()
return ChatCompletion(**data)
finally:
await session.close()
tasks = [
client.chat.completions.create(
model=model_path,
messages=messages,
async def generate_per_replica(server_address, model_path: str, n_samples: int, sampling_params: dict, chat_lst: list):
# here we should sample n_samples for each chat_lst.
# we use aiohttp to avoid hang in AsyncOpenAI when the number of requests is large.
# client = AsyncOpenAI(
# api_key="123-abc",
# base_url=f"http://{server_address}/v1",
# )
chat_complete_request = [
{
"model": model_path,
"messages": messages,
**sampling_params,
)
}
for messages in chat_lst
for _ in range(n_samples)
]
tasks = [submit_request(server_address, **req) for req in chat_complete_request]
results = await asyncio.gather(*tasks)
return results
@ -118,8 +119,10 @@ async def generate(
return results
@ray.remote(num_cpus=1)
def main_task(config):
@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
ray.init(runtime_env={"env_vars": {"TOKENIZERS_PARALLELISM": "true", "NCCL_DEBUG": "WARN", "VLLM_USE_V1": "1"}})
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
OmegaConf.resolve(config)
@ -136,8 +139,21 @@ def main_task(config):
"max_tokens": config.actor_rollout_ref.rollout.response_length,
}
from omegaconf import ListConfig
train_files = config.data.train_files
if not isinstance(train_files, list | ListConfig):
train_files = [train_files]
# read dataset. Note that the dataset should directly contain chat template format (e.g., a list of dictionary)
dataset = pd.read_parquet(config.data.train_files)
datasets = []
for train_file in train_files:
dataset = pd.read_parquet(train_file)
datasets.append(dataset)
# concat dataset
dataset = pd.concat(datasets, axis=0, ignore_index=True)
chat_lst = dataset[config.data.prompt_key].tolist()
chat_lst = [chat.tolist() for chat in chat_lst]
chat_numpy = np.array(chat_lst)
@ -151,7 +167,6 @@ def main_task(config):
)
# reshape results into a numpy array
import itertools
results = list(itertools.chain.from_iterable(gen_results))
@ -170,6 +185,7 @@ def main_task(config):
# write to a new parquet
output_dir = os.path.dirname(config.data.output_path)
makedirs(output_dir, exist_ok=True)
print(f"Saving results to {config.data.output_path}")
dataset.to_parquet(config.data.output_path)

View File

@ -881,7 +881,7 @@ def compute_policy_loss(
return pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower
@register_policy_loss("vanilla")
@register_policy_loss("vanilla") # type: ignore[arg-type]
def compute_policy_loss_vanilla(
old_log_prob: torch.Tensor,
log_prob: torch.Tensor,
@ -889,7 +889,7 @@ def compute_policy_loss_vanilla(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for PPO.
@ -959,11 +959,9 @@ def compute_policy_loss_vanilla(
pg_losses = torch.where(advantages < 0, clip_pg_losses2, clip_pg_losses1)
if config.tis_imp_ratio_cap > 0 and rollout_log_probs is not None:
# Apply truncated importance sampling -> https://fengyao.notion.site/off-policy-rl
tis_imp_ratio = torch.exp(old_log_prob - rollout_log_probs)
tis_imp_ratio = torch.clamp(tis_imp_ratio, max=config.tis_imp_ratio_cap)
pg_losses = pg_losses * tis_imp_ratio
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
@ -978,7 +976,7 @@ def compute_policy_loss_gspo(
response_mask: torch.Tensor,
loss_agg_mode: str = "seq-mean-token-mean",
config: Optional[DictConfig | ActorConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GSPO.
@ -1024,6 +1022,10 @@ def compute_policy_loss_gspo(
pg_losses2 = -advantages * torch.clamp(seq_importance_ratio, 1 - clip_ratio_low, 1 + clip_ratio_high)
pg_losses = torch.maximum(pg_losses1, pg_losses2)
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
# for GSPO, we need to aggregate the loss at the sequence level (seq-mean-token-mean)
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode="seq-mean-token-mean")
@ -1044,7 +1046,7 @@ def compute_policy_loss_gpg(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Adapted from
https://github.com/AMAP-ML/GPG/blob/main/VisualThinker-R1-Zero/src/open-r1-multimodal/src/open_r1/trainer/grpo_trainer.py#L495
@ -1061,6 +1063,10 @@ def compute_policy_loss_gpg(
"""
pg_losses = -log_prob * advantages
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), torch.tensor(0.0), torch.tensor(0.0)
@ -1073,7 +1079,7 @@ def compute_policy_loss_clip_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1155,6 +1161,11 @@ def compute_policy_loss_clip_cov(
pg_clipfrac = verl_F.masked_mean((corr == 0).float(), response_mask)
pg_losses = torch.maximum(pg_losses1, pg_losses2) * corr
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, pg_clipfrac, ppo_kl, torch.tensor(0.0)
@ -1168,7 +1179,7 @@ def compute_policy_loss_kl_cov(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for Clip-Cov.
@ -1227,6 +1238,10 @@ def compute_policy_loss_kl_cov(
large_cov_idxs // advantages.shape[1], large_cov_idxs % advantages.shape[1]
]
# Apply rollout importance sampling weights if provided
if rollout_is_weights is not None:
pg_losses = pg_losses * rollout_is_weights
pg_loss = agg_loss(loss_mat=pg_losses, loss_mask=response_mask, loss_agg_mode=loss_agg_mode)
return pg_loss, torch.tensor(0.0), ppo_kl_abs, torch.tensor(0.0)
@ -1240,7 +1255,7 @@ def compute_policy_loss_geo_mean(
response_mask: torch.Tensor,
loss_agg_mode: str = "token-mean",
config: Optional[DictConfig | AlgoConfig] = None,
rollout_log_probs: torch.Tensor | None = None,
rollout_is_weights: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Compute the clipped policy objective and related metrics for GMPO.
@ -1293,6 +1308,17 @@ def compute_policy_loss_geo_mean(
# otherwise, below would be not consistent with the paper
advantage = (advantages * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
pg_losses = -advantage * ratio
# Apply rollout importance sampling weights if provided
# For geo_mean, IS weights are 2D (batch_size, seq_length) and need to be aggregated to sequence level
if rollout_is_weights is not None:
# Aggregate token-level weights to sequence level using geometric mean for consistency
# Note: rollout_is_weights is always 2D regardless of rollout_is_level
seq_is_weights = torch.exp(
(torch.log(rollout_is_weights + 1e-10) * response_mask).sum(dim=-1) / (response_mask_sum + 1e-8)
)
pg_losses = pg_losses * seq_is_weights
pg_loss = torch.mean(pg_losses)
# higher: ratio is too large that need clamp to clip_high (when adv > 0)

View File

@ -0,0 +1,459 @@
# Copyright 2025 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Rollout Importance Sampling (IS) Helper Module
This module handles importance sampling weight computation for correcting
distribution mismatch between rollout policy (e.g., vLLM BFloat16) and
training policy (e.g., FSDP FP32).
Key Features:
1. Three aggregation levels: token, sequence, geometric
2. Two handling modes: truncate (TIS), mask (MIS)
3. Per-token veto mechanism for catastrophic outliers
4. Memory-efficient computation to prevent CUDA OOM
5. Comprehensive metrics tracking
Usage Notes:
- compute_rollout_importance_weights() computes both IS weights and mismatch metrics
- Used in ray_trainer.py via compute_rollout_importance_weights_and_add_to_batch()
- Also used in dp_actor.py for distributed worker computations
- compute_mismatch_metrics() is called internally by compute_rollout_importance_weights()
References:
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
- Off-policy RL: https://fengyao.notion.site/off-policy-rl
"""
from typing import Any, Optional
import torch
import verl.utils.torch_functional as verl_F
from verl.protocol import DataProto
def compute_rollout_importance_weights(
old_log_prob: torch.Tensor,
rollout_log_prob: torch.Tensor,
response_mask: torch.Tensor,
rollout_is_level: str = "token",
rollout_is_mode: str = "truncate",
rollout_is_threshold: Optional[float] = None,
rollout_is_threshold_lower: Optional[float] = None,
rollout_is_veto_threshold: Optional[float] = 1e-4,
) -> tuple[Optional[DataProto], dict[str, Any]]:
"""Compute importance sampling weights and metrics for rollout-training mismatch correction.
This function handles the computation of importance sampling (IS) weights to correct
for the distribution mismatch between rollout policy and training policy.
Reference:
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
Memory-efficient implementation that prevents CUDA OOM by:
- Using log-space computation where possible
- Applying safety bounds to prevent numerical overflow
- Computing metrics without creating huge intermediate tensors
Args:
old_log_prob: Log probabilities from training policy (e.g., FSDP), shape (batch_size, seq_length)
rollout_log_prob: Log probabilities from rollout policy (e.g., vLLM), shape (batch_size, seq_length)
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
rollout_is_level: Level of IS aggregation:
- "token": Per-token ratios (biased)
- "sequence": Product of ratios (unbiased)
- "geometric": Geometric mean of ratios (experimental)
rollout_is_mode: How to handle weights exceeding threshold:
- "truncate": Cap weights at upper_threshold only (TIS)
- "mask": Zero out weights outside [lower_threshold, upper_threshold] (MIS)
rollout_is_threshold: Upper threshold for IS weights
rollout_is_threshold_lower: Lower threshold for IS weights (mask mode only; if None, defaults to 1/upper)
rollout_is_veto_threshold: Per-token veto threshold. If any token ratio < this, zero entire sequence.
If None, veto mechanism is disabled.
Returns:
Tuple of (weights_proto, metrics) where:
weights_proto: DataProto containing IS weights with key "rollout_is_weights",
shape (batch_size, seq_length). Returns None if rollout_is_threshold is None.
metrics: Dictionary of IS statistics and mismatch metrics (KL, PPL, etc.),
all converted to scalars and prefixed with "mismatch/"
"""
if rollout_is_threshold is None:
return None, {}
# Parse thresholds: if lower not specified, use 1/upper (reciprocal)
upper_threshold = rollout_is_threshold
if rollout_is_threshold_lower is not None:
lower_threshold = rollout_is_threshold_lower
else:
# Default: lower = 1/upper (reciprocal)
lower_threshold = 1.0 / upper_threshold
# Step 1: Compute raw importance weights based on the specified level
log_ratio = old_log_prob - rollout_log_prob
# Pre-compute log thresholds
device = old_log_prob.device
log_threshold_upper = torch.log(torch.tensor(upper_threshold, device=device))
log_threshold_lower = torch.log(torch.tensor(lower_threshold, device=device))
# Safety bound to prevent numerical overflow (exp(20) ≈ 485M)
SAFETY_BOUND = 20.0
# Store unclamped values in log-space for accurate metrics
if rollout_is_level == "token":
# Token-level IS: π_train(a|s) / π_rollout(a|s) per token
log_ratio_for_metrics = log_ratio
# Apply safety bound to prevent overflow
log_ratio_safe = torch.clamp(log_ratio, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_safe)
elif rollout_is_level == "sequence":
# Sequence-level IS: π_train(y|x) / π_rollout(y|x) for entire sequence
# Product of token ratios: exp(Σ log(π_train/π_rollout))
log_ratio_sum = verl_F.masked_sum(log_ratio, response_mask, axis=-1).unsqueeze(-1)
log_ratio_for_metrics = log_ratio_sum # Store for metrics
# Apply safety bound to prevent overflow
log_ratio_sum_safe = torch.clamp(log_ratio_sum, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_sum_safe).expand_as(old_log_prob)
elif rollout_is_level == "geometric":
# Geometric mean IS: (∏ π_train/π_rollout)^(1/T)
# Equivalent to exp(mean(log(π_train/π_rollout)))
log_ratio_mean = verl_F.masked_mean(log_ratio, response_mask, axis=-1).unsqueeze(-1)
log_ratio_for_metrics = log_ratio_mean # Store for metrics
# Geometric mean rarely explodes due to averaging, but apply safety bound anyway
log_ratio_mean_safe = torch.clamp(log_ratio_mean, min=-SAFETY_BOUND, max=SAFETY_BOUND)
rollout_is_weights = torch.exp(log_ratio_mean_safe).expand_as(old_log_prob)
else:
raise ValueError(f"Invalid rollout_is_level: {rollout_is_level}. Must be 'token', 'sequence', or 'geometric'.")
# Step 1.5: Apply per-token veto check in log space (memory efficient)
if rollout_is_veto_threshold is not None:
log_veto_threshold = torch.log(torch.tensor(rollout_is_veto_threshold, device=device))
# Check if any token ratio is below veto threshold (in log space)
# log(π_train/π_rollout) < log(veto_threshold) ⟺ π_train/π_rollout < veto_threshold
catastrophic_tokens = (log_ratio < log_veto_threshold) & response_mask.bool()
# For each sequence, check if it has any catastrophic token
# Use broadcasting instead of expand_as to save memory
has_catastrophic = catastrophic_tokens.any(dim=-1, keepdim=True)
# Create veto mask: 0 if sequence has catastrophic token, 1 otherwise
veto_mask = (~has_catastrophic).float()
else:
# No veto mechanism
catastrophic_tokens = torch.zeros_like(response_mask, dtype=torch.bool)
has_catastrophic = torch.zeros((old_log_prob.size(0), 1), dtype=torch.bool, device=device)
veto_mask = torch.ones((old_log_prob.size(0), 1), dtype=torch.float32, device=device)
# Step 2: Compute comprehensive metrics
metrics = compute_is_metrics(
rollout_is_weights=rollout_is_weights,
log_ratio_for_metrics=log_ratio_for_metrics,
response_mask=response_mask,
rollout_is_level=rollout_is_level,
rollout_is_threshold=upper_threshold,
rollout_is_threshold_lower=lower_threshold,
log_threshold_upper=log_threshold_upper,
log_threshold_lower=log_threshold_lower,
has_catastrophic=has_catastrophic,
catastrophic_tokens=catastrophic_tokens,
SAFETY_BOUND=SAFETY_BOUND,
)
# Step 3: Apply truncation or masking based on mode
if rollout_is_mode == "truncate":
# Truncated IS (TIS): only cap upper bound to prevent overweighting
rollout_is_weights = rollout_is_weights.clamp(max=upper_threshold)
elif rollout_is_mode == "mask":
# Masked IS (MIS): zero out weights outside [lower_threshold, upper_threshold]
mask = (rollout_is_weights >= lower_threshold) & (rollout_is_weights <= upper_threshold)
mask = mask.float()
# Track MIS-specific metrics
metrics["rollout_is_masked_fraction"] = verl_F.masked_mean(1 - mask, response_mask)
# Sequence-level masking fraction
if rollout_is_level in ["sequence", "geometric"]:
# All tokens in a sequence have the same weight, so reuse mask
metrics["rollout_is_seq_masked_fraction"] = (1 - mask[:, 0]).mean()
else:
# Check if any token in each sequence is masked
seq_has_masked = verl_F.masked_sum(1 - mask, response_mask, axis=-1) > 0
metrics["rollout_is_seq_masked_fraction"] = seq_has_masked.float().mean()
rollout_is_weights = rollout_is_weights * mask
else:
raise ValueError(f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate' or 'mask'.")
# Apply veto mask AFTER all thresholding
# This zeros out entire sequences that have any catastrophic token
rollout_is_weights = rollout_is_weights * veto_mask
# Apply response_mask to ensure weights are 0 where mask is 0
rollout_is_weights = rollout_is_weights * response_mask
# Wrap in DataProto for consistency with worker methods
rollout_is_weights_proto = DataProto.from_dict(tensors={"rollout_is_weights": rollout_is_weights})
# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
metrics.update(mismatch_metrics)
# Convert all tensor metrics to scalars for logging
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
metrics_scalar = {}
for key, value in metrics.items():
if isinstance(value, torch.Tensor):
metrics_scalar[f"mismatch/{key}"] = value.item()
else:
metrics_scalar[f"mismatch/{key}"] = value
return rollout_is_weights_proto, metrics_scalar
def compute_is_metrics(
rollout_is_weights: torch.Tensor,
log_ratio_for_metrics: torch.Tensor,
response_mask: torch.Tensor,
rollout_is_level: str,
rollout_is_threshold: float,
rollout_is_threshold_lower: float,
log_threshold_upper: torch.Tensor,
log_threshold_lower: torch.Tensor,
has_catastrophic: torch.Tensor,
catastrophic_tokens: torch.Tensor,
SAFETY_BOUND: float,
) -> dict[str, Any]:
"""Compute comprehensive metrics for importance sampling weights.
Reference:
When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
This function computes metrics using a mix of true unclamped values (for max/min/fractions
in sequence/geometric mode via log-space) and safety-clamped values (for mean/std/ESS)
to balance accuracy with numerical stability and avoid overflow.
"""
# Validate that we have at least one valid sample
assert response_mask.any(), "Expected at least one valid sample in response_mask"
metrics = {}
device = rollout_is_weights.device
# Track veto statistics
metrics["rollout_is_veto_fraction"] = has_catastrophic.float().mean()
metrics["rollout_is_catastrophic_token_fraction"] = verl_F.masked_mean(catastrophic_tokens.float(), response_mask)
# Compute metrics based on IS level
if rollout_is_level in ["sequence", "geometric"]:
# For sequence/geometric, compute true statistics from log-space
# This reflects the actual distribution before clamping
# True max/min in log space
log_max = log_ratio_for_metrics.max()
log_min = log_ratio_for_metrics.min()
# Convert to regular space with safety bound
metrics["rollout_is_max"] = torch.exp(torch.clamp(log_max, max=SAFETY_BOUND))
metrics["rollout_is_min"] = torch.exp(log_min)
# Mean uses clamped weights to avoid overflow
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
# Compute fraction exceeding threshold in log space (accurate)
exceeds_upper = log_ratio_for_metrics > log_threshold_upper
below_lower = log_ratio_for_metrics < log_threshold_lower
if rollout_is_level == "sequence":
# For sequence level, all tokens in a sequence have the same weight
metrics["rollout_is_ratio_fraction_high"] = exceeds_upper.float().mean()
metrics["rollout_is_ratio_fraction_low"] = below_lower.float().mean()
else: # geometric
# Need to expand to match token dimensions
exceeds_upper_expanded = exceeds_upper.expand_as(response_mask)
below_lower_expanded = below_lower.expand_as(response_mask)
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
exceeds_upper_expanded.float(), response_mask
)
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(below_lower_expanded.float(), response_mask)
else:
# Token-level: compute directly from weights
metrics["rollout_is_mean"] = verl_F.masked_mean(rollout_is_weights, response_mask)
# Fraction exceeding thresholds
rollout_is_above_threshold = rollout_is_weights > rollout_is_threshold
rollout_is_below_threshold = rollout_is_weights < rollout_is_threshold_lower
metrics["rollout_is_ratio_fraction_high"] = verl_F.masked_mean(
rollout_is_above_threshold.float(), response_mask
)
metrics["rollout_is_ratio_fraction_low"] = verl_F.masked_mean(rollout_is_below_threshold.float(), response_mask)
# Max/min for token level
mask_bool = response_mask.bool()
metrics["rollout_is_max"] = rollout_is_weights.masked_fill(~mask_bool, float("-inf")).max()
metrics["rollout_is_min"] = rollout_is_weights.masked_fill(~mask_bool, float("inf")).min()
# Compute standard deviation using clamped weights to avoid overflow
mask_count = response_mask.sum()
if mask_count > 1:
# Use clamped weights for variance to avoid squaring huge values
weights_for_std = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
# Use mean from clamped weights for consistency
mean_clamped = verl_F.masked_mean(weights_for_std, response_mask)
rollout_is_var = verl_F.masked_mean(weights_for_std.square(), response_mask) - mean_clamped.square()
metrics["rollout_is_std"] = torch.sqrt(torch.clamp(rollout_is_var, min=0.0))
else:
metrics["rollout_is_std"] = torch.tensor(0.0, device=device)
# Effective sample size (use clamped weights to avoid overflow)
weights_for_ess = rollout_is_weights.clamp(min=rollout_is_threshold_lower, max=rollout_is_threshold)
mean_for_ess = verl_F.masked_mean(weights_for_ess, response_mask)
is_weights_normalized = weights_for_ess / (mean_for_ess + 1e-8)
metrics["rollout_is_eff_sample_size"] = 1.0 / verl_F.masked_mean(is_weights_normalized.square(), response_mask)
# Per-sequence breakdown metrics
if rollout_is_weights.dim() > 1:
# Compute mean IS weight per sequence
seq_mean_weights = verl_F.masked_mean(rollout_is_weights, response_mask, axis=-1)
# Per-sequence statistics
metrics["rollout_is_seq_mean"] = seq_mean_weights.mean()
metrics["rollout_is_seq_std"] = (
seq_mean_weights.std() if seq_mean_weights.numel() > 1 else torch.tensor(0.0, device=device)
)
metrics["rollout_is_seq_max"] = seq_mean_weights.max()
metrics["rollout_is_seq_min"] = seq_mean_weights.min()
# Identify most problematic sequences
seq_deviation = (seq_mean_weights - 1.0).abs()
metrics["rollout_is_seq_max_deviation"] = seq_deviation.max()
# Fraction of sequences with high IS weights
metrics["rollout_is_seq_fraction_high"] = (seq_mean_weights > rollout_is_threshold).float().mean()
metrics["rollout_is_seq_fraction_low"] = (seq_mean_weights < rollout_is_threshold_lower).float().mean()
# Percentile metrics for better distribution understanding
# Get all valid IS weights
flat_weights = rollout_is_weights[response_mask.bool()]
# Compute key percentiles (guaranteed to have elements due to assertion at function start)
assert flat_weights.numel() > 0, "flat_weights should not be empty"
metrics["rollout_is_p25"] = torch.quantile(flat_weights, 0.25)
metrics["rollout_is_p50"] = torch.quantile(flat_weights, 0.50) # median
metrics["rollout_is_p75"] = torch.quantile(flat_weights, 0.75)
metrics["rollout_is_p95"] = torch.quantile(flat_weights, 0.95)
metrics["rollout_is_p99"] = torch.quantile(flat_weights, 0.99)
return metrics
def compute_mismatch_metrics(
old_log_prob: torch.Tensor,
rollout_log_prob: Optional[torch.Tensor],
response_mask: torch.Tensor,
) -> dict[str, Any]:
"""Compute training-inference mismatch metrics (helper function).
This helper function operates on raw tensors and is used internally by:
- compute_rollout_importance_weights() in this module (automatically included)
- Tests (test_rollout_is.py, test_rollout_is_integration.py)
These metrics help diagnose the mismatch between the rollout policy (e.g., vLLM)
and the training policy (e.g., FSDP), which can cause training instability.
Key metrics:
- mismatch_kl: Direct KL divergence estimator KL(π_rollout || π_training)
- mismatch_k3_kl: K3 KL estimator for stability (more stable for small KL)
- training_ppl: Perplexity of training policy
- rollout_ppl: Perplexity of rollout policy
- log_ppl_diff: Difference in log perplexities
- ppl_ratio: Ratio of training PPL to rollout PPL
Args:
old_log_prob: Log probabilities from training policy, shape (batch_size, seq_length)
rollout_log_prob: Log probabilities from rollout policy, shape (batch_size, seq_length)
response_mask: Mask for valid tokens, shape (batch_size, seq_length)
Returns:
Dictionary of mismatch metrics (without prefix)
Reference:
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
"""
# Validate that we have at least one valid token
assert response_mask.any(), "Expected at least one valid token in response_mask"
metrics = {}
# 1. Training policy perplexity (always available)
# Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
# where |T| is the number of tokens generated by the model
mean_log_prob_training = verl_F.masked_mean(old_log_prob, response_mask, axis=-1) # (batch_size,)
training_ppl = torch.exp(-mean_log_prob_training).mean() # Batch mean of per-sequence PPL
metrics["mismatch_training_ppl"] = training_ppl.detach().item()
# Also log log-ppl for easier analysis (avoids exponential scale)
metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()
# 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
if rollout_log_prob is not None:
# 2a. mismatch_kl: Direct estimator for KL(π_rollout || π_training)
# This is the standard KL divergence: E[log(π_rollout) - log(π_training)]
# Positive value means rollout policy is more confident than training policy
metrics["mismatch_kl"] = verl_F.masked_mean(rollout_log_prob - old_log_prob, response_mask).detach().item()
# 2b. mismatch_k3_kl: K3 estimator for KL(π_rollout || π_training)
# More stable for small KL values using: E[exp(log_ratio) - log_ratio - 1]
# Formula: KL ≈ E[r - log(r) - 1] where r = π_training/π_rollout
log_ratio = old_log_prob - rollout_log_prob
mismatch_k3_kl_matrix = torch.exp(log_ratio) - log_ratio - 1
metrics["mismatch_k3_kl"] = verl_F.masked_mean(mismatch_k3_kl_matrix, response_mask).detach().item()
# 2c. Rollout policy perplexity
mean_log_prob_rollout = verl_F.masked_mean(rollout_log_prob, response_mask, axis=-1) # (batch_size,)
rollout_ppl = torch.exp(-mean_log_prob_rollout).mean() # Batch mean of per-sequence PPL
metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()
# 2d. Log PPL difference (sequence-level perplexity difference)
# log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
# Since ppl = exp(-log_prob), we have:
# log(ppl_ratio) = log(training_ppl/rollout_ppl) = log_ppl_diff
# Positive value means training assigns lower probability (higher PPL) than rollout
log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
metrics["mismatch_log_ppl_diff"] = log_ppl_diff.mean().detach().item()
metrics["mismatch_log_ppl_abs_diff"] = log_ppl_diff.abs().mean().detach().item()
metrics["mismatch_log_ppl_diff_max"] = log_ppl_diff.max().detach().item()
metrics["mismatch_log_ppl_diff_min"] = log_ppl_diff.min().detach().item()
# 2e. PPL ratio (how much higher is training PPL vs rollout PPL)
# IMPORTANT: Compute per-sequence ratio first, then average
# For numerical stability, compute in log space using log_ppl_diff
# Note: log_ppl_diff = log(ppl_ratio), so ppl_ratio = exp(log_ppl_diff)
# This is the inverse of geometric IS: ppl_ratio_i = 1 / geometric_is_i for each sequence
ppl_ratio = torch.exp(log_ppl_diff).mean() # mean(exp(log_ppl_diff)) = mean(ppl_ratio_i)
metrics["mismatch_ppl_ratio"] = ppl_ratio.detach().item()
return metrics

View File

@ -49,6 +49,7 @@ from verl.trainer.ppo.metric_utils import (
compute_timing_metrics,
process_validation_metrics,
)
from verl.trainer.ppo.mismatch_helper import compute_rollout_importance_weights
from verl.trainer.ppo.reward import compute_reward, compute_reward_async
from verl.trainer.ppo.utils import Role, WorkerType, need_critic, need_reference_policy, need_reward_model
from verl.utils.checkpoint.checkpoint_manager import find_latest_ckpt_path, should_save_ckpt_esi
@ -918,6 +919,49 @@ class RayPPOTrainer:
)
metrics.update(global_balance_stats)
def compute_rollout_importance_weights_and_add_to_batch(self, batch: DataProto) -> tuple[DataProto, dict]:
"""Compute rollout importance sampling weights and mismatch metrics, conditionally add weights to batch.
This method computes IS weights to correct for distribution mismatch between
rollout policy and training policy. It always computes metrics when enabled, but
only adds weights to batch if algorithm.rollout_is is True.
Args:
batch: DataProto containing old_log_probs, rollout_log_probs, response_mask
Returns:
Tuple of (updated_batch, metrics) where:
- updated_batch: Batch with rollout_is_weights added (if rollout_is=True)
- metrics: Dictionary of IS and mismatch metrics (all with mismatch/ prefix)
"""
# Compute rollout IS weights if enabled and data is available
# rollout_is_threshold is the main on/off switch
if self.config.algorithm.rollout_is_threshold is not None and "rollout_log_probs" in batch.batch:
rollout_is_weights, rollout_is_metrics = compute_rollout_importance_weights(
old_log_prob=batch.batch["old_log_probs"],
rollout_log_prob=batch.batch["rollout_log_probs"],
response_mask=batch.batch["response_mask"],
rollout_is_level=self.config.algorithm.rollout_is_level,
rollout_is_mode=self.config.algorithm.rollout_is_mode,
rollout_is_threshold=self.config.algorithm.rollout_is_threshold,
rollout_is_threshold_lower=self.config.algorithm.rollout_is_threshold_lower,
rollout_is_veto_threshold=self.config.algorithm.rollout_is_veto_threshold,
)
# Control: Should we apply weights to policy loss?
# True = add weights to batch (actor will apply them)
# False = don't add weights (metrics only, no loss modification)
apply_weights = self.config.algorithm.get("rollout_is", False)
if apply_weights:
# Add IS weights to batch for distribution to workers
batch = batch.union(rollout_is_weights)
return batch, rollout_is_metrics
# Return unchanged batch and empty metrics if IS is disabled
return batch, {}
def fit(self):
"""
The training loop of PPO.
@ -1107,6 +1151,13 @@ class RayPPOTrainer:
else:
batch.batch["token_level_rewards"] = batch.batch["token_level_scores"]
# Compute rollout importance sampling weights centrally (once per batch)
# This corrects for mismatch between rollout policy and training policy
# Also computes mismatch metrics (KL, PPL, etc.)
batch, is_metrics = self.compute_rollout_importance_weights_and_add_to_batch(batch)
# IS and mismatch metrics already have mismatch/ prefix
metrics.update(is_metrics)
# compute advantages, executed on the driver process
norm_adv_by_std_in_grpo = self.config.algorithm.get(
"norm_adv_by_std_in_grpo", True
@ -1205,6 +1256,7 @@ class RayPPOTrainer:
# TODO: implement actual tflpo and theoretical tflpo
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
# Note: mismatch metrics (KL, PPL, etc.) are collected at line 1179 after advantage computation
# this is experimental and may be changed/removed in the future in favor of a general-purpose one
if isinstance(self.train_dataloader.sampler, AbstractCurriculumSampler):

View File

@ -146,7 +146,10 @@ class SFTTrainer:
config = self.config
tokenizer = self.model_config.tokenizer
train_dataset = create_sft_dataset(config.data.train_files, config.data, tokenizer)
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
if config.data.val_files:
val_dataset = create_sft_dataset(config.data.val_files, config.data, tokenizer)
else:
val_dataset = None
self.train_dataset, self.val_dataset = train_dataset, val_dataset
@ -181,19 +184,22 @@ class SFTTrainer:
pin_memory_device=device_name,
)
self.val_sampler = DistributedSampler(
self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=self.train_batch_size_per_dp,
sampler=self.val_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
drop_last=True,
pin_memory_device=device_name,
)
if self.val_dataset:
self.val_sampler = DistributedSampler(
self.val_dataset, shuffle=False, num_replicas=dp_size, rank=dp_rank, drop_last=True
)
self.val_dataloader = StatefulDataLoader(
dataset=self.val_dataset,
batch_size=self.train_batch_size_per_dp,
sampler=self.val_sampler,
collate_fn=self.collate_fn,
num_workers=8,
pin_memory=True,
drop_last=True,
pin_memory_device=device_name,
)
else:
self.val_dataloader = None
def fit(self):
is_logging = self.engine.is_mp_src_rank_with_outputs() and self.engine.get_data_parallel_rank() == 0
@ -242,6 +248,7 @@ class SFTTrainer:
}
train_time = 0
total_tokens = 0
for epoch in range(start_epoch, self.config.trainer.total_epochs):
self.train_sampler.set_epoch(epoch=epoch)
@ -302,6 +309,8 @@ class SFTTrainer:
metrics["train/grad_norm"] = metrics.pop("grad_norm")
metrics["train/lr"] = lr
metrics["train/global_tokens"] = output_tensor.sum().item()
total_tokens += metrics["train/global_tokens"]
metrics["train/total_tokens(B)"] = total_tokens / 1e9
# mfu
delta_time = timer.last
estimated_flops, promised_flops = self.flops_counter.estimate_flops(batch_seqlens, delta_time)
@ -315,7 +324,7 @@ class SFTTrainer:
is_save_step = global_step % self.save_freq == 0
# early exit or validation step
if is_last_step or (self.test_freq > 0 and is_valid_step):
if is_last_step and self.val_dataloader is not None or (self.test_freq > 0 and is_valid_step):
# Perform validation
val_losses = []
for val_data in self.val_dataloader:

View File

@ -182,7 +182,8 @@ def find_latest_ckpt_path(path, directory_format="global_step_{}"):
tracker_file = get_checkpoint_tracker_filename(path)
if not os.path.exists(tracker_file):
print(f"Checkpoint tracker file does not exist: {tracker_file}")
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
print(f"Checkpoint tracker file does not exist: {tracker_file}")
return None
with open(tracker_file, "rb") as f:

View File

@ -1 +1 @@
0.5.0.dev
0.6.0

View File

@ -373,13 +373,10 @@ class DataParallelPPOActor(BasePPOActor):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
if self.config.tis_imp_ratio_cap > 0:
assert "rollout_log_probs" in data.batch.keys(), (
"Truncated Importance Sampling (TIS) requires to configure "
"`actor_rollout_ref.rollout.calculate_log_probs=True` "
"and is not currently supported in Server mode (agent loop)."
)
select_keys.append("rollout_log_probs")
# Include pre-computed IS weights if present in batch
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
non_tensor_select_keys = ["multi_modal_inputs"] if has_multi_modal_inputs else []
@ -412,7 +409,6 @@ class DataParallelPPOActor(BasePPOActor):
model_inputs = {**micro_batch.batch, **micro_batch.non_tensor_batch}
response_mask = model_inputs["response_mask"]
old_log_prob = model_inputs["old_log_probs"]
rollout_log_probs = model_inputs["rollout_log_probs"] if self.config.tis_imp_ratio_cap > 0 else None
advantages = model_inputs["advantages"]
entropy_coeff = self.config.entropy_coeff
@ -438,9 +434,21 @@ class DataParallelPPOActor(BasePPOActor):
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
# vanilla -> verl.trainer.ppo.core_algos.compute_policy_loss_vanilla
# Extract pre-computed rollout importance sampling weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = model_inputs.get("rollout_is_weights", None)
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
# are computed centrally in ray_trainer.py for consistency and efficiency.
# This ensures metrics are computed uniformly across all batches at the trainer level
# and avoids redundant computation across workers and micro-batches.
# gpg -> verl.trainer.ppo.core_algos.compute_policy_loss_gpg
# clip_cov -> verl.trainer.ppo.core_algos.compute_policy_loss_clip_cov
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Compute policy loss (all functions return 4 values)
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
@ -448,7 +456,7 @@ class DataParallelPPOActor(BasePPOActor):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_log_probs=rollout_log_probs,
rollout_is_weights=rollout_is_weights,
)
if entropy_coeff != 0:

View File

@ -316,6 +316,10 @@ class MegatronPPOActor(BasePPOActor):
]
if self.config.use_kl_loss:
select_keys.append("ref_log_prob")
# Include pre-computed IS weights if present in batch
# Weights are computed centrally in trainer and added to batch when algorithm.rollout_is=True
if "rollout_is_weights" in data.batch.keys():
select_keys.append("rollout_is_weights")
self.has_multi_modal_inputs = "multi_modal_inputs" in data.non_tensor_batch.keys()
if self.has_multi_modal_inputs:
data = data.select(select_keys, ["multi_modal_inputs"])
@ -419,7 +423,6 @@ class MegatronPPOActor(BasePPOActor):
response_length = responses.size(1)
response_mask = data["response_mask"].to(bool)
loss_agg_mode = self.config.loss_agg_mode
# compute policy loss
log_prob = output["log_probs"][:, -response_length - 1 : -1].contiguous()
ret_entropy = None
@ -434,6 +437,15 @@ class MegatronPPOActor(BasePPOActor):
loss_mode = self.config.policy_loss.get("loss_mode", "vanilla")
policy_loss_fn = get_policy_loss_fn(loss_mode)
# Extract pre-computed rollout importance sampling weights if present
# Weights are computed centrally in trainer and added when algorithm.rollout_is=True
rollout_is_weights = data.get("rollout_is_weights", None)
# NOTE: Both mismatch diagnostic metrics (PPL, KL, etc.) and IS weight metrics
# are computed centrally in ray_trainer.py for consistency and efficiency.
# This ensures metrics are computed uniformly across all batches at the trainer level
# and avoids redundant computation across workers and micro-batches.
pg_loss, pg_clipfrac, ppo_kl, pg_clipfrac_lower = policy_loss_fn(
old_log_prob=old_log_prob,
log_prob=log_prob,
@ -441,6 +453,7 @@ class MegatronPPOActor(BasePPOActor):
response_mask=response_mask,
loss_agg_mode=loss_agg_mode,
config=self.config,
rollout_is_weights=rollout_is_weights,
)
stats.update(

View File

@ -106,7 +106,6 @@ class ActorConfig(BaseConfig):
clip_ratio_c: float = 3.0
loss_agg_mode: str = "token-mean"
entropy_coeff: float = 0
tis_imp_ratio_cap: float = -1
use_kl_loss: bool = False
use_torch_compile: bool = True
kl_loss_coef: float = 0.001

View File

@ -60,16 +60,27 @@ class FSDPOptimizerConfig(OptimizerConfig):
Args:
lr (float): Learning rate.
min_lr_ratio (Optional[float]): Minimum LR ratio for cosine schedule.
warmup_style (str): LR warmup style: "constant" or "cosine".
lr_scheduler_type (str): LR scheduler type: "constant" or "cosine".
num_cycles (float): Number of cosine cycles in LR schedule.
"""
_mutable_fields = OptimizerConfig._mutable_fields.copy()
_mutable_fields.add("lr_scheduler_type")
min_lr_ratio: Optional[float] = None
warmup_style: str = "constant"
# deprecate warmup_style
warmup_style: Optional[str] = None
lr_scheduler_type: str = "constant"
num_cycles: float = 0.5
def __post_init__(self):
assert self.warmup_style in ["constant", "cosine"]
if self.warmup_style is not None:
assert self.warmup_style in ["constant", "cosine"]
warnings.warn(
"`warmup_style` is deprecated, use `lr_scheduler_type` instead.", DeprecationWarning, stacklevel=2
)
self.lr_scheduler_type = self.warmup_style
assert self.lr_scheduler_type in ["constant", "cosine"]
return super().__post_init__()

View File

@ -370,7 +370,7 @@ class FSDPEngine(BaseEngine):
total_steps = optim_config.total_training_steps
num_warmup_steps = optim_config.lr_warmup_steps
warmup_style = optim_config.warmup_style
lr_scheduler_type = optim_config.lr_scheduler_type
min_lr_ratio = optim_config.min_lr_ratio
num_cycles = optim_config.num_cycles
if num_warmup_steps <= 0:
@ -380,9 +380,9 @@ class FSDPEngine(BaseEngine):
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
if lr_scheduler_type == "constant":
lr_scheduler = get_constant_schedule_with_warmup(optimizer=optimizer, num_warmup_steps=num_warmup_steps)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=optimizer,
num_warmup_steps=num_warmup_steps,
@ -391,7 +391,7 @@ class FSDPEngine(BaseEngine):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
return lr_scheduler
def _build_model_optimizer(self):

View File

@ -529,7 +529,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
total_steps = optim_config.get("total_training_steps", 0)
num_warmup_steps = int(optim_config.get("lr_warmup_steps", -1))
warmup_style = optim_config.get("warmup_style", "constant")
lr_scheduler_type = optim_config.get("lr_scheduler_type", "constant")
min_lr_ratio = optim_config.get("min_lr_ratio", 0.0)
num_cycles = optim_config.get("num_cycles", 0.5)
if num_warmup_steps < 0:
@ -539,11 +539,11 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
if self.rank == 0:
print(f"Total steps: {total_steps}, num_warmup_steps: {num_warmup_steps}")
if warmup_style == "constant":
if lr_scheduler_type == "constant":
actor_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=actor_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
actor_lr_scheduler = get_cosine_schedule_with_warmup(
optimizer=actor_optimizer,
num_warmup_steps=num_warmup_steps,
@ -552,7 +552,7 @@ class ActorRolloutRefWorker(Worker, DistProfilerExtension):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
log_gpu_memory_usage(f"After {role} optimizer init", logger=logger)
else:
@ -1386,7 +1386,8 @@ class CriticWorker(Worker, DistProfilerExtension):
total_steps = config.optim.get("total_training_steps", 0)
num_warmup_steps = int(config.optim.get("lr_warmup_steps", -1))
warmup_style = config.optim.get("warmup_style", "constant")
lr_scheduler_type = config.optim.get("lr_scheduler_type", "constant")
if num_warmup_steps < 0:
num_warmup_steps_ratio = config.optim.get("lr_warmup_steps_ratio", 0.0)
num_warmup_steps = int(num_warmup_steps_ratio * total_steps)
@ -1396,11 +1397,11 @@ class CriticWorker(Worker, DistProfilerExtension):
from verl.utils.torch_functional import get_constant_schedule_with_warmup, get_cosine_schedule_with_warmup
if warmup_style == "constant":
if lr_scheduler_type == "constant":
critic_lr_scheduler = get_constant_schedule_with_warmup(
optimizer=critic_optimizer, num_warmup_steps=num_warmup_steps
)
elif warmup_style == "cosine":
elif lr_scheduler_type == "cosine":
min_lr_ratio = config.optim.get("min_lr_ratio", 0.0)
num_cycles = config.optim.get("num_cycles", 0.5)
critic_lr_scheduler = get_cosine_schedule_with_warmup(
@ -1411,7 +1412,7 @@ class CriticWorker(Worker, DistProfilerExtension):
num_cycles=num_cycles,
)
else:
raise NotImplementedError(f"Warmup style {warmup_style} is not supported")
raise NotImplementedError(f"LR scheduler type {lr_scheduler_type} is not supported")
return critic_module, critic_optimizer, critic_lr_scheduler