!451 dynamic batch size 支持不同模块独立设置 max_packing_token_size

Merge pull request !451 from Nurxat/master
This commit is contained in:
Nurxat
2025-08-19 12:40:35 +00:00
committed by i-robot
parent 014f3228b2
commit dd73eb3dd4
9 changed files with 74 additions and 52 deletions

View File

@@ -68,8 +68,8 @@ rl_config:
## 参数说明:
`max_packing_token_size` 是动态批大小（Dynamic Batch Size）机制中的核心参数，用于限制每个拼接后的 micro batch 中 token 的总数，防止因拼接过多序列而导致显存溢出（OOM）。
`dynamic_max_batch_size` 用于限制最大的 micro batch，防止在长序列训练场景下，有多个短序列放入同一批次导致 micro batch size 过大，进而导致 OOM。
`ref_max_packing_token_size`、`actor_max_packing_token_size`、`update_max_packing_token_size` 是动态批大小（Dynamic Batch Size）机制中的核心参数，用于限制每个拼接后的 micro batch 中 token 的总数，防止因拼接过多序列而导致显存溢出（OOM）。
`ref_dynamic_max_batch_size`、`actor_dynamic_max_batch_size`、`update_dynamic_max_batch_size` 是控制 Dynamic Batch Size 分箱之后每个批次中最大序列条数（micro batch size）的参数，防止在长序列训练场景下，有多个短序列放入同一批次导致 micro batch size 过大，进而导致 OOM。
**使用限制**:每条样本的 token 长度必须满足:
```text
@@ -80,7 +80,7 @@ prompt_length[i] + response_length[i] <= max_packing_token_size
```text
max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling_config.max_tokens) * 2
```
`dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM，且发生在计算得出 logits 之后，可以通过设置减小该值来减少显存占用，建议最小设置为 2；若设置为 1，则 Dynamic Batch Size 无意义。
`*_dynamic_max_batch_size` 是可选参数。如果长序列训练过程发生 OOM，且发生在计算得出 logits 之后，可以通过设置减小该值来减少显存占用，建议最小设置为 2；若设置为 1，则 Dynamic Batch Size 无意义。
二者可以根据实际需求调整。
@@ -93,8 +93,12 @@ max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling
```yaml
rl_config:
use_dynamic_bsz: true
max_packing_token_size: 8192
dynamic_max_batch_size: 8 # 可选参数
ref_max_packing_token_size: 8192
ref_dynamic_max_batch_size: 8 # 可选参数
actor_max_packing_token_size: 8192
actor_dynamic_max_batch_size: 8 # 可选参数
update_max_packing_token_size: 8192
update_dynamic_max_batch_size: 8 # 可选参数
```
# 📦 数据并行负载均衡DP Batch Balance特性

View File

@@ -130,8 +130,15 @@ class RLConfig(BaseConfig):
self.use_dynamic_bsz = False
self.max_packing_token_size = 4096
self.log_max_throughput = True
self.ref_max_packing_token_size = self.max_packing_token_size
self.actor_max_packing_token_size = self.max_packing_token_size
self.update_max_packing_token_size = self.max_packing_token_size
self.dynamic_max_batch_size = None
self.ref_dynamic_max_batch_size = self.dynamic_max_batch_size
self.actor_dynamic_max_batch_size = self.dynamic_max_batch_size
self.update_dynamic_max_batch_size = self.dynamic_max_batch_size
self.log_max_throughput = True
# token level loss
self.token_level_loss = True

View File

