!508 【多模态】【feat.】Supports dynamic MBS

Merge pull request !508 from wanghao/mbs
This commit is contained in:
wanghao
2025-08-28 13:37:29 +00:00
committed by i-robot
parent bdcf688590
commit a9faec2b83
3 changed files with 137 additions and 5 deletions

View File

@ -141,10 +141,16 @@ class BaseTrainingEngine(ABC):
partitions = rearrange_micro_batches(seq_len_list, max_packing_token, dynamic_max_batch_size=dynamic_max_batch_size)
batches = []
for key, tensors in batch.items():
for batch_idx, partition in enumerate(partitions):
if batch_idx >= len(batches):
batches.append({})
batches[batch_idx][key] = tensors[partition]
if isinstance(tensors, torch.Tensor):
for batch_idx, partition in enumerate(partitions):
if batch_idx >= len(batches):
batches.append({})
batches[batch_idx][key] = tensors[partition]
elif isinstance(tensors, List):
for batch_idx, partition in enumerate(partitions):
if batch_idx >= len(batches):
batches.append({})
batches[batch_idx][key] = torch.concat([tensors[p] for p in partition])
return batches, partitions
def _forward_backward_batch(self, batch: Dict[str, torch.Tensor], forward_only: bool = False):

View File

@ -6,6 +6,8 @@ from typing import List, Tuple
import torch
import torch.distributed as dist
from mindspeed_rl.utils import is_multimodal
def karmarkar_karp(seqlen_list: List[int], k_partitions: int, equal_size: bool):
"""Karmarkar-Karp algorithm for partitioning a list of integers into k partitions
@ -186,6 +188,125 @@ def get_seqlen_balanced_partitions(seqlen_list: List[int], k_partitions: int, eq
return _check_and_sort_partitions(partitions)
def balanced_bin_packing(seqlen_list: List[int], max_capacity: int):
    """Pack items into the fewest balanced bins without exceeding ``max_capacity``.

    Uses Best Fit Decreasing: items are placed longest-first, each into the
    open bin that leaves the least slack; a new bin is opened only when no
    existing bin fits.  A post-pass (`_balance_bins`) then shifts items to
    reduce load variance across bins.

    Parameters:
        seqlen_list (List[int]):
            seq lengths of each items
        max_capacity (int):
            maximum capacity for each bin/partition

    Returns:
        partitions (List[List[int]]):
            list of partitions, each containing indices of items

    Raises:
        ValueError: if any single item is longer than ``max_capacity``.
    """
    if not seqlen_list:
        return []

    # Longest first; ties broken by higher original index, matching a
    # descending sort on (seqlen, index) pairs.
    order = sorted(enumerate(seqlen_list), key=lambda p: (p[1], p[0]), reverse=True)

    # Each bin: {'items': [(original_idx, seqlen), ...], 'capacity_used': int}
    bins = []
    for original_idx, seqlen in order:
        if seqlen > max_capacity:
            raise ValueError(f"Item with seqlen {seqlen} exceeds max_capacity {max_capacity}")

        # Best Fit: among bins with enough room, pick the one with the
        # smallest remaining capacity (tightest fit).
        chosen = None
        tightest_slack = max_capacity + 1  # sentinel larger than any real slack
        for candidate in bins:
            slack = max_capacity - candidate['capacity_used']
            if seqlen <= slack < tightest_slack:
                chosen = candidate
                tightest_slack = slack

        if chosen is None:
            # Nothing fits: open a new bin for this item.
            bins.append({'items': [(original_idx, seqlen)], 'capacity_used': seqlen})
        else:
            chosen['items'].append((original_idx, seqlen))
            chosen['capacity_used'] += seqlen

    # Post-processing: shift items between bins to reduce load variance.
    _balance_bins(bins, max_capacity)

    # Emit each bin as the list of original item indices it holds.
    return [[idx for idx, _ in bin_info['items']] for bin_info in bins]
def _balance_bins(bins: List[dict], max_capacity: int):
"""Helper function to balance loads across bins by moving items between bins.
Parameters:
bins: List of bin dictionaries with 'items' and 'capacity_used' keys
max_capacity: Maximum capacity per bin
"""
if len(bins) <= 1:
return
# Perform multiple passes to improve balance
max_iterations = 3
for _ in range(max_iterations):
improved = False
# Sort bins by current load
bins.sort(key=lambda b: b['capacity_used'])
# Try to move items from heaviest bins to lightest bins
for heavy_idx in range(len(bins) - 1, 0, -1):
heavy_bin = bins[heavy_idx]
for light_idx in range(heavy_idx):
light_bin = bins[light_idx]
# Calculate load difference
load_diff = heavy_bin['capacity_used'] - light_bin['capacity_used']
if load_diff <= 1: # Already balanced enough
break
# Find items in heavy bin that can be moved to light bin
for item_idx, (idx, seqlen) in enumerate(heavy_bin['items']):
new_light_load = light_bin['capacity_used'] + seqlen
new_heavy_load = heavy_bin['capacity_used'] - seqlen
# Check if move is beneficial and doesn't violate capacity
if (new_light_load <= max_capacity and
abs(new_heavy_load - new_light_load) < load_diff):
# Move the item
item = heavy_bin['items'].pop(item_idx)
light_bin['items'].append(item)
heavy_bin['capacity_used'] -= seqlen
light_bin['capacity_used'] += seqlen
improved = True
break
if improved:
break
if improved:
break
if not improved:
break
def rearrange_micro_batches(
seqlen_list: List[int],
max_token_len: int,
@ -205,6 +326,11 @@ def rearrange_micro_batches(
Returns:
List[List[int]]: List of partitions, each containing indices of items.
"""
if is_multimodal():
# When multimodal, max_token_len is a list representing the maximum token length for each modality.
# Use balanced bin packing algorithm with capacity constraints
return balanced_bin_packing(seqlen_list=seqlen_list, max_capacity=max_token_len)
if max(seqlen_list) > max_token_len:
raise ValueError(
f"seqlen of items:[{max(seqlen_list)}] must <= max_token_len:[{max_token_len}]"

View File

@ -65,7 +65,7 @@ rl_config:
blocking: true
gamma: 1.0
lam: 0.95
use_dynamic_bsz: false
use_dynamic_bsz: true
max_packing_token_size: [15000, 4000]
actor_forward_micro_batch_size: 1
ref_forward_micro_batch_size: 1