Mirror of https://gitee.com/ascend/MindSpeed-RL.git
!451 dynamic batch size: support setting max_packing_token_size independently per module
Merge pull request !451 from Nurxat/master
@@ -68,8 +68,8 @@ rl_config:
 ## Parameter description:
 
-`max_packing_token_size` is the core parameter of the Dynamic Batch Size mechanism: it caps the total number of tokens in each packed micro batch, preventing out-of-memory (OOM) errors caused by packing too many sequences together.
-`dynamic_max_batch_size` caps the micro batch size itself, preventing OOM in long-sequence training when many short sequences land in the same batch and make the micro batch too large.
+`ref_max_packing_token_size`, `actor_max_packing_token_size`, and `update_max_packing_token_size` are the core parameters of the Dynamic Batch Size mechanism: each caps the total number of tokens in a packed micro batch, preventing out-of-memory (OOM) errors caused by packing too many sequences together.
+`ref_dynamic_max_batch_size`, `actor_dynamic_max_batch_size`, and `update_dynamic_max_batch_size` cap the number of sequences (the micro batch size) in each bin produced by Dynamic Batch Size packing, preventing OOM in long-sequence training when many short sequences land in the same batch and make the micro batch too large.
 
 **Usage constraint**: every sample's token length must satisfy:
 
 ```text
@@ -80,7 +80,7 @@ prompt_length[i] + response_length[i] <= max_packing_token_size
 ```text
 max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling_config.max_tokens) * 2
 ```
 
-`dynamic_max_batch_size` is an optional parameter. If a long-sequence training run hits OOM after the logits have been computed, setting or lowering this value reduces memory usage. A minimum of 2 is recommended; at 1, Dynamic Batch Size is meaningless.
+`*_dynamic_max_batch_size` is an optional parameter. If a long-sequence training run hits OOM after the logits have been computed, lowering this value reduces memory usage. A minimum of 2 is recommended; at 1, Dynamic Batch Size is meaningless.
 
 Both can be adjusted according to actual needs.
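
For intuition, dynamic batch size is essentially bin packing under a token budget. The diff does not show `_split_batches_with_dynamic_bsz` itself; the following is a minimal greedy sketch of the idea under the two limits documented above (function and variable names are illustrative, not the repository API):

```python
from typing import List, Optional

def pack_greedy(seq_lens: List[int], max_packing_token_size: int,
                dynamic_max_batch_size: Optional[int] = None) -> List[List[int]]:
    """Greedily pack sample indices into micro batches so that each batch
    holds at most max_packing_token_size tokens and, when set, at most
    dynamic_max_batch_size samples. Assumes every sample fits the budget."""
    bins: List[List[int]] = []
    cur: List[int] = []
    cur_tokens = 0
    for i, n in enumerate(seq_lens):
        over_tokens = cur_tokens + n > max_packing_token_size
        over_count = (dynamic_max_batch_size is not None
                      and len(cur) >= dynamic_max_batch_size)
        if cur and (over_tokens or over_count):
            bins.append(cur)      # close the current bin, start a new one
            cur, cur_tokens = [], 0
        cur.append(i)
        cur_tokens += n
    if cur:
        bins.append(cur)
    return bins

# e.g. pack_greedy([3000, 5000, 2000, 6000], 8192, 2) -> [[0, 1], [2, 3]]
```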
@@ -93,8 +93,12 @@ max_packing_token_size = (rl_config.max_prompt_length + generate_config.sampling
 ```yaml
 rl_config:
   use_dynamic_bsz: true
-  max_packing_token_size: 8192
-  dynamic_max_batch_size: 8 # optional
+  ref_max_packing_token_size: 8192
+  ref_dynamic_max_batch_size: 8 # optional
+  actor_max_packing_token_size: 8192
+  actor_dynamic_max_batch_size: 8 # optional
+  update_max_packing_token_size: 8192
+  update_dynamic_max_batch_size: 8 # optional
 ```
 
 # 📦 Data-parallel load balancing (DP Batch Balance) feature
@@ -130,8 +130,15 @@ class RLConfig(BaseConfig):
 
         self.use_dynamic_bsz = False
         self.max_packing_token_size = 4096
-        self.log_max_throughput = True
+        self.ref_max_packing_token_size = self.max_packing_token_size
+        self.actor_max_packing_token_size = self.max_packing_token_size
+        self.update_max_packing_token_size = self.max_packing_token_size
+        self.dynamic_max_batch_size = None
+        self.ref_dynamic_max_batch_size = self.dynamic_max_batch_size
+        self.actor_dynamic_max_batch_size = self.dynamic_max_batch_size
+        self.update_dynamic_max_batch_size = self.dynamic_max_batch_size
+
+        self.log_max_throughput = True
 
         # token level loss
         self.token_level_loss = True
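
The defaulting above makes every per-module knob start out equal to its global counterpart, so existing configs keep working. A standalone sketch of the same fallback pattern (class and field names here are illustrative, not the repository's BaseConfig):

```python
class PackingDefaults:
    """Per-module packing budgets that fall back to the global values."""

    MODULES = ("ref", "actor", "update")

    def __init__(self, max_packing_token_size: int = 4096,
                 dynamic_max_batch_size=None, **overrides):
        self.max_packing_token_size = max_packing_token_size
        self.dynamic_max_batch_size = dynamic_max_batch_size
        for m in self.MODULES:
            # Use the module-specific value when given, else the global one.
            setattr(self, f"{m}_max_packing_token_size",
                    overrides.get(f"{m}_max_packing_token_size",
                                  max_packing_token_size))
            setattr(self, f"{m}_dynamic_max_batch_size",
                    overrides.get(f"{m}_dynamic_max_batch_size",
                                  dynamic_max_batch_size))

cfg = PackingDefaults(max_packing_token_size=8192,
                      update_max_packing_token_size=4096)
assert cfg.ref_max_packing_token_size == 8192     # inherited from the global
assert cfg.update_max_packing_token_size == 4096  # explicitly overridden
```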
@@ -62,16 +62,8 @@ class BaseTrainingEngine(ABC):
                 temperature: float = 1.0,
                 role: str = None,
                 micro_batch_size: int = 1,
-                 use_dynamic_bsz: bool = False,
-                 max_packing_token_size: bool = 4096,
-                 dynamic_max_batch_size: int = None,
-                 use_remove_padding: bool = False,
-                 set_actual_seq_len: Callable = None,
-                 get_actual_seq_len: Callable = None,
-                 set_position_ids: Callable = None,
                 forward_backward_func: Callable = None,
                 entropy_coeff: float = 0.0,
-                 context_parallel_size: int = 1,
                 kl_penalty: str = "low_var_kl",
                 token_level_loss: bool = False,
                 clip_higher_enable: bool = False,
@@ -81,13 +73,6 @@ class BaseTrainingEngine(ABC):
                 **kwargs):
         self.forward_backward_func = forward_backward_func
         self.micro_batch_size = micro_batch_size
-        self.use_dynamic_bsz = use_dynamic_bsz
-        self.max_packing_token_size = max_packing_token_size
-        self.dynamic_max_batch_size = dynamic_max_batch_size
-        self.use_remove_padding = use_remove_padding
-        self.set_actual_seq_len = set_actual_seq_len
-        self.get_actual_seq_len = get_actual_seq_len
-        self.set_position_ids = set_position_ids
         self.model = model
         self.megatron_config = megatron_config
         self.optimizer = optimizer
@@ -102,7 +87,6 @@ class BaseTrainingEngine(ABC):
         self.kl_penalty = kl_penalty
         self.clip_ratio = clip_ratio
         self.entropy_coeff = entropy_coeff
-        self.context_parallel_size = context_parallel_size
         self.temperature = temperature
         self.token_level_loss = token_level_loss
         self.clip_higher_enable = clip_higher_enable
@@ -111,7 +95,21 @@ class BaseTrainingEngine(ABC):
         self.cliprange_value = cliprange_value
         self.loss_func: BaseLossFunc = LossFuncFactory.get_instance(self.stage, self.role)
         self.kwargs = kwargs
+
+        self.use_remove_padding = kwargs.get('use_remove_padding', False)
+        self.use_dynamic_bsz = kwargs.get('use_dynamic_bsz', False)
+        self.max_packing_token_size = kwargs.get('ref_max_packing_token_size', None)
+        self.dynamic_max_batch_size = kwargs.get('ref_dynamic_max_batch_size', None)
+        if self.max_packing_token_size is None:
+            self.max_packing_token_size = {'actor': kwargs.get('actor_max_packing_token_size', None),
+                                           'update': kwargs.get('update_max_packing_token_size', None)}
+            self.dynamic_max_batch_size = {'actor': kwargs.get('actor_dynamic_max_batch_size', None),
+                                           'update': kwargs.get('update_dynamic_max_batch_size', None)}
+        self.context_parallel_size = kwargs.get('context_parallel_size', 1)
+        self.set_actual_seq_len = kwargs.get('set_actual_seq_len', None)
+        self.get_actual_seq_len = kwargs.get('get_actual_seq_len', None)
+        self.set_position_ids = kwargs.get('set_position_ids', None)
 
     @staticmethod
     def _split_batches(batch: Dict, batch_size: int, shuffle_mini_batch: bool, dim: int = 0, keep_list: bool = False) -> List[Dict]:
         batches = []
@@ -149,7 +147,13 @@ class BaseTrainingEngine(ABC):
 
     def _forward_backward_batch(self, batch: Dict[str, torch.Tensor], forward_only: bool = False):
         if self.use_dynamic_bsz:
-            batches, indices = self._split_batches_with_dynamic_bsz(batch, self.max_packing_token_size, self.dynamic_max_batch_size)
+            if isinstance(self.max_packing_token_size, dict):
+                max_packing_token_size = self.max_packing_token_size['actor'] if forward_only else self.max_packing_token_size['update']  # actor forward or update
+                dynamic_max_batch_size = self.dynamic_max_batch_size['actor'] if forward_only else self.dynamic_max_batch_size['update']
+            else:
+                max_packing_token_size = self.max_packing_token_size  # reference forward
+                dynamic_max_batch_size = self.dynamic_max_batch_size
+            batches, indices = self._split_batches_with_dynamic_bsz(batch, max_packing_token_size, dynamic_max_batch_size)
         else:
             batches = self._split_batches(batch, batch_size=self.micro_batch_size,
                                           shuffle_mini_batch=self.shuffle_mini_batch)
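
The dispatch above encodes the engine's role in the budget's type: a dict-valued budget means the actor engine, where `forward_only` distinguishes the log-prob forward pass (`'actor'`) from the training update (`'update'`), while a scalar budget means a single-purpose engine such as the reference model. The same selection logic read in isolation (names are illustrative, not the repository API):

```python
from typing import Dict, Optional, Union

Budget = Union[Optional[int], Dict[str, Optional[int]]]

def select_packing_budget(budget: Budget, forward_only: bool) -> Optional[int]:
    """Pick the packing budget for the current pass.

    Dict-valued budget -> actor engine: the 'actor' entry is used for
    forward-only (log-prob) passes and the 'update' entry for training
    passes. Scalar budget -> a forward-only role such as the reference model.
    """
    if isinstance(budget, dict):
        return budget['actor' if forward_only else 'update']
    return budget
```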
@@ -130,21 +130,24 @@ class ActorHybridWorkerBase(BaseWorker):
             forward_backward_func=self.forward_backward_func,
             clip_ratio=self.rl_config.clip_ratio,
             micro_batch_size=self.megatron_config.micro_batch_size,
-            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
-            max_packing_token_size=self.rl_config.max_packing_token_size,
-            dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
-            use_remove_padding=self.rl_config.use_remove_padding,
-            set_actual_seq_len=self.set_actual_seq_len,
-            get_actual_seq_len=self.get_actual_seq_len,
-            set_position_ids=self.set_position_ids,
-            context_parallel_size=self.megatron_config.context_parallel_size,
             entropy_coeff=self.rl_config.entropy_coeff,
             kl_penalty=self.rl_config.kl_penalty,
             temperature=self.generate_config.sampling_config["temperature"],
             token_level_loss=self.rl_config.token_level_loss,
             clip_higher_enable=self.rl_config.clip_higher_enable,
             clip_ratio_low=self.rl_config.clip_ratio_low,
-            clip_ratio_high=self.rl_config.clip_ratio_high
+            clip_ratio_high=self.rl_config.clip_ratio_high,
+
+            use_remove_padding=self.rl_config.use_remove_padding,
+            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
+            actor_max_packing_token_size=self.rl_config.actor_max_packing_token_size,
+            update_max_packing_token_size=self.rl_config.update_max_packing_token_size,
+            actor_dynamic_max_batch_size=self.rl_config.actor_dynamic_max_batch_size,
+            update_dynamic_max_batch_size=self.rl_config.update_dynamic_max_batch_size,
+            set_actual_seq_len=self.set_actual_seq_len,
+            get_actual_seq_len=self.get_actual_seq_len,
+            set_position_ids=self.set_position_ids,
+            context_parallel_size=self.megatron_config.context_parallel_size
         )
         self.empty_cache()
         self.actor_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
@@ -84,7 +84,6 @@ class CriticWorkerBase(BaseWorker):
         self.critic_offloader.offload_grad()
         self.critic_offloader.offload_param()
-
         megatron_module = self.get_megatron_module()
         self.critic = Critic(
             self.model,
             megatron_config=self.megatron_config,
@@ -99,6 +98,9 @@ class CriticWorkerBase(BaseWorker):
             forward_backward_func=self.forward_backward_func,
             clip_ratio=self.rl_config.clip_ratio,
             micro_batch_size=self.megatron_config.micro_batch_size,
+            entropy_coeff=self.rl_config.entropy_coeff,
+            cliprange_value=self.rl_config.cliprange_value,
+
             use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
             max_packing_token_size=self.rl_config.max_packing_token_size,
             dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
@@ -106,9 +108,7 @@ class CriticWorkerBase(BaseWorker):
             set_actual_seq_len=self.set_actual_seq_len,
             get_actual_seq_len=self.get_actual_seq_len,
             set_position_ids=self.set_position_ids,
-            context_parallel_size=self.megatron_config.context_parallel_size,
-            entropy_coeff=self.rl_config.entropy_coeff,
-            cliprange_value=self.rl_config.cliprange_value
+            context_parallel_size=self.megatron_config.context_parallel_size
         )
         self.empty_cache()
         self.critic_profiler = profiler_start(self.profiler_config, self.profiler_config.role)
@@ -103,15 +103,16 @@ class IntegratedWorker(ActorHybridWorkerBase, ReferenceWorkerBase, RewardWorkerB
             stage=self.megatron_config.stage,
             forward_backward_func=self.forward_backward_func,
             micro_batch_size=self.megatron_config.micro_batch_size,
-            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
-            max_packing_token_size=self.rl_config.max_packing_token_size,
-            dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
+            temperature=self.generate_config.sampling_config["temperature"],
+
+            use_remove_padding=self.rl_config.use_remove_padding,
+            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
+            ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size,
+            ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
             set_actual_seq_len=self.set_actual_seq_len,
             get_actual_seq_len=self.get_actual_seq_len,
             set_position_ids=self.set_position_ids,
-            context_parallel_size=self.megatron_config.context_parallel_size,
-            temperature=self.generate_config.sampling_config["temperature"]
+            context_parallel_size=self.megatron_config.context_parallel_size
         )
         MsProbe.config_init(self.msprobe_config)
 
@@ -82,15 +82,16 @@ class ReferenceWorkerBase(BaseWorker):
             stage=self.megatron_config.stage,
             forward_backward_func=self.forward_backward_func,
             micro_batch_size=self.megatron_config.micro_batch_size,
-            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
-            max_packing_token_size=self.rl_config.max_packing_token_size,
-            dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
+            temperature=self.generate_config.sampling_config["temperature"],
+
+            use_remove_padding=self.rl_config.use_remove_padding,
+            use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
+            ref_max_packing_token_size=self.rl_config.ref_max_packing_token_size,
+            ref_dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
             set_actual_seq_len=self.set_actual_seq_len,
             get_actual_seq_len=self.get_actual_seq_len,
             set_position_ids=self.set_position_ids,
-            context_parallel_size=self.megatron_config.context_parallel_size,
-            temperature=self.generate_config.sampling_config["temperature"]
+            context_parallel_size=self.megatron_config.context_parallel_size
         )
 
     def init_transfer_dock(self, td, mm_td, sampling_transfer_dock=None):
@@ -73,15 +73,15 @@ class RewardWorkerBase(BaseWorker):
             stage=self.megatron_config.stage,
             forward_backward_func=self.forward_backward_func,
             micro_batch_size=self.megatron_config.micro_batch_size,
+            temperature=self.generate_config.sampling_config["temperature"],
             use_dynamic_bsz=self.rl_config.use_dynamic_bsz,
-            max_packing_token_size=self.rl_config.max_packing_token_size,
-            dynamic_max_batch_size=self.rl_config.dynamic_max_batch_size,
+            max_packing_token_size=self.rl_config.ref_max_packing_token_size,
+            dynamic_max_batch_size=self.rl_config.ref_dynamic_max_batch_size,
             use_remove_padding=self.rl_config.use_remove_padding,
             set_actual_seq_len=self.set_actual_seq_len,
             get_actual_seq_len=self.get_actual_seq_len,
             set_position_ids=self.set_position_ids,
-            context_parallel_size=self.megatron_config.context_parallel_size,
-            temperature=self.generate_config.sampling_config["temperature"]
+            context_parallel_size=self.megatron_config.context_parallel_size
         )
 
     def init_transfer_dock(self, td, sampling_transfer_dock=None):
@@ -62,7 +62,9 @@ rl_config:
   blocking: true
   use_dp_batch_balance: true
   use_dynamic_bsz: true
-  max_packing_token_size: 8192
+  ref_max_packing_token_size: 8192
+  actor_max_packing_token_size: 8192
+  update_max_packing_token_size: 8192
   actor_forward_micro_batch_size: 8
   ref_forward_micro_batch_size: 8
   use_remove_padding: true
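
As a closing note, the usage constraint documented above implies that each budget must at least cover the longest possible sample (prompt plus response). A hypothetical sanity check along those lines, using the field names from the YAML examples (not a helper that exists in the repository):

```python
def check_packing_budgets(rl_config, generate_config):
    """Hypothetical check: every per-module packing budget must hold the
    longest possible sample, i.e. max_prompt_length + max_tokens."""
    longest = (rl_config.max_prompt_length
               + generate_config.sampling_config["max_tokens"])
    for module in ("ref", "actor", "update"):
        budget = getattr(rl_config, f"{module}_max_packing_token_size")
        if budget < longest:
            raise ValueError(
                f"{module}_max_packing_token_size={budget} is smaller than "
                f"the longest possible sample ({longest} tokens)")
```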