👯 Standardize model_args (#2442)

* `model_config` -> `model_args`

* sort
Author: Quentin Gallouédec
Date: 2024-12-10 12:51:20 +01:00 (committed by GitHub)
Parent: 7ba118a229
Commit: 460e780265
20 changed files with 184 additions and 203 deletions


@@ -65,30 +65,28 @@ from trl import (
 if __name__ == "__main__":
     parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
-    script_args, training_args, model_config = parser.parse_args_into_dataclasses()
+    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
     training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
 
     ################
     # Model & Tokenizer
     ################
     torch_dtype = (
-        model_config.torch_dtype
-        if model_config.torch_dtype in ["auto", None]
-        else getattr(torch, model_config.torch_dtype)
+        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_config)
+    quantization_config = get_quantization_config(model_args)
     model_kwargs = dict(
-        revision=model_config.model_revision,
+        revision=model_args.model_revision,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
         use_cache=False if training_args.gradient_checkpointing else True,
         torch_dtype=torch_dtype,
     )
     tokenizer = AutoTokenizer.from_pretrained(
-        model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
+        model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
     )
     model = AutoModelForSequenceClassification.from_pretrained(
-        model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
+        model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs
     )
 
     # Align padding tokens between tokenizer and model
     model.config.pad_token_id = tokenizer.pad_token_id
@@ -97,7 +95,7 @@ if __name__ == "__main__":
     if tokenizer.chat_template is None:
         model, tokenizer = setup_chat_format(model, tokenizer)
 
-    if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
+    if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS":
         warnings.warn(
             "You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
             " Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
@@ -118,7 +116,7 @@ if __name__ == "__main__":
         args=training_args,
         train_dataset=dataset[script_args.dataset_train_split],
         eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
-        peft_config=get_peft_config(model_config),
+        peft_config=get_peft_config(model_args),
     )
     trainer.train()
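
The rename above repeats across the 20 changed files: wherever an example script parsed TRL's `ModelConfig` dataclass into a variable called `model_config`, it now calls it `model_args`, and the helpers (`get_quantization_config`, `get_peft_config`, ...) receive that variable. Below is a minimal, self-contained sketch of the standardized pattern, assuming only TRL's public `ModelConfig`/`ScriptArguments` exports; the file name and CLI values are hypothetical and not part of this commit.

# standardize_args_sketch.py -- illustrative sketch, not part of this commit
import torch
from transformers import HfArgumentParser
from trl import ModelConfig, ScriptArguments, get_peft_config, get_quantization_config

if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, ModelConfig))
    # Convention after this commit: the parsed ModelConfig instance is always
    # named `model_args`, never `model_config`.
    script_args, model_args = parser.parse_args_into_dataclasses()

    # Same idiom as in the diff: "auto" and None pass through unchanged, while
    # a string such as "bfloat16" is resolved to the matching torch dtype.
    torch_dtype = (
        model_args.torch_dtype
        if model_args.torch_dtype in ["auto", None]
        else getattr(torch, model_args.torch_dtype)
    )

    # Helpers that previously received `model_config` now receive `model_args`;
    # both return None unless the matching flags (e.g. --load_in_4bit, --use_peft) are set.
    quantization_config = get_quantization_config(model_args)
    peft_config = get_peft_config(model_args)

    print(script_args.dataset_name, model_args.model_name_or_path, torch_dtype)

Invoked as, for example (values hypothetical): python standardize_args_sketch.py --dataset_name trl-lib/ultrafeedback_binarized --model_name_or_path Qwen/Qwen2-0.5B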