[Model] Add BNB quantization support for Mllama (#9720)
@@ -3,6 +3,7 @@ from typing import Any, Dict, List, Optional
 import torch
 
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
+                                               UnquantizedLinearMethod,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -23,7 +24,7 @@ class BitsAndBytesConfig(QuantizationConfig):
         bnb_4bit_use_double_quant: bool = False,
         llm_int8_enable_fp32_cpu_offload: bool = False,
         llm_int8_has_fp16_weight: bool = False,
-        llm_int8_skip_modules: Optional[Any] = None,
+        llm_int8_skip_modules: Optional[List[str]] = None,
         llm_int8_threshold: float = 0.0,
     ) -> None:
 
@@ -34,11 +35,15 @@ class BitsAndBytesConfig(QuantizationConfig):
         self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
         self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
         self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
-        self.llm_int8_skip_modules = llm_int8_skip_modules
+        self.llm_int8_skip_modules = llm_int8_skip_modules or []
         self.llm_int8_threshold = llm_int8_threshold
 
     def __repr__(self) -> str:
-        return "BitsAndBytesConfig"
+        return (f"BitsAndBytesConfig(load_in_8bit={self.load_in_8bit}, "
+                f"load_in_4bit={self.load_in_4bit}, "
+                f"bnb_4bit_compute_dtype={self.bnb_4bit_compute_dtype}, "
+                f"bnb_4bit_quant_type={self.bnb_4bit_quant_type}, "
+                f"llm_int8_skip_modules={self.llm_int8_skip_modules})")
 
     @classmethod
     def get_name(self) -> str:
@@ -102,8 +107,10 @@ class BitsAndBytesConfig(QuantizationConfig):
                    llm_int8_threshold=llm_int8_threshold)
 
     def get_quant_method(self, layer: torch.nn.Module,
-                         prefix: str) -> Optional["BitsAndBytesLinearMethod"]:
+                         prefix: str) -> Optional["LinearMethodBase"]:
         if isinstance(layer, LinearBase):
+            if is_layer_skipped_bnb(prefix, self.llm_int8_skip_modules):
+                return UnquantizedLinearMethod()
             return BitsAndBytesLinearMethod(self)
         return None
 
@@ -111,6 +118,10 @@ class BitsAndBytesConfig(QuantizationConfig):
         return []
 
 
+def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]):
+    return any(module_name in prefix for module_name in llm_int8_skip_modules)
+
+
 class BitsAndBytesLinearMethod(LinearMethodBase):
     """Linear method for BitsAndBytes.
 
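The new is_layer_skipped_bnb helper is a plain substring match of each configured module name against the layer's full prefix, and get_quant_method now falls back to UnquantizedLinearMethod for any layer that matches. A standalone sketch of that check (the skip list and prefixes below are hypothetical, for illustration only):

from typing import List

def is_layer_skipped_bnb(prefix: str, llm_int8_skip_modules: List[str]) -> bool:
    # Same substring check as the helper added above.
    return any(module_name in prefix for module_name in llm_int8_skip_modules)

skip_modules = ["multi_modal_projector", "lm_head"]   # hypothetical skip config
assert is_layer_skipped_bnb("multi_modal_projector", skip_modules)
assert not is_layer_skipped_bnb(
    "language_model.model.layers.0.self_attn.qkv_proj", skip_modules)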
@@ -211,6 +222,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import MatmulLtState, matmul
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -265,6 +281,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
 
@@ -282,6 +301,11 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
         from bitsandbytes import matmul_4bit
 
         original_type = x.dtype
+        original_shape = x.shape
+        reshape_after_matmul = False
+        if x.ndim > 2:
+            x = x.reshape(-1, x.size(-1))
+            reshape_after_matmul = True
         bf_x = x.to(torch.bfloat16)
 
         qweight = layer.qweight
@@ -310,6 +334,9 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
 
         out = out.to(original_type)
 
+        if reshape_after_matmul:
+            out = out.view(*original_shape[:-1], out.size(-1))
+
         if bias is not None:
             out += bias
 
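Both the 8-bit and 4-bit apply paths now flatten inputs with more than two dimensions down to 2-D before the bitsandbytes matmul and restore the leading dimensions afterwards. A minimal sketch of that shape handling, with torch.matmul standing in for bitsandbytes.matmul / matmul_4bit (an assumption made only so the snippet runs without bitsandbytes):

import torch

def matmul_with_flatten(x: torch.Tensor, weight: torch.Tensor) -> torch.Tensor:
    # Mirrors the reshape bookkeeping added in the hunks above.
    original_shape = x.shape
    reshape_after_matmul = False
    if x.ndim > 2:
        x = x.reshape(-1, x.size(-1))
        reshape_after_matmul = True
    out = torch.matmul(x, weight)        # stand-in for the quantized matmul
    if reshape_after_matmul:
        out = out.view(*original_shape[:-1], out.size(-1))
    return out

x = torch.randn(2, 5, 8)                 # e.g. (batch, seq, hidden)
w = torch.randn(8, 16)
assert matmul_with_flatten(x, w).shape == (2, 5, 16)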
@@ -899,6 +899,19 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         return self._unquantized_generator(hf_weights_files, use_safetensors,
                                            quant_state_dict), quant_state_dict
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(weight_name.lower().endswith(suffix)
+                   for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax", "quant_map", "nested_absmax", "nested_quant_map",
+            "bitsandbytes"
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(self, hf_weights_files, use_safetensors,
                                   quant_state_dict) -> Generator:
         for weight_name, weight_tensor in self._hf_weight_iter(
@@ -912,7 +925,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
             qweight_name = weight_name.replace(".weight", ".qweight")
@@ -932,7 +945,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                                                use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -956,7 +969,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         for weight_name, weight_tensor in self._hf_weight_iter(
                 hf_weights_files, use_safetensors):
 
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" \
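The loader now recognizes bitsandbytes-specific tensors by their suffixes (SCB / weight_format for 8-bit; absmax, quant_map, nested_* and bitsandbytes__* quant-state entries for 4-bit) instead of filtering on .weight/.bias. A standalone sketch of the two predicates; the example checkpoint keys are hypothetical, for illustration only:

def is_8bit_weight_name(weight_name: str) -> bool:
    quantized_suffix = {".scb", ".weight_format"}
    return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)

def is_4bit_weight_name(weight_name: str) -> bool:
    quantized_suffix = {
        "absmax", "quant_map", "nested_absmax", "nested_quant_map", "bitsandbytes"
    }
    suffix = weight_name.split(".")[-1]
    return any(q_suffix in suffix for q_suffix in quantized_suffix)

# Hypothetical checkpoint entries:
assert is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB")
assert is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax")
assert not is_4bit_weight_name("model.layers.0.mlp.down_proj.weight")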
@@ -325,7 +325,10 @@ class MllamaPrecomputedPositionEmbedding(nn.Module):
 # TODO: support other attention backends for attention in vision model
 class MllamaVisionSdpaAttention(nn.Module):
 
-    def __init__(self, config: config_mllama.MllamaVisionConfig):
+    def __init__(self,
+                 config: config_mllama.MllamaVisionConfig,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
         super().__init__()
 
         model_parallel_size = get_tensor_model_parallel_world_size()
@@ -341,12 +344,16 @@ class MllamaVisionSdpaAttention(nn.Module):
             self.head_dim,
             self.num_heads,
             bias=False,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj",
         )
         self.o_proj = RowParallelLinear(
             self.num_heads * self.head_dim,
             self.embed_dim,
             bias=False,
             input_is_parallel=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.o_proj",
         )
 
     def forward(
@@ -393,7 +400,8 @@ class MllamaVisionEncoderLayer(nn.Module):
         self.is_gated = is_gated
         self.intermediate_size = config.intermediate_size
 
-        self.self_attn = MllamaVisionSdpaAttention(config)
+        self.self_attn = MllamaVisionSdpaAttention(
+            config, quant_config=quant_config, prefix=f"{prefix}.self_attn")
         self.mlp = CLIPMLP(config,
                            quant_config=quant_config,
                            prefix=f"{prefix}.mlp")
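MllamaVisionSdpaAttention now accepts quant_config and a prefix and threads both into its QKV and output projections, so each vision linear layer ends up with a fully qualified name that get_quant_method (and the skip-module check above) can match. A sketch of how the prefixes compose, following the f-strings in these hunks; the parent prefix below is an assumption, for illustration only:

layer_prefix = "vision_model.transformer.layers.0"   # assumed prefix of one encoder layer
attn_prefix = f"{layer_prefix}.self_attn"            # passed to MllamaVisionSdpaAttention
print(f"{attn_prefix}.qkv_proj")                     # prefix handed to the QKV projection
print(f"{attn_prefix}.o_proj")                       # prefix handed to the output projection
# vision_model.transformer.layers.0.self_attn.qkv_proj
# vision_model.transformer.layers.0.self_attn.o_proj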
@@ -1002,6 +1010,7 @@ class MllamaForCausalLM(nn.Module):
             org_num_embeddings=config.vocab_size,
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
             quant_config=quant_config,
+            prefix=f"{prefix}.lm_head",
         )
 
     def forward(
@@ -1037,6 +1046,26 @@ class MllamaForCausalLM(nn.Module):
 @INPUT_REGISTRY.register_dummy_encoder_data(dummy_encoder_data_for_mllama)
 @INPUT_REGISTRY.register_input_processor(input_processor_for_mllama)
 class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
 
     def __init__(self,
                  config: config_mllama.MllamaConfig,
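These new class attributes describe the checkpoint layout to the BitsAndBytes loader: which modules are quantization targets, which are column-partitioned under tensor parallelism, and how per-shard weights (q/k/v, gate/up) fold into vLLM's fused qkv_proj and gate_up_proj parameters. A small sketch of how such a mapping can be applied to a checkpoint key; the map_to_stacked helper is hypothetical and not part of this diff:

bitsandbytes_stacked_params_mapping = {
    # shard_name: (stacked parameter name, shard index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

def map_to_stacked(weight_name: str):
    # Hypothetical helper: rewrite a per-shard checkpoint key to the fused
    # parameter it is loaded into, together with its shard index.
    for shard, (stacked, index) in bitsandbytes_stacked_params_mapping.items():
        if shard in weight_name:
            return weight_name.replace(shard, stacked), index
    return weight_name, None

print(map_to_stacked("language_model.model.layers.0.self_attn.q_proj.weight"))
# ('language_model.model.layers.0.self_attn.qkv_proj.weight', 0)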
@@ -1061,10 +1090,13 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
             quant_config=quant_config,
             prefix="language_model",
         )
-        self.multi_modal_projector = nn.Linear(
+        self.multi_modal_projector = ColumnParallelLinear(
             config.vision_config.vision_output_dim,
             config.text_config.hidden_size,
             bias=True,
+            quant_config=quant_config,
+            gather_output=True,
+            prefix="multi_modal_projector",
         )
         self.logits_processor = LogitsProcessor(config.output_hidden_states,
                                                 config.text_config.vocab_size)
@@ -1128,7 +1160,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
             raise ValueError("No images provided.")
         max_num_tiles = max(
             max([len(x) for x in y[0]]) for y in pixel_values)
-        device = self.multi_modal_projector.weight.device
+        device = next(self.multi_modal_projector.parameters()).device
         bsz = len(pixel_values)
         out_num_tiles = []
         out_images = torch.zeros(
@@ -1204,7 +1236,7 @@ class MllamaForConditionalGeneration(nn.Module, SupportsMultiModal):
         cross_attention_states = self.vision_model(pixel_values,
                                                    aspect_ratio_ids,
                                                    aspect_ratio_mask)
-        cross_attention_states = self.multi_modal_projector(
+        cross_attention_states, _ = self.multi_modal_projector(
             cross_attention_states)
 
         bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
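Swapping the projector from nn.Linear to ColumnParallelLinear changes its calling convention: vLLM's parallel linear layers return an (output, output_bias) tuple, and the weight may be sharded or replaced by quantized buffers, so the device is looked up via next(module.parameters()) rather than .weight.device. A minimal sketch of the same call pattern using a stand-in module; the stand-in class and its sizes are assumptions, not vLLM's implementation:

import torch
import torch.nn as nn

class TupleReturningLinear(nn.Module):
    # Stand-in that mimics only the (output, bias) return convention.
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.inner = nn.Linear(in_features, out_features, bias=True)

    def forward(self, x: torch.Tensor):
        return self.inner(x), None

projector = TupleReturningLinear(32, 64)               # illustrative sizes
device = next(projector.parameters()).device           # instead of projector.weight.device
hidden_states, _ = projector(torch.randn(1, 4, 32))    # call site unpacks the tuple
assert hidden_states.shape == (1, 4, 64)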