[Bugfix] Fix get_quant_config when using modelscope (#24421)

Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
Li Wang
2025-09-08 19:03:02 +08:00
committed by GitHub
parent c2a8b08fcd
commit 5e537f45b4
2 changed files with 48 additions and 36 deletions

View File

@ -7,20 +7,19 @@ import time
from collections.abc import Generator, Iterable from collections.abc import Generator, Iterable
from typing import Optional, cast from typing import Optional, cast
import huggingface_hub
import torch import torch
from torch import nn from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm import envs
from vllm.config import LoadConfig, ModelConfig from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import ( from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf, download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, filter_files_not_needed_for_inference, maybe_download_from_modelscope,
pt_weights_iterator, safetensors_weights_iterator) np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.platforms import current_platform from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
@ -57,35 +56,6 @@ class DefaultModelLoader(BaseModelLoader):
raise ValueError(f"Model loader extra config is not supported for " raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}") f"load format {load_config.load_format}")
def _maybe_download_from_modelscope(
        self, model: str, revision: Optional[str]) -> Optional[str]:
    """Fetch `model` from the ModelScope hub when VLLM_USE_MODELSCOPE is set.

    Args:
        model: ModelScope model id, or a path that already exists on disk.
        revision: Revision forwarded to ModelScope's `snapshot_download`.

    Returns:
        `model` itself if it already exists as a path, otherwise the value
        returned by `snapshot_download`. None when VLLM_USE_MODELSCOPE is
        not enabled.
    """
    if not envs.VLLM_USE_MODELSCOPE:
        return None

    # Imported lazily so that modelscope is not required for normal use.
    # pylint: disable=C.
    from modelscope.hub.snapshot_download import snapshot_download

    cache_dir = self.load_config.download_dir
    # The file lock keeps several processes from downloading the same
    # model weights at the same time.
    with get_lock(model, cache_dir):
        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=cache_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=self.load_config.ignore_patterns,
        )
def _prepare_weights( def _prepare_weights(
self, self,
model_name_or_path: str, model_name_or_path: str,
@ -96,7 +66,7 @@ class DefaultModelLoader(BaseModelLoader):
"""Prepare weights for the model. """Prepare weights for the model.
If the model is not local, it will be downloaded.""" If the model is not local, it will be downloaded."""
model_name_or_path = (self._maybe_download_from_modelscope( model_name_or_path = (maybe_download_from_modelscope(
model_name_or_path, revision) or model_name_or_path) model_name_or_path, revision) or model_name_or_path)
is_local = os.path.isdir(model_name_or_path) is_local = os.path.isdir(model_name_or_path)

View File

@ -21,6 +21,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm from tqdm.auto import tqdm
from vllm import envs
from vllm.config import LoadConfig, ModelConfig from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger from vllm.logger import init_logger
@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
return lock return lock
def maybe_download_from_modelscope(
    model: str,
    revision: Optional[str] = None,
    download_dir: Optional[str] = None,
    ignore_patterns: Optional[Union[str, list[str]]] = None,
    allow_patterns: Optional[Union[list[str], str]] = None,
) -> Optional[str]:
    """Fetch `model` from the ModelScope hub when VLLM_USE_MODELSCOPE is set.

    Args:
        model: ModelScope model id, or a path that already exists on disk.
        revision: Revision forwarded to ModelScope's `snapshot_download`.
        download_dir: Passed to `snapshot_download` as `cache_dir`; also
            handed to `get_lock` for the download lock.
        ignore_patterns: Passed to `snapshot_download` as
            `ignore_file_pattern`.
        allow_patterns: Passed to `snapshot_download` as `allow_patterns`.

    Returns:
        `model` itself if it already exists as a path, otherwise the value
        returned by `snapshot_download`. None when VLLM_USE_MODELSCOPE is
        not enabled.
    """
    if not envs.VLLM_USE_MODELSCOPE:
        return None

    # Imported lazily so that modelscope is not required for normal use.
    # pylint: disable=C.
    from modelscope.hub.snapshot_download import snapshot_download

    # The file lock keeps several processes from downloading the same
    # model weights at the same time.
    with get_lock(model, download_dir):
        if os.path.exists(model):
            return model
        return snapshot_download(
            model_id=model,
            cache_dir=download_dir,
            local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
            revision=revision,
            ignore_file_pattern=ignore_patterns,
            allow_patterns=allow_patterns,
        )
def _shared_pointers(tensors): def _shared_pointers(tensors):
ptrs = defaultdict(list) ptrs = defaultdict(list)
for k, v in tensors.items(): for k, v in tensors.items():
@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
# Inflight BNB quantization # Inflight BNB quantization
if model_config.quantization == "bitsandbytes": if model_config.quantization == "bitsandbytes":
return quant_cls.from_config({}) return quant_cls.from_config({})
is_local = os.path.isdir(model_config.model) model_name_or_path = maybe_download_from_modelscope(
model_config.model,
revision=model_config.revision,
download_dir=load_config.download_dir,
allow_patterns=["*.json"],
) or model_config.model
is_local = os.path.isdir(model_name_or_path)
if not is_local: if not is_local:
# Download the config files. # Download the config files.
with get_lock(model_config.model, load_config.download_dir): with get_lock(model_config.model, load_config.download_dir):
@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
tqdm_class=DisabledTqdm, tqdm_class=DisabledTqdm,
) )
else: else:
hf_folder = model_config.model hf_folder = model_name_or_path
possible_config_filenames = quant_cls.get_config_filenames() possible_config_filenames = quant_cls.get_config_filenames()