[Model] Add BNB quantization support for Mllama (#9720)

Author: Isotr0py
Date: 2024-10-29 20:20:02 +08:00
Committed by: GitHub
Parent commit: ef7865b4f9
Commit: 09500f7dde
3 changed files with 84 additions and 12 deletions

View File

@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
import torch
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -23,7 +24,7 @@ class BitsAndBytesConfig(QuantizationConfig):
bnb_4bit_use_double_quant: bool = False,
llm_int8_enable_fp32_cpu_offload: bool = False,
llm_int8_has_fp16_weight: bool = False,
llm_int8_skip_modules: Optional[Any] = None,
llm_int8_skip_modules: Optional[List[str]] = None,
llm_int8_threshold: float = 0.0,
) -> None:
@@ -34,11 +35,15 @@ class BitsAndBytesConfig(QuantizationConfig):
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_skip_modules = llm_int8_skip_modules or []
self.llm_int8_threshold = llm_int8_threshold
def __repr__(self) -> str:
return "BitsAndBytesConfig"
return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
f"load_in_4bit={self.load_in_4bit}, "
f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
@classmethod
def get_name(self) -> str:
@@ -102,8 +107,10 @@ class BitsAndBytesConfig(QuantizationConfig):
llm_int8_threshold=llm_int8_threshold)
def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
prefix: str) -> Optional["LinearMethodBase"]:
if isinstance(layer, LinearBase):
if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
return UnquantizedLinearMethod()
return BitsAndBytesLinearMethod(self)
return None
@@ -111,6 +118,10 @@ class BitsAndBytesConfig(QuantizationConfig):
return []
def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
return any(module_name in prefix for module_name in llm_int8_skip_modules)
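The new is_layer_skipped_bnb helper drives the UnquantizedLinearMethod fallback above: any layer whose prefix contains a name listed in llm_int8_skip_modules is left in full precision. A minimal standalone sketch of the check (the helper mirrors the diff; the skip list and prefixes below are hypothetical examples):

from typing import List

def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]) -> bool:
    # A layer is left unquantized if any skipped module name appears in its prefix.
    return any(module_name in prefix for module_name in llm_int8_skip_modules)

skip = ["lm_head", "multi_modal_projector"]  # hypothetical skip list
print(is_layer_skipped_bnb("language_model.model.layers.0.self_attn.qkv_proj", skip))  # False -> BitsAndBytesLinearMethod
print(is_layer_skipped_bnb("multi_modal_projector", skip))                             # True  -> UnquantizedLinearMethod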
class BitsAndBytesLinearMethod(LinearMethodBase):
"""Linear method for BitsAndBytes.
@@ -211,6 +222,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
from bitsandbytes import MatmulLtState, matmul
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
@@ -265,6 +281,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out = out.to(original_type)
if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))
if bias is not None:
out += bias
@@ -282,6 +301,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
from bitsandbytes import matmul_4bit
original_type = x.dtype
original_shape = x.shape
reshape_after_matmul = False
if x.ndim > 2:
x = x.reshape(-1, x.size(-1))
reshape_after_matmul = True
bf_x = x.to(torch.bfloat16)
qweight = layer.qweight
@@ -310,6 +334,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
out = out.to(original_type)
if reshape_after_matmul:
out = out.view(*original_shape[:-1], out.size(-1))
if bias is not None:
out += bias
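Both _apply_8bit_weight and _apply_4bit_weight gain the same guard: the bitsandbytes matmuls expect 2-D activations, while Mllama's vision path produces higher-rank tensors, so the input is flattened to (tokens, hidden) and the leading dimensions are restored afterwards. A small self-contained sketch of the pattern (plain torch.matmul stands in for the bitsandbytes kernel; the shapes are made up):

import torch

def linear_2d_only(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        # Flatten all leading dims into one "token" dim, as the diff does.
        x = x.reshape(-1, x.size(-1))
        reshape_after_matmul = True
    out = x @ weight.t()  # stand-in for the 2D-only bitsandbytes matmul
    if reshape_after_matmul:
        # Restore the original leading dims, keeping the new output width.
        out = out.view(*original_shape[:-1], out.size(-1))
    return out

x = torch.randn(2, 3, 5, 8)   # e.g. (batch, tiles, tokens, hidden); illustrative shape
w = torch.randn(16, 8)        # (out_features, in_features)
print(linear_2d_only(x, w).shape)  # torch.Size([2, 3, 5, 16])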

View File

@@ -899,6 +899,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
return self._unquantized_generator(hf_weights_files, use_safetensors,
quant_state_dict), quant_state_dict
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix)
for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax", "quant_map", "nested_absmax", "nested_quant_map",
"bitsandbytes"
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
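The loader now classifies checkpoint tensor names explicitly instead of assuming that anything which is not quantization state ends in .weight or .bias: SCB/weight_format suffixes mark 8-bit state, while absmax/quant_map/nested_* blocks mark 4-bit state. A standalone sketch of the two checks with illustrative tensor names (the diff defines them as methods on the loader):

def is_8bit_weight_name(weight_name: str) -> bool:
    quantized_suffix = {".scb", ".weight_format"}
    return any(weight_name.lower().endswith(s) for s in quantized_suffix)

def is_4bit_weight_name(weight_name: str) -> bool:
    quantized_suffix = {"absmax", "quant_map", "nested_absmax",
                        "nested_quant_map", "bitsandbytes"}
    suffix = weight_name.split(".")[-1]
    return any(q in suffix for q in quantized_suffix)

print(is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB"))            # True: 8-bit scale tensor
print(is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax"))  # True: 4-bit quant state
print(is_4bit_weight_name("model.layers.0.mlp.down_proj.weight"))         # False: regular weight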
def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
quant_state_dict) -> Generator:
for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.endswith((".weight", ".bias")):
if self._is_8bit_weight_name(weight_name):
continue
qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
use_safetensors)
temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith((".weight", ".bias")):
if not self._is_4bit_weight_name(weight_name):
continue
# bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors):
if not weight_name.endswith((".weight", ".bias")):
if self._is_4bit_weight_name(weight_name):
continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" \

View File

@@ -325,7 +325,10 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
# TODO: support other attention backends for attention in vision model
class MllamaVisionSdpaAttention(nn.Module):
def __init__(self, config: config_mllama.MllamaVisionConfig):
def __init__(self,
config: config_mllama.MllamaVisionConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
super().__init__()
model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ class MllamaVisionSdpaAttention(nn.Module):
self.head_dim,
self.num_heads,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
)
self.o_proj = RowParallelLinear(
self.num_heads * self.head_dim,
self.embed_dim,
bias=False,
input_is_parallel=True,
quant_config=quant_config,
prefix=f"{prefix}.o_proj",
)
def forward(
@@ -393,7 +400,8 @@ class MllamaVisionEncoderLayer(nn.Module):
self.is_gated = is_gated
self.intermediate_size = config.intermediate_size
self.self_attn = MllamaVisionSdpaAttention(config)
self.self_attn = MllamaVisionSdpaAttention(
config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
self.mlp = CLIPMLP(config,
quant_config=quant_config,
prefix=f"{prefix}.mlp")
@@ -1002,6 +1010,7 @@ class MllamaForCausalLM(nn.Module):
org_num_embeddings=config.vocab_size,
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
quant_config=quant_config,
prefix=f"{prefix}.lm_head",
)
def forward(
@@ -1037,6 +1046,26 @@ class MllamaForCausalLM(nn.Module):
@INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
@INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
".o_proj.",
]
# in TP, these weights are partitioned along the column dimension (dim=-1)
column_parallel_weights_modules = [".down_proj.", ".o_proj."]
bitsandbytes_stacked_params_mapping = {
# shard_name, weight_name, index
"q_proj": ("qkv_proj", 0),
"k_proj": ("qkv_proj", 1),
"v_proj": ("qkv_proj", 2),
"gate_proj": ("gate_up_proj", 0),
"up_proj": ("gate_up_proj", 1),
}
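These class attributes are the contract the BitsAndBytesModelLoader reads when quantizing Mllama: which module names to target, which modules are column-parallel (so their quantized shards are partitioned along dim -1), and how per-projection checkpoint shards fold into vLLM's fused qkv_proj / gate_up_proj parameters. A simplified sketch of how the stacked-params mapping is meant to be read (the remap helper and the example tensor name are illustrative, not the loader's actual code path):

bitsandbytes_stacked_params_mapping = {
    # shard_name: (fused_weight_name, shard_index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def remap(checkpoint_name: str):
    # Map a per-projection checkpoint tensor onto the fused vLLM parameter.
    for shard_name, (fused_name, index) in bitsandbytes_stacked_params_mapping.items():
        if shard_name in checkpoint_name:
            return checkpoint_name.replace(shard_name, fused_name), index
    return checkpoint_name, None

# Hypothetical checkpoint tensor name:
print(remap("language_model.model.layers.0.self_attn.k_proj.weight"))
# -> ('language_model.model.layers.0.self_attn.qkv_proj.weight', 1)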
def __init__(self,
config: config_mllama.MllamaConfig,
@@ -1061,10 +1090,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
quant_config=quant_config,
prefix="language_model",
)
self.multi_modal_projector = nn.Linear(
self.multi_modal_projector = ColumnParallelLinear(
config.vision_config.vision_output_dim,
config.text_config.hidden_size,
bias=True,
quant_config=quant_config,
gather_output=True,
prefix="multi_modal_projector",
)
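Swapping nn.Linear for ColumnParallelLinear lets the projector pick up quant_config (and a prefix that skip rules can match), but it also changes the call convention: vLLM's parallel linear layers return an (output, output_bias) tuple, and gather_output=True keeps the full projected hidden size on every rank. That is why the forward pass below unpacks `cross_attention_states, _ = self.multi_modal_projector(...)` and why the device lookup switches to next(self.multi_modal_projector.parameters()).device. A toy single-rank stand-in showing only the return convention (not vLLM's implementation; sizes are made up):

import torch
import torch.nn as nn

class ToyColumnParallelLinear(nn.Module):
    """Single-rank stand-in mirroring only the (output, bias) return convention."""

    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x: torch.Tensor):
        # The real ColumnParallelLinear returns (output, output_bias); with
        # gather_output=True the output is already all-gathered to full width.
        return self.inner(x), None

proj = ToyColumnParallelLinear(32, 16)
out, _ = proj(torch.randn(2, 4, 32))  # call sites must unpack the tuple
print(out.shape)                      # torch.Size([2, 4, 16])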
self.logits_processor = LogitsProcessor(config.output_hidden_states,
config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
raise ValueError("No images provided.")
max_num_tiles = max(
max([len(x) for x in y[0]]) for y in pixel_values)
device = self.multi_modal_projector.weight.device
device = next(self.multi_modal_projector.parameters()).device
bsz = len(pixel_values)
out_num_tiles = []
out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
cross_attention_states = self.vision_model(pixel_values,
aspect_ratio_ids,
aspect_ratio_mask)
cross_attention_states = self.multi_modal_projector(
cross_attention_states, _ = self.multi_modal_projector(
cross_attention_states)
bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
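Taken together, the three files let Mllama run with in-flight bitsandbytes quantization. A usage sketch, assuming vLLM's engine arguments at the time of this commit (quantization and load_format both set to "bitsandbytes"); the model name is illustrative and the text-only prompt sidesteps multimodal input plumbing:

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # illustrative Mllama checkpoint
    quantization="bitsandbytes",   # quantize weights with bitsandbytes on load
    load_format="bitsandbytes",    # use the BitsAndBytesModelLoader shown above
)

outputs = llm.generate("Summarize what Mllama is in one sentence.",
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)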