Files
verl/recipe/spin/fsdp_workers.py
ShareLer 10f4eb8cfc [misc] chore: fix typo in function name (#2525)
### What does this PR do?

fix typo `gather_outpus_and_unpad` -> `gather_outputs_and_unpad`

### Checklist Before Starting

- [x] Search for similar PRs. Paste at least one query link here: ...
- [x] Format the PR title as `[{modules}] {type}: {description}` (This
will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`,
    `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`,
    `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`,
    `env`, `tool`, `ckpt`, `doc`, `data`
  - If this PR involves multiple modules, separate them with `,` like
    `[megatron, fsdp, doc]`
  - `{type}` is one of `feat`, `fix`, `refactor`, `chore`, `test`
  - If this PR breaks any API (CLI arguments, config, function signature,
    etc.), add `[BREAKING]` to the beginning of the title.
  - Example: `[BREAKING][fsdp, megatron] feat: dynamic batching`

### Test

> For changes that cannot be tested by CI (e.g., algorithm
> implementation, new model support), validate by experiment(s) and show
> results like training curve plots, evaluation results, etc.

### API and Usage Example

> Demonstrate how the API changes if any, and provide usage example(s)
if possible.

```python
# Add code snippet or script demonstrating how to use this
```

### Design & Code Changes

> Demonstrate the high-level design if this PR is complex, and list the
specific changes.

### Checklist Before Submitting

> [!IMPORTANT]
> Please check all the following items before requesting a review,
otherwise the reviewer might deprioritize this PR for review.

- [ ] Read the [Contribute
Guide](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md).
- [ ] Apply [pre-commit
checks](https://github.com/volcengine/verl/blob/main/CONTRIBUTING.md#code-linting-and-formatting):
`pre-commit install && pre-commit run --all-files --show-diff-on-failure
--color=always`
- [ ] Add / Update [the
documentation](https://github.com/volcengine/verl/tree/main/docs).
- [ ] Add unit or end-to-end test(s) to [the CI
workflow](https://github.com/volcengine/verl/tree/main/.github/workflows)
to cover all the code. If not feasible, explain why: ...
- [ ] Once your PR is ready for CI, send a message in [the `ci-request`
channel](https://verl-project.slack.com/archives/C091TCESWB1) in [the
`verl` Slack
workspace](https://join.slack.com/t/verl-project/shared_invite/zt-3855yhg8g-CTkqXu~hKojPCmo7k_yXTQ).

---------

Signed-off-by: ShareLer <ShareLe@163.com>
2025-07-15 19:06:20 +08:00

600 lines
27 KiB
Python

# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import warnings
import psutil
import torch
import torch.distributed
from codetiming import Timer
from omegaconf import open_dict
from torch.distributed.device_mesh import init_device_mesh
import verl.utils.torch_functional as verl_F
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.utils import hf_tokenizer
from verl.utils.checkpoint.fsdp_checkpoint_manager import FSDPCheckpointManager
from verl.utils.device import get_device_id, get_device_name, get_nccl_backend, get_torch_device
from verl.utils.flops_counter import FlopsCounter
from verl.utils.fs import copy_to_local
from verl.utils.fsdp_utils import (
get_fsdp_wrap_policy,
get_init_weight_context_manager,
init_fn,
load_fsdp_model_to_gpu,
load_fsdp_optimizer,
offload_fsdp_model_to_cpu,
offload_fsdp_optimizer,
)
from verl.utils.import_utils import import_external_libs
from verl.utils.model import compute_position_id_with_mask
from verl.utils.profiler import log_gpu_memory_usage
from verl.workers.fsdp_workers import ActorRolloutRefWorker
from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager
# Module-level logger; its level comes from the VERL_PPO_LOGGING_LEVEL environment
# variable and falls back to "WARN" when the variable is unset.
logger = logging.getLogger(__file__)
logger.setLevel(os.environ.get("VERL_PPO_LOGGING_LEVEL", "WARN"))
def create_device_mesh(world_size, fsdp_size):
    """Create the device mesh used for FSDP sharding.

    Args:
        world_size: total number of ranks in the process group.
        fsdp_size: number of ranks a single model replica is sharded across.
            A non-positive value, or a value >= ``world_size``, means "shard
            across every rank" and yields a 1-D mesh with dim name ``"fsdp"``.
            Otherwise a 2-D mesh of shape ``(world_size // fsdp_size, fsdp_size)``
            with dim names ``("ddp", "fsdp")`` is built.

    Returns:
        The mesh produced by ``init_device_mesh`` on the current device type.

    Raises:
        ValueError: if a 2-D mesh is requested and ``world_size`` is not a
            multiple of ``fsdp_size`` (the mesh could not cover every rank).
    """
    # fsdp_size == 0 used to fall into the 2-D branch and divide by zero;
    # treat every non-positive size like the "-1 = shard over all ranks" default.
    if fsdp_size <= 0 or fsdp_size >= world_size:
        return init_device_mesh(get_device_name(), mesh_shape=(world_size,), mesh_dim_names=["fsdp"])
    if world_size % fsdp_size != 0:
        raise ValueError(f"world_size ({world_size}) must be divisible by fsdp_size ({fsdp_size})")
    return init_device_mesh(
        get_device_name(), mesh_shape=(world_size // fsdp_size, fsdp_size), mesh_dim_names=["ddp", "fsdp"]
    )
def get_sharding_strategy(device_mesh):
    """Pick the FSDP ``ShardingStrategy`` that matches a device mesh's rank.

    A 1-D mesh maps to ``FULL_SHARD`` and a 2-D mesh to ``HYBRID_SHARD``.

    Raises:
        NotImplementedError: for meshes with any other number of dimensions.
    """
    from torch.distributed.fsdp import ShardingStrategy

    strategy_for_ndim = {
        1: ShardingStrategy.FULL_SHARD,
        2: ShardingStrategy.HYBRID_SHARD,
    }
    chosen = strategy_for_ndim.get(device_mesh.ndim)
    if chosen is None:
        raise NotImplementedError(f"Get device mesh ndim={device_mesh.ndim}, but only support 1 or 2")
    return chosen
class SPINRolloutRefWorker(ActorRolloutRefWorker):
    """SPIN flavour of ``ActorRolloutRefWorker``.

    Actor and reference policies are wrapped with ``SPINDataParallelPPOActor``
    (from ``recipe.spin.dp_actor``), and the actor is trained through
    :meth:`update_actor_dpo`, which delegates to the actor's
    ``update_policy_dpo_with_ref``.
    """

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Build the modules required by this worker's role(s).

        - actor / rollout / ref: an FSDP-wrapped policy (``self.actor_module_fsdp``)
          together with its optimizer, LR scheduler and HF model config.
        - actor / ref: a ``SPINDataParallelPPOActor`` around that module
          (``self.actor``) plus an ``FSDPCheckpointManager`` for it.
        - rollout: the rollout engine and its sharding manager.
        - ref: a second FSDP module built without an optimizer, wrapped as
          ``self.ref_policy``.
        - actor: a ``FlopsCounter`` used for MFU reporting.
        """
        from recipe.spin.dp_actor import SPINDataParallelPPOActor as DataParallelPPOActor

        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))

        from omegaconf import OmegaConf

        override_model_config = OmegaConf.to_container(self.config.model.get("override_config", OmegaConf.create()))

        use_remove_padding = self.config.model.get("use_remove_padding", False)
        use_fused_kernels = self.config.model.get("use_fused_kernels", False)
        trust_remote_code = self.config.model.get("trust_remote_code", False)
        use_liger = self.config.model.get("use_liger", False)

        if self._is_actor or self._is_rollout or self._is_ref:
            # we need the model for actor and rollout
            if self._is_actor or self._is_ref:
                optim_config = self.config.actor.optim
                fsdp_config = self.config.actor.fsdp_config
            else:
                optim_config = None
                fsdp_config = OmegaConf.create()
            self.actor_module_fsdp, self.actor_optimizer, self.actor_lr_scheduler, self.actor_model_config = (
                self._build_model_optimizer(
                    model_path=self.config.model.path,
                    fsdp_config=fsdp_config,
                    optim_config=optim_config,
                    override_model_config=override_model_config,
                    use_remove_padding=use_remove_padding,
                    use_fused_kernels=use_fused_kernels,
                    enable_gradient_checkpointing=self.config.model.get("enable_gradient_checkpointing", False),
                    trust_remote_code=trust_remote_code,
                    use_liger=use_liger,
                    role="actor",
                )
            )

            # get the original unwrapped module
            self.actor_module = self.actor_module_fsdp._fsdp_wrapped_module

            if self._is_offload_optimizer:
                offload_fsdp_optimizer(optimizer=self.actor_optimizer)
                log_gpu_memory_usage("After offload actor optimizer during init", logger=logger)

        # Wrap the FSDP module in the SPIN actor; the padding/kernel flags are
        # injected into the (struct-mode) actor config first.
        if self._is_actor or self._is_ref:
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
                self.config.actor.use_fused_kernels = use_fused_kernels
            self.actor = DataParallelPPOActor(
                config=self.config.actor, actor_module=self.actor_module_fsdp, actor_optimizer=self.actor_optimizer
            )

        if self._is_rollout:
            self.rollout, self.rollout_sharding_manager = self._build_rollout(trust_remote_code=trust_remote_code)

        if self._is_ref:
            self.ref_module_fsdp = self._build_model_optimizer(
                model_path=self.config.model.path,
                fsdp_config=self.config.ref.fsdp_config,
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                use_fused_kernels=use_fused_kernels,
                trust_remote_code=trust_remote_code,
                use_liger=use_liger,
                role="ref",
            )[0]
            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
                self.config.ref.use_fused_kernels = use_fused_kernels
            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        if self._is_actor:
            self.flops_counter = FlopsCounter(self.actor_model_config)

        # Both the actor and the ref role own ``self.actor`` (built above), so a
        # single checkpoint manager covers either role; build it exactly once.
        if self._is_actor or self._is_ref:
            self.checkpoint_manager = FSDPCheckpointManager(
                model=self.actor_module_fsdp,
                optimizer=self.actor.actor_optimizer,
                lr_scheduler=self.actor_lr_scheduler,
                processing_class=self.processor if self.processor is not None else self.tokenizer,
                checkpoint_config=self.config.actor.checkpoint,
            )

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_ref_log_prob(self, data: DataProto):
        """Compute reference-policy log-probs for ``data``.

        Returns:
            A CPU ``DataProto`` whose ``"ref_log_prob"`` tensor holds the output
            of ``self.ref_policy.compute_log_prob``.
        """
        assert self._is_ref

        # Support all hardwares
        data = data.to(get_device_id())

        micro_batch_size = self.config.ref.log_prob_micro_batch_size_per_gpu
        data.meta_info["micro_batch_size"] = micro_batch_size
        data.meta_info["temperature"] = self.config.rollout.temperature
        data.meta_info["max_token_len"] = self.config.ref.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.ref.log_prob_use_dynamic_bsz
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.ref_policy.compute_log_prob(data=data)
            output = DataProto.from_dict(tensors={"ref_log_prob": output})
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.ref_policy.actor_module._handle.reshard(True)

        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_log_prob(self, data: DataProto):
        """Recompute the current actor's log-probs for ``data``.

        Returns:
            A CPU ``DataProto`` with an ``"old_log_probs"`` tensor and the rollout
            temperature in ``meta_info["temperature"]``.
        """
        assert self._is_actor
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)

        # Support all hardwares
        data = data.to(get_device_id())
        # we should always recompute old_log_probs when it is HybridEngine
        data.meta_info["micro_batch_size"] = self.config.rollout.log_prob_micro_batch_size_per_gpu
        data.meta_info["max_token_len"] = self.config.rollout.log_prob_max_token_len_per_gpu
        data.meta_info["use_dynamic_bsz"] = self.config.rollout.log_prob_use_dynamic_bsz
        data.meta_info["temperature"] = self.config.rollout.temperature
        # perform recompute log_prob
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data)
            output = self.actor.compute_log_prob(data=data)
            output = DataProto.from_dict(
                tensors={"old_log_probs": output}, meta_info={"temperature": self.config.rollout.temperature}
            )
            output = self.ulysses_sharding_manager.postprocess_data(output)

        output = output.to("cpu")

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module
        if self.world_size > 1:
            self.actor.actor_module._handle.reshard(True)

        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)

        log_gpu_memory_usage("After compute_log_prob", logger=logger)
        return output

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def update_actor_dpo(self, data: DataProto):
        """
        Wrapper for actor update step. Handles FSDP state management.
        Calls self.actor.update_policy_dpo_with_ref, which contains the DPO
        logic based on pre-calculated log probabilities.

        Returns:
            A CPU ``DataProto`` whose ``meta_info["metrics"]`` holds the metrics
            returned by the actor plus perf/memory/MFU and learning-rate entries.

        Raises:
            RuntimeError: if ``init_model`` did not create ``self.actor``.
        """
        assert self._is_actor  # Make sure this worker has the actor role
        # getattr: self.actor is only assigned inside init_model, so a plain
        # attribute access would raise AttributeError instead of this message.
        if getattr(self, "actor", None) is None:
            raise RuntimeError("Actor instance (self.actor) not initialized in worker.")

        # Support all hardwares
        data = data.to(get_device_id())

        # --- FSDP State Management ---
        if self._is_offload_param:
            load_fsdp_model_to_gpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            load_fsdp_optimizer(optimizer=self.actor_optimizer, device_id=get_device_id())

        log_gpu_memory_usage("Before update policy (DPO via PPO path)", logger=logger)

        # --- Ulysses Sharding (if used) ---
        with self.ulysses_sharding_manager:
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            # --- Call the core update method (DPO logic) ---
            with Timer(name="update_policy_dpo_via_ppo", logger=None) as timer:  # Use a distinct timer name
                metrics = self.actor.update_policy_dpo_with_ref(data=data)
            delta_time = timer.last

            # --- Add Performance Metrics ---
            # MFU calculation might be less accurate/meaningful here for DPO
            metrics["perf/approx_tokens_processed"] = torch.sum(
                data.batch.get("attention_mask", torch.tensor(0))
            ).item()  # Approx tokens
            metrics["perf/max_memory_allocated_gb"] = get_torch_device().max_memory_allocated() / (1024**3)
            metrics["perf/max_memory_reserved_gb"] = get_torch_device().max_memory_reserved() / (1024**3)
            metrics["perf/cpu_memory_used_gb"] = psutil.virtual_memory().used / (1024**3)

            global_num_tokens = data.meta_info["global_token_num"]
            estimated_flops, promised_flops = self.flops_counter.estimate_flops(global_num_tokens, delta_time)
            # ppo_epochs lives in the actor sub-config (like every other actor
            # setting this worker reads); the worker-level config has no such key.
            ppo_epochs = self.config.actor.get("ppo_epochs", 1)
            metrics["perf/mfu/actor"] = estimated_flops * ppo_epochs / promised_flops / self.world_size

            # --- LR Scheduler Step ---
            lr = self.actor_lr_scheduler.get_last_lr()[0]
            metrics["actor/lr"] = lr
            self.actor_lr_scheduler.step()

            log_gpu_memory_usage("After update policy (DPO via PPO path)", logger=logger)

            # --- Prepare Output ---
            output = DataProto(meta_info={"metrics": metrics})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)
            output = output.to("cpu")

        # --- FSDP State Management (Offload) ---
        if self._is_offload_param:
            offload_fsdp_model_to_cpu(self.actor_module_fsdp)
        if self._is_offload_optimizer:
            offload_fsdp_optimizer(optimizer=self.actor_optimizer)

        return output
# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
    """
    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.

    The model is FSDP-wrapped with CPU parameter offload, optionally runs with
    Ulysses sequence parallelism, and scores each sequence by the logit at its
    last valid token.
    """

    def __init__(self, config):
        super().__init__()
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend=get_nccl_backend())
        self.config = config

        # build device mesh for Ulysses Sequence Parallel
        world_size = torch.distributed.get_world_size()

        fsdp_size = self.config.model.fsdp_config.fsdp_size
        self.device_mesh = create_device_mesh(world_size=world_size, fsdp_size=fsdp_size)

        self.ulysses_device_mesh = None
        self.ulysses_sequence_parallel_size = self.config.get("ulysses_sequence_parallel_size", 1)
        dp = world_size // self.ulysses_sequence_parallel_size
        if self.ulysses_sequence_parallel_size > 1:
            self.ulysses_device_mesh = init_device_mesh(
                get_device_name(), mesh_shape=(dp, self.ulysses_sequence_parallel_size), mesh_dim_names=["dp", "sp"]
            )

        self.ulysses_sharding_manager = FSDPUlyssesShardingManager(self.ulysses_device_mesh)

        self.use_remove_padding = self.config.model.get("use_remove_padding", False)

        # normalize config: turn the global micro batch size into a per-rank one
        if self.config.micro_batch_size is not None:
            self.config.micro_batch_size //= world_size
            self.config.micro_batch_size_per_gpu = self.config.micro_batch_size

    def _build_model(self, config):
        """Load the token-classification reward model and wrap it with FSDP.

        Also sets ``self._do_switch_chat_template`` and, when an input tokenizer
        is configured, loads both ``self.input_tokenizer`` and ``self.tokenizer``.
        """
        # the following line is necessary
        from torch.distributed.fsdp import CPUOffload
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
        from transformers import AutoConfig, AutoModelForTokenClassification

        # download the checkpoint from hdfs
        local_path = copy_to_local(config.model.path)
        trust_remote_code = config.model.get("trust_remote_code", False)

        if self.config.model.input_tokenizer is None:
            self._do_switch_chat_template = False
        else:
            self._do_switch_chat_template = True
            input_tokenizer_local_path = copy_to_local(config.model.input_tokenizer)
            self.input_tokenizer = hf_tokenizer(input_tokenizer_local_path, trust_remote_code=trust_remote_code)
            self.tokenizer = hf_tokenizer(local_path, trust_remote_code=trust_remote_code)

        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        model_config.num_labels = 1

        # The reward model is inference-only (no optimizer is built here), so it
        # is loaded and kept in bf16.
        init_context = get_init_weight_context_manager(
            use_meta_tensor=not model_config.tie_word_embeddings, mesh=self.device_mesh
        )

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            model_config.classifier_dropout = 0.0
            reward_module = AutoModelForTokenClassification.from_pretrained(
                pretrained_model_name_or_path=local_path,
                config=model_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2",
                trust_remote_code=trust_remote_code,
            )

            if config.model.get("use_remove_padding", False) or self.ulysses_sequence_parallel_size > 1:
                from verl.models.transformers.monkey_patch import apply_monkey_patch

                apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)

            reward_module.to(torch.bfloat16)

        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)

        fsdp_mesh = self.device_mesh
        sharding_strategy = get_sharding_strategy(fsdp_mesh)

        reward_module = FSDP(
            reward_module,
            param_init_fn=init_fn,
            use_orig_params=False,
            auto_wrap_policy=auto_wrap_policy,
            device_id=get_device_id(),
            sharding_strategy=sharding_strategy,  # zero3
            sync_module_states=True,
            cpu_offload=CPUOffload(offload_params=True),
            forward_prefetch=False,
            device_mesh=self.device_mesh,
        )

        return reward_module

    @register(dispatch_mode=Dispatch.ONE_TO_ALL)
    def init_model(self):
        """Import any configured external libraries, then build the reward model."""
        # This is used to import external_lib into the huggingface systems
        import_external_libs(self.config.model.get("external_lib", None))
        self.reward_module = self._build_model(config=self.config)

    def _forward_micro_batch(self, micro_batch):
        """Score one micro batch; returns one reward per sequence (the logit at
        the position with the largest masked position id)."""
        from flash_attn.bert_padding import index_first_axis, pad_input, rearrange, unpad_input

        from verl.utils.ulysses import gather_outputs_and_unpad, ulysses_pad_and_slice_inputs

        with torch.no_grad(), torch.autocast(device_type=get_device_name(), dtype=torch.bfloat16):
            input_ids = micro_batch["input_ids"]
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch["attention_mask"]
            position_ids = micro_batch["position_ids"]

            if self.use_remove_padding:
                input_ids_rmpad, indices, *_ = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask
                )  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(
                    rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."), indices
                ).transpose(0, 1)

                # pad and slice the inputs if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    input_ids_rmpad, position_ids_rmpad, pad_size = ulysses_pad_and_slice_inputs(
                        input_ids_rmpad, position_ids_rmpad, sp_size=self.ulysses_sequence_parallel_size
                    )

                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.reward_module(
                    input_ids=input_ids_rmpad, attention_mask=None, position_ids=position_ids_rmpad, use_cache=False
                )  # prevent model thinks we are generating
                reward_rmpad = output.logits
                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)

                # gather output if sp > 1
                if self.ulysses_sequence_parallel_size > 1:
                    reward_rmpad = gather_outputs_and_unpad(
                        reward_rmpad, gather_dim=0, unpad_dim=0, padding_size=pad_size
                    )

                # pad it back
                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch_size, seqlen=seqlen).squeeze(-1)
            else:
                output = self.reward_module(
                    input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids, use_cache=False
                )
                rm_score = output.logits  # (batch_size, seq_len, 1)
                rm_score = rm_score.squeeze(-1)

            # extract the result of the last valid token
            eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
            rm_score = rm_score[torch.arange(batch_size), eos_mask_idx]
            return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):
        """Place each sequence score at its last valid token and return only the
        response columns (the trailing ``responses.shape[-1]`` positions)."""
        batch_size = data.batch.batch_size[0]
        # expand as token_level_reward
        attention_mask = data.batch["attention_mask"]
        position_ids = data.batch["position_ids"]
        response_length = data.batch["responses"].shape[-1]
        eos_mask_idx = torch.argmax(position_ids * attention_mask, dim=-1)  # (bsz,)
        token_level_scores = torch.zeros_like(attention_mask, dtype=scores.dtype)  # (bsz, seqlen)
        token_level_scores[torch.arange(batch_size), eos_mask_idx] = scores

        # select the response part
        token_level_scores = token_level_scores[:, -response_length:]

        return token_level_scores

    def _switch_chat_template(self, data: DataProto):
        """Re-tokenize prompt + response with the reward model's own chat template.

        Responses are decoded with ``self.input_tokenizer`` (EOS text removed),
        appended as an assistant turn to a copy of ``raw_prompt``, rendered with
        ``self.tokenizer.apply_chat_template`` and right-padded/truncated to
        ``config.max_length`` (default: the source sequence length).
        """
        src_max_length = data.batch["attention_mask"].shape[-1]

        src_tokenizer = self.input_tokenizer
        target_tokenizer = self.tokenizer

        # the maximum length is actually determined by the reward model itself
        max_length = self.config.get("max_length", src_max_length)
        if max_length is None:
            max_length = src_max_length
        truncation = self.config.get("truncation", "right")

        rm_input_ids = []
        rm_attention_mask = []

        for i in range(data.batch.batch_size[0]):
            # extract raw prompt. Copy it: appending the assistant turn to the
            # stored list would otherwise mutate the caller's non_tensor_batch.
            raw_prompt = data.non_tensor_batch["raw_prompt"][i]
            if isinstance(raw_prompt, list):
                chat: list = list(raw_prompt)
            else:
                chat: list = raw_prompt.tolist()

            # extract response
            response_ids = data.batch["responses"][i]
            response_length = response_ids.shape[-1]
            valid_response_length = data.batch["attention_mask"][i][-response_length:].sum()
            valid_response_ids = response_ids[:valid_response_length]

            # decode
            response = src_tokenizer.decode(valid_response_ids)
            # remove the EOS text; some tokenizers define no eos_token (None)
            if src_tokenizer.eos_token:
                response = response.replace(src_tokenizer.eos_token, "")

            chat.append({"role": "assistant", "content": response})

            prompt_with_chat_template = target_tokenizer.apply_chat_template(
                chat, add_generation_prompt=False, tokenize=False
            )
            if self.rank == 0 and i == 0:
                # for debugging purpose
                print(f"Switch template. chat: {prompt_with_chat_template}")

            model_inputs = target_tokenizer(prompt_with_chat_template, return_tensors="pt", add_special_tokens=False)

            input_ids, attention_mask = verl_F.postprocess_data(
                input_ids=model_inputs["input_ids"],
                attention_mask=model_inputs["attention_mask"],
                max_length=max_length,
                pad_token_id=target_tokenizer.pad_token_id,
                left_pad=False,  # right padding
                truncation=truncation,
            )  # truncate from the right

            rm_input_ids.append(input_ids)
            rm_attention_mask.append(attention_mask)

        rm_input_ids = torch.cat(rm_input_ids, dim=0)
        rm_attention_mask = torch.cat(rm_attention_mask, dim=0)

        rm_position_ids = compute_position_id_with_mask(rm_attention_mask)

        rm_inputs = {"input_ids": rm_input_ids, "attention_mask": rm_attention_mask, "position_ids": rm_position_ids}

        return DataProto.from_dict(rm_inputs)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def compute_rm_score(self, data: DataProto):
        """Score ``data`` with the reward model.

        Returns:
            A CPU ``DataProto`` whose ``"rm_scores"`` tensor covers the response
            positions, with each sequence's score at its last valid token.
        """
        import itertools

        from verl.utils.seqlen_balancing import get_reverse_idx, rearrange_micro_batches

        # Support all hardwares
        data = data.to(get_device_id())
        if self._do_switch_chat_template:
            rm_data = self._switch_chat_template(data)
        else:
            rm_inputs = {
                "input_ids": data.batch["input_ids"],
                "attention_mask": data.batch["attention_mask"],
                "position_ids": data.batch["position_ids"],
            }
            rm_data = DataProto.from_dict(rm_inputs)

        # Support all hardwares
        rm_data.batch = rm_data.batch.to(get_device_id())

        # perform forward computation
        with self.ulysses_sharding_manager:
            rm_data = self.ulysses_sharding_manager.preprocess_data(data=rm_data)
            data = self.ulysses_sharding_manager.preprocess_data(data=data)

            use_dynamic_bsz = self.config.use_dynamic_bsz
            if use_dynamic_bsz:
                max_token_len = self.config.forward_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
                micro_batches, indices = rearrange_micro_batches(batch=rm_data.batch, max_token_len=max_token_len)
            else:
                micro_batches = rm_data.batch.split(self.config.micro_batch_size_per_gpu)

            scores = torch.cat(
                [self._forward_micro_batch(micro_batch) for micro_batch in micro_batches], dim=0
            )  # (batch_size)

            if use_dynamic_bsz:
                # undo the length-balancing permutation so scores line up with data
                indices = list(itertools.chain.from_iterable(indices))
                assert len(indices) == scores.size(0), f"{len(indices)} vs. {scores.size()}"
                revert_indices = torch.tensor(get_reverse_idx(indices), dtype=torch.long)
                scores = scores[revert_indices]

            token_level_scores = self._expand_to_token_level(data, scores)
            # Note that this is only the scores, may not be the final rewards used to train RL
            output = DataProto.from_dict(tensors={"rm_scores": token_level_scores})
            output = self.ulysses_sharding_manager.postprocess_data(data=output)

        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
        # unshard the root FSDP module (guarded like the other workers' reshard calls)
        if self.world_size > 1:
            self.reward_module._handle.reshard(True)

        output = output.to("cpu")
        return output