Fix CI dev test TypeError: unexpected keyword argument 'load_in_4bit' (#4262)

Albert Villanova del Moral
2025-10-15 18:14:49 +02:00
committed by GitHub
parent 773afd9314
commit 7e0adbc552
4 changed files with 21 additions and 7 deletions
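
For context: transformers' dev branch no longer accepts the bare `load_in_4bit` / `load_in_8bit` keyword arguments in `from_pretrained`, which is what raises the TypeError in the title. A minimal sketch of the migration, using the tiny test model referenced in the tests below:

```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"

# Old style: raises on transformers dev
#   TypeError: ... got an unexpected keyword argument 'load_in_4bit'
# model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)

# New style: pass the flag through a BitsAndBytesConfig instead
# (actually loading in 4-bit requires bitsandbytes and a supported GPU)
model = AutoModelForCausalLM.from_pretrained(
    model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
)
```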

@@ -90,7 +90,7 @@ model = AutoModelForCausalLMWithValueHead.from_pretrained(
     model_name,
     peft_config=lora_config,
     reward_adapter=rm_adapter_id,
-    load_in_8bit=True,
+    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
 )
 ...
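
Note that the updated doc snippet assumes `BitsAndBytesConfig` is imported earlier in the example; if not, the import it relies on is:

```python
from transformers import BitsAndBytesConfig
```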

@@ -642,6 +642,7 @@ class TestDPOTrainer(TrlTestCase):
     def test_dpo_lora_bf16_autocast_llama(self):
         # Note this test only works on compute capability > 7 GPU devices
         from peft import LoraConfig
+        from transformers import BitsAndBytesConfig
 
         model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
         tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -655,7 +656,9 @@ class TestDPOTrainer(TrlTestCase):
         )
 
         # lora model
-        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+        )
 
         training_args = DPOConfig(
             output_dir=self.tmp_dir,
@@ -725,6 +728,7 @@ class TestDPOTrainer(TrlTestCase):
     )
     def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval):
         from peft import LoraConfig
+        from transformers import BitsAndBytesConfig
 
         lora_config = LoraConfig(
             r=16,
@@ -735,7 +739,9 @@ class TestDPOTrainer(TrlTestCase):
         )
 
         # lora model
-        model = AutoModelForCausalLM.from_pretrained(self.model_id, load_in_4bit=True)
+        model = AutoModelForCausalLM.from_pretrained(
+            self.model_id, quantization_config=BitsAndBytesConfig(load_in_4bit=True)
+        )
 
         training_args = DPOConfig(
             output_dir=self.tmp_dir,

@@ -101,9 +101,12 @@ class TestPeftModel(TrlTestCase):
         Simply creates a peft model and checks that it can be loaded.
         """
         from bitsandbytes.nn import Linear8bitLt
+        from transformers import BitsAndBytesConfig
 
         trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(
-            self.causal_lm_model_id, peft_config=self.lora_config, load_in_8bit=True
+            self.causal_lm_model_id,
+            peft_config=self.lora_config,
+            quantization_config=BitsAndBytesConfig(load_in_8bit=True),
         )
 
         # Check that the number of trainable parameters is correct
         nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad)
@@ -111,7 +114,7 @@ class TestPeftModel(TrlTestCase):
         assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt)
 
         causal_lm_model = AutoModelForCausalLM.from_pretrained(
-            self.causal_lm_model_id, load_in_8bit=True, device_map="auto"
+            self.causal_lm_model_id, quantization_config=BitsAndBytesConfig(load_in_8bit=True), device_map="auto"
         )
         trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config)
 
         # Check that the number of trainable parameters is correct

@@ -153,8 +153,13 @@ class PreTrainedModelWrapper(nn.Module):
         current_device = cls._get_current_device()
         if isinstance(pretrained_model_name_or_path, str):
-            is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
-            is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
+            quantization_config = pretrained_kwargs.get("quantization_config", None)
+            if quantization_config is not None:
+                is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
+                is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
+            else:
+                is_loaded_in_8bit = pretrained_kwargs["load_in_8bit"] if "load_in_8bit" in pretrained_kwargs else False
+                is_loaded_in_4bit = pretrained_kwargs["load_in_4bit"] if "load_in_4bit" in pretrained_kwargs else False
         else:
             is_loaded_in_8bit = getattr(pretrained_model_name_or_path, "is_loaded_in_8bit", False)
             is_loaded_in_4bit = getattr(pretrained_model_name_or_path, "is_loaded_in_4bit", False)
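
For reference, a self-contained sketch of the resolution logic introduced above, assuming only that transformers' `BitsAndBytesConfig` exposes `load_in_8bit` / `load_in_4bit` attributes (it does in current releases); the legacy branch uses `dict.get` as an equivalent shorthand for the conditional expressions in the diff:

```python
from transformers import BitsAndBytesConfig

pretrained_kwargs = {"quantization_config": BitsAndBytesConfig(load_in_8bit=True)}

# Prefer the quantization_config object when one was passed...
quantization_config = pretrained_kwargs.get("quantization_config", None)
if quantization_config is not None:
    is_loaded_in_8bit = getattr(quantization_config, "load_in_8bit", False)
    is_loaded_in_4bit = getattr(quantization_config, "load_in_4bit", False)
else:
    # ...and fall back to the legacy bare kwargs otherwise
    is_loaded_in_8bit = pretrained_kwargs.get("load_in_8bit", False)
    is_loaded_in_4bit = pretrained_kwargs.get("load_in_4bit", False)

assert is_loaded_in_8bit and not is_loaded_in_4bit
```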