Mirror of https://github.com/vllm-project/vllm.git
[Bugfix] Fix get_quant_config when using modelscope (#24421)
Signed-off-by: wangli <wangli858794774@gmail.com>
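
This change moves the private DefaultModelLoader._maybe_download_from_modelscope helper to a module-level maybe_download_from_modelscope function in weight_utils, adding download_dir, ignore_patterns, and allow_patterns parameters, and makes get_quant_config resolve the model through it (downloading only *.json files) before checking whether the path is a local directory. Previously get_quant_config used model_config.model directly, so for ModelScope-hosted models the quantization config lookup fell through to the Hugging Face download path.

A minimal sketch (not part of the diff) of how the fixed path is exercised; the model id below is illustrative:

# Sketch only: enable ModelScope before importing vLLM so that
# envs.VLLM_USE_MODELSCOPE is honored by the loader and by get_quant_config.
import os
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM

# With the fix, the quantization config JSONs are fetched from the ModelScope
# snapshot (allow_patterns=["*.json"]) instead of being looked up on the
# Hugging Face Hub.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4")  # illustrative model id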
@@ -7,20 +7,19 @@ import time
from collections.abc import Generator, Iterable
from typing import Optional, cast

import huggingface_hub
import torch
from torch import nn
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from vllm import envs
from vllm.config import LoadConfig, ModelConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.base_loader import BaseModelLoader
from vllm.model_executor.model_loader.weight_utils import (
    download_safetensors_index_file_from_hf, download_weights_from_hf,
    fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
    filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
    pt_weights_iterator, safetensors_weights_iterator)
    filter_files_not_needed_for_inference, maybe_download_from_modelscope,
    np_cache_weights_iterator, pt_weights_iterator,
    safetensors_weights_iterator)
from vllm.platforms import current_platform

logger = init_logger(__name__)
@@ -57,35 +56,6 @@ class DefaultModelLoader(BaseModelLoader):
            raise ValueError(f"Model loader extra config is not supported for "
                             f"load format {load_config.load_format}")

    def _maybe_download_from_modelscope(
            self, model: str, revision: Optional[str]) -> Optional[str]:
        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
        if envs.VLLM_USE_MODELSCOPE:
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.
            from modelscope.hub.snapshot_download import snapshot_download

            # Use file lock to prevent multiple processes from
            # downloading the same model weights at the same time.
            with get_lock(model, self.load_config.download_dir):
                if not os.path.exists(model):
                    model_path = snapshot_download(
                        model_id=model,
                        cache_dir=self.load_config.download_dir,
                        local_files_only=huggingface_hub.constants.
                        HF_HUB_OFFLINE,
                        revision=revision,
                        ignore_file_pattern=self.load_config.ignore_patterns,
                    )
                else:
                    model_path = model
                return model_path
        return None

    def _prepare_weights(
        self,
        model_name_or_path: str,
@@ -96,7 +66,7 @@ class DefaultModelLoader(BaseModelLoader):
        """Prepare weights for the model.

        If the model is not local, it will be downloaded."""
        model_name_or_path = (self._maybe_download_from_modelscope(
        model_name_or_path = (maybe_download_from_modelscope(
            model_name_or_path, revision) or model_name_or_path)

        is_local = os.path.isdir(model_name_or_path)
@@ -21,6 +21,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm

from vllm import envs
from vllm.config import LoadConfig, ModelConfig
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.logger import init_logger
@@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
    return lock


def maybe_download_from_modelscope(
        model: str,
        revision: Optional[str] = None,
        download_dir: Optional[str] = None,
        ignore_patterns: Optional[Union[str, list[str]]] = None,
        allow_patterns: Optional[Union[list[str],
                                       str]] = None) -> Optional[str]:
    """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.

    Returns the path to the downloaded model, or None if the model is not
    downloaded from ModelScope."""
    if envs.VLLM_USE_MODELSCOPE:
        # download model from ModelScope hub,
        # lazy import so that modelscope is not required for normal use.
        # pylint: disable=C.
        from modelscope.hub.snapshot_download import snapshot_download

        # Use file lock to prevent multiple processes from
        # downloading the same model weights at the same time.
        with get_lock(model, download_dir):
            if not os.path.exists(model):
                model_path = snapshot_download(
                    model_id=model,
                    cache_dir=download_dir,
                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
                    revision=revision,
                    ignore_file_pattern=ignore_patterns,
                    allow_patterns=allow_patterns,
                )
            else:
                model_path = model
            return model_path
    return None


def _shared_pointers(tensors):
    ptrs = defaultdict(list)
    for k, v in tensors.items():
@@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
    # Inflight BNB quantization
    if model_config.quantization == "bitsandbytes":
        return quant_cls.from_config({})
    is_local = os.path.isdir(model_config.model)
    model_name_or_path = maybe_download_from_modelscope(
        model_config.model,
        revision=model_config.revision,
        download_dir=load_config.download_dir,
        allow_patterns=["*.json"],
    ) or model_config.model
    is_local = os.path.isdir(model_name_or_path)
    if not is_local:
        # Download the config files.
        with get_lock(model_config.model, load_config.download_dir):
@@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
            tqdm_class=DisabledTqdm,
        )
    else:
        hf_folder = model_config.model
        hf_folder = model_name_or_path

    possible_config_filenames = quant_cls.get_config_filenames()
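
For reference, a standalone sketch of calling the new helper the way get_quant_config now does; the model id is illustrative, and the ModelScope branch only runs when VLLM_USE_MODELSCOPE is enabled:

import os

from vllm.model_executor.model_loader.weight_utils import (
    maybe_download_from_modelscope)

model_id = "Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4"  # illustrative model id

# Returns the local snapshot path when the model was fetched from ModelScope,
# otherwise None, so callers fall back to the original model id.
model_name_or_path = maybe_download_from_modelscope(
    model_id,
    revision=None,
    download_dir=None,
    allow_patterns=["*.json"],  # config files only, matching get_quant_config
) or model_id

is_local = os.path.isdir(model_name_or_path)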