mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Quantization] Enable BNB support for InternS1 (#21953)
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
This commit is contained in:
@ -34,7 +34,8 @@ from vllm.model_executor.model_loader.weight_utils import (
|
||||
filter_duplicate_safetensors_files, filter_files_not_needed_for_inference,
|
||||
pt_weights_iterator, safetensors_weights_iterator)
|
||||
from vllm.model_executor.models import is_pooling_model
|
||||
from vllm.model_executor.utils import (get_packed_modules_mapping,
|
||||
from vllm.model_executor.utils import (get_moe_expert_mapping,
|
||||
get_packed_modules_mapping,
|
||||
set_weight_attrs)
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
@ -43,6 +44,12 @@ from vllm.platforms import current_platform
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def is_moe_model(model: torch.nn.Module) -> bool:
|
||||
"""Checks if the model contains FusedMoE layers."""
|
||||
return bool(any(
|
||||
isinstance(module, FusedMoE) for module in model.modules()))
|
||||
|
||||
|
||||
class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""Model loader to load model weights with BitAndBytes quantization."""
|
||||
|
||||
@ -61,6 +68,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# Store all module names (from transformers) that support
|
||||
# BNB quantization.
|
||||
self.target_modules: list[str] = []
|
||||
# Store the mapping of expert parameters for MoE models.
|
||||
self.expert_params_mapping: list[tuple[str, str, int, str]] = []
|
||||
# mapping weight names from transformers to vllm.
|
||||
self.weight_mapper: Callable = lambda name: name
|
||||
self.pre_quant: bool = False
|
||||
@ -413,13 +422,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
# in case model has a mixture of disk-merged and disk-split
|
||||
# weights with same last name.
|
||||
self.target_modules.append(name)
|
||||
elif (isinstance(module, FusedMoE)
|
||||
and hasattr(module.quant_method, "quant_config")):
|
||||
if not hasattr(model, "get_expert_mapping"):
|
||||
raise AttributeError(
|
||||
f"MoE Model {type(model).__name__} does not support "
|
||||
"BitsAndBytes quantization yet. Ensure this model has "
|
||||
"'get_expert_mapping' method.")
|
||||
elif isinstance(module, FusedMoE) and hasattr(
|
||||
module.quant_method, "quant_config"):
|
||||
# TODO: support FusedMoE with prequant and 8bit.
|
||||
if self.pre_quant:
|
||||
raise ValueError(
|
||||
@ -430,9 +434,9 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"BitsAndBytes 8bit quantization with FusedMoE is not "
|
||||
"supported yet.")
|
||||
# Get the corresponding weight name using module name and
|
||||
# get_expert_mapping.
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
for exp in expert_mapping:
|
||||
# expert_params_mapping.
|
||||
|
||||
for exp in self.expert_params_mapping:
|
||||
weight_name = exp[1]
|
||||
rep_name = name.replace("experts",
|
||||
"") + weight_name.removesuffix(".")
|
||||
@ -464,7 +468,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
elif isinstance(module, (RowParallelLinear, )):
|
||||
self.column_sharded_weights_modules.append(name)
|
||||
elif isinstance(module, FusedMoE):
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
expert_mapping = self.expert_params_mapping
|
||||
for exp in expert_mapping:
|
||||
if exp[-1] == "w2":
|
||||
weight_name = exp[1]
|
||||
@ -516,6 +520,13 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
self.is_pool_model = is_pooling_model(model)
|
||||
self.modules_mapping = ParamMapping(get_packed_modules_mapping(model))
|
||||
|
||||
if is_moe_model(model):
|
||||
self.expert_params_mapping = get_moe_expert_mapping(model)
|
||||
if not self.expert_params_mapping:
|
||||
raise AttributeError(
|
||||
f"MoE Model {type(model).__name__} does not support "
|
||||
"BitsAndBytes quantization yet. Ensure this model has "
|
||||
"'get_expert_mapping' method.")
|
||||
# For some models like Molmo, we need to use hf_to_vllm_mapper
|
||||
# to ensure correct loading of weights.
|
||||
if hf_to_vllm_mapper := getattr(model, "hf_to_vllm_mapper", None):
|
||||
@ -569,10 +580,10 @@ class BitsAndBytesModelLoader(BaseModelLoader):
|
||||
"""
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
if not hasattr(model, "get_expert_mapping"):
|
||||
if not self.expert_params_mapping:
|
||||
return dict()
|
||||
|
||||
expert_mapping = model.get_expert_mapping()
|
||||
expert_mapping = self.expert_params_mapping
|
||||
expert_qs_dict = {}
|
||||
for name, module in model.named_modules():
|
||||
if not isinstance(module, FusedMoE):
|
||||
|
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
"""Utils for model executor."""
|
||||
|
||||
import copy
|
||||
from typing import Any, Optional
|
||||
|
||||
@ -9,6 +10,7 @@ import torch
|
||||
|
||||
def set_random_seed(seed: int) -> None:
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
|
||||
@ -29,7 +31,7 @@ def set_weight_attrs(
|
||||
return
|
||||
for key, value in weight_attrs.items():
|
||||
assert not hasattr(
|
||||
weight, key), (f"Overwriting existing tensor attribute: {key}")
|
||||
weight, key), f"Overwriting existing tensor attribute: {key}"
|
||||
|
||||
# NOTE(woosuk): During weight loading, we often do something like:
|
||||
# narrowed_tensor = param.data.narrow(0, offset, len)
|
||||
@ -41,6 +43,7 @@ def set_weight_attrs(
|
||||
# we sync the param tensor after its weight loader is called.
|
||||
# TODO(woosuk): Remove this hack once we have a better solution.
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if current_platform.is_tpu() and key == "weight_loader":
|
||||
value = _make_synced_weight_loader(value)
|
||||
setattr(weight, key, value)
|
||||
@ -77,4 +80,17 @@ def get_packed_modules_mapping(model: torch.nn.Module) -> dict[str, list[str]]:
|
||||
f"safely because of conflicts from {type(child).__name__}.")
|
||||
else:
|
||||
parent_map.update(child_map)
|
||||
return parent_map
|
||||
return parent_map
|
||||
|
||||
|
||||
def get_moe_expert_mapping(
|
||||
model: torch.nn.Module, ) -> list[tuple[str, str, int, str]]:
|
||||
if parent_map := getattr(model, "get_expert_mapping", None):
|
||||
return parent_map()
|
||||
else:
|
||||
# We only check main components instead of whole model submodules
|
||||
for child in model.children():
|
||||
child_map = getattr(child, "get_expert_mapping", None)
|
||||
if child_map is not None:
|
||||
return child_map()
|
||||
return []
|
||||
|
Reference in New Issue
Block a user