Files
MindSpeed-RL/docs/features/msprobe.md
LookAround a4f41a0ac2 !459 【修改说明】添加PPO相关说明文档以及一些bug fix
Merge pull request !459 from LookAround/ppo_second_pr
2025-07-08 08:47:03 +00:00

6.0 KiB
Raw Permalink Blame History

精度分析Msprobe

概述

msprobe模块为强化学习训练流程提供了配置采集、关键过程数据采集比对、模型层输入输出数据采集比对的能力帮助精度问题分析和调优。

特性使用

注意:当前 msprobe 数据采集仅支持共卡模式integrated场景。

前置条件

安装msprobe三方库安装指南

配置选项

精度分析工具通过 YAML 配置文件中的 msprobe_config 部分进行配置:

msprobe_config:
  msprobe: false
  dump_path: "./msprobe_dump"
  key_data_dump: false
  configurations_dump: false
  actor_infer_dump: false
  token_range_start: 0
  token_range_end: 0
  actor_train_dump: false
  reference_dump: false
  critic_train_dump: false
  step_start: 0
  step_end: 0

配置参数说明

参数 说明 可选值
msprobe 是否使能msprobe true/false开启后下列的采集项才会生效
dump_path 存盘路径 str默认值"./msprobe_dump"
key_data_dump 关键过程数据采集 true/false默认false是否采集关键过程数据包括prompt、response、ref_log_prob、advantage、log_prob、kl_loss、loss的统计量信息最大值、最小值、均值、L2norm值和真实数据
configurations_dump 训练配置采集 true/false默认false是否采集训练配置
actor_infer_dump actor的推理阶段模型层输入输出 true/false默认false是否采集actor_generate_sequences阶段的模型层数据
token_range_start 采集推理生成token的开始范围 int默认0与token_range_end搭配使用表示采集推理生成的从第几个到第几个范围内的token数据
token_range_end 采集推理生成token的结束范围 int默认0如果只想采某一个token的数据设置为跟token_range_start一样
actor_train_dump actor的训练阶段模型层输入输出 true/false默认false是否采集actor_compute_log_prob、actor_update阶段的模型层数据
reference_dump reference的模型层输入输出 true/false默认false是否采集reference的模型层数据
critic_train_dump critic的训练阶段模型层输入输出 true/false默认false是否采集critic_compute_values、critic_update阶段的模型层数据
step_start 采集开始步数 int默认0只对actor_train_dump、reference_dump、actor_generate_sequences、critic_train_dump生效
step_end 采集结束步数 int默认0只对actor_train_dump、reference_dump、actor_generate_sequences、critic_train_dump生效。如果只想采某一步的数据设置为跟step_start一样

落盘数据说明

msprobe_dump/
├── actor_generate_sequences/  # actor_generate_sequences阶段的模型层数据
├── actor_compute_log_prob/  # actor_compute_log_prob阶段的模型层数据
├── critic_compute_values/  # critic_compute_values阶段的模型层数据
├── actor_update/  # actor_update阶段的模型层数据
├── critic_update/  # critic_update阶段的模型层数据
├── reference_compute_log_prob/  # reference的模型层数据
├── data/  # 训练过程关键数据
│   └── advantages/  
│   └── kl_loss/  
│   └── log_prob/
│   └── old_log_prob/
│   └── loss/  
│   └── prompts/  
│   └── ref_log_prob/  
│   └── responses/
│   └── values/  
├── configurations.json  # 训练配置文件

适用场景

精度对齐(例如确定性问题)

  1. 训练关键数据采集。

按如下配置运行两次模型两次需要设置不同的dump_path

msprobe_config:
  msprobe: true
  dump_path: "./msprobe_dump"
  key_data_dump: true

得到两次训练过程的关键阶段性数据,这个数据我们用来定界到模型或代码块。

  1. 将采集到的数据进行比对。

复制如下训练脚本将dump_path1和dump_path2改为前一步中两次采集设置的两个输出路径output_path改为自己的存盘路径执行该脚本

from msprobe.core import SingleComparator
SingleComparator.compare(
    "dump_path1", 
    "dump_path2", 
    "output_path")

得到一个比对结果目录,里面会包含各项关键数据比对结果表格。

  1. 观察结果表格找到首个出现差异的地方例如responses完全一致ref_log_prob存在差异则可以定界到reference model计算存在确定性问题

  2. 再按如下配置运行两次模型两次需要设置不同的dump_pathstep_start和step_end均设置为问题出现的步数

msprobe_config:
  msprobe: true
  dump_path: "./msprobe_dump"
  reference_dump: true
  step_start: 2
  step_end: 2

得到两次训练过程的reference模型层的输入输出数据这个数据我们用来定位问题点。

  1. 将采集到的模型层输入输出进行比对

复制如下训练脚本,将./dump_path1/step2和./dump_path2/step2改为需要比对的step层级路径比如./msprobe_dump/reference_compute_log_prob/step2output_path改为自己的存盘路径执行该脚本

from msprobe.pytorch import *
compare_distributed(
    './dump_path1/step2', 
    './dump_path2/step2', 
    './output_path')

获得一个比对结果表格compare_result_{timestamp}.xlsx。

  1. 在结果表格中找到首个差异点,这就是问题点

以上是一个简单示例,具体可以根据问题情况不同灵活运用该特性

更多功能

关键数据比对指南key_data_dump

模型层数据比对指南actor_train_dump、reference_dump、critic_train_dump