Enabled BnB NF4 inference on Gaudi (#20172)

Signed-off-by: Ruheena Suhani Shaik <rsshaik@habana.ai>
Author: Ruheena Suhani Shaik
Date: 2025-07-15 08:56:08 +05:30 (committed by GitHub)
Parent: 80305c1b24
Commit: 016b8d1b7f

2 changed files with 18 additions and 8 deletions

File 1 of 2

@@ -13,6 +13,7 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
 from vllm.model_executor.layers.quantization import QuantizationMethods
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
@@ -390,12 +391,11 @@ def _apply_bnb_4bit_fake(
 try:
-    direct_register_custom_op(
-        op_name="apply_bnb_4bit",
-        op_func=_apply_bnb_4bit,
-        mutates_args=["out"],
-        fake_impl=_apply_bnb_4bit_fake,
-    )
+    direct_register_custom_op(op_name="apply_bnb_4bit",
+                              op_func=_apply_bnb_4bit,
+                              mutates_args=["out"],
+                              fake_impl=_apply_bnb_4bit_fake,
+                              dispatch_key=current_platform.dispatch_key)
     apply_bnb_4bit = torch.ops.vllm.apply_bnb_4bit
 except AttributeError as error:
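
Note: vLLM's direct_register_custom_op wraps torch.library registration, and the new dispatch_key argument selects which backend the op binds to. A minimal sketch of that underlying mechanism, using a hypothetical demo::scale op rather than vLLM code:

```python
import torch
from torch.library import Library

# Hypothetical op for illustration only; "demo" is not a vLLM namespace.
demo_lib = Library("demo", "FRAGMENT")
demo_lib.define("scale(Tensor x, float f) -> Tensor")

def _scale_impl(x: torch.Tensor, f: float) -> torch.Tensor:
    return x * f

# direct_register_custom_op forwards its dispatch_key to a registration
# like this one; current_platform.dispatch_key is "HPU" on Gaudi, so the
# op binds to the Gaudi backend instead of the default "CUDA".
demo_lib.impl("scale", _scale_impl, "CPU")

print(torch.ops.demo.scale(torch.ones(3), 2.0))  # tensor([2., 2., 2.])
```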

File 2 of 2

@@ -199,6 +199,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         if self.pre_quant:
             if self.load_8bit:
+                if current_platform.is_hpu():
+                    raise ValueError(
+                        "currently hpu supports 4bit quantization only")
                 return self._quantized_8bit_generator(
                     hf_weights_files, use_safetensors,
                     quant_state_dict), quant_state_dict
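
Note: with this guard, pre-quantized 8-bit checkpoints fail fast on Gaudi while NF4 checkpoints proceed to load. A usage sketch, assuming a Gaudi-enabled vLLM install; the checkpoint name is illustrative and the load_format flag may be optional in newer versions:

```python
from vllm import LLM, SamplingParams

# Pre-quantized BnB NF4 checkpoint (illustrative name); an 8-bit
# checkpoint would instead hit the ValueError added above on HPU.
llm = LLM(model="unsloth/tinyllama-bnb-4bit",
          quantization="bitsandbytes",
          load_format="bitsandbytes")

out = llm.generate(["Gaudi runs NF4:"], SamplingParams(max_tokens=32))
print(out[0].outputs[0].text)
```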
@@ -302,6 +306,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                     in temp_state_dict):
                 quant_state = _parse_quant_state(mapped_weight_name,
                                                  temp_state_dict)
+                if current_platform.is_hpu():
+                    assert quant_state.quant_type == "nf4", (
+                        "currently hpu supports nf4 quant_type only")
                 quant_state_dict[mapped_weight_name] = quant_state
                 yield org_weight_name, weight_tensor
             else:
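
Note: the assertion keys off bitsandbytes' QuantState, whose quant_type field records whether a 4-bit tensor was packed as nf4 or fp4. A rough sketch of where that field originates, assuming a device-enabled bitsandbytes build (not part of this diff):

```python
import torch
import bitsandbytes.functional as F

weight = torch.randn(128, 128, device="cuda", dtype=torch.float16)

# quantize_4bit returns the packed tensor plus a QuantState; the
# loader's new assert inspects quant_state.quant_type on HPU.
packed, quant_state = F.quantize_4bit(weight, quant_type="nf4")
assert quant_state.quant_type == "nf4"  # "fp4" stays unsupported on HPU

restored = F.dequantize_4bit(packed, quant_state)
```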
@@ -372,10 +380,12 @@ class BitsAndBytesModelLoader(BaseModelLoader):
                 ...]
             # bitsandbytes requires data in GPU
-            if weight_sub_tensor.is_cuda:
+            if (weight_sub_tensor.is_cuda
+                    or weight_sub_tensor.device.type == "hpu"):
                 loaded_weight = weight_sub_tensor
             else:
-                loaded_weight = weight_sub_tensor.cuda()
+                loaded_weight = weight_sub_tensor.to(
+                    device=current_platform.device_type)
             # remove the following after the issue is fixed:
             # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1342
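
Note: this last hunk replaces the hard-coded .cuda() with a platform-keyed move. The same pattern in isolation, assuming a vLLM install:

```python
import torch
from vllm.platforms import current_platform

weight_sub_tensor = torch.empty(4, 4)  # starts on CPU

# device_type is "cuda" on NVIDIA and "hpu" on Gaudi, so one .to()
# call covers both backends without a hard-coded .cuda().
if weight_sub_tensor.device.type not in ("cuda", "hpu"):
    weight_sub_tensor = weight_sub_tensor.to(
        device=current_platform.device_type)
```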