[Bugfix] Fix get_quant_config when using modelscope (#24421)

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-09-08 19:03:02 +08:00
committed by GitHub
parent c2a8b08fcd
commit 5e537f45b4
2 changed files with 48 additions and 36 deletions

View File

@ -7,20 +7,19 @@ import time
from collections.abc import Generator, Iterable
from typing import Optional, cast
import huggingface_hub
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -57,35 +56,6 @@ class DefaultModelLoader(BaseModelLoader):
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
def _maybe_download_from_modelscope(
self, model: str, revision: Optional[str]) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model, self.load_config.download_dir):
if not os.path.exists(model):
model_path = snapshot_download(
model_id=model,
cache_dir=self.load_config.download_dir,
local_files_only=huggingface_hub.constants.
HF_HUB_OFFLINE,
revision=revision,
ignore_file_pattern=self.load_config.ignore_patterns,
)
else:
model_path = model
return model_path
return None
def _prepare_weights(
self,
model_name_or_path: str,
@ -96,7 +66,7 @@ class DefaultModelLoader(BaseModelLoader):
"""Prepare weights for the model.
If the model is not local, it will be downloaded."""
model_name_or_path = (self._maybe_download_from_modelscope(
model_name_or_path = (maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path)
is_local = os.path.isdir(model_name_or_path)

View File

@ -21,6 +21,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
from vllm import envs
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
return lock
def maybe_download_from_modelscope(
model: str,
revision: Optional[str] = None,
download_dir: Optional[str] = None,
ignore_patterns: Optional[Union[str, list[str]]] = None,
allow_patterns: Optional[Union[list[str],
str]] = None) -> Optional[str]:
"""Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
Returns the path to the downloaded model, or None if the model is not
downloaded from ModelScope."""
if envs.VLLM_USE_MODELSCOPE:
# download model from ModelScope hub,
# lazy import so that modelscope is not required for normal use.
# pylint: disable=C.
from modelscope.hub.snapshot_download import snapshot_download
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model, download_dir):
if not os.path.exists(model):
model_path = snapshot_download(
model_id=model,
cache_dir=download_dir,
local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
revision=revision,
ignore_file_pattern=ignore_patterns,
allow_patterns=allow_patterns,
)
else:
model_path = model
return model_path
return None
def _shared_pointers(tensors):
ptrs = defaultdict(list)
for k, v in tensors.items():
@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
# Inflight BNB quantization
if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({})
is_local = os.path.isdir(model_config.model)
model_name_or_path = maybe_download_from_modelscope(
model_config.model,
revision=model_config.revision,
download_dir=load_config.download_dir,
allow_patterns=["*.json"],
) or model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local:
# Download the config files.
with get_lock(model_config.model, load_config.download_dir):
@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm,
)
else:
hf_folder = model_config.model
hf_folder = model_name_or_path
possible_config_filenames = quant_cls.get_config_filenames()