mirror of
https://github.com/volcengine/verl.git
synced 2025-10-20 13:43:50 +08:00
[algo, perf] feat: Vectorize GRPO Advantage Estimator - 13~26x Speedup (#3635)
### What does this PR do? Implements a vectorized GRPO advantage path for outcome-only RL in core_algos.py, keeping the original implementation intact and selectable. This yields large speedups at medium–large batch sizes by replacing Python-side grouping loops with segment reductions and one-pass gathers. Results (CPU, Apple M-series example; float32): ```shell [CPU] bs= 512 T= 512 G= 10 | orig=5.47ms vec=0.21ms speedup=26.16x [CPU] bs= 1024 T=1024 G= 16 | orig=11.05ms vec=0.54ms speedup=20.60x [CPU] bs= 2048 T=2048 G= 32 | orig=23.20ms vec=1.74ms speedup=13.32x ``` ```shell [GRPO] seed=0 groups=5 shape=torch.Size([64, 128]) mask_tokens=4147 adv_max_diff=2.384e-07 ret_max_diff=2.384e-07 [GRPO] seed=1 groups=8 shape=torch.Size([128, 256]) mask_tokens=16364 adv_max_diff=2.384e-07 ret_max_diff=2.384e-07 [GRPO] seed=2 groups=10 shape=torch.Size([512, 512]) mask_tokens=130968 adv_max_diff=4.768e-07 ret_max_diff=4.768e-07 ``` ### Checklist Before Starting - [x] Search for similar PRs. Paste at least one query link here: #3634 - [x] Format the PR title as `[{modules}] {type}: {description}` (This will be checked by the CI) - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data` - If this PR involves multiple modules, separate them with `,` like `[megatron, fsdp, doc]` - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test` - If this PR breaks any API (CLI arguments, config, function signature, etc.), add `[BREAKING]` to the beginning of the title. - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching` ### Test > For changes that can not be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results like training curve plots, evaluation results, etc. 
### API and Usage Example > Demonstrate how the API changes if any, and provide usage example(s) if possible. ```python # Add code snippet or script demonstrating how to use this ``` ### Design & Code Changes > Demonstrate the high-level design if this PR is complex, and list the specific changes. ### Checklist Before Submitting > [!IMPORTANT] > Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review. - [x] Read the [Contribute Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md). - [x] Apply [pre-commit checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting): `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always` - [x] Add / Update [the documentation](https://github.com/volcengine/verl/tree/main/docs). - [x] Add unit or end-to-end test(s) to [the CI workflow](https://github.com/volcengine/verl/tree/main/.github/workflows) to cover all the code. If not feasible, explain why: ... - [x] Once your PR is ready for CI, send a message in [the `ci-request` channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the `verl` Slack workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ). (If not accessible, please try [the Feishu group (飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)
This commit is contained in:
@ -501,7 +501,7 @@ Algorithm
|
||||
|
||||
- ``gamma``: discount factor
|
||||
- ``lam``: Trade-off between bias and variance in the GAE estimator
|
||||
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``
|
||||
- ``adv_estimator``: Support ``gae``, ``grpo``, ``reinforce_plus_plus``, ``reinforce_plus_plus_baseline``, ``rloo``, ``rloo_vectorized``, ``grpo_vectorized``
|
||||
- ``use_kl_in_reward``: Whether to enable in-reward kl penalty. Default is False.
|
||||
- ``kl_penalty``: Support ``kl``, ``abs``, ``mse``, ``low_var_kl`` and ``full``. How to
|
||||
calculate the kl divergence between actor and reference policy. For
|
||||
|
@ -22,6 +22,8 @@ import torch
|
||||
import verl.trainer.ppo.core_algos
|
||||
from verl.trainer.ppo.core_algos import (
|
||||
compute_gae_advantage_return,
|
||||
compute_grpo_outcome_advantage,
|
||||
compute_grpo_vectorized_outcome_advantage,
|
||||
compute_rloo_outcome_advantage,
|
||||
compute_rloo_vectorized_outcome_advantage,
|
||||
get_adv_estimator_fn,
|
||||
@ -257,5 +259,59 @@ def test_rloo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_grou
|
||||
assert torch.allclose(ret1, ret2, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "batch_size,seq_len,num_groups,seed",
    [(64, 128, 5, 0), (128, 256, 8, 1), (512, 512, 10, 2)],
)
def test_grpo_and_vectorized_equivalence(batch_size: int, seq_len: int, num_groups: int, seed: int):
    """Check that the loop-based and vectorized GRPO estimators agree.

    Both estimators receive exactly the same masked random rewards, mask and
    group index; their advantages and returns must match in shape and be equal
    within ``rtol=1e-5`` / ``atol=1e-6``.
    """
    # Seed every RNG the helpers might draw from so each case is reproducible.
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    # Inputs are generated in a fixed order (index, mask, rewards) so the
    # random streams are consumed deterministically.
    index = _make_group_index(batch_size, num_groups)
    response_mask = _rand_mask(batch_size, seq_len)
    token_level_rewards = torch.randn(batch_size, seq_len, dtype=torch.float32) * response_mask

    shared_inputs = {
        "token_level_rewards": token_level_rewards,
        "response_mask": response_mask,
        "index": index,
    }
    adv1, ret1 = compute_grpo_outcome_advantage(**shared_inputs)  # reference implementation
    adv2, ret2 = compute_grpo_vectorized_outcome_advantage(**shared_inputs)  # vectorized implementation

    # Diagnostics printed for visibility when pytest runs with -s.
    adv_max_diff = (adv1 - adv2).abs().max().item()
    ret_max_diff = (ret1 - ret2).abs().max().item()
    total_mask_tokens = int(response_mask.sum().item())
    print(
        f"[GRPO] seed={seed} groups={num_groups} shape={adv1.shape} "
        f"mask_tokens={total_mask_tokens} adv_max_diff={adv_max_diff:.3e} ret_max_diff={ret_max_diff:.3e}"
    )

    expected_shape = (batch_size, seq_len)
    pairs = ((adv1, adv2), (ret1, ret2))
    for reference, candidate in pairs:
        assert reference.shape == candidate.shape == expected_shape
    for reference, candidate in pairs:
        assert torch.allclose(reference, candidate, rtol=1e-5, atol=1e-6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
72
tests/utils/test_groupwise.py
Normal file
72
tests/utils/test_groupwise.py
Normal file
@ -0,0 +1,72 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
|
||||
os.environ.setdefault("VERL_FORCE_DEVICE", "cpu") # ensure CPU for tests
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from verl.utils import as_torch_index, group_mean_std
|
||||
|
||||
|
||||
def test_as_torch_index_basic_integers():
    """Plain integer labels become a CPU ``torch.long`` tensor that preserves grouping."""
    labels = [2, 2, 5, 7, 5, 2]
    g = as_torch_index(labels)

    assert g.dtype == torch.long
    assert g.device.type == "cpu"

    ids = g.tolist()
    # The first two inputs carry the same label, so they must share an id.
    assert ids[0] == ids[1]
    # Three distinct labels {2, 5, 7} must yield three distinct ids.
    assert len(torch.unique(g)) == 3
|
||||
|
||||
|
||||
def test_as_torch_index_near_integer_floats():
    """Floats within 1e-6 of an integer are treated as that integer label."""
    nearly_integral = np.array([1.0000001, 2.0, 1.0, 3.0000000001], dtype=np.float64)
    g = as_torch_index(nearly_integral)

    assert g.dtype == torch.long
    # 1.0000001 and 1.0 collapse onto the same label, leaving {1, 2, 3}.
    assert len(torch.unique(g)) == 3
|
||||
|
||||
|
||||
def test_as_torch_index_factorization_mixed():
    """Mixed string/int labels are factorized without merging "0042" with 42."""
    mixed_labels = ["a", "b", "a", "c", "0042", 42]
    g = as_torch_index(mixed_labels)

    ids = g.tolist()
    # The zero-padded string and the integer must stay in different groups.
    assert ids[4] != ids[5]
    # Distinct labels: "a", "b", "c", "0042", 42.
    assert len(torch.unique(g)) == 5
|
||||
|
||||
|
||||
def test_group_mean_std_simple():
    """Check mean/std/count on a two-member group plus a singleton group.

    Group 0 holds scores [1, 3]; group 1 holds [2].
    Group 0: mean = 2 and the sample (Bessel-corrected) variance is
    ((1 - 2)^2 + (3 - 2)^2) / (2 - 1) = 2, so std = sqrt(2).
    Group 1 is a singleton and must report the fallback mean = 0, std = 1.
    """
    gidx = as_torch_index([0, 1, 0])
    scores = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)

    mean_g, std_g, cnt_g = group_mean_std(scores, gidx)

    assert torch.allclose(mean_g, torch.tensor([2.0, 0.0]))
    assert torch.allclose(cnt_g, torch.tensor([2.0, 1.0]))

    # Singleton fallback values are exact, not approximate.
    assert mean_g[1].item() == 0.0
    assert std_g[1].item() == 1.0

    assert pytest.approx(std_g[0].item(), rel=1e-6) == (2.0**0.5)
|
||||
|
||||
|
||||
def test_group_mean_std_empty():
    """Empty scores and indices produce three empty result tensors."""
    no_scores = torch.tensor([], dtype=torch.float32)
    no_groups = torch.tensor([], dtype=torch.long)

    mean_g, std_g, cnt_g = group_mean_std(no_scores, no_groups)

    assert mean_g.numel() == 0 and std_g.numel() == 0 and cnt_g.numel() == 0
|
@ -30,6 +30,7 @@ from omegaconf import DictConfig
|
||||
|
||||
import verl.utils.torch_functional as verl_F
|
||||
from verl.trainer.config import AlgoConfig
|
||||
from verl.utils import as_torch_index, group_mean_std
|
||||
from verl.utils.import_utils import deprecated
|
||||
from verl.workers.config import ActorConfig
|
||||
|
||||
@ -103,6 +104,7 @@ class AdvantageEstimator(str, Enum):
|
||||
GRPO_PASSK = "grpo_passk"
|
||||
GPG = "gpg"
|
||||
RLOO_VECTORIZED = "rloo_vectorized"
|
||||
GRPO_VECTORIZED = "grpo_vectorized"
|
||||
|
||||
|
||||
ADV_ESTIMATOR_REGISTRY: dict[str, Any] = {}
|
||||
@ -326,6 +328,33 @@ def compute_grpo_outcome_advantage(
|
||||
return scores, scores
|
||||
|
||||
|
||||
@register_adv_est(AdvantageEstimator.GRPO_VECTORIZED)
def compute_grpo_vectorized_outcome_advantage(
    token_level_rewards: torch.Tensor,
    response_mask: torch.Tensor,
    index: np.ndarray,
    epsilon: float = 1e-6,
    norm_adv_by_std_in_grpo: bool = True,
    config: Optional[AlgoConfig] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Vectorized GRPO (outcome-only) advantage.

    Each row's outcome score ``r_i`` is the sum of its token-level rewards over
    the last dimension. With ``mean_g`` / ``std_g`` the statistics of the group
    that row ``i`` belongs to (as returned by ``group_mean_std``):

        a_i = (r_i - mean_g) / (std_g + epsilon)   if norm_adv_by_std_in_grpo
        a_i = r_i - mean_g                          otherwise

    The scalar ``a_i`` is then broadcast across the token dimension and
    multiplied by ``response_mask``.

    Args:
        token_level_rewards: Per-token rewards; summed over ``dim=-1``.
        response_mask: Mask multiplied element-wise with the broadcast scalars.
        index: Group label of every row; canonicalized with ``as_torch_index``.
        epsilon: Passed to ``group_mean_std`` as its variance floor and also
            added to the standard deviation before dividing.
        norm_adv_by_std_in_grpo: Whether to divide by the group standard deviation.
        config: Accepted for signature compatibility; not read by this function.

    Returns:
        ``(advantages, returns)``; both entries are the same tensor.
    """
    with torch.no_grad():
        scores = token_level_rewards.sum(dim=-1)
        g = as_torch_index(index, device=scores.device)
        # Pin the statistics to the scores' device. Without an explicit device,
        # group_mean_std picks one from its own environment policy, which can
        # differ from scores.device and make `mean_g[g]` / `scores - ...` fail
        # with a cross-device error.
        mean_g, std_g, _ = group_mean_std(scores, g, eps=epsilon, device=scores.device)

        centered = scores - mean_g[g]
        if norm_adv_by_std_in_grpo:
            scalars = centered / (std_g[g] + epsilon)
        else:
            scalars = centered

        advantages = scalars.unsqueeze(-1) * response_mask
        return advantages, advantages
|
||||
|
||||
|
||||
@register_adv_est(AdvantageEstimator.GRPO_PASSK) # or simply: @register_adv_est("grpo_passk")
|
||||
def compute_grpo_passk_outcome_advantage(
|
||||
token_level_rewards: torch.Tensor,
|
||||
|
@ -14,8 +14,12 @@
|
||||
|
||||
from . import config, tokenizer
|
||||
from .config import omega_conf_to_dataclass, validate_config
|
||||
from .groupwise import as_torch_index, group_mean_std
|
||||
from .tokenizer import hf_processor, hf_tokenizer
|
||||
|
||||
__all__ = (
|
||||
tokenizer.__all__ + config.__all__ + ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
|
||||
tokenizer.__all__
|
||||
+ config.__all__
|
||||
+ ["hf_processor", "hf_tokenizer", "omega_conf_to_dataclass", "validate_config"]
|
||||
+ ["as_torch_index", "group_mean_std"]
|
||||
)
|
||||
|
223
verl/utils/groupwise.py
Normal file
223
verl/utils/groupwise.py
Normal file
@ -0,0 +1,223 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Copyright 2025 ModelBest Inc. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Group-wise helpers for RL training utilities.
|
||||
|
||||
Public API:
|
||||
- as_torch_index(index, device=None) -> torch.LongTensor
|
||||
- group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g)
|
||||
|
||||
Default device policy:
|
||||
- If `device` is None:
|
||||
* In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU.
|
||||
* Else if CUDA is available: use CUDA.
|
||||
* Else: use CPU.
|
||||
- You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu").
|
||||
|
||||
Notes:
|
||||
- as_torch_index: canonicalizes arbitrary group labels to a contiguous 1-D torch.long
|
||||
tensor in range [0..G-1]. Robust to torch/numpy/list/tuple, ints/floats/bools,
|
||||
numeric strings, UUIDs, mixed object arrays. Near-integer floats (|x-round(x)|<=1e-6)
|
||||
are rounded; otherwise factorization is applied.
|
||||
- group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance
|
||||
(denominator max(count-1, 1)). Singleton groups fall back to mean=0, std=1 for
|
||||
compatibility with common “native” conventions.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from verl.utils.device import get_torch_device
|
||||
|
||||
__all__ = ["as_torch_index", "group_mean_std"]
|
||||
|
||||
|
||||
def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:
|
||||
"""
|
||||
Resolve device according to policy described in the module docstring.
|
||||
Priority:
|
||||
1) explicit argument
|
||||
2) VERL_FORCE_DEVICE env
|
||||
3) pytest detection -> cpu
|
||||
4) cuda if available, else cpu
|
||||
"""
|
||||
if explicit is not None:
|
||||
return torch.device(explicit)
|
||||
|
||||
forced = os.getenv("VERL_FORCE_DEVICE")
|
||||
if forced:
|
||||
return torch.device(forced)
|
||||
|
||||
# Heuristic: pytest sets PYTEST_CURRENT_TEST
|
||||
if "PYTEST_CURRENT_TEST" in os.environ:
|
||||
return torch.device("cpu")
|
||||
|
||||
return get_torch_device()
|
||||
|
||||
|
||||
def _to_1d_numpy_object_array(x: Any) -> np.ndarray:
    """
    Best-effort conversion of arbitrary input to a flat (1-D) numpy array.

    ``np.asarray`` is tried first and keeps whatever dtype numpy infers. If it
    raises, the input is materialized as a list of objects; if that also
    raises, the input itself is wrapped as a single object element. Results
    that are not 1-D (scalars, matrices, ...) are flattened.
    """
    try:
        converted = np.asarray(x)
    except Exception:
        # Deliberately broad: this helper must never fail on odd label containers.
        try:
            converted = np.array(list(x), dtype=object)
        except Exception:
            converted = np.array([x], dtype=object)

    return converted if converted.ndim == 1 else converted.reshape(-1)
|
||||
|
||||
|
||||
def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Convert arbitrary group labels to a contiguous 1-D torch.long tensor (0..G-1).

    Equal labels always map to the same id and distinct labels to distinct ids.
    Ids follow the sorted order of the distinct labels, so integer labels that
    already are exactly 0..G-1 come back unchanged.

    Label interpretation:
      * integer / bool inputs are used as-is;
      * finite floats that are all within 1e-6 of an integer are rounded;
      * any other float input is grouped by exact value (never truncated);
      * string arrays that parse entirely as integers are compared numerically
        (so "0042" and "42" are the same label in an all-numeric string array);
      * everything else (UUIDs, mixed objects, ...) is factorized, falling back
        to comparing ``str(label)`` when the labels cannot be ordered.

    Args:
        index: Any iterable of labels or tensor/ndarray.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        torch.LongTensor with shape (N,) and values in 0..G-1.
    """
    target = _resolve_device(device)

    # ---------- torch.Tensor input ----------
    if isinstance(index, torch.Tensor):
        flat = index.detach().reshape(-1)
        integral_dtypes = (torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool)
        if flat.dtype in integral_dtypes:
            return _dense_ids_from_tensor(flat, target)

        if flat.is_floating_point():
            as64 = flat.to(dtype=torch.float64)
            rounded = torch.round(as64)
            # inf compares "close" to itself but has no integer value, so
            # require finiteness before casting the rounded values to long.
            all_finite = bool(torch.isfinite(as64).all())
            if all_finite and torch.allclose(as64, rounded, rtol=0.0, atol=1e-6):
                return _dense_ids_from_tensor(rounded, target)
            scalars = as64.tolist()
        else:
            scalars = flat.tolist()

        # Remaining tensors are grouped by the text of each element.
        text_labels = np.array([str(value) for value in scalars], dtype=object)
        return _dense_ids_from_array(text_labels, target)

    # ---------- everything else goes through numpy ----------
    arr = _to_1d_numpy_object_array(index)
    kind = arr.dtype.kind

    # Signed / unsigned integers and booleans.
    if kind in "iub":
        return _dense_ids_from_array(arr.astype(np.int64, copy=False), target)

    if kind == "f":
        arr64 = arr.astype(np.float64, copy=False)
        rounded = np.rint(arr64)
        if bool(np.isfinite(arr64).all()) and np.allclose(arr64, rounded, rtol=0.0, atol=1e-6):
            return _dense_ids_from_array(rounded.astype(np.int64), target)
        # Fractional or non-finite values: group by exact value. Casting to
        # int64 here would truncate and merge e.g. 1.2 with 1.7.
        text_labels = np.array([str(value) for value in arr64.tolist()], dtype=object)
        return _dense_ids_from_array(text_labels, target)

    # Numeric-string coercion. Object arrays are only coerced when they hold no
    # floats, for the same truncation reason as above.
    coercible = kind in "US" or (
        kind == "O" and not any(isinstance(item, (float, np.floating)) for item in arr)
    )
    if coercible:
        try:
            return _dense_ids_from_array(arr.astype(np.int64), target)
        except Exception:
            # Best effort: labels that are not all integers are factorized below.
            pass

    if arr.dtype != object:
        arr = arr.astype(object)
    return _dense_ids_from_array(arr, target)


