Files
openmind/docs/zh/api_reference/apis/trainer_api.md
2024-11-26 09:24:53 +08:00

66 KiB
Raw Permalink Blame History

Trainer 模块接口

openmind.TrainingArguments类

TrainingArguments 类用于配置训练任务的参数,包括训练过程中所需的超参数、模型保存路径、日志记录选项、学习率等。

参数列表

  • PyTorch和MindSpore的TrainingArguments类共同支持的参数
参数名 PyTorch类型 MindSpore类型 描述 PyTorch默认值 MindSpore默认值
output_dir str str 输出目录。 "./output"
overwrite_output_dir bool bool 是否覆盖输出目录。 False False
seed int int 随机种子。 42 42
use_cpu bool bool 是否使用CPU。 False False
do_train bool bool 是否进行训练。 False False
do_eval bool bool 是否进行评估。 False False
do_predict bool bool 是否进行推理。 False False
num_train_epochs float float 训练的总轮数。 3.0 3.0
resume_from_checkpoint str str 预加载权重。 None None
evaluation_strategy Union[IntervalStrategy, str] Union[IntervalStrategy, str] 评估策略。 "no" "no"
per_device_train_batch_size int int 每个设备的训练批大小。 8 8
per_device_eval_batch_size int int 每个设备的评估批大小。 8 8
per_gpu_train_batch_size int int 每个GPU的训练批大小。不推荐使用 None None
per_gpu_eval_batch_size int int 每个GPU的评估批大小。不推荐使用 None None
gradient_accumulation_steps int int 梯度累积步数。 1 1
ignore_data_skip bool bool 断点续训是否忽略数据跳过。 False False
dataloader_drop_last bool bool 数据加载器是否丢弃最后一批。 False True
dataloader_num_workers int int 数据加载器进程数。 0 8
optim Union[OptimizerNames, str] Union[OptimizerType, str] 优化器。 "adamw_torch" "fp32_adamw"
adam_beta1 float float Adam优化器的beta1。 0.9 0.9
adam_beta2 float float Adam优化器的beta2。 0.999 0.999
adam_epsilon float float Adam优化器的epsilon。 1e-8 1e-8
weight_decay float float 权重衰减。 0.0 0.0
lr_scheduler_type Union[SchedulerType, str] Union[LrSchedulerType, str] 学习率调度器类型。 "linear" "cosine"
learning_rate float float 学习率。 5e-5 5e-5
warmup_ratio float float 预热比率。 0.0 None
warmup_steps int int 预热步数。 0 0
max_grad_norm float float 梯度裁剪的最大范数。 1.0 1.0
logging_strategy Union[IntervalStrategy, str] Union[LoggingIntervalStrategy, str] 日志保存策略。 "steps" "steps"
logging_steps float float 日志保存步数。 500 1
save_steps float float 权重保存步数。 500 500
save_strategy str Union[SaveIntervalStrategy, str] 权重保存策略。 "steps" "steps"
save_total_limit int int 权重最大保存数量限制。 None 5
save_on_each_node bool bool 是否分片保存权重。 False True
hub_model_id str str Hub模型ID。 None None
hub_strategy Union[HubStrategy, str] Union[HubStrategy, str] Hub推送策略。 "every_save" "every_save"
hub_token str str Hub令牌。 None None
hub_private_repo bool bool Hub私有仓库。 False False
hub_always_push bool bool 是否始终推送到Hub。 False False
data_seed int int 数据采样器随机种子数。 None None
eval_steps float float 评估阶段的步骤数。 None None
push_to_hub bool bool 是否推送到Hub。 False False
  • PyTorch独立支持的参数