@@ -62,16 +62,8 @@ class BaseTrainingEngine(ABC):
temperature: float = 1.0,
role: str = None,
micro_batch_size: int = 1,
use_dynamic_bsz: bool = False,
max_packing_token_size: bool = 4096,
dynamic_max_batch_size: int = None,
use_remove_padding: bool = False,
set_actual_seq_len: Callable = None,
get_actual_seq_len: Callable = None,
set_position_ids: Callable = None,
forward_backward_func: Callable = None,
entropy_coeff: float = 0.0,
context_parallel_size: int = 1,
kl_penalty: str = "low_var_kl",
token_level_loss: bool = False,
clip_higher_enable: bool = False,
@@ -81,13 +73,6 @@
**kwargs):
self.forward_backward_func = forward_backward_func
self.micro_batch_size = micro_batch_size
self.use_dynamic_bsz = use_dynamic_bsz
self.max_packing_token_size = max_packing_token_size
self.dynamic_max_batch_size = dynamic_max_batch_size
self.use_remove_padding = use_remove_padding
self.set_actual_seq_len = set_actual_seq_len
self.get_actual_seq_len = get_actual_seq_len
self.set_position_ids = set_position_ids
self.model = model
self.megatron_config = megatron_config
self.optimizer = optimizer
@@ -102,7 +87,6 @@
self.kl_penalty = kl_penalty
self.clip_ratio = clip_ratio
self.entropy_coeff = entropy_coeff
self.context_parallel_size = context_parallel_size
self.temperature = temperature
self.token_level_loss = token_level_loss
self.clip_higher_enable = clip_higher_enable
@@ -111,7 +95,21 @@
self.cliprange_value = cliprange_value
self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role)
self.kwargs = kwargs
self.use_remove_padding = kwargs.get('use_remove_padding', False)
self.use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False)
self.max_packing_token_size = kwargs.get('ref_max_packing_token_size', None)
self.dynamic_max_batch_size = kwargs.get('ref_dynamic_max_batch_size', None)
if self.max_packing_token_size is None:
self.max_packing_token_size = {'actor': kwargs.get('actor_max_packing_token_size', None),
'update': kwargs.get('update_max_packing_token_size', None)}
self.dynamic_max_batch_size = {'actor': kwargs.get('actor_dynamic_max_batch_size', None),
'update': kwargs.get('update_dynamic_max_batch_size', None)}
self.context_parallel_size = kwargs.get('context_parallel_size', 1)
self.set_actual_seq_len = kwargs.get('set_actual_seq_len', None)
self.get_actual_seq_len = kwargs.get('get_actual_seq_len', None)
self.set_position_ids = kwargs.get('set_position_ids', None)
@staticmethod
def _split_batches(batch: Dict, batch_size: int, shuffle_mini_batch: bool, dim: int = 0, keep_list: bool = False) -> List[Dict]:
batches = []
@@ -149,7 +147,13 @@
def _forward_backward_batch(self, batch: Dict[str, torch.Tensor], forward_only: bool = False):
if self.use_dynamic_bsz:
batches, indices = self._split_batches_with_dynamic_bsz(batch, self.max_packing_token_size, self.dynamic_max_batch_size)
if isinstance(self.max_packing_token_size, dict):
max_packing_token_size = self.max_packing_token_size['actor'] if forward_only else self.max_packing_token_size['update'] # actor forward or update
dynamic_max_batch_size = self.dynamic_max_batch_size['actor'] if forward_only else self.dynamic_max_batch_size['update']
else:
max_packing_token_size = self.max_packing_token_size # reference forward
dynamic_max_batch_size = self.dynamic_max_batch_size
batches, indices = self._split_batches_with_dynamic_bsz(batch, max_packing_token_size, dynamic_max_batch_size)
else:
batches = self._split_batches(batch, batch_size=self.micro_batch_size,
shuffle_mini_batch=self.shuffle_mini_batch)

View File

