Mirror of https://gitee.com/ascend/MindSpeed-RL.git, synced 2025-10-20 16:23:45 +08:00

Compare commits: d6c6a1ad81...6cb2e0d61e

3 commits:

- 6cb2e0d61e
- d75a6148df
- 870967acc1
```diff
@@ -35,12 +35,6 @@ def dpo_train():
     from mindspeed_llm.tasks.posttrain.base.base_trainer import BaseTrainer
     BaseTrainer.model_provider = model_provider_swap
 
-    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
-        gpt_model_provider, ModelType.encoder_or_decoder)
-    logger.info('after model, optimizer and learning rate scheduler are built')
-
-    model_arch_config = get_model_config(model[0])
-
     # build tokenizer
     tokenizer = get_tokenizer(args.tokenizer_name_or_path,
                               prompt_type=args.prompt_type, prompt_type_path=args.prompt_type_path)
```
```diff
@@ -72,16 +66,6 @@ def dpo_train():
     )
     logger.info('after datasets are built')
 
-    # Backward compatibility, assume fixed batch size.
-    if args.iteration > 0 and args.consumed_train_samples == 0:
-        if args.train_samples is not None:
-            raise ValueError('only backward compatiblity support for iteration-based training')
-        args.consumed_train_samples = args.iteration * args.global_batch_size
-    if args.iteration > 0 and args.consumed_valid_samples == 0:
-        if args.train_samples is None:
-            args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
-                                          args.eval_iters * args.global_batch_size
-
     data_loader = PromptDataLoader(
         train_dataset, args.global_batch_size,
         args.num_workers, args.seed, args.dataset_additional_keys,
```
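The removed block reconstructed the consumed-sample counters when resuming an iteration-based run. As a standalone sketch of that arithmetic, with illustrative values (one global batch per training step, `eval_iters` batches every `eval_interval` steps):

```python
# Illustrative values; the names mirror the Megatron-style args in the diff above.
iteration = 100          # resuming from training step 100
global_batch_size = 64
eval_interval = 50       # validate every 50 steps
eval_iters = 10          # batches consumed per validation run

# One global batch is consumed per training step.
consumed_train_samples = iteration * global_batch_size  # 6400

# Validation runs every eval_interval steps, eval_iters batches each time.
consumed_valid_samples = (iteration // eval_interval) * eval_iters * global_batch_size  # 12800
```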
```diff
@@ -303,7 +287,7 @@ def separate_config_and_parse_args(config):
     return megatron_config
 
 
-@hydra.main(config_path='../configs', config_name='dpo_qwen3_30b_a3b', version_base=None)
+@hydra.main(config_path='../configs', config_name='dpo_qwen3_30b_a3b_A3', version_base=None)
 def main(config):
     megatron_config = separate_config_and_parse_args(config)
     initialize_megatron(config=megatron_config)
```
```diff
@@ -97,6 +97,7 @@ generate_config:
   # vLLM model-related settings
   max_num_seqs: 64
   max_model_len: 3072
   max_num_batched_tokens: 8192
   dtype: "bfloat16"
   gpu_memory_utilization: 0.6
```
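These keys correspond to standard vLLM engine arguments. A minimal sketch of the equivalent standalone engine construction, assuming a placeholder model path (the actual wiring inside MindSpeed-RL may differ):

```python
from vllm import LLM

# Placeholder model path; the kwargs mirror the generate_config keys above.
llm = LLM(
    model="/path/to/model",
    max_num_seqs=64,               # max concurrent sequences per scheduler step
    max_model_len=3072,            # max context length (prompt + generated tokens)
    max_num_batched_tokens=8192,   # per-step token budget for continuous batching
    dtype="bfloat16",
    gpu_memory_utilization=0.6,    # fraction of device memory vLLM may reserve
)
```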
````diff
@@ -65,17 +65,6 @@ megatron_training:
   reset_attention_mask: true
 ```
 
-For the Direct Preference Optimization (DPO) algorithm, it can be enabled with the following configuration:
-
-```yaml
-# goes under megatron_training
-megatron_training:
-  variable_seq_lengths: true
-  context_parallel_size: 2
-  context_parallel_algo: megatron_cp_algo
-  cp_attention_mask_type: causal
-```
-
 Where:
````
```diff
@@ -4,7 +4,7 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
 export HYDRA_FULL_ERROR=1
 export CLOSE_MATMUL_K_SHIFT=1
 
-GPUS_PER_NODE=8
+GPUS_PER_NODE=16
 MASTER_ADDR=localhost
 MASTER_PORT=6005
 NNODES=2
```
```diff
@@ -20,5 +20,5 @@ DISTRIBUTED_ARGS="
 "
 
 torchrun $DISTRIBUTED_ARGS cli/train_dpo.py \
-    --config-name dpo_qwen3_30b_a3b_A2 \
+    --config-name dpo_qwen3_30b_a3b_A3 \
     | tee logs/RL_dpo_qwen3_30b_a3b_rank${NODE_RANK}.log
```
```diff
@@ -381,6 +381,11 @@ class MegatronConfig(BaseConfig):
         self.use_cp_send_recv_overlap = False
         self.use_fused_ring_attention_update = False
         self.dpo_loss_type = 'sigmoid'
+        self.ref_model = ''
+        self.refer_model_iter = 1
+        self.wandb_exp_name = ''
+        self.wandb_project = ''
+        self.wandb_save_dir = ''
 
         self.use_ascend_coc = False
         self.coc_mode = -1
```
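`dpo_loss_type = 'sigmoid'` selects the standard Bradley-Terry form of the DPO objective, computed against the frozen reference model named by `ref_model`. A minimal reference sketch of that loss, not the repository's implementation; `beta` and the per-sequence log-probability inputs are assumptions:

```python
import torch.nn.functional as F

def dpo_sigmoid_loss(policy_chosen_logps, policy_rejected_logps,
                     ref_chosen_logps, ref_rejected_logps, beta=0.1):
    """Sigmoid (Bradley-Terry) DPO loss over per-sequence log-prob tensors."""
    # Implicit rewards: scaled log-ratios of the policy vs. the frozen reference.
    chosen_rewards = beta * (policy_chosen_logps - ref_chosen_logps)
    rejected_rewards = beta * (policy_rejected_logps - ref_rejected_logps)
    # -log sigmoid of the reward margin; minimized when chosen out-scores rejected.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()
```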
```diff
@@ -215,8 +215,16 @@ class ActorHybridWorkerBase(BaseWorker):
         if is_multimodal():
             experience_columns.extend(['attention_mask', 'position_ids'])
 
-        experience_count = self.rl_config.actor_update_dispatch_size
+        experience_count = (
+            self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size()
+        )
+
+        if self.rl_config.filter_groups_enable:
+            experience_count = (
+                self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt //
+                self.parallel_state.get_data_parallel_world_size()
+            )
 
         if skip_actor_log_prob:
             experience_columns.remove('old_log_prob')
```
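The replacement derives each data-parallel rank's experience count from the global batch rather than a dedicated dispatch-size knob. A worked example with illustrative values:

```python
# Illustrative values; the names mirror the attributes in the diff above.
global_batch_size = 64
dp_world_size = 4

# Default: split the global batch evenly across data-parallel ranks.
experience_count = global_batch_size // dp_world_size  # 16

# With group filtering enabled, size from the filtered train batch instead.
filter_groups_train_batch_size = 8
n_samples_per_prompt = 4
experience_count = (filter_groups_train_batch_size * n_samples_per_prompt) // dp_world_size  # 8
```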