Files
verl/recipe/entropy/entropy_ray_trainer.py
ℍ𝕠𝕝𝕝𝕠𝕨 𝕄𝕒𝕟 ae5d8504d4 [trainer] feat: ReMax support using reward model for baseline (#3780)
### What does this PR do?

> Add **concise** overview of what this PR aims to achieve or
accomplish. Reference related GitHub issues and PRs that help with the
review.

In addition to reward functions, we should also support using a reward model (RM) to
calculate the ReMax reward baseline.

### Checklist Before Starting

- [X] Search for similar PRs. Paste at least one query link here: ...
- [X] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
- `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
`trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
`ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
`env`, `tool`, `ckpt`, `doc`, `data`
- If this PR involves multiple modules, separate them with `,` like
`[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
- If this PR breaks any API (CLI arguments, config, function signature,
etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that can not be tested by CI (e.g., algorithm
implementation, new model support), validate by experiment(s) and show
results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```
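
A minimal usage sketch, assuming the standard verl PPO trainer config keys `algorithm.adv_estimator` and `reward_model.enable` (neither key is introduced by this PR; adapt the override style to your own YAML or CLI launcher):

```python
# Hypothetical sketch: enable the ReMax estimator together with a reward model, so the
# greedy baseline rollout is also scored by the RM (config keys assumed from verl's
# PPO trainer config; adjust to your setup).
from omegaconf import OmegaConf

overrides = OmegaConf.create(
    {
        "algorithm": {"adv_estimator": "remax"},  # ReMax advantage estimator
        "reward_model": {"enable": True},         # use the RM for rollout scores and the baseline
    }
)
# merge `overrides` into the trainer config before constructing the trainer
```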

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.
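
The core change is in `RayEntropyTrainer.fit` (full file below): when `algorithm.adv_estimator == AdvantageEstimator.REMAX` and a reward model worker group is available (`self.use_rm`), the greedy baseline rollout is now scored with `rm_wg.compute_rm_score` before `compute_reward`, and the temporary `rm_scores` keys are popped together with the baseline generation output. A condensed excerpt of that path (not standalone code, just the relevant lines):

```python
# Excerpt condensed from entropy_ray_trainer.py below: RM-aware ReMax baseline.
rm_scores = None
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
    rm_scores = self.rm_wg.compute_rm_score(new_batch)   # reward-model scores for the greedy rollout
    new_batch = new_batch.union(rm_scores)
reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)  # combine RM and rule-based rewards
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)            # one scalar baseline per sequence

keys_to_pop = set(gen_baseline_output.batch.keys())
if rm_scores is not None:
    keys_to_pop.update(rm_scores.batch.keys())            # also drop the temporary rm_scores keys
new_batch.pop(batch_keys=list(keys_to_pop))
new_batch.batch["reward_baselines"] = reward_baseline_tensor
```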

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [X] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [X] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [X] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [X] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [X] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).
(If not accessible, please try [the Feishu group
(飞书群)](https://applink.larkoffice.com/client/chat/chatter/add_by_link?link_token=772jd4f1-cd91-441e-a820-498c6614126a).)

Signed-off-by: Hollow Man <hollowman@opensuse.org>
2025-10-17 12:07:05 +08:00


# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
FSDP PPO Trainer with Ray-based single controller.
This trainer supports model-agnostic model initialization with HuggingFace.
"""
import uuid
from collections import defaultdict
from copy import deepcopy
from pprint import pprint
import numpy as np
import torch
from tqdm import tqdm
from verl import DataProto
from verl.trainer.ppo.metric_utils import (
compute_data_metrics,
compute_throughout_metrics,
compute_timing_metrics,
reduce_metrics,
)
from verl.trainer.ppo.ray_trainer import (
AdvantageEstimator,
RayPPOTrainer,
apply_kl_penalty,
compute_advantage,
compute_response_mask,
)
from verl.trainer.ppo.reward import compute_reward
from verl.utils.profiler import simple_timer
class RayEntropyTrainer(RayPPOTrainer):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""
def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
to construct the PPO dataflow.
The lightweight advantage computation is done on the driver process.
"""
from omegaconf import OmegaConf
from verl.utils.tracking import Tracking
logger = Tracking(
project_name=self.config.trainer.project_name,
experiment_name=self.config.trainer.experiment_name,
default_backend=self.config.trainer.logger,
config=OmegaConf.to_container(self.config, resolve=True),
)
self.global_steps = 0
# load checkpoint before doing anything
self._load_checkpoint()
# perform validation before training
# currently, we only support validation using the reward_function.
if self.val_reward_fn is not None and self.config.trainer.get("val_before_train", True):
val_metrics = self._validate()
assert val_metrics, f"{val_metrics=}"
pprint(f"Initial validation metrics: {val_metrics}")
logger.log(data=val_metrics, step=self.global_steps)
if self.config.trainer.get("val_only", False):
return
# add tqdm
progress_bar = tqdm(total=self.total_training_steps, initial=self.global_steps, desc="Training Progress")
# we start from step 1
self.global_steps += 1
last_val_metrics = None
timing_raw = defaultdict(float)
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
new_batch: DataProto = DataProto.from_single_dict(batch_dict)
num_gen_batches += 1
# pop those keys for generation
if "multi_modal_inputs" in new_batch.non_tensor_batch.keys():
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids", "multi_modal_data", "multi_modal_inputs"],
)
else:
gen_batch = new_batch.pop(
batch_keys=["input_ids", "attention_mask", "position_ids"],
non_tensor_batch_keys=["raw_prompt_ids"],
)
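# each prompt is duplicated rollout.n times (interleaved) so that every prompt yields
# n sampled responses from the generation step below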
gen_batch_output = gen_batch.repeat(
repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True
)
is_last_step = self.global_steps >= self.total_training_steps
with simple_timer("step", timing_raw):
# generate a batch
with simple_timer("gen", timing_raw):
if not self.async_rollout_mode:
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch_output)
else:
gen_batch_output = self.async_rollout_manager.generate_sequences(gen_batch_output)
if self.config.algorithm.adv_estimator == AdvantageEstimator.REMAX:
with simple_timer("gen_max", timing_raw):
gen_baseline_batch = deepcopy(gen_batch)
gen_baseline_batch.meta_info["do_sample"] = False
gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_batch)
new_batch = new_batch.union(gen_baseline_output)
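# ReMax baseline: score the greedy (do_sample=False) rollout with the reward model when
# one is configured, combine it with the rule-based reward_fn, and store only the
# per-sequence sum in "reward_baselines", which compute_advantage uses as the ReMax baseline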
# compute reward model score on new_batch
rm_scores = None
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
rm_scores = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(rm_scores)
reward_baseline_tensor, _ = compute_reward(new_batch, self.reward_fn)
reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
keys_to_pop = set(gen_baseline_output.batch.keys())
if rm_scores is not None:
keys_to_pop.update(rm_scores.batch.keys())
new_batch.pop(batch_keys=list(keys_to_pop))
new_batch.batch["reward_baselines"] = reward_baseline_tensor
del rm_scores, gen_baseline_batch, gen_baseline_output
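# tag every prompt with a unique id so that its n repeated rollouts can be grouped back
# together for the reward-variance filtering below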
new_batch.non_tensor_batch["uid"] = np.array(
[str(uuid.uuid4()) for _ in range(len(new_batch.batch))], dtype=object
)
# repeat to align with repeated responses in rollout
new_batch = new_batch.repeat(repeat_times=self.config.actor_rollout_ref.rollout.n, interleave=True)
new_batch = new_batch.union(gen_batch_output)
with simple_timer("reward", timing_raw):
# compute scores, supporting both model-based and function-based rewards.
# We first compute scores with the reward model, then call reward_fn to combine
# them with the rule-based results.
if self.use_rm and "rm_scores" not in new_batch.batch.keys():
# we first compute reward model score
reward_tensor = self.rm_wg.compute_rm_score(new_batch)
new_batch = new_batch.union(reward_tensor)
# we combine with rule-based rm
reward_tensor, reward_extra_infos_dict = compute_reward(new_batch, self.reward_fn)
new_batch.batch["token_level_scores"] = reward_tensor
print(f"{list(reward_extra_infos_dict.keys())=}")
if reward_extra_infos_dict:
new_batch.non_tensor_batch.update(
{k: np.array(v) for k, v in reward_extra_infos_dict.items()}
)
# compute rewards; apply the KL penalty if algorithm.use_kl_in_reward is enabled
if self.config.algorithm.use_kl_in_reward:
new_batch, kl_metrics = apply_kl_penalty(
new_batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
)
metrics.update(
kl_metrics
) # TODO: This will be cleared if we use multiple generation batches
else:
new_batch.batch["token_level_rewards"] = new_batch.batch["token_level_scores"]
if not self.config.algorithm.filter_groups.enable:
batch = new_batch
else: # NOTE: When the number of prompts left after filtering is less than the train batch size,
# we skip to the next generation batch
metric_name = self.config.algorithm.filter_groups.metric
if metric_name == "seq_final_reward":
# Turn to numpy for easier filtering
new_batch.non_tensor_batch["seq_final_reward"] = (
new_batch.batch["token_level_rewards"].sum(dim=-1).numpy()
)
elif metric_name == "seq_reward":
new_batch.non_tensor_batch["seq_reward"] = (
new_batch.batch["token_level_scores"].sum(dim=-1).numpy()
)
# Collect the sequence reward for each trajectory
prompt_uid2metric_vals = defaultdict(list)
for uid, metric_val in zip(
new_batch.non_tensor_batch["uid"], new_batch.non_tensor_batch[metric_name], strict=True
):
prompt_uid2metric_vals[uid].append(metric_val)
prompt_uid2metric_std = {}
for prompt_uid, metric_vals in prompt_uid2metric_vals.items():
prompt_uid2metric_std[prompt_uid] = np.std(metric_vals)
kept_prompt_uids = [
uid
for uid, std in prompt_uid2metric_std.items()
if std > 0 or len(prompt_uid2metric_vals[uid]) == 1
]
num_prompt_in_batch += len(kept_prompt_uids)
kept_traj_idxs = []
for idx, traj_from_prompt_uid in enumerate(new_batch.non_tensor_batch["uid"]):
if traj_from_prompt_uid in kept_prompt_uids:
kept_traj_idxs.append(idx)
new_batch = new_batch[kept_traj_idxs]
batch = new_batch if batch is None else DataProto.concat([batch, new_batch])
prompt_bsz = self.config.data.train_batch_size
if num_prompt_in_batch < prompt_bsz:
print(f"{num_prompt_in_batch=} < {prompt_bsz=}")
max_num_gen_batches = self.config.algorithm.filter_groups.max_num_gen_batches
if max_num_gen_batches <= 0 or num_gen_batches < max_num_gen_batches:
print(f"{num_gen_batches=}. Keep generating...")
continue
else:
raise ValueError(
f"{num_gen_batches=} >= {max_num_gen_batches=}."
+ " Generated too many. Please check if your data are too difficult."
+ " You could also try set max_num_gen_batches=0 to enable endless trials."
)
else:
# Align the batch
traj_bsz = self.config.data.train_batch_size * self.config.actor_rollout_ref.rollout.n
print(
f"Collected {num_prompt_in_batch} / {self.config.data.train_batch_size} prompt. "
f"Collecting finished."
)
batch = batch[:traj_bsz]
# === Updating ===
batch.batch["response_mask"] = compute_response_mask(batch)
# balance the number of valid tokens on each dp rank.
# Note that this breaks the order of data inside the batch.
# Please take care when you implement group-based adv computation such as GRPO and RLOO
if self.config.trainer.balance_batch:
self._balance_batch(batch, metrics=metrics)
# compute global_valid tokens
batch.meta_info["global_token_num"] = torch.sum(batch.batch["attention_mask"], dim=-1).tolist()
# recompute old_log_probs
with simple_timer("old_log_prob", timing_raw):
old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
batch = batch.union(old_log_prob)
if self.use_reference_policy:
# compute reference log_prob
with simple_timer("ref", timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
# compute values
if self.use_critic:
with simple_timer("values", timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
with simple_timer("adv", timing_raw):
# compute advantages, executed on the driver process
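# norm_adv_by_std_in_grpo (default True) controls whether group-centered advantages are
# additionally divided by the group's reward std (GRPO-style normalization)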
norm_adv_by_std_in_grpo = self.config.algorithm.get("norm_adv_by_std_in_grpo", True)
batch = compute_advantage(
batch,
adv_estimator=self.config.algorithm.adv_estimator,
gamma=self.config.algorithm.gamma,
lam=self.config.algorithm.lam,
num_repeat=self.config.actor_rollout_ref.rollout.n,
norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
)
# update critic
if self.use_critic:
with simple_timer("update_critic", timing_raw):
critic_output = self.critic_wg.update_critic(batch)
critic_output_metrics = reduce_metrics(critic_output.meta_info["metrics"])
metrics.update(critic_output_metrics)
# critic warmup: skip actor updates until global_steps reaches trainer.critic_warmup
if self.config.trainer.critic_warmup <= self.global_steps:
# update actor
with simple_timer("update_actor", timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"])
metrics.update(actor_output_metrics)
# validate
if (
self.val_reward_fn is not None
and self.config.trainer.test_freq > 0
and (is_last_step or self.global_steps % self.config.trainer.test_freq == 0)
):
with simple_timer("testing", timing_raw):
val_metrics: dict = self._validate()
if is_last_step:
last_val_metrics = val_metrics
metrics.update(val_metrics)
if self.config.trainer.save_freq > 0 and (
is_last_step or self.global_steps % self.config.trainer.save_freq == 0
):
with simple_timer("save_checkpoint", timing_raw):
self._save_checkpoint()
# collect metrics
metrics.update(compute_data_metrics(batch=batch, use_critic=self.use_critic))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))
# TODO: implement actual tflops and theoretical tflops
n_gpus = self.resource_pool_manager.get_n_gpus()
metrics.update(compute_throughout_metrics(batch=batch, timing_raw=timing_raw, n_gpus=n_gpus))
timing_raw = defaultdict(float) # clear timing
metrics["train/num_gen_batches"] = num_gen_batches
batch = None
num_prompt_in_batch = 0
num_gen_batches = 0
# TODO: make a canonical logger that supports various backends
logger.log(data=metrics, step=self.global_steps)
if is_last_step:
pprint(f"Final validation metrics: {last_val_metrics}")
progress_bar.close()
return
progress_bar.update(1)
self.global_steps += 1