参数名 类型 描述 默认值
optim_args str 优化器参数。 None
label_names List[str] 标签名称。 None
load_best_model_at_end bool 是否在最后加载最佳模型。 False
metric_for_best_model str 用于最佳模型的指标。 None
greater_is_better bool 指标是否越大越好。 None
label_smoothing_factor float 标签平滑因子。 0.0
include_inputs_for_metrics bool 指标中是否包含输入。 False
prediction_loss_only bool 是否在执行评估和生成预测时,仅返回损失。 False
eval_accumulation_steps int 需要累积输出张量的预测步骤数。 None
eval_delay float 第一次评估需要等待的步骤数。 None
max_steps int 最大训练步数。 -1
lr_scheduler_kwargs dict 调度器的额外参数。 {}
log_level str 日志等级。 "passive"
log_level_replica str 在副本上使用的日志等级。 "warning"
log_on_each_node bool 分布式训练是否只在主节点记录日志。 True
logging_dir str 日志保存目录。 None
logging_first_step bool 是否记录第一个“global_step”。 False
logging_nan_inf_filter bool 是否过滤'nan'和'inf'损失以进行日志记录。 True
save_safetensors bool 是否以safetensor格式保存权重。 True
save_only_model bool 在checkpointing时是否只保存模型状态。 False
jit_mode_eval bool 是否使用PyTorch jit跟踪进行推理。 False
use_ipex bool 是否使用Intel扩展。不支持 False
bf16 bool 是否使用bf16格式。 False
fp16 bool 是否使用fp16格式。 False
tf32 bool 是否使用tf32格式。不支持 None
fp16_opt_level str 权重保存策略。 "O1"
fp16_backend str 指定fp16所使用的后端。 "auto"
half_precision_backend str 定义混精训练所使用的设备。 "auto"
bf16_full_eval bool 是否在评估阶段使用bf16。 False
fp16_full_eval bool 是否在评估阶段使用fp16。 False
disable_tqdm bool 是否禁用进度条工具。 None
remove_unused_columns bool 是否自动删除模型forward方法未使用的列。 True
fsdp Union[List[FSDPOption, str]] 是否使用fsdp。 None
fsdp_config Union[dict, str] fsdp配置。 None
local_rank int 分布式训练的进程号。 -1
tpu_num_cores int 使用TPU训练时使用的内核数。不支持 None
past_index int 使用hidden states用作预测时的index。 -1
ddp_backend str ddp分布式训练所使用的后端。 None
run_name str 运行描述符。 None
deepspeed str deepspeed配置。 None
accelerator_config str accelerate配置。 None
debug Union[str, List[DebugOption]] 启用一个或多个调试功能。 None
length_column_name str 预先计算长度的列名。 "length"
group_by_length bool 是否将训练数据集中长度大致相同的样本组合在一起。 False
ddp_find_unused_parameters bool 'find_unused_parameters'是否传递给'DistributedDataParallel'。 None
report_to List[str] 要报告结果和日志的集成列表。 "all"
ddp_bucket_cap_mb int 'bucket_cap_mb'传递给'DistributedDataParallel'的值。 None
ddp_broadcast_buffers bool 'ddp_broadcast_buffers'的值是否传递给'DistributedDataParallel'。 None
ddp_timeout int ddp调用的超时。 1800
dataloader_pin_memory bool 是否要在数据加载器中固定内存。 True
dataloader_persistent_workers bool 是否保持工作线程数据集实例处于活动状态。 False
dataloader_prefetch_factor int 每个线程提前装载的Batch数。 None
skip_memory_metrics bool 是否跳过将内存探查器报告添加到指标。 True
gradient_checkpointing bool 是否使用梯度检查点来节省内存。 False
gradient_checkpointing_kwargs dict 梯度检查点相关参数。 None
auto_find_batch_size bool 是否通过指数衰减自动找到适合内存的批处理大小。 False
full_determinism bool 调用'enable_full_determinism'而不是'set_seed',以确保分布式训练中的可重复结果。 False
torchdynamo str 指定TorchDynamo的后端编译器。不支持 None
ray_scope str 使用Ray进行超参数搜索时要使用的范围。 "last"
use_mps_device bool 是否使用mps设备。不支持 False
torch_compile bool 是否使用PyTorch 2.0编译模型。(不支持) False
torch_compile_backend str torch.compile所使用的后端。不支持 None
torch_compile_mode str torch.compile模式。不支持 None
split_batches bool 是否将数据加载器生成的批次拆分到设备之间。 None
include_tokens_per_second bool 是否计算每个设备每秒的tokens。 None
include_num_input_tokens_seen bool 是否跟踪在整个训练过程中看到的输入tokens。 None
neftune_noise_alpha float 是否激活NEFTune噪声嵌入。 None
optim_target_modules Union[str, List[str]] 要优化的目标模块。 None
  • MindSpore独立支持的参数