def _dense_ids_from_tensor(values: torch.Tensor, target: torch.device) -> torch.Tensor:
    """Map integral tensor values to dense ids 0..G-1 without leaving torch."""
    _, inverse = torch.unique(values.to(dtype=torch.long), return_inverse=True)
    return inverse.reshape(-1).to(device=target, dtype=torch.long)


def _dense_ids_from_array(labels: np.ndarray, target: torch.device) -> torch.Tensor:
    """Map the labels of a 1-D numpy array to dense ids 0..G-1 on ``target``."""
    try:
        _, inverse = np.unique(labels, return_inverse=True)
    except Exception:
        # Mutually unorderable objects (e.g. str vs int): compare by text instead.
        as_text = np.array([str(item) for item in labels], dtype=object)
        _, inverse = np.unique(as_text, return_inverse=True)

    inverse = np.asarray(inverse).reshape(-1).astype(np.int64, copy=False)
    return torch.from_numpy(inverse).to(device=target)
|
||||
|
||||
|
||||
@torch.no_grad()
def group_mean_std(
    scores: torch.Tensor,
    gidx: torch.Tensor,
    eps: float = 1e-6,
    device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-group mean/std/count in pure PyTorch.

        mean_g = sum_g(x) / count_g
        std_g  = sqrt( max( sum_g((x - mean_g)^2) / max(count_g - 1, 1), eps ) )

    The variance is accumulated from centered values (two passes) rather than
    from ``sum(x^2) - sum(x)^2 / count``: in float32 the latter loses all
    significant digits when the group mean is large compared to the spread.

    Groups with at most one member (including ids in 0..G-1 that never occur)
    fall back to mean=0, std=1.

    Args:
        scores: (N,) float tensor (flattened if it has another shape).
        gidx  : (N,) long/int tensor with group indices (0..G-1).
        eps   : Numerical floor for variance.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        mean_g: (G,) float32
        std_g : (G,) float32
        count : (G,) float32
        All three live on the resolved device; G is ``max(gidx) + 1``.

    Raises:
        ValueError: If ``scores`` and ``gidx`` hold a different number of elements.
    """
    target = _resolve_device(device)

    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)
    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)

    if scores.numel() != gidx.numel():
        raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}")

    if gidx.numel() == 0:
        # Three independent empty tensors so callers may mutate one safely.
        return (
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
        )

    num_groups = int(torch.max(gidx).item()) + 1

    def _segment_sum(values: torch.Tensor) -> torch.Tensor:
        # Sum `values` into one bucket per group id.
        buckets = torch.zeros(num_groups, device=target, dtype=torch.float32)
        return buckets.index_add_(0, gidx, values)

    count = _segment_sum(torch.ones_like(scores, dtype=torch.float32))
    mean = _segment_sum(scores) / count.clamp_min(1.0)

    # Second pass: squared deviations from each element's own group mean.
    centered = scores - mean[gidx]
    var = _segment_sum(centered * centered) / (count - 1.0).clamp_min(1.0)
    std = torch.sqrt(torch.clamp(var, min=eps))

    # Singleton (or empty) groups: mean=0, std=1
    singleton = count <= 1.0
    mean = torch.where(singleton, torch.zeros_like(mean), mean)
    std = torch.where(singleton, torch.ones_like(std), std)

    return mean, std, count
|
Reference in New Issue
Block a user