mirror of
https://github.com/huggingface/trl.git
synced 2025-10-20 18:43:52 +08:00
👯 Standardize model_args
(#2442)
* `model_config` -> `model_args` * sort
This commit is contained in:
committed by
GitHub
parent
7ba118a229
commit
460e780265
@@ -65,30 +65,28 @@ from trl import (
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = HfArgumentParser((ScriptArguments, RewardConfig, ModelConfig))
|
||||
script_args, training_args, model_config = parser.parse_args_into_dataclasses()
|
||||
script_args, training_args, model_args = parser.parse_args_into_dataclasses()
|
||||
training_args.gradient_checkpointing_kwargs = dict(use_reentrant=False)
|
||||
|
||||
################
|
||||
# Model & Tokenizer
|
||||
################
|
||||
torch_dtype = (
|
||||
model_config.torch_dtype
|
||||
if model_config.torch_dtype in ["auto", None]
|
||||
else getattr(torch, model_config.torch_dtype)
|
||||
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
|
||||
)
|
||||
quantization_config = get_quantization_config(model_config)
|
||||
quantization_config = get_quantization_config(model_args)
|
||||
model_kwargs = dict(
|
||||
revision=model_config.model_revision,
|
||||
revision=model_args.model_revision,
|
||||
device_map=get_kbit_device_map() if quantization_config is not None else None,
|
||||
quantization_config=quantization_config,
|
||||
use_cache=False if training_args.gradient_checkpointing else True,
|
||||
torch_dtype=torch_dtype,
|
||||
)
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
model_config.model_name_or_path, trust_remote_code=model_config.trust_remote_code, use_fast=True
|
||||
model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
|
||||
)
|
||||
model = AutoModelForSequenceClassification.from_pretrained(
|
||||
model_config.model_name_or_path, num_labels=1, trust_remote_code=model_config.trust_remote_code, **model_kwargs
|
||||
model_args.model_name_or_path, num_labels=1, trust_remote_code=model_args.trust_remote_code, **model_kwargs
|
||||
)
|
||||
# Align padding tokens between tokenizer and model
|
||||
model.config.pad_token_id = tokenizer.pad_token_id
|
||||
@@ -97,7 +95,7 @@ if __name__ == "__main__":
|
||||
if tokenizer.chat_template is None:
|
||||
model, tokenizer = setup_chat_format(model, tokenizer)
|
||||
|
||||
if model_config.use_peft and model_config.lora_task_type != "SEQ_CLS":
|
||||
if model_args.use_peft and model_args.lora_task_type != "SEQ_CLS":
|
||||
warnings.warn(
|
||||
"You are using a `task_type` that is different than `SEQ_CLS` for PEFT. This will lead to silent bugs"
|
||||
" Make sure to pass --lora_task_type SEQ_CLS when using this script with PEFT.",
|
||||
@@ -118,7 +116,7 @@ if __name__ == "__main__":
|
||||
args=training_args,
|
||||
train_dataset=dataset[script_args.dataset_train_split],
|
||||
eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
|
||||
peft_config=get_peft_config(model_config),
|
||||
peft_config=get_peft_config(model_args),
|
||||
)
|
||||
trainer.train()
|
||||
|
||||
|
Reference in New Issue
Block a user