参数名 类型 描述 默认值
only_save_strategy bool 任务是否保存策略文件后直接退出。 False
auto_trans_ckpt bool 是否开启权重自动转换。 False
src_strategy str 预加载权重的分布式策略文件。 None
batch_size int 每个设备的训练批大小。会覆盖per_device_train_batch_size None
sink_mode bool 是否通过通道将数据直接下沉到设备 True
sink_size int 每步训练或评估的数据下沉数量 2
mode int 指示运行在 GRAPH_MODE0 或 PYNATIVE_MODE1 0
resume_training bool 是否开启断点续训。 False
remote_save_url str OBS保存路径。 None
device_id int 设备号。 0
device_target str 执行的目标设备,支持 'Ascend'、'GPU' 和 'CPU'。 "Ascend"
enable_graph_kernel bool 是否启用图融合。 False
graph_kernel_flags str 图融合级别。 "--opt_level=0"
save_graphs bool 是否保存计算图。 False
save_graphs_path str 保存计算图路径。 "./graph"
max_call_depth int 函数调用的最大深度。 10000
max_device_memory str 设备的最大可用内存。 "1024GB"
use_parallel bool 是否开启并行模式。 False
parallel_mode int 指示是否运行于数据并行0、半自动并行1、自动并行2或混合并行3模式。 1
gradients_mean bool 是否在梯度AllReduce后执行平均算子。 False
loss_repeated_mean bool 在重复计算时,是否向后执行均值操作符。 False
enable_alltoall bool 是否允许在通信过程中生成 AllToAll 通信操作符。 False
enable_parallel_optimizer bool 是否开启优化器并行。 False
full_batch bool 如果在自动并行模式下加载整个批处理数据集,则应将 full_batch 设置为 True。当前不建议使用此接口请将其替换为 dataset_strategy。 True
dataset_strategy Union[str, tuple] 数据集分片策略。 "full_batch"
search_mode str 策略搜索模式,仅在自动并行模式下有效,实验性接口,请谨慎使用。 "sharding_propagation"
data_parallel int 数据并行。 1
gradient_accumulation_shard bool 累积梯度变量是否沿着数据并行维度进行分割。 False
parallel_optimizer_threshold int 设置参数分割的阈值。 64
optimizer_weight_shard_size int 设置指定优化器权重分割的通信域大小。 -1
strategy_ckpt_save_file str 保存分布式策略文件的路径。 "./ckpt_strategy.ckpt"
model_parallel int 模型并行。 1
expert_parallel int 专家并行。 1
pipeline_stage int 流水线并行。 1
gradient_aggregation_group int 梯度通信操作融合组的大小。 4
micro_batch_num int 流水线计算最小批次数量。 1
micro_batch_interleave_num int 多副本并行数量。 1
use_seq_parallel bool 是否启用序列并行。 False
vocab_emb_dp bool 是否仅沿着数据并行维度分割词汇表。 True
expert_num int 专家的数量。 1
capacity_factor float 专家因子。 1.05
aux_loss_factor float 损失贡献因子。 0.05
num_experts_chosen int 每个标记选择的专家数量。 1
recompute bool 重计算。 False
select_recompute bool 选择重计算。 False
parallel_optimizer_comm_recompute bool 是否重新计算由优化器并行引入的 AllGather 通信。 False
mp_comm_recompute bool 是否重新计算模型并行引入的通信操作。 True
recompute_slice_activation bool 是否对保留在内存中的 Cell 输出进行切片。 False
layer_scale bool 是否启用层衰减。 False
layer_decay float 层衰减系数。 0.65
lr_end float 最终学习率。 1e-6
warmup_lr_init float 预热阶段的初始学习率。 0.0
warmup_epochs int 在总步数的 warmup_epochs 部分进行线性预热。 None
lr_scale bool 是否启用学习率缩放。 False
lr_scale_factor int 学习率缩放因子。 256
python_multiprocessing bool 是否启动 Python 多进程模式。 False
numa_enable bool 将 NUMA 的默认状态设置为启用状态。 False
prefetch_size int 设置管道中线程的队列容量。 1
wrapper_type str 包装器的类名。 "MFTrainOneStepCell"
scale_sense Union[str, float] scale sense的值或类名。 "DynamicLossScaleUpdateCell"
loss_scale_value int 初始损失缩放因子。 65536
loss_scale_factor int 损失缩放系数的增量和减量因子。 2
loss_scale_window int 增加损失缩放系数的最大连续训练步数。 1000
use_clip_grad bool 是否启用梯度裁剪。 True
train_dataset str 训练集路径。 None
eval_dataset str 评估集路径。 None
dataset_task str 数据集对应的任务类型。 None
dataset_type str 数据集类型。 None
train_dataset_in_columns list[str] 训练集输入标签名称。 None
train_dataset_out_columns list[str] 训练集输出标签名称。 None
eval_dataset_in_columns list[str] 评估集输入标签名称。 None
eval_dataset_out_columns list[str] 评估集输出标签名称。 None
shuffle bool 训练集是否乱序。 True
repeat int 训练集重复次数。 1
metric_type Union[List[str], str] 矩阵类型。 None
save_seconds int 每隔 X 秒保存一次检查点。 None
integrated_save bool 在自动并行场景中是否合并并保存分割的张量。 None
eval_epochs int 每次评估之间的纪元间隔数1 表示每个纪元结束时进行评估。 None
profile bool 是否开启性能分析收集。 False
profile_start_step int 性能分析起始step。 1
profile_end_step int 性能分析结束step。 10
init_start_profile bool 是否在 Profiler 初始化时启用数据收集。 False
profile_communication bool 是否在多设备训练中收集通信性能数据。 False
profile_memory bool 是否收集张量内存数据。 True
auto_tune bool 是否启用自动数据加速。 False
filepath_prefix str 优化后的全局配置的保存路径和文件前缀。 "./autotune"
autotune_per_step int 设置调整自动数据加速配置的步长间隔。 10

