[Bugfix] Check bnb_4bit_quant_storage for bitsandbytes (#10642)

This commit is contained in:
Michael Goin
2024-11-26 15:29:00 -05:00
committed by GitHub
parent 9a99273b48
commit 7576cd38df

View File

@ -20,6 +20,7 @@ class BitsAndBytesConfig(QuantizationConfig):
load_in_8bit: bool = False,
load_in_4bit: bool = True,
bnb_4bit_compute_dtype: str = "float32",
bnb_4bit_quant_storage: str = "uint8",
bnb_4bit_quant_type: str = "fp4",
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
@ -31,6 +32,7 @@ class BitsAndBytesConfig(QuantizationConfig):
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
self.bnb_4bit_quant_storage = bnb_4bit_quant_storage
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
@ -38,10 +40,15 @@ class BitsAndBytesConfig(QuantizationConfig):
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
if self.bnb_4bit_quant_storage not in ["uint8"]:
raise ValueError("Unsupported bnb_4bit_quant_storage: "
f"{self.bnb_4bit_quant_storage}")
def __repr__(self) -> str:
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_storage={self.bnb_4bit_quant_storage}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@ -80,6 +87,9 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_compute_dtype = get_safe_value(config,
["bnb_4bit_compute_dtype"],
default_value="float32")
bnb_4bit_quant_storage = get_safe_value(config,
["bnb_4bit_quant_storage"],
default_value="uint8")
bnb_4bit_quant_type = get_safe_value(config, ["bnb_4bit_quant_type"],
default_value="fp4")
bnb_4bit_use_double_quant = get_safe_value(
@ -99,6 +109,7 @@ class BitsAndBytesConfig(QuantizationConfig):
load_in_8bit=load_in_8bit,
load_in_4bit=load_in_4bit,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
bnb_4bit_quant_storage=bnb_4bit_quant_storage,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
llm_int8_enable_fp32_cpu_offload=llm_int8_enable_fp32_cpu_offload,