mirror of
https://gitee.com/ascend/MindSpeed-RL.git
synced 2025-10-20 16:23:45 +08:00
!508 [Multimodal][feat.] Supports dynamic MBS
Merge pull request !508 from wanghao/mbs
This commit is contained in:
@ -141,10 +141,16 @@ class BaseTrainingEngine(ABC):
|
||||
partitions = rearrange_micro_batches(seq_len_list, max_packing_token, dynamic_max_batch_size=dynamic_max_batch_size)
|
||||
batches = []
|
||||
for key, tensors in batch.items():
|
||||
for batch_idx, partition in enumerate(partitions):
|
||||
if batch_idx >= len(batches):
|
||||
batches.append({})
|
||||
batches[batch_idx][key] = tensors[partition]
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
for batch_idx, partition in enumerate(partitions):
|
||||
if batch_idx >= len(batches):
|
||||
batches.append({})
|
||||
batches[batch_idx][key] = tensors[partition]
|
||||
elif isinstance(tensors, List):
|
||||
for batch_idx, partition in enumerate(partitions):
|
||||
if batch_idx >= len(batches):
|
||||
batches.append({})
|
||||
batches[batch_idx][key] = torch.concat([tensors[p] for p in partition])
|
||||
return batches, partitions
|
||||
|
||||
def _forward_backward_batch(self, batch: Dict[str, torch.Tensor], forward_only: bool = False):
|
||||
|
@ -6,6 +6,8 @@ from typing import List, Tuple
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from mindspeed_rl.utils import is_multimodal
|
||||
|
||||
|
||||
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
|
||||
"""Karmarkar-Karp algorithm for partitioning a list of integers into k partitions
|
||||
@ -186,6 +188,125 @@ def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, eq
|
||||
return _check_and_sort_partitions(partitions)
|
||||
|
||||
|
||||
def balanced_bin_packing(seqlen_list: List[int], max_capacity: int):
    """Pack items into capacity-bounded bins while keeping the bin loads balanced.

    Uses Best-Fit-Decreasing: items are placed largest-first into the open bin
    with the least leftover room that still fits them; a rebalancing pass
    (_balance_bins) then evens out the loads across bins.

    Parameters:
        seqlen_list (List[int]):
            seq lengths of each items
        max_capacity (int):
            maximum capacity for each bin/partition

    Returns:
        partitions (List[List[int]]):
            list of partitions, each containing indices of items

    Raises:
        ValueError: if any single item is longer than max_capacity.
    """
    # NOTE(review): rearrange_micro_batches forwards max_token_len here, and its
    # multimodal comment describes that value as a list; the arithmetic below
    # requires a plain int — confirm what callers actually pass.
    if not seqlen_list:
        return []

    # Largest items first (Best Fit Decreasing): (length, index) pairs, descending.
    ordered = sorted(((length, pos) for pos, length in enumerate(seqlen_list)), reverse=True)

    # Each bin: {'items': [(original_index, length), ...], 'capacity_used': int}
    bins = []

    for length, original_pos in ordered:
        if length > max_capacity:
            raise ValueError(f"Item with seqlen {length} exceeds max_capacity {max_capacity}")

        # Best Fit: among bins with enough room, take the one with the least
        # leftover space (ties resolved toward the earliest bin).
        fitting = [
            (max_capacity - candidate['capacity_used'], bin_pos)
            for bin_pos, candidate in enumerate(bins)
            if max_capacity - candidate['capacity_used'] >= length
        ]
        if fitting:
            _, bin_pos = min(fitting)
            bins[bin_pos]['items'].append((original_pos, length))
            bins[bin_pos]['capacity_used'] += length
        else:
            # Nothing fits: open a fresh bin for this item.
            bins.append({'items': [(original_pos, length)], 'capacity_used': length})

    # Smooth out load imbalance by moving items between bins (mutates in place).
    _balance_bins(bins, max_capacity)

    # Strip lengths; each partition is just the original item indices.
    return [[idx for idx, _ in bin_info['items']] for bin_info in bins]
|
||||
|
||||
|
||||
def _balance_bins(bins: List[dict], max_capacity: int):
    """Even out bin loads in place by shifting single items between bins.

    Runs up to three passes. Each pass sorts the bins by load (ascending) and
    moves at most one item from a heavier bin to a lighter one, provided the
    receiver stays within max_capacity and the pair ends up closer in load
    than before. Stops early once a full pass finds no beneficial move.

    Parameters:
        bins: List of bin dictionaries with 'items' and 'capacity_used' keys
        max_capacity: Maximum capacity per bin
    """
    if len(bins) <= 1:
        return  # nothing to balance

    for _ in range(3):  # bounded number of improvement passes
        moved = False

        # Lightest bins first; heavy donors sit at the tail.
        bins.sort(key=lambda entry: entry['capacity_used'])

        # Walk donors from the heaviest end toward the light end.
        for donor_pos in range(len(bins) - 1, 0, -1):
            donor = bins[donor_pos]

            for receiver_pos in range(donor_pos):
                receiver = bins[receiver_pos]

                gap = donor['capacity_used'] - receiver['capacity_used']
                if gap <= 1:
                    # Receivers only get heavier from here on — pair is balanced.
                    break

                # First item whose transfer narrows the gap without overflow.
                for item_pos, (_, length) in enumerate(donor['items']):
                    receiver_after = receiver['capacity_used'] + length
                    donor_after = donor['capacity_used'] - length

                    if receiver_after > max_capacity:
                        continue  # would overflow the receiver
                    if abs(donor_after - receiver_after) >= gap:
                        continue  # move would not improve balance

                    receiver['items'].append(donor['items'].pop(item_pos))
                    donor['capacity_used'] -= length
                    receiver['capacity_used'] += length
                    moved = True
                    break  # at most one move per pass, then re-sort

                if moved:
                    break

            if moved:
                break

        if not moved:
            return  # converged: a full scan produced no move
|
||||
|
||||
|
||||
def rearrange_micro_batches(
|
||||
seqlen_list: List[int],
|
||||
max_token_len: int,
|
||||
@ -205,6 +326,11 @@ def rearrange_micro_batches(
|
||||
Returns:
|
||||
List[List[int]]: List of partitions, each containing indices of items.
|
||||
"""
|
||||
if is_multimodal():
|
||||
# When multimodal, max_token_len is a list representing the maximum token length for each modality.
|
||||
# Use balanced bin packing algorithm with capacity constraints
|
||||
return balanced_bin_packing(seqlen_list=seqlen_list, max_capacity=max_token_len)
|
||||
|
||||
if max(seqlen_list) > max_token_len:
|
||||
raise ValueError(
|
||||
f"seqlen of items:[{max(seqlen_list)}] must <= max_token_len:[{max_token_len}]"
|
||||
|
@ -65,7 +65,7 @@ rl_config:
|
||||
blocking: true
|
||||
gamma: 1.0
|
||||
lam: 0.95
|
||||
use_dynamic_bsz: false
|
||||
use_dynamic_bsz: true
|
||||
max_packing_token_size: [15000, 4000]
|
||||
actor_forward_micro_batch_size: 1
|
||||
ref_forward_micro_batch_size: 1
|
||||
|
Reference in New Issue
Block a user