train_batch_size

获取训练批大小

接口原型

def train_batch_size()

eval_batch_size

获取评估批大小

接口原型

def eval_batch_size()

world_size

获取并行的进程数量

接口原型

def world_size()

process_index

获取当前进程的索引

接口原型

def process_index()

local_process_index

获取当前本地进程的索引

接口原型

def local_process_index()

should_log

获取当前进程是否应生成日志,当前仅支持PyTorch

接口原型

def should_log()

should_save

获取当前进程是否应写入磁盘,当前仅支持PyTorch

接口原型

def should_save()

_setup_devices

设置设备,当前仅支持PyTorch

接口原型

def _setup_devices()

device

获取当前进程使用的设备,当前仅支持PyTorch

接口原型

def device()

get_process_log_level

获取进程日志级别,当前仅支持PyTorch

接口原型

def get_process_log_level()

main_process_first

主进程优先,当前仅支持PyTorch

接口原型

def main_process_first(local: bool = True, desc: str = "work")

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
local 是否本地。 bool 不支持
desc 工作描述。 str 不支持

get_warmup_steps

获取预热迭代步数

接口原型

def get_warmup_steps(num_training_steps: int)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
num_training_steps 训练迭代步数。 int int

to_dict

将实例序列化为字典

接口原型

def to_dict()

to_json_string

将实例序列化为JSON字符串

接口原型

def to_json_string()

to_sanitized_dict

将实例序列化为可用于TensorBoard的参数字典当前仅支持PyTorch

接口原型

def to_sanitized_dict()

set_training

设置训练参数

接口原型

