Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Quantization][V1] BitsAndBytes support V1 (#15611)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
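With bitsandbytes removed from V1's unsupported-quantization list (see the EngineArgs hunk further down), a 4-bit checkpoint can be served through the normal entry point without forcing the V0 engine. A minimal usage sketch, reusing the model name and settings that appear in the tests below; the prompt and decoding settings are illustrative:

```python
from vllm import LLM, SamplingParams

# Model name and engine settings mirror the regression tests below; prompt is illustrative.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          quantization="bitsandbytes",
          max_model_len=4096,
          max_num_seqs=2)
params = SamplingParams(temperature=0, max_tokens=8)
outputs = llm.generate(["Hello, my name is"], params)
print(outputs[0].outputs[0].text)
```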
@@ -425,7 +425,6 @@ def test_bnb_regression(
         max_model_len=4096,
         max_num_seqs=2,
         quantization="bitsandbytes",
-        load_format="bitsandbytes",
     )
     sampling_params = SamplingParams(
         temperature=0,
@@ -72,7 +72,6 @@ def test_distributed(
         "meta-llama/Llama-3.2-1B-Instruct",
         {
             "quantization": "bitsandbytes",
-            "load_format": "bitsandbytes",
         },
     ),
 ])
@@ -101,8 +101,6 @@ def test_load_pp_4bit_bnb_model(model_name, description) -> None:
         "--enable-prefix-caching",
         "--quantization",
         "bitsandbytes",
-        "--load-format",
-        "bitsandbytes",
         "--gpu-memory-utilization",
         "0.7",
     ]
@@ -137,7 +135,6 @@ def validate_generated_texts(hf_runner,
     # when using distributed inference
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
-                     load_format='bitsandbytes',
                      tensor_parallel_size=vllm_tp_size,
                      enforce_eager=False) as llm:
         vllm_outputs = llm.generate_greedy(prompts, 8)
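Read together, the four test hunks above drop the explicit load_format argument: passing quantization="bitsandbytes" (or --quantization bitsandbytes on the server CLI) is assumed to be enough for the loader to select the BitsAndBytes weight format on its own. A condensed sketch of the simplified runner call, with the hf_runner comparison and distributed settings of the real tests omitted and an illustrative prompt (vllm_runner is the test fixture used in these files):

```python
# Hypothetical condensed form of the updated test setup.
with vllm_runner("meta-llama/Llama-3.2-1B-Instruct",
                 quantization='bitsandbytes',
                 enforce_eager=False) as llm:
    vllm_outputs = llm.generate_greedy(["The capital of France is"], 8)
```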
@@ -682,8 +682,9 @@ class ModelConfig:

     def _verify_bnb_config(self) -> None:
         """
-        The current version of bitsandbytes (0.44.0) with 8-bit models does not
+        The current version of bitsandbytes (0.45.3) with 8-bit models does not
         yet support CUDA graph.
+        # TODO Remove this when bitsandbytes supports.
         """
         is_bitsandbytes = self.quantization == "bitsandbytes"
         has_quantization_config = (getattr(self.hf_config,
@@ -698,8 +699,9 @@ class ModelConfig:
                 not self.enforce_eager,
         ]):
             logger.warning(
-                "CUDA graph is not supported on BitAndBytes 8bit yet, "
+                "CUDA graph is not supported on BitsAndBytes 8bit yet, "
                 "fallback to the eager mode.")
+
             self.enforce_eager = True

     def _verify_with_expert_parallelism(self) -> None:
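The guard above only affects 8-bit bitsandbytes checkpoints; 4-bit models keep CUDA graphs. A standalone sketch of the check, under the assumption that hf_config.quantization_config is exposed as a dict with a load_in_8bit flag (the attribute names mirror the diff, the helper itself is illustrative):

```python
# Illustrative re-statement of the 8-bit eager fallback; `cfg` stands in for ModelConfig.
def verify_bnb_config(cfg) -> None:
    is_bitsandbytes = cfg.quantization == "bitsandbytes"
    quant_cfg = getattr(cfg.hf_config, "quantization_config", None)
    is_8bit = bool(quant_cfg) and quant_cfg.get("load_in_8bit", False)
    if is_bitsandbytes and is_8bit and not cfg.enforce_eager:
        # bitsandbytes 8-bit kernels do not support CUDA graph capture yet,
        # so the engine falls back to eager execution.
        cfg.enforce_eager = True
```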
@@ -1616,7 +1616,7 @@ class EngineArgs:
             return False

         # Some quantization is not compatible with torch.compile.
-        V1_UNSUPPORTED_QUANT = ["bitsandbytes", "gguf"]
+        V1_UNSUPPORTED_QUANT = ["gguf"]
         if model_config.quantization in V1_UNSUPPORTED_QUANT:
             _raise_or_fallback(
                 feature_name=f"--quantization {model_config.quantization}",
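With bitsandbytes removed, only GGUF still routes through the fallback path here. A tiny runnable illustration of the membership check; _raise_or_fallback itself (raise if V1 was requested explicitly, otherwise log and fall back) is not reproduced:

```python
# Illustrative check mirroring the EngineArgs logic above.
V1_UNSUPPORTED_QUANT = ["gguf"]

def quantization_supported_on_v1(quantization: str) -> bool:
    return quantization not in V1_UNSUPPORTED_QUANT

assert quantization_supported_on_v1("bitsandbytes")   # no longer forces a fallback
assert not quantization_supported_on_v1("gguf")       # still handled by _raise_or_fallback
```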
@@ -9,6 +9,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.utils import direct_register_custom_op


 class BitsAndBytesConfig(QuantizationConfig):
@@ -321,9 +322,6 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
               x: torch.Tensor,
               bias: Optional[torch.Tensor] = None) -> torch.Tensor:

-        # only load the bitsandbytes module when needed
-        from bitsandbytes import matmul_4bit
-
         original_type = x.dtype
         original_shape = x.shape
         reshape_after_matmul = False
@@ -343,19 +341,7 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
                           out_dim_1,
                           dtype=torch.bfloat16,
                           device=x.device)

-        current_index = 0
-        for i in range(len(quant_states)):
-            output_size = quant_states[i].shape[0]
-            # It is more efficient to use out kwarg like
-            # matmul_4bit(..., out = ...). Infeasible now due to the bug
-            # https://github.com/TimDettmers/bitsandbytes/issues/1235.
-            # Need to change after the bug is fixed.
-            out[:, current_index:current_index + output_size] = matmul_4bit(
-                bf_x, qweight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
-
-            current_index += output_size
-
+        apply_bnb_4bit(bf_x, qweight, offsets, out)
         out = out.to(original_type)

         if reshape_after_matmul:
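For context, apply() flattens inputs with more than two dimensions before the matmul, upcasts to bfloat16, and restores the original shape and dtype afterwards; the custom op call above only replaces the inline shard loop. A runnable toy illustration of that reshape-around-matmul pattern, with a plain dense matmul standing in for the 4-bit kernel and made-up shapes:

```python
import torch

x = torch.randn(2, 3, 8)                  # (batch, seq, hidden), illustrative
w = torch.randn(16, 8)                    # fused output width of 16
original_shape = x.shape
reshape_after_matmul = x.ndim > 2
if reshape_after_matmul:
    x = x.reshape(-1, x.shape[-1])        # (6, 8) for the 2D kernel
out = x @ w.t()                           # stand-in for apply_bnb_4bit / matmul_4bit
if reshape_after_matmul:
    out = out.view(*original_shape[:-1], out.shape[-1])   # back to (2, 3, 16)
print(out.shape)                          # torch.Size([2, 3, 16])
```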
@@ -365,3 +351,46 @@ class BitsAndBytesLinearMethod(LinearMethodBase):
             out += bias

         return out
+
+
+def _apply_bnb_4bit(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    # only load the bitsandbytes module when needed
+    from bitsandbytes import matmul_4bit
+    quant_states = weight.bnb_quant_state
+    current_index = 0
+    for i in range(len(quant_states)):
+        output_size = quant_states[i].shape[0]
+        # It is more efficient to use out kwarg like
+        # matmul_4bit(..., out = ...). Infeasible now due to the bug
+        # https://github.com/TimDettmers/bitsandbytes/issues/1235.
+        # Need to change after the bug is fixed.
+        out[:, current_index:current_index + output_size] = matmul_4bit(
+            x, weight[offsets[i]:offsets[i + 1]].t(), quant_states[i])
+        current_index += output_size
+
+
+def _apply_bnb_4bit_fake(
+        x: torch.Tensor,
+        weight: torch.Tensor,
+        offsets: torch.Tensor,
+        out: torch.Tensor,
+) -> None:
+    return
+
+
+try:
+    direct_register_custom_op(
+        op_name="apply_bnb_4bit",
+        op_func=_apply_bnb_4bit,
+        mutates_args=["out"],
+        fake_impl=_apply_bnb_4bit_fake,
+    )
+    apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
+
+except AttributeError as error:
+    raise error
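direct_register_custom_op is, in effect, vLLM's thin wrapper over torch.library registration: the real kernel mutates out in place, while the fake impl gives torch.compile a shape-only stand-in so bitsandbytes is never imported during tracing. The snippet below sketches the same mutating-op-plus-fake-impl pattern with PyTorch's public torch.library.custom_op API (PyTorch >= 2.4); the vllm_sketch namespace and the dense stand-in kernel are illustrative, not the registration path vLLM actually uses:

```python
import torch

@torch.library.custom_op("vllm_sketch::apply_bnb_4bit", mutates_args=("out",))
def apply_bnb_4bit_sketch(x: torch.Tensor, weight: torch.Tensor,
                          offsets: torch.Tensor, out: torch.Tensor) -> None:
    # Eager implementation: a plain matmul stands in for bitsandbytes' matmul_4bit.
    out.copy_(x @ weight.t())

@apply_bnb_4bit_sketch.register_fake
def _(x, weight, offsets, out) -> None:
    # Shape-only stand-in used during tracing, mirroring _apply_bnb_4bit_fake above.
    return

x = torch.randn(4, 8)
w = torch.randn(16, 8)
out = torch.empty(4, 16)
torch.ops.vllm_sketch.apply_bnb_4bit(x, w, torch.tensor([0, 16]), out)
print(out.shape)   # torch.Size([4, 16])
```

Routing the call through torch.ops rather than the plain Python function is what lets the compiler treat the 4-bit matmul as a single opaque node.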
@@ -1259,6 +1259,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                    pack_ratio)

             offsets = np.concatenate(([0], np.cumsum(num_elements)))
+            # Make torch infer_schema happy
+            offsets = torch.tensor(offsets).cpu()
             set_weight_attrs(param, {"bnb_shard_offsets": offsets})

         if load_8bit:
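The two added loader lines exist because the shard offsets now flow into the custom op above: torch's schema inference accepts Tensor arguments but not NumPy arrays, so the exclusive prefix sum is converted to a CPU tensor before being attached as bnb_shard_offsets. A small runnable illustration with made-up shard sizes:

```python
import numpy as np
import torch

num_elements = [1024, 256, 256]           # illustrative per-shard element counts
offsets = np.concatenate(([0], np.cumsum(num_elements)))
# Convert to a CPU tensor so torch's infer_schema sees a Tensor, not an ndarray.
offsets = torch.tensor(offsets).cpu()
print(offsets)                            # tensor([   0, 1024, 1280, 1536])
start, end = int(offsets[1]), int(offsets[2])   # rows belonging to the second shard
```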