[ulysses] fix: repeat kv heads by sp_size//nheads_k if nheads_k is less than sp_size (#850)

This commit is contained in:
Joel
2025-04-01 07:25:53 +08:00
committed by haibin.lin
parent b70981bdb9
commit de9e01b847
5 changed files with 201 additions and 135 deletions

View File

@ -127,7 +127,7 @@ class PRIMERewardModelWorker(Worker):
if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=reward_module)
apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
# some parameters may not in torch_dtype
reward_module.to(torch_dtype)

View File

@ -11,7 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
import contextlib
from dataclasses import dataclass
import pytest
import torch
import copy
import torch.distributed
@ -23,16 +27,45 @@ from verl.utils.ulysses import get_ulysses_sequence_parallel_world_size, set_uly
from verl.workers.sharding_manager import FSDPUlyssesShardingManager
from verl.protocol import DataProto
from flash_attn.bert_padding import unpad_input, index_first_axis, rearrange
from transformers import LlamaConfig, Qwen2Config
from transformers import LlamaConfig, Qwen2Config, PretrainedConfig
from transformers import AutoModelForCausalLM
from verl.models.transformers.monkey_patch import apply_monkey_patch
# TODO(sgm): add more models for test
# we only need one scale for each model
test_configs = {
'llama': (LlamaConfig(num_hidden_layers=2), apply_monkey_patch),
'qwen2': (Qwen2Config(num_hidden_layers=2), apply_monkey_patch)
}
@dataclass
class SequenceParallelConfig:
config: PretrainedConfig
sp_size: int
is_valid: bool
def test_configs():
return [
SequenceParallelConfig(LlamaConfig(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=32),
sp_size=8,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
num_attention_heads=28,
num_key_value_heads=4,
hidden_size=3584),
sp_size=4,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2,
num_attention_heads=28,
num_key_value_heads=4,
hidden_size=3584),
sp_size=8,
is_valid=False),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
sp_size=4,
is_valid=True),
SequenceParallelConfig(Qwen2Config(num_hidden_layers=2, num_attention_heads=32, num_key_value_heads=4),
sp_size=8,
is_valid=True),
]
def sync_model_parameters_global(layer):
@ -41,11 +74,23 @@ def sync_model_parameters_global(layer):
torch.distributed.broadcast(tensor=p.data, src=0)
def test_hf_casual_fwd():
@pytest.mark.parametrize("test_config", test_configs())
def test_hf_casual_fwd_bwd(test_config):
if not torch.distributed.is_initialized():
initialize_global_process_group()
context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError)
with context:
world_size = torch.distributed.get_world_size()
_hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size)
# TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort`
# torch.distributed.destroy_process_group()
def _hf_casual_fwd(config, sp_size, dp_size):
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
sp_size = 8
dp_size = 1
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, sp_size),
mesh_dim_names=('dp', 'sp'))
@ -55,75 +100,71 @@ def test_hf_casual_fwd():
seqlen = 128
response_length = 127
for model_name, (config, apply_monkey_patch) in test_configs.items():
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
apply_monkey_patch(model)
model = model.to(device='cuda')
sync_model_parameters_global(model)
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
apply_monkey_patch(model, sp_size)
model = model.to(device='cuda')
sync_model_parameters_global(model)
# different rank will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
# different rank will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
}
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
}
model_inputs = DataProto.from_dict(model_inputs)
model_inputs = DataProto.from_dict(model_inputs)
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# slice input tensor for ulysses
# input_ids are padded and sliced
# postition_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
# slice input tensor for ulysses
# input_ids are padded and sliced
# postition_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
# input with input_ids_rmpad and postition_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
# input with input_ids_rmpad and postition_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
logits_rmpad_local = model(input_ids_rmpad, position_ids=position_ids_rmpad,
use_cache=False).logits # (1, total_nnz, vocab_size)
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
print(f'Fwd Check pass')
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
def test_hf_casual_fwd_bwd():
def _hf_casual_fwd_bwd(config, sp_size, dp_size):
assert torch.cuda.device_count() >= 2, "need at least 2 gpus for test"
sp_size = 8
dp_size = 1
ulysses_device_mesh = init_device_mesh(device_type='cuda',
mesh_shape=(dp_size, sp_size),
mesh_dim_names=('dp', 'sp'))
@ -133,82 +174,78 @@ def test_hf_casual_fwd_bwd():
seqlen = 128
response_length = 127
for model_name, (config, apply_monkey_patch) in test_configs.items():
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
apply_monkey_patch(model)
model = model.to(device='cuda')
sync_model_parameters_global(model)
# patch before load
with torch.device('cuda'):
model = AutoModelForCausalLM.from_config(config=config,
torch_dtype=torch.bfloat16,
attn_implementation='flash_attention_2')
apply_monkey_patch(model, sp_size)
model = model.to(device='cuda')
sync_model_parameters_global(model)
# different rank will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
# different rank will generate different input_ids following fsdp
input_ids = torch.randint(low=0, high=config.vocab_size, size=(batch_size, seqlen), device='cuda')
attention_mask = create_random_mask(input_ids=input_ids,
max_ratio_of_left_padding=0,
max_ratio_of_valid_token=0.9,
min_ratio_of_valid_token=0.8)
position_ids = compute_position_id_with_mask(
attention_mask) # TODO(sgm): we can construct the position_ids_rmpad here
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
}
model_inputs = {
'input_ids': input_ids.cuda(),
'attention_mask': attention_mask.cuda(),
'position_ids': position_ids.int().cuda()
}
model_inputs = DataProto.from_dict(model_inputs)
model_inputs = DataProto.from_dict(model_inputs)
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# 1. perform ulysses forward
with sharding_manager:
model_inputs = sharding_manager.preprocess_data(model_inputs)
input_ids = model_inputs.batch['input_ids']
attention_mask = model_inputs.batch['attention_mask']
position_ids = model_inputs.batch['position_ids']
input_ids_rmpad, indices, *_ = unpad_input(input_ids.unsqueeze(-1),
attention_mask) # input_ids_rmpad (total_nnz, ...)
input_ids_rmpad = input_ids_rmpad.transpose(0, 1) # (1, total_nnz)
# unpad the position_ids to align the rotary
position_ids_rmpad = index_first_axis(rearrange(position_ids.unsqueeze(-1), "b s ... -> (b s) ..."),
indices).transpose(0, 1)
# slice input tensor for ulysses
# input_ids are padded and sliced
# postition_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
# slice input tensor for ulysses
# input_ids are padded and sliced
# postition_ids are only padded but not sliced
input_ids_rmpad_sliced, position_ids_rmpad_padded, pad_size = ulysses_pad_and_slice_inputs(
input_ids_rmpad, position_ids_rmpad, sp_size=get_ulysses_sequence_parallel_world_size())
# input with input_ids_rmpad and postition_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
# input with input_ids_rmpad and postition_ids to enable flash attention varlen
logits_split_in_seq = model(input_ids_rmpad_sliced, position_ids=position_ids_rmpad_padded,
use_cache=False).logits # (1, total_nnz/n, vocab_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# all_gather output
logits_full = gather_outpus_and_unpad(logits_split_in_seq, gather_dim=1, unpad_dim=1, padding_size=pad_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
input_ids_full = copy.deepcopy(input_ids_rmpad)
position_ids_full = copy.deepcopy(position_ids_rmpad)
model_no_sp = copy.deepcopy(model)
logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full,
use_cache=False).logits # (1, total_nnz, vocab_size)
# 2. perform normal forward
set_ulysses_sequence_parallel_group(None)
input_ids_full = copy.deepcopy(input_ids_rmpad)
position_ids_full = copy.deepcopy(position_ids_rmpad)
model_no_sp = copy.deepcopy(model)
logits_rmpad_local = model_no_sp(input_ids_full, position_ids=position_ids_full,
use_cache=False).logits # (1, total_nnz, vocab_size)
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
mean_local = logits_rmpad_local.mean()
mean_full = logits_full.mean()
mean_full.backward()
mean_local.backward()
mean_full.backward()
mean_local.backward()
# 3. check the gradients
grad = model.model.layers[0].self_attn.q_proj.weight.grad
grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)
print(f'Fwd + BWD Check pass')
# 3. check the gradients
grad = model.model.layers[0].self_attn.q_proj.weight.grad
grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad
torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5)
torch.testing.assert_close(grad, grad_full, atol=1e-2, rtol=1e-5)
if __name__ == '__main__':
local_rank, rank, world_size = initialize_global_process_group()
test_hf_casual_fwd()
test_hf_casual_fwd_bwd()
pytest.main([__file__, "-svv"])

