Compare commits


3 Commits

Author SHA1 Message Date

yzb
6cb2e0d61e !555 add config
Merge pull request !555 from yzb/master
2025-08-25 08:10:29 +00:00

d75a6148df !553 [bug_fix]experience_count
Merge pull request !553 from zhoubeirong/grpo_32b
2025-08-25 06:33:24 +00:00

870967acc1 !549 [pytorch][refact]DPO continue pretraining
Merge pull request !549 from shenjiarun/master
2025-08-25 01:09:48 +00:00
7 changed files with 19 additions and 32 deletions

View File

@@ -35,12 +35,6 @@ def dpo_train():
from mindspeed_llm.tasks.posttrain.base.base_trainer import BaseTrainer
BaseTrainer.model_provider = model_provider_swap
model, optimizer, opt_param_scheduler = setup_model_and_optimizer(
gpt_model_provider, ModelType.encoder_or_decoder)
logger.info('after model, optimizer and learning rate scheduler are built')
model_arch_config = get_model_config(model[0])
# build tokenizer
tokenizer = get_tokenizer(args.tokenizer_name_or_path,
prompt_type=args.prompt_type, prompt_type_path=args.prompt_type_path)
@@ -72,16 +66,6 @@ def dpo_train():
)
logger.info('after datasets are built')
# Backward compatibility, assume fixed batch size.
if args.iteration > 0 and args.consumed_train_samples == 0:
if args.train_samples is not None:
raise ValueError('only backward compatiblity support for iteration-based training')
args.consumed_train_samples = args.iteration * args.global_batch_size
if args.iteration > 0 and args.consumed_valid_samples == 0:
if args.train_samples is None:
args.consumed_valid_samples = (args.iteration // args.eval_interval) * \
args.eval_iters * args.global_batch_size
data_loader = PromptDataLoader(
train_dataset, args.global_batch_size,
args.num_workers, args.seed, args.dataset_additional_keys,
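
The backward-compatibility block removed in this hunk derived the consumed sample counters from the iteration count. A minimal worked example of that arithmetic, with illustrative values, is:

```python
# Illustrative values only; the formulas come from the removed block above.
iteration = 1000          # optimizer steps already completed
global_batch_size = 128   # samples consumed per step
eval_interval = 100       # run evaluation every 100 steps
eval_iters = 10           # batches per evaluation pass

consumed_train_samples = iteration * global_batch_size
# 1000 * 128 = 128000 training samples already consumed

consumed_valid_samples = (iteration // eval_interval) * eval_iters * global_batch_size
# (1000 // 100) * 10 * 128 = 12800 validation samples already consumed
```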
@@ -303,7 +287,7 @@ def separate_config_and_parse_args(config):
return megatron_config
@hydra.main(config_path='../configs', config_name='dpo_qwen3_30b_a3b', version_base=None)
@hydra.main(config_path='../configs', config_name='dpo_qwen3_30b_a3b_A3', version_base=None)
def main(config):
megatron_config = separate_config_and_parse_args(config)
initialize_megatron(config=megatron_config)
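
The only change in this hunk is the default Hydra config name, from dpo_qwen3_30b_a3b to dpo_qwen3_30b_a3b_A3. A minimal sketch of such an entry point (the config contents are assumed, not shown here):

```python
import hydra
from omegaconf import DictConfig, OmegaConf

# The default config_name is used when no --config-name override is given on the
# command line; the launch script later in this change passes one explicitly.
@hydra.main(config_path='../configs', config_name='dpo_qwen3_30b_a3b_A3', version_base=None)
def main(config: DictConfig):
    print(OmegaConf.to_yaml(config))  # inspect the resolved configuration

if __name__ == '__main__':
    main()
```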

View File

@@ -97,6 +97,7 @@ generate_config:
# vllm model-related settings
max_num_seqs: 64
max_model_len: 3072
max_num_batched_tokens: 8192
dtype: "bfloat16"
gpu_memory_utilization: 0.6
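
The new max_num_batched_tokens key, like its neighbours, matches a standard vLLM engine argument. A rough sketch of how such a generate_config could be forwarded to a vLLM engine (illustrative only; the model path and the wiring are assumptions, not this repository's code):

```python
from vllm import LLM  # assumes vLLM is installed

llm = LLM(
    model="Qwen3-30B-A3B",        # hypothetical model path
    max_num_seqs=64,              # concurrent sequences scheduled per step
    max_model_len=3072,           # maximum context length per sequence
    max_num_batched_tokens=8192,  # token budget per scheduler step
    dtype="bfloat16",
    gpu_memory_utilization=0.6,   # fraction of device memory the engine may claim
)
```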

View File

@@ -65,17 +65,6 @@ megatron_training:
reset_attention_mask: true
```
For the direct preference optimization (DPO) algorithm, it can be enabled via the following configuration:
```yaml
# set under megatron_training
megatron_training:
variable_seq_lengths: true
context_parallel_size: 2
context_parallel_algo: megatron_cp_algo
cp_attention_mask_type: causal
```
where:

View File

@@ -4,7 +4,7 @@ export PYTORCH_NPU_ALLOC_CONF=expandable_segments:True
export HYDRA_FULL_ERROR=1
export CLOSE_MATMUL_K_SHIFT=1
GPUS_PER_NODE=8
GPUS_PER_NODE=16
MASTER_ADDR=localhost
MASTER_PORT=6005
NNODES=2
@@ -20,5 +20,5 @@ DISTRIBUTED_ARGS="
"
torchrun $DISTRIBUTED_ARGS cli/train_dpo.py \
--config-name dpo_qwen3_30b_a3b_A2 \
--config-name dpo_qwen3_30b_a3b_A3 \
| tee logs/RL_dpo_qwen3_30b_a3b_rank${NODE_RANK}.log
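
The script raises GPUS_PER_NODE from 8 to 16 and switches the config to the A3 variant, so with NNODES=2 the launch spans 32 ranks. A small sketch of the implied world-size arithmetic (values taken from this script):

```python
# Values from the launch script above; torchrun spawns one process per device.
GPUS_PER_NODE = 16   # devices per node in the A3 setup (per this script)
NNODES = 2
WORLD_SIZE = GPUS_PER_NODE * NNODES
print(f"{GPUS_PER_NODE} processes per node x {NNODES} nodes = {WORLD_SIZE} ranks")
```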

View File

@@ -381,6 +381,11 @@ class MegatronConfig(BaseConfig):
self.use_cp_send_recv_overlap = False
self.use_fused_ring_attention_update = False
self.dpo_loss_type = 'sigmoid'
self.ref_model = ''
self.refer_model_iter = 1
self.wandb_exp_name = ''
self.wandb_project = ''
self.wandb_save_dir = ''
self.use_ascend_coc = False
self.coc_mode = -1
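
The new wandb_* defaults look like standard Weights & Biases settings; a hedged sketch of how such fields are typically consumed (the mapping below is an assumption, not this repository's logging code):

```python
import wandb  # assumes the wandb package is installed

def maybe_init_wandb(cfg):
    # Assumption: empty strings (the defaults above) mean "logging disabled".
    if cfg.wandb_project:
        wandb.init(
            project=cfg.wandb_project,
            name=cfg.wandb_exp_name or None,
            dir=cfg.wandb_save_dir or None,
        )
```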

View File

@@ -215,8 +215,16 @@ class ActorHybridWorkerBase(BaseWorker):
if is_multimodal():
experience_columns.extend(['attention_mask', 'position_ids'])
experience_count = self.rl_config.actor_update_dispatch_size
experience_count = (
self.megatron_config.global_batch_size // self.parallel_state.get_data_parallel_world_size()
)
if self.rl_config.filter_groups_enable:
experience_count = (
self.rl_config.filter_groups_train_batch_size * self.rl_config.n_samples_per_prompt //
self.parallel_state.get_data_parallel_world_size()
)
if skip_actor_log_prob:
experience_columns.remove('old_log_prob')
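
The bug fix replaces the fixed actor_update_dispatch_size with a value derived from the global batch size (or, when filter groups are enabled, from the filtered train batch size). A worked example with illustrative values:

```python
# Illustrative values; the formulas mirror the new code above.
global_batch_size = 256
data_parallel_world_size = 8
filter_groups_enable = True
filter_groups_train_batch_size = 16
n_samples_per_prompt = 8

# Default: each data-parallel rank dispatches its share of the global batch.
experience_count = global_batch_size // data_parallel_world_size          # 256 // 8 = 32

if filter_groups_enable:
    # Filtered case: dispatch size follows the filtered train batch size
    # times the number of samples generated per prompt.
    experience_count = (filter_groups_train_batch_size * n_samples_per_prompt
                        // data_parallel_world_size)                       # 16 * 8 // 8 = 16
```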