[algo, perf] feat: Vectorize GRPO Advantage Estimator - 13~26x Speedup (#3635)

### What does this PR do?  

Implements a vectorized GRPO advantage estimator for outcome-only RL in
`core_algos.py`, keeping the original implementation intact and
selectable. It yields large speedups at medium-to-large batch sizes by
replacing the Python-side grouping loops with segment reductions and
one-pass gathers.
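
For intuition, a minimal sketch of the segment-reduction pattern (the shipped
helpers live in `verl/utils/groupwise.py`; this is a simplification, not the
exact code):

```python
import torch

def group_mean(scores: torch.Tensor, gidx: torch.Tensor, num_groups: int) -> torch.Tensor:
    """One segment reduction replaces a Python loop over groups."""
    ones = torch.ones_like(scores)
    count = torch.zeros(num_groups).index_add_(0, gidx, ones)    # per-group sizes
    total = torch.zeros(num_groups).index_add_(0, gidx, scores)  # per-group sums
    mean_g = total / count.clamp_min(1.0)
    return mean_g[gidx]  # one-pass gather back to per-sample order

scores = torch.tensor([1.0, 3.0, 2.0, 4.0])
gidx = torch.tensor([0, 0, 1, 1])
print(group_mean(scores, gidx, 2))  # tensor([2., 2., 3., 3.])
```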


Results (CPU, Apple M-series example; float32):
```shell
[CPU] bs=  512 T= 512 G= 10 | orig=5.47ms vec=0.21ms speedup=26.16x
[CPU] bs= 1024 T=1024 G= 16 | orig=11.05ms vec=0.54ms speedup=20.60x
[CPU] bs= 2048 T=2048 G= 32 | orig=23.20ms vec=1.74ms speedup=13.32x
```
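
The benchmark script itself is not part of this diff; a rough equivalent
microbenchmark would look like this (hypothetical sketch, no warmup or
median-of-runs handling):

```python
import time

import numpy as np
import torch

from verl.trainer.ppo.core_algos import (
    compute_grpo_outcome_advantage,
    compute_grpo_vectorized_outcome_advantage,
)

bs, T, G = 512, 512, 10
rewards = torch.randn(bs, T, dtype=torch.float32)
mask = torch.ones(bs, T, dtype=torch.float32)
index = np.random.randint(0, G, size=bs)

for name, fn in [("orig", compute_grpo_outcome_advantage),
                 ("vec", compute_grpo_vectorized_outcome_advantage)]:
    t0 = time.perf_counter()
    for _ in range(10):
        fn(token_level_rewards=rewards, response_mask=mask, index=index)
    print(f"{name}={(time.perf_counter() - t0) / 10 * 1e3:.2f}ms")
```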

Numerical equivalence against the original implementation (output of the
parametrized test in the diff below):

```shell
[GRPO] seed=0 groups=5 shape=torch.Size([64, 128]) mask_tokens=4147 adv_max_diff=2.384e-07 ret_max_diff=2.384e-07
[GRPO] seed=1 groups=8 shape=torch.Size([128, 256]) mask_tokens=16364 adv_max_diff=2.384e-07 ret_max_diff=2.384e-07
[GRPO] seed=2 groups=10 shape=torch.Size([512, 512]) mask_tokens=130968 adv_max_diff=4.768e-07 ret_max_diff=4.768e-07
```


### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: #3634
- [x] Format the PR title as `[{modules}] {type}: {description}` (this
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
    `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
    `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
    `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like
    `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
    etc.), add `[BREAKING]` to the beginning of the title.
    - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

Covered by the new parametrized test `test_grpo_and_vectorized_equivalence`
(see the diff below), which asserts shape and numerical equivalence
(`rtol=1e-5`, `atol=1e-6`) against the original estimator across three
batch/group configurations; the benchmark and equivalence printouts are
shown above.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

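No existing APIs change. The new estimator can be selected with the usual verl
config override (assuming the standard `algorithm.adv_estimator` key), e.g.
`algorithm.adv_estimator=grpo_vectorized`, or called directly. A minimal sketch
mirroring the signature added in `core_algos.py`:

```python
import numpy as np
import torch

from verl.trainer.ppo.core_algos import compute_grpo_vectorized_outcome_advantage

# Toy inputs: 6 responses of 4 tokens each, grouped 2-per-prompt via `index`
token_level_rewards = torch.randn(6, 4)
response_mask = torch.ones(6, 4)
index = np.array([0, 0, 1, 1, 2, 2])  # group label per sample

advantages, returns = compute_grpo_vectorized_outcome_advantage(
    token_level_rewards=token_level_rewards,
    response_mask=response_mask,
    index=index,
)
assert advantages.shape == (6, 4)  # outcome scalar broadcast over masked tokens
```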

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

- New `grpo_vectorized` advantage estimator registered in
`verl/trainer/ppo/core_algos.py`: sums token-level rewards to one outcome
score per response, computes per-group mean/std with segment reductions, and
broadcasts the normalized scalar over the response mask.
- New `verl/utils/groupwise.py` providing `as_torch_index` (canonicalize
arbitrary group labels to a `torch.long` index tensor) and `group_mean_std`
(per-group mean/std/count via `index_add_`), both exported from `verl.utils`.
- Docs: the `adv_estimator` option list gains `grpo_vectorized`.
- Tests: a parametrized equivalence test against the original
`compute_grpo_outcome_advantage`, plus unit tests for the group-wise helpers.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [x] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [x] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [x] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [x] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [x] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
Commit `4ff3ce2fed` (parent `c03dcb0f8f`), CedricHuang, 2025-09-27: 6 changed
files with 386 additions and 2 deletions.

Docs (algorithm config reference): add `grpo_vectorized` to the documented
`adv_estimator` options.
```diff
@@ -501,7 +501,7 @@ Algorithm
 - ``gamma``: discount factor
 - ``lam``: Trade-off between bias and variance in the GAE estimator
-- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``
+- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``, ``grpo_vectorized``
 - ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False.
 - ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to
   calculate the kl divergence between actor and reference policy. For
```

Tests (advantage-estimator equivalence suite):

```diff
@@ -22,6 +22,8 @@ import torch
 import verl.trainer.ppo.core_algos
 from verl.trainer.ppo.core_algos import (
     compute_gae_advantage_return,
+    compute_grpo_outcome_advantage,
+    compute_grpo_vectorized_outcome_advantage,
     compute_rloo_outcome_advantage,
     compute_rloo_vectorized_outcome_advantage,
     get_adv_estimator_fn,
@@ -257,5 +259,59 @@ def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
     assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "batch_size,seq_len,num_groups,seed",
+    [
+        (64, 128, 5, 0),
+        (128, 256, 8, 1),
+        (512, 512, 10, 2),
+    ],
+)
+def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
+    # Set seeds for reproducibility
+    torch.manual_seed(seed)
+    random.seed(seed)
+    np.random.seed(seed)
+
+    # Generate group indices (numpy array of shape [batch_size])
+    index = _make_group_index(batch_size, num_groups)
+    # Generate binary response mask (at least one valid token per row)
+    response_mask = _rand_mask(batch_size, seq_len)
+    # Generate token-level rewards and apply mask
+    base_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32)
+    token_level_rewards = base_rewards * response_mask
+
+    # Compute GRPO outcome advantage (original implementation)
+    adv1, ret1 = compute_grpo_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+    # Compute GRPO outcome advantage (vectorized implementation)
+    adv2, ret2 = compute_grpo_vectorized_outcome_advantage(
+        token_level_rewards=token_level_rewards,
+        response_mask=response_mask,
+        index=index,
+    )
+
+    # Diagnostic info for visibility (same style as the RLOO test)
+    adv_max_diff = (adv1 - adv2).abs().max().item()
+    ret_max_diff = (ret1 - ret2).abs().max().item()
+    total_mask_tokens = int(response_mask.sum().item())
+    print(
+        f"[GRPO] seed={seed} groups={num_groups} shape={adv1.shape} "
+        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
+    )
+
+    # Assert shape and numerical equivalence
+    assert adv1.shape == adv2.shape == (batch_size, seq_len)
+    assert ret1.shape == ret2.shape == (batch_size, seq_len)
+    assert torch.allclose(adv1, adv2, rtol=1e-5, atol=1e-6)
+    assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
+
+
 if __name__ == "__main__":
     unittest.main()
```

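To reproduce the equivalence printout locally, run the suite with `-s` so the
diagnostics are shown (hypothetical invocation; point pytest at the test file
in your checkout):

```shell
pytest -s -k "grpo_and_vectorized_equivalence" tests/
```
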
New test file (unit tests for the group-wise helpers, 72 lines):

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

os.environ.setdefault("VERL_FORCE_DEVICE", "cpu")  # ensure CPU for tests

import numpy as np
import pytest
import torch

from verl.utils import as_torch_index, group_mean_std


def test_as_torch_index_basic_integers():
    g = as_torch_index([2, 2, 5, 7, 5, 2])
    assert g.dtype == torch.long
    assert g.device.type == "cpu"
    # Equal labels must map to equal indices (integer inputs pass through as-is)
    assert g.tolist()[0] == g.tolist()[1]
    assert len(torch.unique(g)) == 3  # {2, 5, 7} -> 3 groups


def test_as_torch_index_near_integer_floats():
    arr = np.array([1.0000001, 2.0, 1.0, 3.0000000001], dtype=np.float64)
    g = as_torch_index(arr)  # near-integer floats are rounded to integer indices
    assert g.dtype == torch.long
    assert len(torch.unique(g)) == 3  # {1, 2, 3}


def test_as_torch_index_factorization_mixed():
    labels = ["a", "b", "a", "c", "0042", 42]
    g = as_torch_index(labels)
    # "0042" and 42 should NOT be the same group (strings are not coerced here)
    assert g.tolist()[4] != g.tolist()[5]
    assert len(torch.unique(g)) == 5


def test_group_mean_std_simple():
    # groups: 0 -> [1, 3], 1 -> [2]
    scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    gidx = as_torch_index([0, 1, 0])
    mean_g, std_g, cnt_g = group_mean_std(scores, gidx)
    # group 0: mean = (1 + 3) / 2 = 2
    # sample std (unbiased) = sqrt( (sum(x^2) - sum(x)^2 / n) / (n - 1) )
    #                       = sqrt( ((1^2 + 3^2) - (1 + 3)^2 / 2) / (2 - 1) )
    #                       = sqrt(10 - 8) = sqrt(2)
    assert torch.allclose(mean_g, torch.tensor([2.0, 0.0]))
    assert torch.allclose(cnt_g, torch.tensor([2.0, 1.0]))
    # singleton group -> mean = 0, std = 1
    assert mean_g[1].item() == 0.0
    assert std_g[1].item() == 1.0
    assert pytest.approx(std_g[0].item(), rel=1e-6) == (2.0**0.5)


def test_group_mean_std_empty():
    scores = torch.tensor([], dtype=torch.float32)
    gidx = torch.tensor([], dtype=torch.long)
    mean_g, std_g, cnt_g = group_mean_std(scores, gidx)
    assert mean_g.numel() == 0 and std_g.numel() == 0 and cnt_g.numel() == 0
```

verl/trainer/ppo/core_algos.py:

```diff
@@ -30,6 +30,7 @@ from omegaconf import DictConfig
 import verl.utils.torch_functional as verl_F
 from verl.trainer.config import AlgoConfig
+from verl.utils import as_torch_index, group_mean_std
 from verl.utils.import_utils import deprecated
 from verl.workers.config import ActorConfig
@@ -103,6 +104,7 @@ class AdvantageEstimator(str, Enum):
     GRPO_PASSK = "grpo_passk"
     GPG = "gpg"
     RLOO_VECTORIZED = "rloo_vectorized"
+    GRPO_VECTORIZED = "grpo_vectorized"
 
 
 ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
@@ -326,6 +328,33 @@ def compute_grpo_outcome_advantage(
     return scores, scores
 
 
+@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)
+def compute_grpo_vectorized_outcome_advantage(
+    token_level_rewards: torch.Tensor,
+    response_mask: torch.Tensor,
+    index: np.ndarray,
+    epsilon: float = 1e-6,
+    norm_adv_by_std_in_grpo: bool = True,
+    config: Optional[AlgoConfig] = None,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Vectorized GRPO, outcome-only. For each group g:
+
+        a_i = \\frac{r_i - \\mu_g}{\\sigma_g}  (or without dividing by \\sigma_g),
+
+    then the scalar is broadcast across the token dimension (multiplied by response_mask).
+    """
+    with torch.no_grad():
+        scores = token_level_rewards.sum(dim=-1)
+        g = as_torch_index(index, device=scores.device)
+        mean_g, std_g, _ = group_mean_std(scores, g, eps=epsilon)
+        if norm_adv_by_std_in_grpo:
+            scalars = (scores - mean_g[g]) / (std_g[g] + epsilon)
+        else:
+            scalars = scores - mean_g[g]
+        advantages = scalars.unsqueeze(-1) * response_mask
+    return advantages, advantages
+
+
 @register_adv_est(AdvantageEstimator.GRPO_PASSK)  # or simply: @register_adv_est("grpo_passk")
 def compute_grpo_passk_outcome_advantage(
     token_level_rewards: torch.Tensor,
```

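As a hand-check of the formula above (hypothetical numbers): a group with
outcome rewards [1, 3] has mu_g = 2 and sigma_g = sqrt(2), so the normalized
advantages come out near ±0.707:

```python
import numpy as np
import torch

from verl.trainer.ppo.core_algos import compute_grpo_vectorized_outcome_advantage

rewards = torch.tensor([[1.0], [3.0]])  # outcome reward on a single token
mask = torch.ones(2, 1)
adv, _ = compute_grpo_vectorized_outcome_advantage(
    token_level_rewards=rewards, response_mask=mask, index=np.array([0, 0])
)
print(adv.squeeze(-1))  # ~ tensor([-0.7071, 0.7071])
```
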
verl/utils/__init__.py:

```diff
@@ -14,8 +14,12 @@
 from . import config, tokenizer
 from .config import omega_conf_to_dataclass, validate_config
+from .groupwise import as_torch_index, group_mean_std
 from .tokenizer import hf_processor, hf_tokenizer
 
 __all__ = (
-    tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
+    tokenizer.__all__
+    + config.__all__
+    + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
+    + ["as_torch_index", "group_mean_std"]
 )
```

verl/utils/groupwise.py (new file, 223 lines):

```python
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Group-wise helpers for RL training utilities.

Public API:
    - as_torch_index(index, device=None) -> torch.LongTensor
    - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g)

Default device policy:
    - If `device` is None:
        * In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU.
        * Else if CUDA is available: use CUDA.
        * Else: use CPU.
    - You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu").

Notes:
    - as_torch_index: canonicalizes arbitrary group labels to a 1-D torch.long
      tensor. Robust to torch/numpy/list/tuple, ints/floats/bools, numeric strings,
      UUIDs, mixed object arrays. Integer-like inputs (including near-integer floats
      with |x - round(x)| <= 1e-6) pass through as integer indices; everything else
      is factorized to contiguous ids 0..G-1.
    - group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for
      variance (denominator max(count-1, 1)). Singleton groups fall back to mean=0,
      std=1 for compatibility with common "native" conventions.
"""

from __future__ import annotations

import os
from typing import Any, Optional

import numpy as np
import torch

from verl.utils.device import get_torch_device

__all__ = ["as_torch_index", "group_mean_std"]


def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:
    """
    Resolve device according to the policy described in the module docstring.

    Priority:
        1) explicit argument
        2) VERL_FORCE_DEVICE env
        3) pytest detection -> cpu
        4) cuda if available, else cpu
    """
    if explicit is not None:
        return torch.device(explicit)
    forced = os.getenv("VERL_FORCE_DEVICE")
    if forced:
        return torch.device(forced)
    # Heuristic: pytest sets PYTEST_CURRENT_TEST
    if "PYTEST_CURRENT_TEST" in os.environ:
        return torch.device("cpu")
    return get_torch_device()


def _to_1d_numpy_object_array(x: Any) -> np.ndarray:
    """Best-effort: convert arbitrary input into a 1-D numpy array; fall back to object dtype."""
    try:
        arr = np.asarray(x)
    except Exception:
        try:
            arr = np.array(list(x), dtype=object)
        except Exception:
            arr = np.array([x], dtype=object)
    if arr.ndim != 1:
        arr = arr.reshape(-1)
    return arr


def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Convert arbitrary group labels to a 1-D torch.long index tensor
    (integer-like labels pass through; other labels are factorized to 0..G-1).

    Args:
        index: Any iterable of labels or tensor/ndarray.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        torch.LongTensor with shape (N,)
    """
    target = _resolve_device(device)

    # ---------- Fast path: torch.Tensor ----------
    if isinstance(index, torch.Tensor):
        t = index.reshape(-1)
        if t.dtype in (torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool):
            return t.to(device=target, dtype=torch.long)
        if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
            t64 = t.to(dtype=torch.float64)
            rounded = torch.round(t64)
            if torch.allclose(t64, rounded, rtol=0.0, atol=1e-6):
                return rounded.to(device=target, dtype=torch.long)
            arr = np.array([str(x.item()) for x in t], dtype=object)
        else:
            arr = np.array([str(x.item()) if hasattr(x, "item") else str(x) for x in t], dtype=object)
    else:
        # ---------- Non-torch: go through numpy ----------
        arr = _to_1d_numpy_object_array(index)

        # Pure integers (incl. bool)
        if arr.dtype != object and np.issubdtype(arr.dtype, np.integer):
            return torch.from_numpy(arr.astype(np.int64, copy=False)).to(device=target)

        # Floats nearly equal to integers
        if arr.dtype != object and np.issubdtype(arr.dtype, np.floating):
            arr64 = arr.astype(np.float64, copy=False)
            rounded = np.rint(arr64)
            if np.allclose(arr64, rounded, rtol=0.0, atol=1e-6):
                return torch.from_numpy(rounded.astype(np.int64)).to(device=target)
            # fall through

        # Try numeric string coercion
        try:
            coerced = arr.astype(np.int64)
            return torch.from_numpy(coerced).to(device=target)
        except Exception:
            pass
        if arr.dtype != object:
            arr = arr.astype(object)

    # ---------- Factorization (UUIDs / mixed types / arbitrary labels) ----------
    try:
        _, inv = np.unique(arr, return_inverse=True)
    except Exception:
        sarr = np.array([str(x) for x in arr], dtype=object)
        _, inv = np.unique(sarr, return_inverse=True)
    inv = inv.astype(np.int64, copy=False)
    return torch.from_numpy(inv).to(device=target)


@torch.no_grad()
def group_mean_std(
    scores: torch.Tensor,
    gidx: torch.Tensor,
    eps: float = 1e-6,
    device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-group mean/std/count in pure PyTorch.

        mean_g = sum / count
        std_g  = sqrt( max( (sum2 - sum^2/count) / max(count-1, 1), eps ) )

    Singleton groups fall back to mean=0, std=1.

    Args:
        scores: (N,) float tensor.
        gidx  : (N,) long/int tensor with group indices (0..G-1).
        eps   : Numerical floor for variance.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        mean_g: (G,) float32
        std_g : (G,) float32
        count : (G,) float32
    """
    target = _resolve_device(device)
    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)
    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)
    if scores.numel() != gidx.numel():
        raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}")

    G = int(torch.max(gidx).item()) + 1 if gidx.numel() > 0 else 0
    if G == 0:
        # Return empty tensors on the selected device
        empty = torch.empty(0, device=target, dtype=torch.float32)
        return empty, empty, empty

    # Segment reductions: one index_add_ per statistic
    ones = torch.ones_like(scores, dtype=torch.float32)
    count = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, ones)
    s1 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores)
    s2 = torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, scores * scores)

    mean = s1 / count.clamp_min(1.0)
    var_num = s2 - (s1 * s1) / count.clamp_min(1.0)
    denom = (count - 1.0).clamp_min(1.0)
    var = var_num / denom
    std = torch.sqrt(torch.clamp(var, min=eps))

    # Singleton groups: mean=0, std=1
    single = count <= 1.0
    if torch.any(single):
        mean = mean.clone()
        std = std.clone()
        mean[single] = 0.0
        std[single] = 1.0
    return mean, std, count
```
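
A quick usage sketch of the two helpers together (assumptions: CPU forced via
the env override documented above; labels are arbitrary strings):

```python
import os

os.environ["VERL_FORCE_DEVICE"] = "cpu"  # exercise the env override

import torch

from verl.utils import as_torch_index, group_mean_std

gidx = as_torch_index(["uuid-a", "uuid-b", "uuid-a"])  # factorized to [0, 1, 0]
mean_g, std_g, cnt_g = group_mean_std(torch.tensor([1.0, 5.0, 3.0]), gidx)
print(mean_g.tolist(), std_g.tolist(), cnt_g.tolist())
# group "uuid-a": mean=2.0, std=sqrt(2); group "uuid-b": singleton -> mean=0, std=1
```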