[Core] Support model loader plugins (#21067)
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadFormat
 
 test_model = "openai-community/gpt2"
 
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 
 
 def test_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model,
-                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
+    with vllm_runner(test_model, load_format="fastsafetensors") as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs
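
The change above is the user-visible surface of this PR: `load_format` is now a plain lowercase string instead of a `LoadFormat` enum member. A minimal offline sketch of the same call (model and prompts are illustrative):

    from vllm import LLM, SamplingParams

    # "fastsafetensors" still maps to DefaultModelLoader, which switches to
    # the fastsafetensors weights iterator for this format.
    llm = LLM(model="openai-community/gpt2", load_format="fastsafetensors")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.8, top_p=0.95, seed=0))
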
 tests/model_executor/model_loader/__init__.py      |  0  (new file)
 tests/model_executor/model_loader/test_registry.py | 37  (new file)
@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from torch import nn
+
+from vllm.config import LoadConfig, ModelConfig
+from vllm.model_executor.model_loader import (get_model_loader,
+                                              register_model_loader)
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+
+
+@register_model_loader("custom_load_format")
+class CustomModelLoader(BaseModelLoader):
+
+    def __init__(self, load_config: LoadConfig) -> None:
+        super().__init__(load_config)
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        pass
+
+
+def test_register_model_loader():
+    load_config = LoadConfig(load_format="custom_load_format")
+    assert isinstance(get_model_loader(load_config), CustomModelLoader)
+
+
+def test_invalid_model_loader():
+    with pytest.raises(ValueError):
+
+        @register_model_loader("invalid_load_format")
+        class InValidModelLoader:
+            pass
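
The new test file doubles as the plugin recipe. A sketch of an out-of-tree loader, with a hypothetical module and format name; exposing the module through vLLM's general-plugins entry-point mechanism is one way to get the registration imported before engine start:

    # my_loader_plugin.py -- hypothetical out-of-tree module
    from torch import nn

    from vllm.config import ModelConfig
    from vllm.model_executor.model_loader import register_model_loader
    from vllm.model_executor.model_loader.base_loader import BaseModelLoader


    @register_model_loader("s3_streaming")  # hypothetical format name
    class S3StreamingLoader(BaseModelLoader):

        def download_model(self, model_config: ModelConfig) -> None:
            # Fetch the checkpoint from object storage (omitted in sketch).
            ...

        def load_weights(self, model: nn.Module,
                         model_config: ModelConfig) -> None:
            # Stream tensors into the initialized model (omitted in sketch).
            ...

Once the module is imported, `LoadConfig(load_format="s3_streaming")` resolves to this class.
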
@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
 from vllm import SamplingParams
-from vllm.config import LoadConfig, LoadFormat
+from vllm.config import LoadConfig
 from vllm.model_executor.model_loader import get_model_loader
 
+load_format = "runai_streamer"
 test_model = "openai-community/gpt2"
 
 prompts = [
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 
 
 def get_runai_model_loader():
-    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    load_config = LoadConfig(load_format=load_format)
     return get_model_loader(load_config)
 
 
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
 
 
 def test_runai_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
+    with vllm_runner(test_model, load_format=load_format) as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs
@@ -65,7 +65,7 @@ if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
-    from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
 ConfigType = type[DataclassInstance]
@@ -78,6 +78,7 @@ else:
     QuantizationConfig = Any
     QuantizationMethods = Any
     BaseModelLoader = Any
+    LoadFormats = Any
     TensorizerConfig = Any
     ConfigType = type
 HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@@ -1773,29 +1774,12 @@ class CacheConfig:
         logger.warning("Possibly too large swap space. %s", msg)
 
 
-class LoadFormat(str, enum.Enum):
-    AUTO = "auto"
-    PT = "pt"
-    SAFETENSORS = "safetensors"
-    NPCACHE = "npcache"
-    DUMMY = "dummy"
-    TENSORIZER = "tensorizer"
-    SHARDED_STATE = "sharded_state"
-    GGUF = "gguf"
-    BITSANDBYTES = "bitsandbytes"
-    MISTRAL = "mistral"
-    RUNAI_STREAMER = "runai_streamer"
-    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
-    FASTSAFETENSORS = "fastsafetensors"
-
-
 @config
 @dataclass
 class LoadConfig:
     """Configuration for loading the model weights."""
 
-    load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    load_format: Union[str, LoadFormats] = "auto"
     """The format of the model weights to load:\n
     - "auto" will try to load the weights in the safetensors format and fall
     back to the pytorch bin format if safetensors format is not available.\n
@@ -1816,7 +1800,8 @@ class LoadConfig:
     - "gguf" will load weights from GGUF format files (details specified in
     https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
     - "mistral" will load weights from consolidated safetensors files used by
-    Mistral models."""
+    Mistral models.\n
+    - Other custom values can be supported via plugins."""
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
@@ -1864,10 +1849,7 @@ class LoadConfig:
         return hash_str
 
     def __post_init__(self):
-        if isinstance(self.load_format, str):
-            load_format = self.load_format.lower()
-            self.load_format = LoadFormat(load_format)
-
+        self.load_format = self.load_format.lower()
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",
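
With the enum gone, validation is split: `__post_init__` only lowercases the string, and unknown names are rejected later by the registry lookup in `get_model_loader`. A small illustration:

    from vllm.config import LoadConfig

    # Input is case-insensitive; unknown values are not rejected here but
    # in get_model_loader's registry lookup.
    cfg = LoadConfig(load_format="RUNAI_STREAMER")
    assert cfg.load_format == "runai_streamer"
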
@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat,
-                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LogprobsMode,
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.usage.usage_lib import UsageContext
 else:
     ExecutorBase = Any
     QuantizationMethods = Any
+    LoadFormats = Any
     UsageContext = Any
 
 logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    load_format: str = LoadConfig.load_format
+    load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -547,9 +548,7 @@ class EngineArgs:
         title="LoadConfig",
         description=LoadConfig.__doc__,
     )
-    load_group.add_argument("--load-format",
-                            choices=[f.value for f in LoadFormat],
-                            **load_kwargs["load_format"])
+    load_group.add_argument("--load-format", **load_kwargs["load_format"])
     load_group.add_argument("--download-dir",
                             **load_kwargs["download_dir"])
     load_group.add_argument("--model-loader-extra-config",
@@ -864,10 +863,9 @@ class EngineArgs:
 
         # NOTE: This is to allow model loading from S3 in CI
         if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
-                and self.model in MODELS_ON_S3
-                and self.load_format == LoadFormat.AUTO):  # noqa: E501
+                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
             self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
-            self.load_format = LoadFormat.RUNAI_STREAMER
+            self.load_format = "runai_streamer"
 
         return ModelConfig(
             model=self.model,
@@ -1299,7 +1297,7 @@ class EngineArgs:
         #############################################################
         # Unsupported Feature Flags on V1.
 
-        if self.load_format == LoadFormat.SHARDED_STATE.value:
+        if self.load_format == "sharded_state":
             _raise_or_fallback(
                 feature_name=f"--load_format {self.load_format}",
                 recommend_to_remove=False)
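
Dropping the `choices` list from `--load-format` is what lets plugin formats through CLI parsing. In Python terms (the format name here is hypothetical and must be registered by a plugin first):

    from vllm.engine.arg_utils import EngineArgs

    # Any registered string is now accepted, not just the old enum values.
    args = EngineArgs(model="openai-community/gpt2", load_format="my_loader")
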
@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-from typing import Optional
+from typing import Literal, Optional
 
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
 from vllm.model_executor.model_loader.utils import (
     get_architecture_class_name, get_model_architecture, get_model_cls)
 
 logger = init_logger(__name__)
 
+# Reminder: Please update docstring in `LoadConfig`
+# if a new load format is added here
+LoadFormats = Literal[
+    "auto",
+    "bitsandbytes",
+    "dummy",
+    "fastsafetensors",
+    "gguf",
+    "mistral",
+    "npcache",
+    "pt",
+    "runai_streamer",
+    "runai_streamer_sharded",
+    "safetensors",
+    "sharded_state",
+    "tensorizer",
+]
+_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
+    "auto": DefaultModelLoader,
+    "bitsandbytes": BitsAndBytesModelLoader,
+    "dummy": DummyModelLoader,
+    "fastsafetensors": DefaultModelLoader,
+    "gguf": GGUFModelLoader,
+    "mistral": DefaultModelLoader,
+    "npcache": DefaultModelLoader,
+    "pt": DefaultModelLoader,
+    "runai_streamer": RunaiModelStreamerLoader,
+    "runai_streamer_sharded": ShardedStateLoader,
+    "safetensors": DefaultModelLoader,
+    "sharded_state": ShardedStateLoader,
+    "tensorizer": TensorizerLoader,
+}
+
+
+def register_model_loader(load_format: str):
+    """Register a customized vllm model loader.
+
+    When a load format is not supported by vllm, you can register a customized
+    model loader to support it.
+
+    Args:
+        load_format (str): The model loader format name.
+
+    Examples:
+        >>> from vllm.config import LoadConfig
+        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
+        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+        >>>
+        >>> @register_model_loader("my_loader")
+        ... class MyModelLoader(BaseModelLoader):
+        ...     def download_model(self):
+        ...         pass
+        ...
+        ...     def load_weights(self):
+        ...         pass
+        >>>
+        >>> load_config = LoadConfig(load_format="my_loader")
+        >>> type(get_model_loader(load_config))
+        <class 'MyModelLoader'>
+    """  # noqa: E501
+
+    def _wrapper(model_loader_cls):
+        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
+            logger.warning(
+                "Load format `%s` is already registered, and will be "
+                "overwritten by the new loader class `%s`.", load_format,
+                model_loader_cls)
+        if not issubclass(model_loader_cls, BaseModelLoader):
+            raise ValueError("The model loader must be a subclass of "
+                             "`BaseModelLoader`.")
+        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
+        logger.info("Registered model loader `%s` with load format `%s`",
+                    model_loader_cls, load_format)
+        return model_loader_cls
+
+    return _wrapper
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
-    if isinstance(load_config.load_format, type):
-        return load_config.load_format(load_config)
-
-    if load_config.load_format == LoadFormat.DUMMY:
-        return DummyModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.TENSORIZER:
-        return TensorizerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.SHARDED_STATE:
-        return ShardedStateLoader(load_config)
-
-    if load_config.load_format == LoadFormat.BITSANDBYTES:
-        return BitsAndBytesModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.GGUF:
-        return GGUFModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
-        return RunaiModelStreamerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
-        return ShardedStateLoader(load_config, runai_model_streamer=True)
-
-    return DefaultModelLoader(load_config)
+    load_format = load_config.load_format
+    if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
+        raise ValueError(f"Load format `{load_format}` is not supported")
+    return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
 
 
 def get_model(*,
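
The if/elif chain becomes a dict lookup, so built-ins and plugins resolve through the same path. A quick check of both outcomes:

    from vllm.config import LoadConfig
    from vllm.model_executor.model_loader import get_model_loader

    # Built-in formats resolve through _LOAD_FORMAT_TO_MODEL_LOADER.
    loader = get_model_loader(LoadConfig(load_format="gguf"))
    print(type(loader).__name__)  # GGUFModelLoader

    # Unregistered formats now raise here (previously the LoadFormat enum
    # conversion in LoadConfig.__post_init__ rejected them).
    try:
        get_model_loader(LoadConfig(load_format="does_not_exist"))
    except ValueError as exc:
        print(exc)  # Load format `does_not_exist` is not supported
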
@@ -66,6 +125,7 @@ __all__ = [
     "get_architecture_class_name",
     "get_model_architecture",
     "get_model_cls",
+    "register_model_loader",
     "BaseModelLoader",
     "BitsAndBytesModelLoader",
     "GGUFModelLoader",
@@ -13,7 +13,7 @@ from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm import envs
-from vllm.config import LoadConfig, LoadFormat, ModelConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.weight_utils import (
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
         use_safetensors = False
         index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
-        if load_format == LoadFormat.AUTO:
+        if load_format == "auto":
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif (load_format == LoadFormat.SAFETENSORS
-              or load_format == LoadFormat.FASTSAFETENSORS):
+        elif (load_format == "safetensors"
+              or load_format == "fastsafetensors"):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
-        elif load_format == LoadFormat.MISTRAL:
+        elif load_format == "mistral":
             use_safetensors = True
             allow_patterns = ["consolidated*.safetensors"]
             index_file = "consolidated.safetensors.index.json"
-        elif load_format == LoadFormat.PT:
+        elif load_format == "pt":
             allow_patterns = ["*.pt"]
-        elif load_format == LoadFormat.NPCACHE:
+        elif load_format == "npcache":
             allow_patterns = ["*.bin"]
         else:
             raise ValueError(f"Unknown load_format: {load_format}")
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
-        if self.load_config.load_format == LoadFormat.NPCACHE:
+        if self.load_config.load_format == "npcache":
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+            if self.load_config.load_format == "fastsafetensors":
                 weights_iterator = fastsafetensors_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,
@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self,
-                 load_config: LoadConfig,
-                 runai_model_streamer: bool = False):
+    def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
 
-        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
 
     def iterate_over_files(
             self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
-        if self.runai_model_streamer:
+        if self.load_config.load_format == "runai_streamer_sharded":
             yield from runai_safetensors_weights_iterator(paths, True)
         else:
             from safetensors.torch import safe_open