diff --git a/vllm/model_executor/model_loader/default_loader.py b/vllm/model_executor/model_loader/default_loader.py index 1e5aa9e571..c8ad3a55d9 100644 --- a/vllm/model_executor/model_loader/default_loader.py +++ b/vllm/model_executor/model_loader/default_loader.py @@ -7,20 +7,19 @@ import time from collections.abc import Generator, Iterable from typing import Optional, cast -import huggingface_hub import torch from torch import nn from transformers.utils import SAFE_WEIGHTS_INDEX_NAME -from vllm import envs from vllm.config import LoadConfig, ModelConfig from vllm.logger import init_logger from vllm.model_executor.model_loader.base_loader import BaseModelLoader from vllm.model_executor.model_loader.weight_utils import ( download_safetensors_index_file_from_hf, download_weights_from_hf, fastsafetensors_weights_iterator, filter_duplicate_safetensors_files, - filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator, - pt_weights_iterator, safetensors_weights_iterator) + filter_files_not_needed_for_inference, maybe_download_from_modelscope, + np_cache_weights_iterator, pt_weights_iterator, + safetensors_weights_iterator) from vllm.platforms import current_platform logger = init_logger(__name__) @@ -57,35 +56,6 @@ class DefaultModelLoader(BaseModelLoader): raise ValueError(f"Model loader extra config is not supported for " f"load format {load_config.load_format}") - def _maybe_download_from_modelscope( - self, model: str, revision: Optional[str]) -> Optional[str]: - """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. - - Returns the path to the downloaded model, or None if the model is not - downloaded from ModelScope.""" - if envs.VLLM_USE_MODELSCOPE: - # download model from ModelScope hub, - # lazy import so that modelscope is not required for normal use. - # pylint: disable=C. - from modelscope.hub.snapshot_download import snapshot_download - - # Use file lock to prevent multiple processes from - # downloading the same model weights at the same time. - with get_lock(model, self.load_config.download_dir): - if not os.path.exists(model): - model_path = snapshot_download( - model_id=model, - cache_dir=self.load_config.download_dir, - local_files_only=huggingface_hub.constants. - HF_HUB_OFFLINE, - revision=revision, - ignore_file_pattern=self.load_config.ignore_patterns, - ) - else: - model_path = model - return model_path - return None - def _prepare_weights( self, model_name_or_path: str, @@ -96,7 +66,7 @@ class DefaultModelLoader(BaseModelLoader): """Prepare weights for the model. If the model is not local, it will be downloaded.""" - model_name_or_path = (self._maybe_download_from_modelscope( + model_name_or_path = (maybe_download_from_modelscope( model_name_or_path, revision) or model_name_or_path) is_local = os.path.isdir(model_name_or_path) diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index f87eeaa456..50056038b6 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -21,6 +21,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download from safetensors.torch import load_file, safe_open, save_file from tqdm.auto import tqdm +from vllm import envs from vllm.config import LoadConfig, ModelConfig from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger @@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path], return lock +def maybe_download_from_modelscope( + model: str, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + ignore_patterns: Optional[Union[str, list[str]]] = None, + allow_patterns: Optional[Union[list[str], + str]] = None) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if envs.VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + # Use file lock to prevent multiple processes from + # downloading the same model weights at the same time. + with get_lock(model, download_dir): + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=ignore_patterns, + allow_patterns=allow_patterns, + ) + else: + model_path = model + return model_path + return None + + def _shared_pointers(tensors): ptrs = defaultdict(list) for k, v in tensors.items(): @@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig, # Inflight BNB quantization if model_config.quantization == "bitsandbytes": return quant_cls.from_config({}) - is_local = os.path.isdir(model_config.model) + model_name_or_path = maybe_download_from_modelscope( + model_config.model, + revision=model_config.revision, + download_dir=load_config.download_dir, + allow_patterns=["*.json"], + ) or model_config.model + is_local = os.path.isdir(model_name_or_path) if not is_local: # Download the config files. with get_lock(model_config.model, load_config.download_dir): @@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig, tqdm_class=DisabledTqdm, ) else: - hf_folder = model_config.model + hf_folder = model_name_or_path possible_config_filenames = quant_cls.get_config_filenames()