@@ -130,21 +130,24 @@ class ActorHybridWorkerBase(BaseWorker):
forward_backward_func=self.forward_backward_func,
clip_ratio=self.rl_config.clip_ratio,
micro_batch_size=self.megatron_config.micro_batch_size,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
max_packing_token_size=self.rl_config.max_packing_token_size,
dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size,
entropy_coeff=self.rl_config.entropy_coeff,
kl_penalty=self.rl_config.kl_penalty,
temperature=self.generate_config.sampling_config["temperature"],
token_level_loss=self.rl_config.token_level_loss,
clip_higher_enable=self.rl_config.clip_higher_enable,
clip_ratio_low=self.rl_config.clip_ratio_low,
clip_ratio_high=self.rl_config.clip_ratio_high
clip_ratio_high=self.rl_config.clip_ratio_high,
use_remove_padding=self.rl_config.use_remove_padding,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
actor_max_packing_token_size=self.rl_config.actor_max_packing_token_size,
update_max_packing_token_size=self.rl_config.update_max_packing_token_size,
actor_dynamic_max_batch_size=self.rl_config.actor_dynamic_max_batch_size,
update_dynamic_max_batch_size=self.rl_config.update_dynamic_max_batch_size,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size
)
self.empty_cache()
self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role)

View File

@@ -84,7 +84,6 @@ class CriticWorkerBase(BaseWorker):
self.critic_offloader.offload_grad()
self.critic_offloader.offload_param()
megatron_module = self.get_megatron_module()
self.critic = Critic(
self.model,
megatron_config=self.megatron_config,
@@ -99,6 +98,9 @@
forward_backward_func=self.forward_backward_func,
clip_ratio=self.rl_config.clip_ratio,
micro_batch_size=self.megatron_config.micro_batch_size,
entropy_coeff=self.rl_config.entropy_coeff,
cliprange_value=self.rl_config.cliprange_value,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
max_packing_token_size=self.rl_config.max_packing_token_size,
dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
@@ -106,9 +108,7 @@
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size,
entropy_coeff=self.rl_config.entropy_coeff,
cliprange_value=self.rl_config.cliprange_value
context_parallel_size=self.megatron_config.context_parallel_size
)
self.empty_cache()
self.critic_profiler = profiler_start(self.profiler_config, self.profiler_config.role)

View File

@@ -103,15 +103,16 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
max_packing_token_size=self.rl_config.max_packing_token_size,
dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
temperature=self.generate_config.sampling_config["temperature"],
use_remove_padding=self.rl_config.use_remove_padding,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size,
ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size,
temperature=self.generate_config.sampling_config["temperature"]
context_parallel_size=self.megatron_config.context_parallel_size
)
MsProbe.config_init(self.msprobe_config)

View File

@@ -82,15 +82,16 @@ class ReferenceWorkerBase(BaseWorker):
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
max_packing_token_size=self.rl_config.max_packing_token_size,
dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
temperature=self.generate_config.sampling_config["temperature"],
use_remove_padding=self.rl_config.use_remove_padding,
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size,
ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size,
temperature=self.generate_config.sampling_config["temperature"]
context_parallel_size=self.megatron_config.context_parallel_size
)
def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):

View File

@@ -73,15 +73,15 @@ class RewardWorkerBase(BaseWorker):
stage=self.megatron_config.stage,
forward_backward_func=self.forward_backward_func,
micro_batch_size=self.megatron_config.micro_batch_size,
temperature=self.generate_config.sampling_config["temperature"],
use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
max_packing_token_size=self.rl_config.max_packing_token_size,
dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
max_packing_token_size=self.rl_config.ref_max_packing_token_size,
dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
use_remove_padding=self.rl_config.use_remove_padding,
set_actual_seq_len=self.set_actual_seq_len,
get_actual_seq_len=self.get_actual_seq_len,
set_position_ids=self.set_position_ids,
context_parallel_size=self.megatron_config.context_parallel_size,
temperature=self.generate_config.sampling_config["temperature"]
context_parallel_size=self.megatron_config.context_parallel_size
)
def init_transfer_dock(self, td, sampling_transfer_dock=None):

View File

@@ -62,7 +62,9 @@ rl_config:
blocking: true
use_dp_batch_balance: true
use_dynamic_bsz: true
max_packing_token_size: 8192
ref_max_packing_token_size: 8192
actor_max_packing_token_size: 8192
update_max_packing_token_size: 8192
actor_forward_micro_batch_size: 8
ref_forward_micro_batch_size: 8
use_remove_padding: true