Mirror of https://github.com/volcengine/verl.git (synced 2025-10-20 13:43:50 +08:00)
[misc] feat: support rmpad/data-packing in FSDP with transformers (#91)
* init commit of rmpad
* add rmpad test
* support rmpad in actor model
* add test for value model
* support rmpad in critic and rm
* fix actor return, fix num_labels, and clean up unused rmpad code
* fix critic and benchmark
* update script
* fix critic
* lint
* fix util issue
* fix unnecessary unpad
* address issues
* fix args
* update test and update the rmpad-supported model list
* fix typo
* fix typo and fix name
* rename rmpad to remove padding
* fix arch to model_type
* add CI for e2e rmpad and fix typo
* lint
* fix CI
* fix typo
* update tests for customized tokenizer in actor
* fix rmpad test
* update requirement on transformers, as hf_rollout may have an issue
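For readers new to remove padding (rmpad), the idea threaded through the FSDP workers in this commit is to drop pad tokens before the forward pass and scatter the per-token outputs back afterwards. A minimal standalone sketch with illustrative tensors only, assuming flash-attn is installed; it uses the same unpad_input/pad_input helpers the new code calls:

    # Illustrative round trip (not part of the commit): pack valid tokens, then un-pack outputs.
    import torch
    from flash_attn.bert_padding import unpad_input, pad_input

    input_ids = torch.tensor([[0, 0, 5, 6, 7],
                              [0, 8, 9, 10, 11]])        # left-padded batch: (batch=2, seqlen=5)
    attention_mask = (input_ids != 0).int()

    # Pack: keep only the 7 valid tokens as one flat sequence of length total_nnz.
    input_ids_rmpad, indices, cu_seqlens, max_seqlen = unpad_input(input_ids.unsqueeze(-1), attention_mask)

    # ... run the model on the packed tokens (with unpadded position_ids) ...

    # Unpack: scatter per-token outputs back to (batch, seqlen); pad positions are filled with zeros.
    restored = pad_input(input_ids_rmpad, indices, batch=2, seqlen=5).squeeze(-1)
    assert torch.equal(restored, input_ids)

Because no compute is spent on pad tokens, micro-batches with very uneven sequence lengths get packed into one short varlen sequence, which is what the new use_remove_padding flags below enable per role (actor, ref, critic, reward model).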
.github/workflows/e2e_gpu.yml (vendored, 6 changes)
@@ -23,6 +23,7 @@ jobs:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
      HF_HUB_ENABLE_HF_TRANSFER: 1
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
@@ -32,7 +33,12 @@ jobs:
          fetch-depth: 0
      - name: Install the current repository
        run: |
          pip3 install hf_transfer
          pip3 install -e .[test]
      - name: Running digit completion e2e training tests on 8 L20 GPUs
        run: |
          bash tests/e2e/run_ray_trainer.sh
      - name: Running digit completion e2e training tests on 8 L20 GPUs (with rmpad)
        run: |
          pip3 install --upgrade transformers
          bash tests/e2e/run_ray_trainer_rmpad.sh
.github/workflows/model.yml (vendored, new file, 39 lines)
@@ -0,0 +1,39 @@
name: model_rmpad

on:
  # Trigger the workflow on push or pull request,
  # but only for the main branch
  push:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/model.yml
  pull_request:
    branches:
      - main
    paths:
      - "**/*.py"
      - .github/workflows/model.yml

jobs:
  e2e_gpu:
    runs-on: [self-hosted, l20-1]
    env:
      HTTP_PROXY: ${{ secrets.PROXY_HTTP }}
      HTTPS_PROXY: ${{ secrets.PROXY_HTTPS }}
      NO_PROXY: "localhost,127.0.0.1"
    container:
      image: verlai/verl:vemlp-th2.4.0-cu124-vllm0.6.3-ray2.10-te1.7-v0.0.3
      options: --gpus all --shm-size=10g
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
        with:
          fetch-depth: 0
      - name: Install the current repository and upgrade to latest transformers
        run: |
          pip3 install -e .[test]
          pip3 install --upgrade transformers
      - name: Running digit completion e2e training tests on 8 L20 GPUs
        run: |
          pytest -s tests/model/test_transformer.py
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=deepseek-ai/deepseek-llm-7b-chat \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=32 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=deepseek-ai/deepseek-llm-7b-chat \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=32 \
@@ -9,6 +9,7 @@ python3 -m verl.trainer.main_ppo \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=google/gemma-2-2b-it \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=128 \
    actor_rollout_ref.actor.ppo_micro_batch_size=4 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -21,6 +22,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.ref.log_prob_micro_batch_size=4 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=google/gemma-2-2b-it \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=4 \
@@ -17,6 +17,7 @@ python3 -m verl.trainer.main_ppo \
    data.max_response_length=512 \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -29,6 +30,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2-7B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=16 \
@@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
    data.return_raw_chat=True \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.optim.lr_warmup_steps_ratio=0.1 \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=16 \
@@ -31,6 +32,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.ref.log_prob_micro_batch_size=16 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.optim.lr_warmup_steps_ratio=0.05 \
    critic.model.path=Qwen/Qwen2-7B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
@@ -40,6 +42,7 @@ python3 -m verl.trainer.main_ppo \
    critic.model.fsdp_config.optimizer_offload=False \
    reward_model.enable=True \
    reward_model.model.path=sfairXC/FsfairX-Gemma2-RM-v0.1 \
    reward_model.model.use_remove_padding=True \
    reward_model.model.fsdp_config.param_offload=True \
    reward_model.micro_batch_size=16 \
    algorithm.kl_ctrl.kl_coef=0.001 \
@@ -18,6 +18,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.model.path=Qwen/Qwen2.5-32B-Instruct \
    actor_rollout_ref.model.enable_gradient_checkpointing=False \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=True \
    actor_rollout_ref.actor.ppo_mini_batch_size=256 \
    actor_rollout_ref.actor.ppo_micro_batch_size=16 \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
@@ -30,6 +31,7 @@ python3 -m verl.trainer.main_ppo \
    actor_rollout_ref.ref.log_prob_micro_batch_size=128 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    critic.optim.lr=1e-5 \
    critic.model.use_remove_padding=True \
    critic.model.path=Qwen/Qwen2.5-32B-Instruct \
    critic.model.enable_gradient_checkpointing=False \
    critic.ppo_micro_batch_size=32 \
@@ -39,7 +39,7 @@ dependencies = [
    "pybind11",
    "ray",
    "tensordict",
    "transformers",
    "transformers<4.48",
    "vllm<=0.6.3",
]
@@ -8,6 +8,6 @@ pandas
pybind11
ray
tensordict<0.6
transformers
transformers<4.48
vllm<=0.6.3
wandb
@@ -14,9 +14,11 @@ actor_rollout_ref:
  hybrid_engine: True
  model:
    path: ~/verl/tests/e2e/arithmetic_sequence/model
    tokenizer_path: ${actor_rollout_ref.model.path}
    external_lib: tests.e2e.envs.digit_completion
    override_config: {}
    enable_gradient_checkpointing: False
    use_remove_padding: False
  actor:
    strategy: fsdp  # This is for backward-compatibility
    ppo_mini_batch_size: 200
@@ -76,6 +78,7 @@ critic:
    override_config: {}
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      grad_offload: False
@@ -104,6 +107,7 @@ reward_model:
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    offload: False
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
  micro_batch_size: 8
@@ -119,7 +119,7 @@ def main(config):
    pprint(OmegaConf.to_container(config, resolve=True))  # resolve=True will eval symbol values

    # download the checkpoint from hdfs
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.path)
    local_path = copy_local_path_from_hdfs(config.actor_rollout_ref.model.tokenizer_path)
    local_path = os.path.expanduser(local_path)
    # instantiate the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(local_path)
tests/e2e/run_ray_trainer_rmpad.sh (new file, 14 lines)
@@ -0,0 +1,14 @@
#!/usr/bin/env bash

set -e -x

python3 tests/e2e/arithmetic_sequence/rl/main_trainer.py \
    data.train_files=tests/e2e/arithmetic_sequence/data/train.parquet \
    data.val_files=tests/e2e/arithmetic_sequence/data/test.parquet \
    actor_rollout_ref.model.path=tests/e2e/arithmetic_sequence/model \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
    actor_rollout_ref.model.tokenizer_path=tests/e2e/arithmetic_sequence/model \
    critic.model.path=Qwen/Qwen2.5-0.5B \
    critic.model.use_remove_padding=True \
    trainer.total_epochs=1
tests/model/test_transformer.py (new file, 129 lines)
@@ -0,0 +1,129 @@
from transformers import AutoModelForCausalLM, AutoConfig, AutoModelForTokenClassification, AutoTokenizer

import torch
from verl.utils.model import create_random_mask, compute_position_id_with_mask
from verl.utils.torch_functional import masked_mean, log_probs_from_logits_all_rmpad, logprobs_from_logits
from flash_attn.bert_padding import unpad_input, pad_input, index_first_axis, rearrange

from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = [
    LlamaConfig(num_hidden_layers=1),
    MistralConfig(num_hidden_layers=1),
    GemmaConfig(num_hidden_layers=1),
    Qwen2Config(num_hidden_layers=1)
]
# test_cases = ['deepseek-ai/deepseek-llm-7b-chat', 'Qwen/Qwen2-7B-Instruct']


def test_hf_casual_models():
    batch_size = 4
    seqlen = 128
    response_length = 127

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        with torch.device('cuda'):
            model = AutoModelForCausalLM.from_config(config=config,
                                                     torch_dtype=torch.bfloat16,
                                                     attn_implementation='flash_attention_2')
            model = model.to(device='cuda')
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
        attention_mask = create_random_mask(input_ids=input_ids,
                                            max_ratio_of_left_padding=0.1,
                                            max_ratio_of_valid_token=0.8,
                                            min_ratio_of_valid_token=0.5)
        position_ids = compute_position_id_with_mask(
            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                              indices).transpose(0, 1)

        # input with input_ids_rmpad and position_ids to enable flash attention varlen
        logits_rmpad = model(input_ids_rmpad, position_ids=position_ids_rmpad,
                             use_cache=False).logits  # (1, total_nnz, vocab_size)

        origin_logits = model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              position_ids=position_ids,
                              use_cache=False).logits
        origin_logits_rmpad, origin_logits_indices, _, _ = unpad_input(origin_logits, attention_mask)

        logits_rmpad = logits_rmpad.squeeze(0)
        log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                    logits_rmpad=logits_rmpad,
                                                    indices=indices,
                                                    batch_size=batch_size,
                                                    seqlen=seqlen,
                                                    response_length=response_length)  # (batch, seqlen)
        origin_log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                           logits_rmpad=origin_logits_rmpad,
                                                           indices=origin_logits_indices,
                                                           batch_size=batch_size,
                                                           seqlen=seqlen,
                                                           response_length=response_length)  # (batch, seqlen)

        torch.testing.assert_close(masked_mean(log_probs, attention_mask[:, -response_length - 1:-1]),
                                   masked_mean(origin_log_probs, attention_mask[:, -response_length - 1:-1]),
                                   atol=1e-2,
                                   rtol=1e-5)
        print('Check pass')


def test_hf_value_models():
    batch_size = 4
    seqlen = 128

    for config in test_configs:
        # config = AutoConfig.from_pretrained(test_case)
        config.num_labels = 1
        setattr(config, 'classifier_dropout', 0)
        setattr(config, 'hidden_dropout', 0)
        with torch.device('cuda'):
            model = AutoModelForTokenClassification.from_config(config=config,
                                                                torch_dtype=torch.bfloat16,
                                                                attn_implementation='flash_attention_2')
            model = model.to(device='cuda')
        input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
        attention_mask = create_random_mask(input_ids=input_ids,
                                            max_ratio_of_left_padding=0.1,
                                            max_ratio_of_valid_token=0.8,
                                            min_ratio_of_valid_token=0.5)
        position_ids = compute_position_id_with_mask(
            attention_mask)  # TODO(sgm): we can construct the position_ids_rmpad here

        input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
            input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
        input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

        # unpad the position_ids to align the rotary
        position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                              indices).transpose(0, 1)

        origin_logits = model(input_ids=input_ids,
                              attention_mask=attention_mask,
                              position_ids=position_ids,
                              use_cache=False).logits

        # input with input_ids_rmpad and position_ids to enable flash attention varlen
        rmpad_logits = model(input_ids_rmpad, position_ids=position_ids_rmpad,
                             use_cache=False).logits  # (1, total_nnz, 1)
        rmpad_logits = rmpad_logits.squeeze(0)
        pad_logits = pad_input(rmpad_logits, indices, batch_size, seqlen=seqlen)

        torch.testing.assert_close(masked_mean(pad_logits, attention_mask[:, :, None]),
                                   masked_mean(origin_logits, attention_mask[:, :, None]),
                                   atol=1e-2,
                                   rtol=1e-5)
        print('Value model check pass')


if __name__ == '__main__':
    test_hf_casual_models()
    test_hf_value_models()
@@ -17,6 +17,21 @@ from typing import List, Optional, Type

import torch.nn as nn

# Supported models using HF Rmpad
# TODO(sgm): HF may support more than listed here; we should add more after testing
from transformers import LlamaConfig, MistralConfig, GemmaConfig, Qwen2Config

_REOVEPAD_MODELS = {'llama': LlamaConfig, 'mistral': MistralConfig, 'gemma': GemmaConfig, 'qwen2': Qwen2Config}


def check_model_support_rmpad(model_type: str):
    assert isinstance(model_type, str)
    if not model_type in _REOVEPAD_MODELS.keys():
        raise ValueError(f"Model architecture {model_type} is not supported for now. "
                         f"RMPad supported architectures: {_REOVEPAD_MODELS.keys()}")


# Supported models in Megatron-LM
# Architecture -> (module, class).
_MODELS = {
    "LlamaForCausalLM":
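For context, a hypothetical usage sketch (not part of the diff) of the check that the FSDP workers below perform before enabling remove padding:

    # Hypothetical usage sketch: gate remove-padding on the HF config's model_type.
    from transformers import AutoConfig
    from verl.models.registry import check_model_support_rmpad

    config = AutoConfig.from_pretrained('Qwen/Qwen2-7B-Instruct')
    check_model_support_rmpad(config.model_type)  # 'qwen2' passes; unsupported types raise ValueError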
@@ -17,6 +17,7 @@ actor_rollout_ref:
    external_lib: null
    override_config: { }
    enable_gradient_checkpointing: False
    use_remove_padding: False
  actor:
    strategy: fsdp  # This is for backward-compatibility
    ppo_mini_batch_size: 256
@@ -83,6 +84,7 @@ critic:
    override_config: { }
    external_lib: ${actor_rollout_ref.model.external_lib}
    enable_gradient_checkpointing: False
    use_remove_padding: False
    fsdp_config:
      param_offload: False
      grad_offload: False
@@ -105,6 +107,7 @@ reward_model:
    input_tokenizer: ${actor_rollout_ref.model.path}  # set this to null if the chat template is identical
    path: ~/models/FsfairX-LLaMA3-RM-v0.1
    external_lib: ${actor_rollout_ref.model.external_lib}
    use_remove_padding: False
    fsdp_config:
      min_num_params: 0
      param_offload: False
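A small illustrative sketch (assumed names, reduced config, not from the commit) of how such a flag is flipped from the command line, mirroring the *.use_remove_padding=True overrides in the shell scripts above:

    # Illustrative OmegaConf/Hydra-style override sketch, reduced to the one key of interest.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create({'model': {'use_remove_padding': False}})   # default from the YAML above
    cfg = OmegaConf.merge(cfg, OmegaConf.from_dotlist(['model.use_remove_padding=true']))
    # The workers read the flag with self.config.model.get('use_remove_padding', False).
    assert cfg.model.use_remove_padding is True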
@@ -298,7 +298,7 @@ def log_probs_from_logits_response(input_ids, logits, response_length):


def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad, response_length):
    """Compute the log_probs from logits with rmpad input_ids and logits. Note that
    """Compute the log_probs from logits with rmpad logits and pad input. Note that
    logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between
    logits and input_ids.
    The reason for this function is to compute logprobs_from_logits in rmpad mode, because it is memory-intensive
@@ -326,6 +326,34 @@ def log_probs_from_logits_response_rmpad(input_ids, attention_mask, logits_rmpad
    return output


def log_probs_from_logits_all_rmpad(input_ids_rmpad, logits_rmpad, indices, batch_size, seqlen, response_length):
    """Compute the log_probs from logits with rmpad input_ids and logits. Note that
    logits_rmpad = model(input_ids_rmpad). For each sentence, there is a shift between
    logits and input_ids.
    The reason for this function is to compute logprobs_from_logits in rmpad mode, because it is memory-intensive
    for large vocab_size

    Args:
        input_ids_rmpad: [1, total_nnz]
        logits_rmpad: [total_nnz, vocab_size]
        indices: [total_nnz]
        batch_size: int
        seqlen: int
        response_length: int
    """
    from flash_attn.bert_padding import pad_input
    input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # transpose back to [total_nnz, 1]
    input_ids_rmpad = input_ids_rmpad.squeeze(-1)
    input_ids_rmpad_rolled = torch.roll(input_ids_rmpad, shifts=-1, dims=0)
    full_log_probs_rmpad = logprobs_from_logits(logits=logits_rmpad, labels=input_ids_rmpad_rolled)  # (total_nnz,)
    full_output = pad_input(hidden_states=full_log_probs_rmpad.unsqueeze(-1),
                            indices=indices,
                            batch=batch_size,
                            seqlen=seqlen)
    output = full_output.squeeze(-1)[:, -response_length - 1:-1]  # [batch_size, response_length]
    return output


from transformers.generation.logits_process import (TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper)
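A tiny illustrative sketch (toy values, not from the commit) of the torch.roll label shift used above: in packed (rmpad) form, the label for position t is the token at position t+1.

    # Toy sketch of the next-token label shift in packed form.
    import torch

    input_ids_rmpad = torch.tensor([11, 12, 13, 21, 22])      # two packed sequences: [11, 12, 13] and [21, 22]
    labels = torch.roll(input_ids_rmpad, shifts=-1, dims=0)   # -> [12, 13, 21, 22, 11]
    # Positions whose rolled label crosses a sequence boundary (13 -> 21, 22 -> 11) land on slots that the
    # response mask zeroes out downstream (or fall outside the sliced window), so they never enter the loss.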
@@ -14,7 +14,7 @@
"""
Single Process Actor
"""
from typing import Iterable
from typing import Iterable, Tuple

import torch
from torch import nn
@@ -24,7 +24,9 @@ from verl import DataProto
from verl.trainer.ppo import core_algos
from verl.workers.actor import BasePPOActor
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import logprobs_from_logits
from verl.utils.torch_functional import logprobs_from_logits, log_probs_from_logits_all_rmpad

from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

__all__ = ['DataParallelPPOActor']

@@ -41,17 +43,47 @@ class DataParallelPPOActor(BasePPOActor):
        super().__init__(config)
        self.actor_module = actor_module
        self.actor_optimizer = actor_optimizer
        self.use_remove_padding = self.config.get('use_remove_padding', False)
        print(f'Actor use_remove_padding={self.use_remove_padding}')

    def _forward_micro_batch(self, micro_batch, temperature):
    def _forward_micro_batch(self, micro_batch, temperature) -> Tuple[torch.Tensor, torch.Tensor]:
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.actor_module(input_ids=micro_batch['input_ids'],
                                       attention_mask=micro_batch['attention_mask'],
                                       position_ids=micro_batch['position_ids'],
                                       use_cache=False)  # prevent the model from thinking we are generating
            logits = output.logits / temperature
            logits = logits[:, -response_length - 1:-1]
            log_probs = logprobs_from_logits(logits, micro_batch['responses'])
            input_ids = micro_batch['input_ids']
            batch_size, seqlen = input_ids.shape
            attention_mask = micro_batch['attention_mask']
            position_ids = micro_batch['position_ids']

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)
                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.actor_module(input_ids=input_ids_rmpad,
                                           attention_mask=None,
                                           position_ids=position_ids_rmpad,
                                           use_cache=False)  # prevent the model from thinking we are generating
                logits_rmpad = output.logits.squeeze(0)  # (total_nnz, vocab_size)
                logits_rmpad /= temperature
                log_probs = log_probs_from_logits_all_rmpad(input_ids_rmpad=input_ids_rmpad,
                                                            logits_rmpad=logits_rmpad,
                                                            indices=indices,
                                                            batch_size=batch_size,
                                                            seqlen=seqlen,
                                                            response_length=response_length)  # (batch, seqlen)
                logits = logits_rmpad
            else:
                output = self.actor_module(input_ids=input_ids,
                                           attention_mask=attention_mask,
                                           position_ids=position_ids,
                                           use_cache=False)  # prevent the model from thinking we are generating
                logits = output.logits / temperature
                logits = logits[:, -response_length - 1:-1]
                log_probs = logprobs_from_logits(logits, micro_batch['responses'])
            return logits, log_probs

    def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
@@ -145,8 +177,17 @@ class DataParallelPPOActor(BasePPOActor):
                                                      advantages=advantages,
                                                      eos_mask=response_mask,
                                                      cliprange=clip_ratio)

                entropy_loss = core_algos.compute_entropy_loss(logits, response_mask)
                # compute entropy loss
                if self.use_remove_padding:
                    full_response_mask = attention_mask.clone()
                    full_response_mask[:, :-response_length] = 0  # set the prompt part to zero
                    full_response_mask_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                        full_response_mask.unsqueeze(-1), attention_mask=attention_mask)
                    full_response_mask_rmpad = full_response_mask_rmpad.squeeze(-1)  # (total_nnz)
                    entropy_loss = core_algos.compute_entropy_loss(logits, full_response_mask_rmpad)  # (total_nnz,)
                else:
                    entropy_loss = core_algos.compute_entropy_loss(logits, response_mask)
                # compute policy loss
                policy_loss = pg_loss - entropy_loss * entropy_coeff

                loss = policy_loss / self.gradient_accumulation
@@ -29,6 +29,8 @@ from verl.workers.critic import BasePPOCritic
from verl.utils.py_functional import append_to_dict
from verl.utils.torch_functional import masked_mean

from flash_attn.bert_padding import pad_input, unpad_input, rearrange, index_first_axis

__all__ = ['DataParallelPPOCritic']


@@ -38,6 +40,8 @@ class DataParallelPPOCritic(BasePPOCritic):
        super().__init__(config=config)
        self.critic_module = critic_module
        self.critic_optimizer = critic_optimizer
        self.use_remove_padding = self.config.model.get('use_remove_padding', False)
        print(f'Critic use_remove_padding={self.use_remove_padding}')

        assert self.config.ppo_mini_batch_size % self.config.ppo_micro_batch_size == 0
        self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size
@@ -45,12 +49,37 @@ class DataParallelPPOCritic(BasePPOCritic):
    def _forward_micro_batch(self, micro_batch):
        response_length = micro_batch['responses'].size(-1)
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.critic_module(input_ids=micro_batch['input_ids'],
                                        attention_mask=micro_batch['attention_mask'],
                                        position_ids=micro_batch['position_ids'],
                                        use_cache=False)  # prevent the model from thinking we are generating
            values = output.logits
            values = values[:, -response_length - 1:-1]
            input_ids = micro_batch['input_ids']
            batch, seqlen = input_ids.shape
            attention_mask = micro_batch['attention_mask']
            position_ids = micro_batch['position_ids']

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)
                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.critic_module(input_ids=input_ids_rmpad,
                                            attention_mask=None,
                                            position_ids=position_ids_rmpad,
                                            use_cache=False)  # prevent the model from thinking we are generating
                values_rmpad = output.logits
                values_rmpad = values_rmpad.squeeze(0)  # (total_nnz)

                # pad it back
                values = pad_input(values_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
                values = values[:, -response_length - 1:-1]
            else:
                output = self.critic_module(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            position_ids=position_ids,
                                            use_cache=False)  # prevent the model from thinking we are generating
                values = output.logits
                values = values[:, -response_length - 1:-1].squeeze(-1)
            return values

    def _make_minibatch_iterator(self, data: DataProto) -> Iterable[DataProto]:
@@ -23,7 +23,7 @@ import torch
import torch.distributed
import verl.utils.hdfs_io as hdfs_io
import verl.utils.torch_functional as verl_F
from omegaconf import DictConfig
from omegaconf import DictConfig, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import register, Dispatch
@@ -95,6 +95,7 @@ class ActorRolloutRefWorker(Worker):
                               fsdp_config,
                               optim_config,
                               override_model_config,
                               use_remove_padding=False,
                               enable_gradient_checkpointing=False,
                               trust_remote_code=False):
        from verl.utils.model import print_model_size, update_model_config
@@ -119,6 +120,10 @@ class ActorRolloutRefWorker(Worker):
        # override model kwargs
        actor_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)

        if use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(actor_model_config.model_type)

        override_config_kwargs = {
            'bos_token_id': self.tokenizer.bos_token_id,
            'eos_token_id': self.tokenizer.eos_token_id,
@@ -254,6 +259,8 @@ class ActorRolloutRefWorker(Worker):
        from omegaconf import OmegaConf
        override_model_config = OmegaConf.to_container(self.config.model.get('override_config', OmegaConf.create()))

        use_remove_padding = self.config.model.get('use_remove_padding', False)

        if self._is_actor or self._is_rollout:
            # we need the model for actor and rollout
            if self._is_actor:
@@ -267,6 +274,7 @@ class ActorRolloutRefWorker(Worker):
                fsdp_config=fsdp_config,
                optim_config=optim_config,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                enable_gradient_checkpointing=self.config.model.get('enable_gradient_checkpointing', False),
                trust_remote_code=self.config.model.get('trust_remote_code', False))
@@ -283,6 +291,8 @@ class ActorRolloutRefWorker(Worker):
        # load from checkpoint
        if self._is_actor:
            OmegaConf.set_struct(self.config.actor, True)
            with open_dict(self.config.actor):
                self.config.actor.use_remove_padding = use_remove_padding
            self.actor = DataParallelPPOActor(config=self.config.actor,
                                              actor_module=self.actor_module_fsdp,
                                              actor_optimizer=self.actor_optimizer)
@@ -295,12 +305,15 @@ class ActorRolloutRefWorker(Worker):
                fsdp_config=self.config.ref.fsdp_config,
                optim_config=None,
                override_model_config=override_model_config,
                use_remove_padding=use_remove_padding,
                trust_remote_code=self.config.model.get(
                    'trust_remote_code', False))[0]
            if self._is_offload_param:
                offload_fsdp_param_and_grad(module=self.ref_module_fsdp, offload_grad=self._is_offload_grad)

            OmegaConf.set_struct(self.config.ref, True)
            with open_dict(self.config.ref):
                self.config.ref.use_remove_padding = use_remove_padding
            self.ref_policy = DataParallelPPOActor(config=self.config.ref, actor_module=self.ref_module_fsdp)

        torch.cuda.empty_cache()
@@ -463,7 +476,6 @@ class CriticWorker(Worker):
        local_path = copy_local_path_from_hdfs(config.model.path)
        # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info
        # using random initialized model from any architecture. May not be the same as Actor.
        # TODO: support loading critic weights from RM. Support using AutoModelForTokenClassification

        tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path)
        self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False))
@@ -482,22 +494,28 @@ class CriticWorker(Worker):
        torch_dtype = self.config.model.fsdp_config.get('model_dtype', 'fp32')
        torch_dtype = PrecisionType.to_dtype(torch_dtype)

        from transformers import AutoConfig, AutoModelForCausalLM
        from transformers import AutoConfig, AutoModelForTokenClassification
        from torch import nn

        trust_remote_code = False
        critic_model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        critic_model_config.num_labels = 1

        use_remove_padding = config.model.get('use_remove_padding', False)
        if use_remove_padding:
            from verl.models.registry import check_model_support_rmpad
            check_model_support_rmpad(critic_model_config.model_type)

        init_context = get_init_weight_context_manager()
        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            critic_module = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                 torch_dtype=torch_dtype,
                                                                 config=critic_model_config,
                                                                 attn_implementation='flash_attention_2',
                                                                 trust_remote_code=trust_remote_code)
            critic_module.lm_head = nn.Sequential(nn.Linear(critic_model_config.hidden_size, 1, dtype=torch_dtype),
                                                  LambdaLayer(fn=squeeze))
            setattr(critic_model_config, 'classifier_dropout', 0.)
            setattr(critic_model_config, 'hidden_dropout', '0')
            critic_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                            torch_dtype=torch_dtype,
                                                                            config=critic_model_config,
                                                                            attn_implementation='flash_attention_2',
                                                                            trust_remote_code=trust_remote_code)

        # some parameters may not be in torch_dtype
        critic_module.to(torch_dtype)
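For intuition, a standalone sketch (random weights, illustrative only, not part of the commit) of why a token-classification head with num_labels=1 serves as the per-token value/reward head; it assumes a transformers release recent enough to provide token-classification heads for these architectures, which is why the CI above upgrades transformers:

    # Illustrative sketch: num_labels=1 attaches one scalar head per token position.
    import torch
    from transformers import AutoModelForTokenClassification, LlamaConfig

    config = LlamaConfig(num_hidden_layers=1, num_labels=1)
    model = AutoModelForTokenClassification.from_config(config)
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    values = model(input_ids=input_ids).logits   # shape (2, 16, 1): one scalar per token
    print(values.shape)

This replaces the previous trick of grafting an nn.Linear value head onto an AutoModelForCausalLM, and it keeps the critic and reward model compatible with the remove-padding forward path.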
@@ -642,9 +660,10 @@ class CriticWorker(Worker):
            offload_fsdp_param_and_grad(module=self.critic_module, offload_grad=self._is_offload_grad)


# TODO(sgm): we may need to extract it to dp_reward_model.py
class RewardModelWorker(Worker):
    """
    Note that we only implement the reward model that is subclass of AutoModelForSequenceClassification.
    Note that we only implement the reward model that is subclass of AutoModelForTokenClassification.
    """

    def __init__(self, config):
@@ -653,12 +672,12 @@ class RewardModelWorker(Worker):
        if not torch.distributed.is_initialized():
            torch.distributed.init_process_group(backend="nccl")
        self.config = config

        self.use_remove_padding = self.config.model.get('use_remove_padding', False)
        self.config.micro_batch_size //= torch.distributed.get_world_size()

    def _build_model(self, config):
        # the following line is necessary
        from transformers import AutoModelForSequenceClassification, AutoConfig
        from transformers import AutoModelForTokenClassification, AutoConfig
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload

        # download the checkpoint from hdfs
@@ -675,15 +694,18 @@ class RewardModelWorker(Worker):

        trust_remote_code = config.model.get('trust_remote_code', False)
        model_config = AutoConfig.from_pretrained(local_path, trust_remote_code=trust_remote_code)
        model_config.num_labels = 1
        # note that we have to create model in fp32. Otherwise, the optimizer is in bf16, which is incorrect
        init_context = get_init_weight_context_manager(use_meta_tensor=not model_config.tie_word_embeddings)

        with init_context(), warnings.catch_warnings():
            warnings.simplefilter("ignore")
            reward_module = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                               torch_dtype=torch.bfloat16,
                                                                               attn_implementation='flash_attention_2',
                                                                               trust_remote_code=trust_remote_code)
            setattr(model_config, 'classifier_dropout', 0.)
            reward_module = AutoModelForTokenClassification.from_pretrained(pretrained_model_name_or_path=local_path,
                                                                            config=model_config,
                                                                            torch_dtype=torch.bfloat16,
                                                                            attn_implementation='flash_attention_2',
                                                                            trust_remote_code=trust_remote_code)
            reward_module.to(torch.bfloat16)
        auto_wrap_policy = get_fsdp_wrap_policy(module=reward_module, config=self.config.model.fsdp_config)
@@ -707,12 +729,39 @@ class RewardModelWorker(Worker):
        torch.cuda.empty_cache()

    def _forward_micro_batch(self, micro_batch):
        from verl.utils.torch_functional import prepare_input_for_rmpad
        from flash_attn.bert_padding import pad_input, unpad_input, index_first_axis, rearrange

        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch.bfloat16):
            output = self.reward_module(input_ids=micro_batch['input_ids'],
                                        attention_mask=micro_batch['attention_mask'],
                                        position_ids=micro_batch['position_ids'])
            rm_score = output.logits  # (batch_size,)
            rm_score = rm_score.squeeze(-1)
            input_ids = micro_batch['input_ids']
            batch, seqlen = input_ids.shape
            attention_mask = micro_batch['attention_mask']
            position_ids = micro_batch['position_ids']

            if self.use_remove_padding:
                input_ids_rmpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(
                    input_ids.unsqueeze(-1), attention_mask)  # input_ids_rmpad (total_nnz, ...)
                input_ids_rmpad = input_ids_rmpad.transpose(0, 1)  # (1, total_nnz)

                # unpad the position_ids to align the rotary
                position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
                                                      indices).transpose(0, 1)
                # only pass input_ids and position_ids to enable flash_attn_varlen
                output = self.reward_module(input_ids=input_ids_rmpad,
                                            attention_mask=None,
                                            position_ids=position_ids_rmpad,
                                            use_cache=False)  # prevent the model from thinking we are generating
                reward_rmpad = output.logits
                reward_rmpad = reward_rmpad.squeeze(0)  # (total_nnz)

                # pad it back
                rm_score = pad_input(reward_rmpad, indices=indices, batch=batch, seqlen=seqlen).squeeze(-1)
            else:
                output = self.reward_module(input_ids=input_ids,
                                            attention_mask=attention_mask,
                                            position_ids=position_ids)
                rm_score = output.logits  # (batch_size, seq_len, 1)
                rm_score = rm_score.squeeze(-1)
            return rm_score

    def _expand_to_token_level(self, data: DataProto, scores: torch.Tensor):