View File

@ -29,6 +29,18 @@ from verl.utils.ulysses import (
)
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=2, repeats=n_rep). The hidden states go from (batch,
seqlen, num_key_value_heads, head_dim) to (batch, seqlen, num_attention_heads, head_dim)
"""
batch, slen, num_key_value_heads, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, :, None, :].expand(batch, slen, num_key_value_heads, n_rep, head_dim)
return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep, head_dim)
def _ulysses_flash_attention_forward(
query_states: torch.Tensor,
key_states: torch.Tensor,
@ -54,6 +66,17 @@ def _ulysses_flash_attention_forward(
########## AlltoAll for Ulysses ##########
if ulysses_sp_size > 1:
assert position_ids is not None, "position_ids is required for Ulysses sequence parallelism"
# NOTE: repeat kv heads to be divided by sequence parallel. Instead of repeating nheads_q//nheads_k,
# we choose to repeat sp_size//nheads_k, since flash_attention supports MQA/GQA.
# For example:
# - nheads_k=4, sp=8, repeats=2
# - nheads_k=8, sp=8, repeats=1
# - nheads_k=16, sp=8, repeats=1
repeats = max(ulysses_sp_size // key_states.size(2), 1)
key_states = repeat_kv(key_states, repeats)
value_states = repeat_kv(value_states, repeats)
# (bsz, seq_len/n, n_head, head_dim) -> (bsz, seq_len, n_head/n, head_dim)
query_states = gather_seq_scatter_heads(query_states, seq_dim=1, head_dim=2)
key_states = gather_seq_scatter_heads(key_states, seq_dim=1, head_dim=2)
@ -84,10 +107,16 @@ def _ulysses_flash_attention_forward(
return attn_output
def apply_monkey_patch(model: PreTrainedModel):
def apply_monkey_patch(model: PreTrainedModel, ulysses_sp_size: int):
"""Replace _flash_attention_forward to _ulysses_flash_attention_forward"""
module = sys.modules[model.__module__]
num_attention_heads, num_key_value_heads = model.config.num_attention_heads, model.config.num_key_value_heads
assert num_attention_heads % ulysses_sp_size == 0, \
f"num_attention_heads {num_attention_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
assert num_key_value_heads % ulysses_sp_size == 0 or ulysses_sp_size % num_key_value_heads == 0, \
f"num_key_value_heads {num_key_value_heads} must be divisible by ulysses_sp_size {ulysses_sp_size}"
# TODO: VLM models only, unify monkey patch to LLM models.
if model.config.model_type in ("qwen2_vl", "qwen2_5_vl"): # patch remove padding for qwen2vl mrope
from verl.models.transformers.qwen2_vl import ulysses_flash_attn_forward

View File

@ -210,7 +210,7 @@ class FSDPSFTTrainer(object):
if self.use_remove_padding or self.config.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=self.model)
apply_monkey_patch(model=self.model, ulysses_sp_size=self.config.ulysses_sequence_parallel_size)
# Apply Liger kernel if use_liger is enabled
if self.config.model.get('use_liger', False):

View File

@ -204,7 +204,7 @@ class ActorRolloutRefWorker(Worker):
if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=actor_module)
apply_monkey_patch(model=actor_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
# Apply Liger kernel to the model if use_liger is set to True
if use_liger:
@ -709,7 +709,7 @@ class CriticWorker(Worker):
use_remove_padding = config.model.get('use_remove_padding', False)
if use_remove_padding or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=critic_module)
apply_monkey_patch(model=critic_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
# some parameters may not in torch_dtype
critic_module.to(torch_dtype)
@ -967,7 +967,7 @@ class RewardModelWorker(Worker):
if config.model.get('use_remove_padding', False) or self.ulysses_sequence_parallel_size > 1:
from verl.models.transformers.monkey_patch import apply_monkey_patch
apply_monkey_patch(model=reward_module)
apply_monkey_patch(model=reward_module, ulysses_sp_size=self.ulysses_sequence_parallel_size)
reward_module.to(torch.bfloat16)