def set_training(
    learning_rate: float = 5e-5,
    batch_size: int = 8,
    weight_decay: float = 0,
    num_epochs: float = 3,
    max_steps: int = -1,
    gradient_accumulation_steps: int = 1,
    seed: int = 42,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
learning_rate 初始学习率。 float float
batch_size 每个设备训练的批量大小。 int int
weight_decay 权重衰减。 float float
num_epochs 执行的总训练周期数。 float float
max_steps 最大训练步数。 int 不支持
gradient_accumulation_steps 累积梯度的更新步数。 int int
seed 在训练开始时设置的随机种子。 int int
kwargs["gradient_checkpointing"] 如果为True则使用梯度检查点来节省内存但反向传播速度会减慢。 bool 不支持

set_evaluate

设置评估参数

接口原型

def set_evaluate(
    strategy: Union[str, IntervalStrategy] = "no",
    steps: int = 500,
    batch_size: int = 8,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
strategy 训练过程中采用的评估策略,支持以下取值:
- "no": 在训练过程中不进行评估。
- "steps": 每steps步进行一次评估(并记录日志)。
- "epoch": 在每个周期结束时进行一次评估。
Union[str, IntervalStrategy] Union[str, IntervalStrategy]
steps 如果strategy="steps",则在两次评估之间的更新步数。 int int
batch_size 用于评估的每个设备的批大小。 int int
kwargs["accumulation_steps"] 在将输出张量移动到CPU之前累积预测步骤的输出张量的数量。 int 不支持
kwargs["delay"] 在进行第一次评估之前等待的周期数或步数具体取决于evaluation_strategy。 float 不支持
kwargs["loss_only"] 仅忽略除损失之外的所有输出。 bool 不支持
kwargs["jit_mode"] 是否在推理中使用PyTorch jit。 bool 不支持

set_testing

设置测试参数

接口原型

def set_testing(
    batch_size: int = 8,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
batch_size 用于测试的每个设备的批大小。 int int
kwargs["loss_only"] 仅忽略除损失之外的所有输出。 bool bool
kwargs["jit_mode"] 是否在推理中使用PyTorch jit。 bool 不支持

set_save

设置与保存相关的所有参数

接口原型

def set_save(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    total_limit: Optional[int] = None,
    on_each_node: bool = False,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
strategy 权重保存策略,支持以下取值:
- "no": 训练期间不保存检查点。
- "epoch": 每个周期结束时保存检查点。
- "steps": 每save_steps步保存检查点。
Union[str, IntervalStrategy] Union[str, IntervalStrategy]
steps 如果strategy="steps",该参数表示两次检查点保存之间的更新步数。 int int
total_limit 限制检查点的总数量。删除output_dir中较旧的检查点。 int int
on_each_node 在进行多节点分布式训练时,是否在每个节点上保存模型和检查点,还是只在主节点上保存。 bool bool

set_logging

设置与日志记录相关的所有参数

接口原型

def set_logging(
    strategy: Union[str, IntervalStrategy] = "steps",
    steps: int = 500,
    report_to: Union[str, List[str]] = "none",
    level: str = "passive",
    first_step: bool = False,
    nan_inf_filter: bool = False,
    on_each_node: bool = False,
    replica_level: str = "passive",
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
strategy 训练日志保存策略,支持以下取值:
- "no": 训练期间不进行日志记录。
- "epoch": 每个周期结束时进行日志记录。
- "steps": 每save_steps步进行日志记录。
Union[str, IntervalStrategy] Union[str, IntervalStrategy]
steps 如果strategy="steps",则在两次日志记录之间的更新步数。 int int
report_to 用于将模型推送到Hub的令牌。 str 不支持
level 主进程上要使用的记录器日志级别。包括:"debug""info""warning""error""critical",以及"passive" str 不支持
first_step 是否记录和评估第一个global_step bool 不支持
nan_inf_filter 是否过滤日志中的naninf损失。 bool 不支持
on_each_node 在分布式训练中,是否在每个节点上使用log_level记录,还是只在主节点上记录。 bool 不支持
replica_level 在副本上使用的记录器日志级别。 str 不支持

set_push_to_hub

设置与Hub同步检查点相关的所有参数

接口原型

def set_push_to_hub(
    model_id: str,
    strategy: Union[str, HubStrategy] = "every_save",
    token: Optional[str] = None,
    private_repo: bool = False,
    always_push: bool = False,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
model_id 与本地output_dir同步的存储库的名称。它可以是一个简单的模型ID在这种情况下模型将被推送到您的命名空间。也可以是整个存储库名称例如"user_name/model" str str
strategy 定义推送到Hub的策略支持以下取值
- "end": 当调用Trainer.save_model方法时,推送模型、其配置、分词器和模型卡片。
- "every_save": 每次保存模型时,推送模型、配置、分词器和模型卡片。推送是异步的,不会阻塞训练,如果保存非常频繁,只有在前一个推送完成后才会尝试新的推送。在训练结束时,最终模型会进行最后一次推送。
- "checkpoint": 类似于 "every_save"但最新的检查点也被推送到名为last-checkpoint的子文件夹中使您可以轻松地使用trainer.train(resume_from_checkpoint="last-checkpoint")恢复训练。
- "all_checkpoints": 类似于"checkpoint",但所有检查点都被推送(因此您将在最终存储库中获得每个文件夹的一个检查点文件夹。)
Union[str, HubStrategy] Union[str, HubStrategy]
token 用于将模型推送到Hub的令牌。 str str
private_repo 如果为True则Hub存储库将设置为私有。 bool bool
always_push 如果为False,当前一个推送未完成时,Trainer将跳过推送检查点。 bool bool

set_optimizer

设置与优化器及其超参数相关的所有参数

接口原型

def set_optimizer(
    name: Union[str, OptimizerNames],
    learning_rate: float = 5e-5,
    weight_decay: float = 0,
    beta1: float = 0.9,
    beta2: float = 0.999,
    epsilon: float = 1e-8,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
name 优化器类型。 Union[str, OptimizerNames] Union[str, OptimizerType]
learning_rate 初始学习率。 float float
lr_end 最终学习率。 不支持t float
weight_decay 权重衰减。 float float
beta1 Adam优化器或其变体的beta1超参数。 float float
beta2 Adam优化器或其变体的beta2超参数。 float float
epsilon Adam优化器或其变体的epsilon超参数。 float float
kwargs["args"] 传递给AnyPrecisionAdamW的可选参数仅当optim="adamw_anyprecision"时有用默认None。 str 不支持

set_lr_scheduler

设置与学习率调度器及其超参数相关的所有参数

接口原型

def set_lr_scheduler(
    name: Union[str, SchedulerType] = "linear",
    num_epochs: float = 3.0,
    max_steps: int = -1,
    warmup_ratio: float = 0,
    warmup_steps: int = 0,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
name 学习率调度器类型。 Union[str, SchedulerType] Union[str, LrSchedulerType]
num_epochs 训练的总轮数。 float float
max_steps 最大训练步数。 int 不支持
warmup_ratio 用于从0到 learning_rate 进行线性预热的总训练步数的比率。 float float
warmup_steps 用于从0到learning_rate进行线性预热的步数。 int int

set_dataloader

设置数据加载器

接口原型

def set_dataloader(
    train_batch_size: int = 8,
    eval_batch_size: int = 8,
    drop_last: bool = False,
    num_workers: int = 0,
    ignore_data_skip: bool = False,
    sampler_seed: Optional[int] = None,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
train_batch_size 训练的批大小。 int int
eval_batch_size 评估的批大小。 int int
drop_last 是否丢弃最后一个不完整的批次。 bool bool
num_workers 用于数据加载的子进程数量。 int int
ignore_data_skip 在恢复训练时,是否跳过批次和周期以使数据加载处于与上一次训练相同的阶段。 bool bool
sampler_seed 用于数据采样器的随机种子。 int int
kwargs["pin_memory"] 是否要在数据加载器中固定内存默认True。 bool 不支持
kwargs["persistent_workers"] 如果为True则数据加载器在数据集被消耗一次后不会关闭工作进程。这允许保持工作进程的数据集实例处于活动状态。可能会加速训练但会增加RAM使用量默认False。 bool 不支持
kwargs["auto_find_batch_size"] 自动找到适合内存的批大小需要安装accelerate默认False。 bool 不支持
kwargs["prefetch_factor"] 每个进程预先会加载的批次数量。 int 不支持

openmind.Trainer类

Trainer类用于实现模型的训练、评估和推理等功能。它是训练的核心组件,提供了许多方法和功能来管理整个训练过程,包括数据加载、模型前向传播、损失计算、梯度更新等。

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
args 用于配置数据集、超参数、优化器等的任务配置。 TrainingArguments TrainingArguments
task 任务类型。 不支持 str
model 用于训练、评估或进行预测的模型实例。 Union[PreTrainedModel, torch.nn.Module] Union[mindformers.models.PreTrainedModel, str]
model_name 模型名称。 不支持 str
pet_method PET方法名称。 不支持 str
tokenizer 分词器。 PreTrainedTokenizerBase mindformers.models.PreTrainedTokenizerBase
train_dataset 训练数据集。 Dataset Union[str, mindspore.dataset.BaseDataset]
eval_dataset 评估数据集。 Union[Dataset, Dict[str, Dataset]] Union[str, mindspore.dataset.BaseDataset]
data_collator 数据批处理的函数。 DataCollator 不支持
image_processor 图像预处理的处理器。 不支持 mindformers.models.BaseImageProcessor
audio_processor 音频预处理的处理器。 不支持 mindformers.models.BaseAudioProcessor
optimizers 优化器。 Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] mindspore.nn.Optimizer
compute_metrics 评估时计算指标的函数。 Callable[[EvalPrediction], Dict] Union[dict, set]
callbacks 回调函数列表。 List[TrainerCallback] Union[List[mindspore.train.Callback], mindspore.train.Callback]
eval_callbacks 评估回调函数列表。 不支持 Union[List[mindspore.train.Callback], mindspore.train.Callback]
model_init 实例化要使用的模型的函数。 Callable[[], PreTrainedModel] 不支持
preprocess_logits_for_metrics 计算指标前对输出结果预处理函数。 Callable[[torch.Tensor, torch.Tensor], torch.Tensor] 不支持
save_config 保存当前任务的配置。 不支持 bool

train

执行训练步骤

接口原型

def train(*args, **kwargs)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
train_checkpoint 恢复网络的训练权重。 不支持 Union[str, bool]
resume_from_checkpoint 预加载权重。 Union[str, bool] Union[str, bool]
trial 运行的试验或用于超参数搜索的超参数字典。 Union["optuna.Trial", Dict[str, Any]] 不支持
ignore_keys_for_eval 在训练期间用于收集评估预测时,应该忽略的模型输出中的键列表。 List[str] 不支持
resume_training 断点续训开关。 不支持 bool
auto_trans_ckpt 权重自动转换开关。 不支持 bool
src_strategy 预加载权重的分布式策略文件。 不支持 str
do_eval 是否在训练过程中进行评估。 不支持 bool

evaluate

运行评估

接口原型

def evaluate(*args, **kwargs)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
eval_dataset 评估数据集。 Union[Dataset, Dict[str, Dataset]] Union[str, mindspore.dataset.BaseDataset, mindspore.dataset.Dataset, Iterable]
eval_checkpoint 评估网络的权重。 不支持 Union[str, bool]
ignore_keys 在收集预测时,应该忽略模型输出中的键列表。 List[str] 不支持
metric_key_prefix 指标名前缀。 str 不支持
auto_trans_ckpt 权重自动转换开关。 不支持 bool
src_strategy 预加载权重的分布式策略文件。 不支持 str

predict

运行推理

接口原型

def predict(*args, **kwargs)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
predict_checkpoint 推理网络的权重。 不支持 Union[str, bool]
test_dataset 推理数据集。 Dataset 不支持
ignore_keys 在收集预测时,应该忽略模型输出中的键列表。 List[str] 不支持
metric_key_prefix 指标名前缀。 str 不支持
input_data 推理输入数据。 不支持 Union[GeneratorDataset,Tensor, np.ndarray, Image, str, list]
batch_size 批处理大小。 不支持 int
auto_trans_ckpt 权重自动转换开关。 不支持 bool
src_strategy 预加载权重的分布式策略文件。 不支持 str

add_callback

向当前回调列表中添加一个回调函数

接口原型

def add_callback(callback)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
callback 回调函数。 Union[type, TrainerCallback] Union[type, mindspore.train.Callback]

pop_callback

从当前回调列表中删除一个回调并将其返回。如果找不到回调,则返回None(不会引发错误)

接口原型

def pop_callback(callback)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
callback 回调函数。 Union[type, TrainerCallback] Union[type, mindspore.train.Callback]

remove_callback

从当前回调列表中删除一个回调函数

接口原型

def remove_callback(callback)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
callback 回调函数。 Union[type, TrainerCallback] Union[type, mindspore.train.Callback]

save_model

保存模型,以便您可以使用from_pretrained()方法重新加载它

接口原型

def save_model(*args, **kwargs)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
output_dir 模型保存路径。 str str
_internal_call args.push_to_hub为True的情况下当用户调用save_model方法时是否将模型上传到存储库Hub中。默认值为False表示进行推送。 bool bool

init_hf_repo

创建并初始化args.hub_model_id中的git repo。

接口原型

def init_hf_repo()

push_to_hub

modeltokenizer上传到模型存储库Hub中的仓库args.hub_model_id

接口原型

def push_to_hub(
    commit_message: Optional[str] = "End of training",
    blocking: bool = True,
    **kwargs,
)

参数列表

参数名 描述 PyTorch支持类型 MindSpore支持类型
commit_message 推送时的提交消息,默认为"End of training"。 str str
blocking 函数是否应该仅在git push完成时返回默认为True。 bool bool