[misc] feat: support rmpad/data-packing in FSDP with transformers (#91)

* init commit of rmpad

* add rmpad test

* support rmpad in actor model

* add test for value model

* support rmpad in critic and rm

* fix actor return, fix num_labels, and clean up unused rmpad

* fix critic and benchmark

* update script

* fix critic

* lint

* fix util issue

* fix unnecessary unpad

* address issues

* fix args

* update tests and the rmpad supported model list

* fix typo

* fix typo and fix name

* rename rmpad to remove padding

* fix arch to model_type

* add ci for e2e rmpad and fix typo

* lint

* fix ci

* fix typo

* update tests for customized tokenizer in actor

* fix rmpad test

* update the transformers requirement as hf_rollout may have issues
Authored by Guangming Sheng on 2025-01-11 16:50:15 +08:00, committed by GitHub
parent e88cf81ae8
commit 569210e06c
19 changed files with 413 additions and 45 deletions
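For context on what remove padding (rmpad) / data packing does: the diffs below pack only the valid (non-padded) tokens of a batch into a single flat stream before the forward pass and scatter the outputs back afterwards, using the flash_attn.bert_padding helpers that appear throughout this commit. A minimal round-trip sketch with toy shapes (not part of the commit; assumes a flash-attn version whose unpad_input returns a 4-tuple, as used in the diffs below):

```python
import torch
from flash_attn.bert_padding import unpad_input, pad_input

batch_size, seqlen, hidden = 2, 5, 8
x = torch.randn(batch_size, seqlen, hidden)
attention_mask = torch.tensor([[0, 1, 1, 1, 0],
                               [1, 1, 1, 0, 0]])  # 1 = real token, 0 = padding

# pack: keep only the valid tokens -> (total_nnz, hidden)
x_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask)
assert x_rmpad.shape[0] == int(attention_mask.sum())

# ... the model runs on the packed stream here (flash-attn varlen path) ...

# unpack: scatter the packed outputs back to the padded layout -> (batch_size, seqlen, hidden)
x_back = pad_input(x_rmpad, indices, batch=batch_size, seqlen=seqlen)
assert torch.equal(x_back * attention_mask.unsqueeze(-1), x_back)  # padded slots stay zero
```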


@ -23,6 +23,7 @@ jobs:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
HF_HUB_ENABLE_HF_TRANSFER: 1
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
@ -32,7 +33,12 @@ jobs:
fetch-depth: 0
- name: Install the current repository
run: |
pip3 install hf_transfer
pip3 install -e .[test]
- name: Running digit completion e2e training tests on 8 L20 GPUs
run: |
bash tests/e2e/run_ray_trainer.sh
- name: Running digit completion e2e training tests on 8 L20 GPUs (with rmpad)
run: |
pip3 install --upgrade transformers
bash tests/e2e/run_ray_trainer_rmpad.sh

.github/workflows/model.yml (new file, 39 lines)

@ -0,0 +1,39 @@
name: model_rmpad
on:
# Trigger the workflow on push or pull request,
# but only for the main branch
push:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/model.yml
pull_request:
branches:
- main
paths:
- "**/*.py"
- .github/workflows/model.yml
jobs:
e2e_gpu:
runs-on: [self-hosted, l20-1]
env:
HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
NO_PROXY: "localhost,127.0.0.1"
container:
image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
options: --gpus all --shm-size=10g
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
with:
fetch-depth: 0
- name: Install the current repository and upgrade to the latest transformers
run: |
pip3 install -e .[test]
pip3 install --upgrade transformers
- name: Running digit completion e2e training tests on 8 L20 GPUs
run: |
pytest -s tests/model/test_transformer.py


@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=32 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \


@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=google/gemma-2-2b-it \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=128 \
actor_rollout_ref.actor.ppo_micro_batch_size=4 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=google/gemma-2-2b-it \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=4 \


@ -17,6 +17,7 @@ python3 -m verl.trainer.main_ppo \
data.max_response_length=512 \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@ -29,6 +30,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=16 \


@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
data.return_raw_chat=True \
actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
@ -31,6 +32,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.optim.lr_warmup_steps_ratio=0.05 \
critic.model.path=Qwen/Qwen2-7B-Instruct \
critic.model.enable_gradient_checkpointing=False \
@ -40,6 +42,7 @@ python3 -m verl.trainer.main_ppo \
critic.model.fsdp_config.optimizer_offload=False \
reward_model.enable=True \
reward_model.model.path=sfairXC/FsfairX-Gemma2-RM-v0.1\
reward_model.model.use_remove_padding=True \
reward_model.model.fsdp_config.param_offload=True \
reward_model.micro_batch_size=16 \
algorithm.kl_ctrl.kl_coef=0.001 \


@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
actor_rollout_ref.model.enable_gradient_checkpointing=False \
actor_rollout_ref.actor.optim.lr=1e-6 \
actor_rollout_ref.model.use_remove_padding=True \
actor_rollout_ref.actor.ppo_mini_batch_size=256 \
actor_rollout_ref.actor.ppo_micro_batch_size=16 \
actor_rollout_ref.actor.fsdp_config.param_offload=False \
@ -30,6 +31,7 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
critic.optim.lr=1e-5 \
critic.model.use_remove_padding=True \
critic.model.path=Qwen/Qwen2.5-32B-Instruct \
critic.model.enable_gradient_checkpointing=False \
critic.ppo_micro_batch_size=32 \


@ -39,7 +39,7 @@ dependencies = [
"pybind11",
"ray",
"tensordict",
"transformers",
"transformers<4.48",
"vllm<=0.6.3",
]


@ -8,6 +8,6 @@ pandas
pybind11
ray
tensordict<0.6
transformers
transformers<4.48
vllm<=0.6.3
wandb
wandb


@ -14,9 +14,11 @@ actor_rollout_ref:
hybrid_engine: True
model:
path: ~/verl/tests/e2e/arithmetic_sequence/model
tokenizer_path: ${actor_rollout_ref.model.path}
external_lib: tests.e2e.envs.digit_completion
override_config: {}
enable_gradient_checkpointing: False
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 200
@ -76,6 +78,7 @@ critic:
override_config: {}
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
use_remove_padding: False
fsdp_config:
param_offload: False
grad_offload: False
@ -104,6 +107,7 @@ reward_model:
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
offload: False
use_remove_padding: False
fsdp_config:
min_num_params: 0
micro_batch_size: 8


@ -119,7 +119,7 @@ def main(config):
pprint(OmegaConf.to_container(config, resolve=True)) # resolve=True will eval symbol values
# download the checkpoint from hdfs
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.tokenizer_path)
local_path = os.path.expanduser(local_path)
# instantiate tokenizer
tokenizer = AutoTokenizer.from_pretrained(local_path)


@ -0,0 +1,14 @@
#!/usr/bin/env bash
set -e -x
python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
actor_rollout_ref.model.tokenizer_path=tests/e2e/arithmetic_sequence/model \
critic.model.path=Qwen/Qwen2.5-0.5B \
critic.model.use_remove_padding=True \
trainer.total_epochs=1


@ -0,0 +1,129 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer
import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
LlamaConfig(num_hidden_layers=1),
MistralConfig(num_hidden_layers=1),
GemmaConfig(num_hidden_layers=1),
Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']
def test_hf_casual_models():
batch_size = 4
seqlen = 128
response_length = 127
for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# input with input_ids_rmpad and position_ids to enable flash attention varlen
logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)
origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)
logits_rmpad = logits_rmpad.squeeze(0)
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=logits_rmpad,
indices=indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=origin_logits_rmpad,
indices=origin_logits_indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]),
masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]),
atol=1e-2,
rtol=1e-5)
print(f'Check pass')
def test_hf_value_models():
batch_size = 4
seqlen = 128
for config in test_configs:
# config = AutoConfig.from_pretrained(test_case)
config.num_labels = 1
setattr(config, 'classifier_dropout', 0)
setattr(config, 'hidden_dropout', 0)
with torch.device('cuda'):
model = AutoModelForTokenClassification.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
model = model.to(device='cuda')
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0.1,
max_ratio_of_valid_token=0.8,
min_ratio_of_valid_token=0.5)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
origin_logits = model(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False).logits
# input with input_ids_rmpad and position_ids to enable flash attention varlen
rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, 1)
rmpad_logits = rmpad_logits.squeeze(0)
pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)
torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]),
masked_mean(origin_logits, attention_mask[:, :, None]),
atol=1e-2,
rtol=1e-5)
print('Value model check pass')
if __name__ == '__main__':
test_hf_casual_models()
test_hf_value_models()


@ -17,6 +17,21 @@ from typing import List, Optional, Type
import torch.nn as nn
# Supported models using HF Rmpad
# TODO(sgm): HF may support more models than listed here; we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}
def check_model_support_rmpad(model_type: str):
assert isinstance(model_type, str)
if not model_type in _REOVEPAD_MODELS.keys():
raise ValueError(f"Model architecture {model_type} is not supported for now. "
f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}")
# Supported models in Megatron-LM
# Architecture -> (module, class).
_MODELS = {
"LlamaForCausalLM":


@ -17,6 +17,7 @@ actor_rollout_ref:
external_lib: null
override_config: { }
enable_gradient_checkpointing: False
use_remove_padding: False
actor:
strategy: fsdp # This is for backward-compatibility
ppo_mini_batch_size: 256
@ -83,6 +84,7 @@ critic:
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
use_remove_padding: False
fsdp_config:
param_offload: False
grad_offload: False
@ -105,6 +107,7 @@ reward_model:
input_tokenizer: ${actor_rollout_ref.model.path} # set this to null if the chat template is identical
path: ~/models/FsfairX-LLaMA3-RM-v0.1
external_lib: ${actor_rollout_ref.model.external_lib}
use_remove_padding: False
fsdp_config:
min_num_params: 0
param_offload: False
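The new use_remove_padding keys default to False, so existing configs are unaffected. The FSDP workers later in this diff read the model-level flag with .get() and copy it into the struct-locked actor/ref sub-configs under open_dict; a condensed sketch of that pattern with a toy config (not the full trainer config):

```python
from omegaconf import OmegaConf, open_dict

config = OmegaConf.create({
    'model': {'use_remove_padding': True},
    'actor': {'strategy': 'fsdp'},
})

# a missing key falls back to False, so older configs keep working
use_remove_padding = config.model.get('use_remove_padding', False)

OmegaConf.set_struct(config.actor, True)
with open_dict(config.actor):  # temporarily allow adding a new key to a struct config
    config.actor.use_remove_padding = use_remove_padding
```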


@ -298,7 +298,7 @@ def log_probs_from_logits_response(input_ids, logits, response_length):
def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
"""Compute the log_probs from logits with rmpad input_ids and logits. Note that
"""Compute the log_probs from logits with rmpad logits and pad input. Note that
logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between
logits and input_ids.
The reason for this function is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
@ -326,6 +326,34 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
return output
def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):
"""Compute the log_probs from logits with rmpad input_ids and logits. Note that
logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between
logits and input_ids.
The reason for this function is to compute logprobs_from_logits in rmpad mode because it is memory-intensive
for large vocab_size
Args:
input_ids_rmpad: [1, total_nnz]
logits_rmpad: [total_nnz, vocab_size]
indices: [total_nnz]
batch_size: int
seqlen: int
response_length: int
"""
from flash_attn.bert_padding import pad_input
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # transpose back to [total_nnz, 1]
input_ids_rmpad = input_ids_rmpad.squeeze(-1)
input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled) # (total_nnz,)
full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
indices=indices,
batch=batch_size,
seqlen=seqlen)
output = full_output.squeeze(-1)[:, -response_length - 1:-1] # [batch_size, response_length]
return output
from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
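For reference, a shape-level sketch of how the new helper is intended to be driven, mirroring the packing done in tests/model/test_transformer.py added by this commit (random tensors stand in for model outputs):

```python
import torch
from flash_attn.bert_padding import unpad_input
from verl.utils.torch_functional import log_probs_from_logits_all_rmpad

batch_size, seqlen, response_length, vocab = 2, 8, 4, 32
input_ids = torch.randint(0, vocab, (batch_size, seqlen))
attention_mask = torch.ones(batch_size, seqlen, dtype=torch.long)

# pack input_ids exactly as the actor does: (batch, seqlen) -> (1, total_nnz)
input_ids_rmpad, indices, _, _ = unpad_input(input_ids.unsqueeze(-1), attention_mask)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1)

# stand-in for model(input_ids_rmpad, ...).logits.squeeze(0); shape (total_nnz, vocab_size)
logits_rmpad = torch.randn(input_ids_rmpad.shape[1], vocab)

log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                             logits_rmpad=logits_rmpad,
                                             indices=indices,
                                             batch_size=batch_size,
                                             seqlen=seqlen,
                                             response_length=response_length)
assert log_probs.shape == (batch_size, response_length)
```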


@ -14,7 +14,7 @@
"""
Single Process Actor
"""
from typing import Iterable
from typing import Iterable, Tuple
import torch
from torch import nn
@ -24,7 +24,9 @@ from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
__all__ = ['DataParallelPPOActor']
@ -41,17 +43,47 @@ class DataParallelPPOActor(BasePPOActor):
super().__init__(config)
self.actor_module = actor_module
self.actor_optimizer = actor_optimizer
self.use_remove_padding = self.config.get('use_remove_padding', False)
print(f'Actor use_remove_padding={self.use_remove_padding}')
def _forward_micro_batch(self, micro_batch, temperature):
def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
response_length = micro_batch['responses'].size(-1)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.actor_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False) # prevent the model from thinking we are generating
logits = output.logits / temperature
logits = logits[:, -response_length - 1:-1]
log_probs = logprobs_from_logits(logits, micro_batch['responses'])
input_ids = micro_batch['input_ids']
batch_size, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.actor_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False) # prevent the model from thinking we are generating
logits_rmpad = output.logits.squeeze(0) # (total_nnz, vocab_size)
logits_rmpad /= temperature
log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
logits_rmpad=logits_rmpad,
indices=indices,
batch_size=batch_size,
seqlen=seqlen,
response_length=response_length) # (batch, seqlen)
logits = logits_rmpad
else:
output = self.actor_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False) # prevent the model from thinking we are generating
logits = output.logits / temperature
logits = logits[:, -response_length - 1:-1]
log_probs = logprobs_from_logits(logits, micro_batch['responses'])
return logits, log_probs
def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
@ -145,8 +177,17 @@ class DataParallelPPOActor(BasePPOActor):
advantages=advantages,
eos_mask=response_mask,
cliprange=clip_ratio)
entropy_loss = core_algos.compute_entropy_loss(logits, response_mask)
# compute entropy loss
if self.use_remove_padding:
full_response_mask = attention_mask.clone()
full_response_mask[:, :-response_length] = 0 # set the prompt part to zero
full_response_mask_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
full_response_mask.unsqueeze(-1), attention_mask=attention_mask)
full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1) # (total_nnz)
entropy_loss = core_algos.compute_entropy_loss(logits, full_response_mask_rmpad) # (total_nnz,)
else:
entropy_loss = core_algos.compute_entropy_loss(logits, response_mask)
# compute policy loss
policy_loss = pg_loss - entropy_loss * entropy_coeff
loss = policy_loss / self.gradient_accumulation
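The entropy term is where the packed path differs most: logits stay in the packed (total_nnz, vocab_size) layout, so the response-only mask is unpadded with the same attention mask before being passed to compute_entropy_loss. A toy check of that alignment (illustrative shapes only, not part of the commit):

```python
import torch
from flash_attn.bert_padding import unpad_input

batch_size, seqlen, response_length = 2, 6, 3
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 0],
                               [1, 1, 1, 1, 0, 0]])

# response-only mask in the padded layout: zero the prompt part, keep the response tail
full_response_mask = attention_mask.clone()
full_response_mask[:, :-response_length] = 0

# unpad it with the *attention* mask so it lines up position-by-position with the packed logits
mask_rmpad, indices, _, _ = unpad_input(full_response_mask.unsqueeze(-1), attention_mask)
mask_rmpad = mask_rmpad.squeeze(-1)  # (total_nnz,)

# the same number of response tokens survives in both layouts
assert int(mask_rmpad.sum()) == int(full_response_mask.sum())
```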


@ -29,6 +29,8 @@ from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean
from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis
__all__ = ['DataParallelPPOCritic']
@ -38,6 +40,8 @@ class DataParallelPPOCritic(BasePPOCritic):
super().__init__(config=config)
self.critic_module = critic_module
self.critic_optimizer = critic_optimizer
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
print(f'Critic use_remove_padding={self.use_remove_padding}')
assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
@ -45,12 +49,37 @@ class DataParallelPPOCritic(BasePPOCritic):
def _forward_micro_batch(self, micro_batch):
response_length = micro_batch['responses'].size(-1)
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.critic_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'],
use_cache=False) # prevent the model from thinking we are generating
values = output.logits
values = values[:, -response_length - 1:-1]
input_ids = micro_batch['input_ids']
batch, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.critic_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False) # prevent the model from thinking we are generating
values_rmpad = output.logits
values_rmpad = values_rmpad.squeeze(0) # (total_nnz)
# pad it back
values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
values = values[:, -response_length - 1:-1]
else:
output = self.critic_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
use_cache=False) # prevent the model from thinking we are generating
values = output.logits
values = values[:, -response_length - 1:-1].squeeze(-1)
return values
def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:


@ -23,7 +23,7 @@ import torch
import torch.distributed
import verl.utils.hdfs_io as hdfs_io
import verl.utils.torch_functional as verl_F
from omegaconf import DictConfig
from omegaconf import DictConfig, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
@ -95,6 +95,7 @@ class ActorRolloutRefWorker(Worker):
fsdp_config,
optim_config,
override_model_config,
use_remove_padding=False,
enable_gradient_checkpointing=False,
trust_remote_code=False):
from verl.utils.model import print_model_size, update_model_config
@ -119,6 +120,10 @@ class ActorRolloutRefWorker(Worker):
# override model kwargs
actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
if use_remove_padding:
from verl.models.registry import check_model_support_rmpad
check_model_support_rmpad(actor_model_config.model_type)
override_config_kwargs = {
'bos_token_id': self.tokenizer.bos_token_id,
'eos_token_id': self.tokenizer.eos_token_id,
@ -254,6 +259,8 @@ class ActorRolloutRefWorker(Worker):
from omegaconf import OmegaConf
override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))
use_remove_padding = self.config.model.get('use_remove_padding', False)
if self._is_actor or self._is_rollout:
# we need the model for actor and rollout
if self._is_actor:
@ -267,6 +274,7 @@ class ActorRolloutRefWorker(Worker):
fsdp_config=fsdp_config,
optim_config=optim_config,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
trust_remote_code=self.config.model.get('trust_remote_code', False))
@ -283,6 +291,8 @@ class ActorRolloutRefWorker(Worker):
# load from checkpoint
if self._is_actor:
OmegaConf.set_struct(self.config.actor, True)
with open_dict(self.config.actor):
self.config.actor.use_remove_padding = use_remove_padding
self.actor = DataParallelPPOActor(config=self.config.actor,
actor_module=self.actor_module_fsdp,
actor_optimizer=self.actor_optimizer)
@ -295,12 +305,15 @@ class ActorRolloutRefWorker(Worker):
fsdp_config=self.config.ref.fsdp_config,
optim_config=None,
override_model_config=override_model_config,
use_remove_padding=use_remove_padding,
trust_remote_code=self.config.model.get(
'trust_remote_code', False))[0]
if self._is_offload_param:
offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)
OmegaConf.set_struct(self.config.ref, True)
with open_dict(self.config.ref):
self.config.ref.use_remove_padding = use_remove_padding
self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)
torch.cuda.empty_cache()
@ -463,7 +476,6 @@ class CriticWorker(Worker):
local_path = copy_local_path_from_hdfs(config.model.path)
# note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
# using random initialized model from any architecture. May not be the same as Actor.
# TODO: support loading critic weights from RM. Support using AutoModelForTokenClassification
tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
@ -482,22 +494,28 @@ class CriticWorker(Worker):
torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
torch_dtype = PrecisionType.to_dtype(torch_dtype)
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForTokenClassification
from torch import nn
trust_remote_code = False
critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
critic_model_config.num_labels = 1
use_remove_padding = config.model.get('use_remove_padding', False)
if use_remove_padding:
from verl.models.registry import check_model_support_rmpad
check_model_support_rmpad(critic_model_config.model_type)
init_context = get_init_weight_context_manager()
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
critic_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
critic_module.lm_head = nn.Sequential(nn.Linear(critic_model_config.hidden_size, 1, dtype=torch_dtype),
LambdaLayer(fn=squeeze))
setattr(critic_model_config, 'classifier_dropout', 0.)
setattr(critic_model_config, 'hidden_dropout', '0')
critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch_dtype,
config=critic_model_config,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
# some parameters may not be in torch_dtype
critic_module.to(torch_dtype)
@ -642,9 +660,10 @@ class CriticWorker(Worker):
offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)
# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
"""
Note that we only implement the reward model that is a subclass of AutoModelForSequenceClassification.
Note that we only implement the reward model that is a subclass of AutoModelForTokenClassification.
"""
def __init__(self, config):
@ -653,12 +672,12 @@ class RewardModelWorker(Worker):
if not torch.distributed.is_initialized():
torch.distributed.init_process_group(backend="nccl")
self.config = config
self.use_remove_padding = self.config.model.get('use_remove_padding', False)
self.config.micro_batch_size //= torch.distributed.get_world_size()
def _build_model(self, config):
# the following line is necessary
from transformers import AutoModelForSequenceClassification, AutoConfig
from transformers import AutoModelForTokenClassification, AutoConfig
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload
# download the checkpoint from hdfs
@ -675,15 +694,18 @@ class RewardModelWorker(Worker):
trust_remote_code = config.model.get('trust_remote_code', False)
model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
model_config.num_labels = 1
# note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)
with init_context(), warnings.catch_warnings():
warnings.simplefilter("ignore")
reward_module = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=local_path,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
setattr(model_config, 'classifier_dropout', 0.)
reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
config=model_config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2',
trust_remote_code=trust_remote_code)
reward_module.to(torch.bfloat16)
auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
@ -707,12 +729,39 @@ class RewardModelWorker(Worker):
torch.cuda.empty_cache()
def _forward_micro_batch(self, micro_batch):
from verl.utils.torch_functional import prepare_input_for_rmpad
from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange
with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
output = self.reward_module(input_ids=micro_batch['input_ids'],
attention_mask=micro_batch['attention_mask'],
position_ids=micro_batch['position_ids'])
rm_score = output.logits # (batch_size,)
rm_score = rm_score.squeeze(-1)
input_ids = micro_batch['input_ids']
batch, seqlen = input_ids.shape
attention_mask = micro_batch['attention_mask']
position_ids = micro_batch['position_ids']
if self.use_remove_padding:
input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
input_ids.unsqueeze(-1), attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# only pass input_ids and position_ids to enable flash_attn_varlen
output = self.reward_module(input_ids=input_ids_rmpad,
attention_mask=None,
position_ids=position_ids_rmpad,
use_cache=False) # prevent the model from thinking we are generating
reward_rmpad = output.logits
reward_rmpad = reward_rmpad.squeeze(0) # (total_nnz)
# pad it back
rm_score = pad_input(reward_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
else:
output = self.reward_module(input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids)
rm_score = output.logits # (batch_size, seq_len, 1)
rm_score = rm_score.squeeze(-1)
return rm_score
def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):