Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
[Bugfix] Fix get_quant_config when using modelscope (#24421)
Signed-off-by: wangli <wangli858794774@gmail.com>
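For context: with VLLM_USE_MODELSCOPE enabled, model weights are fetched from the ModelScope hub, but get_quant_config() previously still resolved the quantization config files against the raw model ID (and therefore the Hugging Face Hub) rather than the downloaded ModelScope snapshot. Below is a minimal sketch of how the affected code path is reached, assuming a quantized checkpoint; the model ID is an illustrative placeholder, not something referenced by this commit.

import os

# Enable ModelScope as the download source (vLLM environment variable).
os.environ["VLLM_USE_MODELSCOPE"] = "True"

from vllm import LLM

# Loading a quantized model triggers get_quant_config(), which must locate
# the quantization JSON files; with ModelScope enabled they should be read
# from the locally downloaded snapshot instead of being looked up by the
# raw model ID on the Hugging Face Hub.
llm = LLM(model="Qwen/Qwen2.5-7B-Instruct-GPTQ-Int4")  # placeholder model ID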
@@ -7,20 +7,19 @@ import time
 from collections.abc import Generator, Iterable
 from typing import Optional, cast
 
-import huggingface_hub
 import torch
 from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
-from vllm import envs
 from vllm.config import LoadConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
-    filter_files_not_needed_for_inference, get_lock, np_cache_weights_iterator,
-    pt_weights_iterator, safetensors_weights_iterator)
+    filter_files_not_needed_for_inference, maybe_download_from_modelscope,
+    np_cache_weights_iterator, pt_weights_iterator,
+    safetensors_weights_iterator)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -57,35 +56,6 @@ class DefaultModelLoader(BaseModelLoader):
             raise ValueError(f"Model loader extra config is not supported for "
                              f"load format {load_config.load_format}")
 
-    def _maybe_download_from_modelscope(
-            self, model: str, revision: Optional[str]) -> Optional[str]:
-        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
-
-        Returns the path to the downloaded model, or None if the model is not
-        downloaded from ModelScope."""
-        if envs.VLLM_USE_MODELSCOPE:
-            # download model from ModelScope hub,
-            # lazy import so that modelscope is not required for normal use.
-            # pylint: disable=C.
-            from modelscope.hub.snapshot_download import snapshot_download
-
-            # Use file lock to prevent multiple processes from
-            # downloading the same model weights at the same time.
-            with get_lock(model, self.load_config.download_dir):
-                if not os.path.exists(model):
-                    model_path = snapshot_download(
-                        model_id=model,
-                        cache_dir=self.load_config.download_dir,
-                        local_files_only=huggingface_hub.constants.
-                        HF_HUB_OFFLINE,
-                        revision=revision,
-                        ignore_file_pattern=self.load_config.ignore_patterns,
-                    )
-                else:
-                    model_path = model
-                return model_path
-        return None
-
     def _prepare_weights(
         self,
         model_name_or_path: str,
@@ -96,7 +66,7 @@ class DefaultModelLoader(BaseModelLoader):
         """Prepare weights for the model.
 
         If the model is not local, it will be downloaded."""
-        model_name_or_path = (self._maybe_download_from_modelscope(
+        model_name_or_path = (maybe_download_from_modelscope(
             model_name_or_path, revision) or model_name_or_path)
 
         is_local = os.path.isdir(model_name_or_path)
@@ -21,6 +21,7 @@ from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
 
+from vllm import envs
 from vllm.config import LoadConfig, ModelConfig
 from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.logger import init_logger
@@ -95,6 +96,41 @@ def get_lock(model_name_or_path: Union[str, Path],
     return lock
 
 
+def maybe_download_from_modelscope(
+        model: str,
+        revision: Optional[str] = None,
+        download_dir: Optional[str] = None,
+        ignore_patterns: Optional[Union[str, list[str]]] = None,
+        allow_patterns: Optional[Union[list[str],
+                                       str]] = None) -> Optional[str]:
+    """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
+
+    Returns the path to the downloaded model, or None if the model is not
+    downloaded from ModelScope."""
+    if envs.VLLM_USE_MODELSCOPE:
+        # download model from ModelScope hub,
+        # lazy import so that modelscope is not required for normal use.
+        # pylint: disable=C.
+        from modelscope.hub.snapshot_download import snapshot_download
+
+        # Use file lock to prevent multiple processes from
+        # downloading the same model weights at the same time.
+        with get_lock(model, download_dir):
+            if not os.path.exists(model):
+                model_path = snapshot_download(
+                    model_id=model,
+                    cache_dir=download_dir,
+                    local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
+                    revision=revision,
+                    ignore_file_pattern=ignore_patterns,
+                    allow_patterns=allow_patterns,
+                )
+            else:
+                model_path = model
+            return model_path
+    return None
+
+
 def _shared_pointers(tensors):
     ptrs = defaultdict(list)
     for k, v in tensors.items():
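For reference, a minimal usage sketch of the new module-level helper added above; the model ID and download directory are placeholders, not values from this commit.

from vllm.model_executor.model_loader.weight_utils import (
    maybe_download_from_modelscope)

model = "Qwen/Qwen2.5-0.5B-Instruct"  # placeholder ModelScope model ID
# Returns the local snapshot path when VLLM_USE_MODELSCOPE is enabled and
# None otherwise, so callers fall back to the original identifier.
local_path = maybe_download_from_modelscope(
    model,
    revision=None,
    download_dir="/tmp/vllm-downloads",  # placeholder cache directory
    allow_patterns=["*.json"],  # restrict the snapshot to config files
) or model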
@@ -169,7 +205,13 @@ def get_quant_config(model_config: ModelConfig,
     # Inflight BNB quantization
     if model_config.quantization == "bitsandbytes":
         return quant_cls.from_config({})
-    is_local = os.path.isdir(model_config.model)
+    model_name_or_path = maybe_download_from_modelscope(
+        model_config.model,
+        revision=model_config.revision,
+        download_dir=load_config.download_dir,
+        allow_patterns=["*.json"],
+    ) or model_config.model
+    is_local = os.path.isdir(model_name_or_path)
     if not is_local:
         # Download the config files.
         with get_lock(model_config.model, load_config.download_dir):
@@ -182,7 +224,7 @@ def get_quant_config(model_config: ModelConfig,
                 tqdm_class=DisabledTqdm,
             )
     else:
-        hf_folder = model_config.model
+        hf_folder = model_name_or_path
 
     possible_config_filenames = quant_cls.get_config_filenames()
 
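Taken together, the two get_quant_config() hunks resolve the quantization config directory from the ModelScope snapshot when one exists. Below is a simplified, hypothetical mirror of that resolution step (not the vLLM function itself; locating and parsing the JSON config is omitted, and the function name is made up for illustration).

import os

from vllm.config import LoadConfig, ModelConfig
from vllm.model_executor.model_loader.weight_utils import (
    maybe_download_from_modelscope)


def resolve_quant_config_dir(model_config: ModelConfig,
                             load_config: LoadConfig) -> tuple[str, bool]:
    # Key change: resolve against the ModelScope snapshot (if any) before
    # deciding whether the quantization config files are already local.
    model_name_or_path = maybe_download_from_modelscope(
        model_config.model,
        revision=model_config.revision,
        download_dir=load_config.download_dir,
        allow_patterns=["*.json"],
    ) or model_config.model
    # Before this commit the check below used model_config.model, so a
    # ModelScope snapshot sitting in the download directory was never found.
    is_local = os.path.isdir(model_name_or_path)
    return model_name_or_path, is_local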