mirror of https://github.com/huggingface/trl.git (synced 2025-10-20 18:43:52 +08:00)
Fix CI dev test TypeError: unexpected keyword argument 'load_in_4bit' (#4262)
commit 7e0adbc552 (parent 773afd9314), committed via GitHub
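
The tests broke because they passed `load_in_4bit=True` (or `load_in_8bit=True`) directly to `from_pretrained`, a shortcut the transformers dev build used in CI no longer accepts, hence the `TypeError: ... unexpected keyword argument 'load_in_4bit'`. The fix routes the flag through a `BitsAndBytesConfig` instead. A minimal before/after sketch of the pattern (the tiny test model id is taken from the diff below; actually loading in 4-bit assumes bitsandbytes and a CUDA device are available):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

# Old call, removed by this commit; newer transformers rejects the kwarg:
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

# New call: wrap the flag in a BitsAndBytesConfig and pass it as quantization_config
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True),
)
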
@@ -90,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
     model_name,
     peft_config=lora_config,
     reward_adapter=rm_adapter_id,
-    load_in_8bit=True,
+    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 )

 ...

@@ -642,6 +642,7 @@ class TestDPOTrainer(TrlTestCase):
     def test_dpo_lora_bf16_autocast_llama(self):
         # Note this test only works on compute capability > 7 GPU devices
         from peft import LoraConfig
+        from transformers import BitsAndBytesConfig

         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         tokenizer = AutoTokenizer.from_pretrained(model_id)

@@ -655,7 +656,9 @@ class TestDPOTrainer(TrlTestCase):
         )

         # lora model
-        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+        )

         training_args = DPOConfig(
             output_dir=self.tmp_dir,

@@ -725,6 +728,7 @@ class TestDPOTrainer(TrlTestCase):
     )
     def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval):
         from peft import LoraConfig
+        from transformers import BitsAndBytesConfig

         lora_config = LoraConfig(
             r=16,

@@ -735,7 +739,9 @@ class TestDPOTrainer(TrlTestCase):
         )

         # lora model
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+        )

         training_args = DPOConfig(
             output_dir=self.tmp_dir,

@@ -101,9 +101,12 @@ class TestPeftModel(TrlTestCase):
         Simply creates a peft model and checks that it can be loaded.
         """
         from bitsandbytes.nn import Linear8bitLt
+        from transformers import BitsAndBytesConfig

         trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
-            self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True
+            self.causal_lm_model_id,
+            peft_config=self.lora_config,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
         )
         # Check that the number of trainable parameters is correct
         nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)

@@ -111,7 +114,7 @@ class TestPeftModel(TrlTestCase):
         assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)

         causal_lm_model = AutoModelForCausalLM.from_pretrained(
-            self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
+            self.causal_lm_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
         )
         trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
         # Check that the number of trainable parameters is correct

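The peft test keeps its assertion that quantization actually took effect by inspecting the layer class. A standalone sketch of that check, reusing the tiny Qwen2 test model from the DPO tests above (8-bit loading here assumes bitsandbytes and a CUDA GPU):

from bitsandbytes.nn import Linear8bitLt
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model = AutoModelForCausalLM.from_pretrained(
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
# 8-bit loading replaces nn.Linear modules with bitsandbytes Linear8bitLt
assert isinstance(model.model.layers[0].mlp.gate_proj, Linear8bitLt)
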
@@ -153,8 +153,13 @@ class PreTrainedModelWrapper(nn.Module):

         current_device = cls._get_current_device()
         if isinstance(pretrained_model_name_or_path, str):
-            is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
-            is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
+            quantization_config = pretrained_kwargs.get("quantization_config", None)
+            if quantization_config is not None:
+                is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
+                is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
+            else:
+                is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
+                is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
         else:
             is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
             is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
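
The wrapper change applies the same migration inside `PreTrainedModelWrapper.from_pretrained`: when a `quantization_config` is supplied, the 8-bit/4-bit flags are read off that object, and the legacy keyword arguments are only consulted as a fallback. A small self-contained check of that lookup logic, assuming only that `BitsAndBytesConfig` exposes `load_in_8bit`/`load_in_4bit` attributes:

from transformers import BitsAndBytesConfig

pretrained_kwargs = {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}

quantization_config = pretrained_kwargs.get("quantization_config", None)
if quantization_config is not None:
    # read the flags from the config object; a missing attribute defaults to False
    is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
    is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
else:
    # legacy path: plain keyword arguments
    is_loaded_in_8bit = pretrained_kwargs.get("load_in_8bit", False)
    is_loaded_in_4bit = pretrained_kwargs.get("load_in_4bit", False)

assert is_loaded_in_8bit and not is_loaded_in_4bit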