mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
143 lines
4.9 KiB
Python
143 lines
4.9 KiB
Python
# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
# Full training:
python examples/scripts/gkd.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-5 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --output_dir gkd-model \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --gradient_checkpointing

# LoRA:
python examples/scripts/gkd.py \
    --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
    --teacher_model_name_or_path Qwen/Qwen2-1.5B-Instruct \
    --dataset_name trl-lib/chatbot_arena_completions \
    --learning_rate 2e-4 \
    --per_device_train_batch_size 4 \
    --gradient_accumulation_steps 8 \
    --output_dir gkd-model \
    --logging_steps 10 \
    --num_train_epochs 1 \
    --push_to_hub \
    --gradient_checkpointing \
    --use_peft \
    --lora_r 64 \
    --lora_alpha 16
"""
|
|
|
|
from accelerate import PartialState
|
|
from datasets import load_dataset
|
|
from transformers import AutoTokenizer, GenerationConfig
|
|
|
|
from trl import (
|
|
GKDConfig,
|
|
GKDTrainer,
|
|
LogCompletionsCallback,
|
|
ModelConfig,
|
|
ScriptArguments,
|
|
TrlParser,
|
|
get_kbit_device_map,
|
|
get_peft_config,
|
|
get_quantization_config,
|
|
)
|
|
|
|
|
|
if __name__ == "__main__":
    # Parse the script, training and model argument groups.
    parser = TrlParser((ScriptArguments, GKDConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # ------------------------------------------------------------------
    # Model & tokenizer
    # ------------------------------------------------------------------
    quantization_config = get_quantization_config(model_args)

    # Student model kwargs, stored on the training config. The cache is
    # disabled whenever gradient checkpointing is requested.
    model_kwargs = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": model_args.torch_dtype,
        "use_cache": not training_args.gradient_checkpointing,
        "device_map": None if quantization_config is None else get_kbit_device_map(),
        "quantization_config": quantization_config,
    }
    training_args.model_init_kwargs = model_kwargs

    # Teacher model kwargs: same settings as the student except that the
    # cache is always enabled.
    # NOTE(review): the teacher reuses the student's `model_revision`; confirm
    # this is intended when the two checkpoints live in different repos.
    teacher_model_kwargs = {
        "revision": model_args.model_revision,
        "trust_remote_code": model_args.trust_remote_code,
        "attn_implementation": model_args.attn_implementation,
        "torch_dtype": model_args.torch_dtype,
        "use_cache": True,
        "device_map": None if quantization_config is None else get_kbit_device_map(),
        "quantization_config": quantization_config,
    }
    training_args.teacher_model_init_kwargs = teacher_model_kwargs

    # The tokenizer is loaded from the student checkpoint with left padding.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        padding_side="left",
    )
    # Fall back to the EOS token when the tokenizer defines no pad token.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # ------------------------------------------------------------------
    # Dataset
    # ------------------------------------------------------------------
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)

    def _to_chat_prompt(example):
        """Replace the "prompt" field with its chat-template rendering (as text)."""
        rendered = tokenizer.apply_chat_template(
            example["prompt"],
            tokenize=False,
            add_generation_prompt=True,
        )
        return {"prompt": rendered}

    with PartialState().local_main_process_first():
        dataset = dataset.map(_to_chat_prompt, num_proc=training_args.dataset_num_proc)

    # ------------------------------------------------------------------
    # Training
    # ------------------------------------------------------------------
    # Student and teacher are passed by name/path; an eval split is only
    # supplied when an evaluation strategy is configured.
    trainer = GKDTrainer(
        model=model_args.model_name_or_path,
        teacher_model=training_args.teacher_model_name_or_path,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=(
            None if training_args.eval_strategy == "no" else dataset[script_args.dataset_test_split]
        ),
        processing_class=tokenizer,
        peft_config=get_peft_config(model_args),
    )

    # When evaluation is enabled, also log sampled completions for 8 prompts.
    if training_args.eval_strategy != "no":
        generation_config = GenerationConfig(
            max_new_tokens=training_args.max_new_tokens,
            do_sample=True,
            temperature=training_args.temperature,
        )
        completions_callback = LogCompletionsCallback(trainer, generation_config, num_prompts=8)
        trainer.add_callback(completions_callback)

    trainer.train()

    # Persist the trained model, then optionally publish it to the Hub.
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
|