[Quantization] Enable BNB support for InternS1 (#21953)

Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Jee Jee Li
2025-08-01 19:09:54 +08:00
committed by GitHub
parent 4931486988
commit 28b18cc741
2 changed files with 43 additions and 16 deletions
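
For context, this commit lets vLLM's in-flight BitsAndBytes quantization path handle InternS1's MoE weights. A minimal usage sketch of the enabled feature (the HF model id, prompt, and engine arguments are illustrative assumptions, not taken from this commit):

from vllm import LLM, SamplingParams

# Assumed model id; quantization="bitsandbytes" routes weight loading through
# BitsAndBytesModelLoader, which the changes below extend to FusedMoE models.
llm = LLM(model="internlm/Intern-S1",
          quantization="bitsandbytes",
          trust_remote_code=True)

outputs = llm.generate(["Summarize the properties of benzene in one sentence."],
                       SamplingParams(temperature=0.0, max_tokens=48))
print(outputs[0].outputs[0].text)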


@@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.model_executor.models import is_pooling_model
-from vllm.model_executor.utils import (get_packed_modules_mapping,
+from vllm.model_executor.utils import (get_moe_expert_mapping,
+get_packed_modules_mapping,
set_weight_attrs)
from vllm.platforms import current_platform
@@ -43,6 +44,12 @@ from vllm.platforms import current_platform
logger = init_logger(__name__)
+def is_moe_model(model: torch.nn.Module) -> bool:
+    """Checks if the model contains FusedMoE layers."""
+    return bool(any(
+        isinstance(module, FusedMoE) for module in model.modules()))
class BitsAndBytesModelLoader(BaseModelLoader):
"""Model loader to load model weights with BitAndBytes quantization."""
@@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# Store all module names (from transformers) that support
# BNB quantization.
self.target_modules: list[str] = []
+# Store the mapping of expert parameters for MoE models.
+self.expert_params_mapping: list[tuple[str, str, int, str]] = []
# mapping weight names from transformers to vllm.
self.weight_mapper: Callable = lambda name: name
self.pre_quant: bool = False
@@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
# in case model has a mixture of disk-merged and disk-split
# weights with same last name.
self.target_modules.append(name)
-elif (isinstance(module, FusedMoE)
-and hasattr(module.quant_method, "quant_config")):
-if not hasattr(model, "get_expert_mapping"):
-raise AttributeError(
-f"MoE Model {type(model).__name__} does not support "
-"BitsAndBytes quantization yet. Ensure this model has "
-"'get_expert_mapping' method.")
+elif isinstance(module, FusedMoE) and hasattr(
+module.quant_method, "quant_config"):
# TODO: support FusedMoE with prequant and 8bit.
if self.pre_quant:
raise ValueError(
@@ -430,9 +434,9 @@
"BitsAndBytes 8bit quantization with FusedMoE is not "
"supported yet.")
# Get the corresponding weight name using module name and
-# get_expert_mapping.
-expert_mapping = model.get_expert_mapping()
-for exp in expert_mapping:
+# expert_params_mapping.
+for exp in self.expert_params_mapping:
weight_name = exp[1]
rep_name = name.replace("experts",
"") + weight_name.removesuffix(".")
@@ -464,7 +468,7 @@
elif isinstance(module, (RowParallelLinear, )):
self.column_sharded_weights_modules.append(name)
elif isinstance(module, FusedMoE):
-expert_mapping = model.get_expert_mapping()
+expert_mapping = self.expert_params_mapping
for exp in expert_mapping:
if exp[-1] == "w2":
weight_name = exp[1]
@@ -516,6 +520,13 @@
self.is_pool_model = is_pooling_model(model)
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
+if is_moe_model(model):
+self.expert_params_mapping = get_moe_expert_mapping(model)
+if not self.expert_params_mapping:
+raise AttributeError(
+f"MoE Model {type(model).__name__} does not support "
+"BitsAndBytes quantization yet. Ensure this model has "
+"'get_expert_mapping' method.")
# For some models like Molmo, we need to use hf_to_vllm_mapper
# to ensure correct loading of weights.
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
@@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
"""
from bitsandbytes.functional import QuantState
-if not hasattr(model, "get_expert_mapping"):
+if not self.expert_params_mapping:
return dict()
-expert_mapping = model.get_expert_mapping()
+expert_mapping = self.expert_params_mapping
expert_qs_dict = {}
for name, module in model.named_modules():
if not isinstance(module, FusedMoE):
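
A note on the tuple indexing in the hunks above (weight_name = exp[1], exp[-1] == "w2"): the entries stored in self.expert_params_mapping are assumed to follow vLLM's usual (param_name, weight_name, expert_id, shard_id) layout for FusedMoE expert mappings, for example:

# Illustrative values only; the concrete names are assumptions, not from this diff.
expert_params_mapping = [
    # (param_name,    weight_name,            expert_id, shard_id)
    ("experts.w13_",  "experts.0.gate_proj.", 0,         "w1"),
    ("experts.w2_",   "experts.0.down_proj.", 0,         "w2"),
    ("experts.w13_",  "experts.0.up_proj.",   0,         "w3"),
]
# exp[1] supplies the checkpoint weight-name fragment used to build target
# module names, and exp[-1] == "w2" marks the down-projection, whose weights
# are column-sharded and therefore grouped with RowParallelLinear above.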


@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utils for model executor."""
import copy
from typing import Any, Optional
@@ -9,6 +10,7 @@ import torch
def set_random_seed(seed: int) -> None:
from vllm.platforms import current_platform
current_platform.seed_everything(seed)
@@ -29,7 +31,7 @@ def set_weight_attrs(
return
for key, value in weight_attrs.items():
assert not hasattr(
weight, key), (f"Overwriting existing tensor attribute: {key}")
weight, key), f"Overwriting existing tensor attribute: {key}"
# NOTE(woosuk): During weight loading, we often do something like:
# narrowed_tensor = param.data.narrow(0, offset, len)
@@ -41,6 +43,7 @@ def set_weight_attrs(
# we sync the param tensor after its weight loader is called.
# TODO(woosuk): Remove this hack once we have a better solution.
from vllm.platforms import current_platform
if current_platform.is_tpu() and key == "weight_loader":
value = _make_synced_weight_loader(value)
setattr(weight, key, value)
@@ -77,4 +80,17 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
f"safely because of conflicts from {type(child).__name__}.")
else:
parent_map.update(child_map)
-    return parent_map
+    return parent_map
+
+
+def get_moe_expert_mapping(
+        model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]:
+    if parent_map := getattr(model, "get_expert_mapping", None):
+        return parent_map()
+    else:
+        # We only check main components instead of whole model submodules
+        for child in model.children():
+            child_map = getattr(child, "get_expert_mapping", None)
+            if child_map is not None:
+                return child_map()
+        return []
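
The new helper looks for get_expert_mapping on the model itself and, failing that, on its direct children, which is presumably what multimodal wrappers like InternS1 need, since the MoE language model sits there as a submodule. A sketch of what a model-side implementation might look like, assuming the FusedMoE.make_expert_params_mapping helper and a hypothetical num_experts attribute:

import torch
from vllm.model_executor.layers.fused_moe import FusedMoE


class HypotheticalMoEModel(torch.nn.Module):
    """Hypothetical model exposing the mapping that get_moe_expert_mapping finds."""

    def __init__(self, num_experts: int = 8):
        super().__init__()
        self.num_experts = num_experts

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        # One (param_name, weight_name, expert_id, shard_id) entry per expert
        # projection; the checkpoint projection names are assumptions.
        return FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
            num_experts=self.num_experts)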