[Core] Support model loader plugins (#21067)

Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
22quinn
2025-07-24 01:49:44 -07:00
committed by GitHub
parent f0f4de8f26
commit 610852a423
9 changed files with 159 additions and 86 deletions

View File

@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm import SamplingParams
-from vllm.config import LoadFormat
 test_model = "openai-community/gpt2"
@@ -17,7 +16,6 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 def test_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model,
-                     load_format=LoadFormat.FASTSAFETENSORS) as llm:
+    with vllm_runner(test_model, load_format="fastsafetensors") as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs

View File

@@ -0,0 +1,37 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import pytest
+from torch import nn
+
+from vllm.config import LoadConfig, ModelConfig
+from vllm.model_executor.model_loader import (get_model_loader,
+                                              register_model_loader)
+from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+
+
+@register_model_loader("custom_load_format")
+class CustomModelLoader(BaseModelLoader):
+
+    def __init__(self, load_config: LoadConfig) -> None:
+        super().__init__(load_config)
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    def load_weights(self, model: nn.Module,
+                     model_config: ModelConfig) -> None:
+        pass
+
+
+def test_register_model_loader():
+    load_config = LoadConfig(load_format="custom_load_format")
+    assert isinstance(get_model_loader(load_config), CustomModelLoader)
+
+
+def test_invalid_model_loader():
+    with pytest.raises(ValueError):
+
+        @register_model_loader("invalid_load_format")
+        class InValidModelLoader:
+            pass
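
A minimal usage sketch, not part of this diff: once a loader class is registered under a custom format name (as in the test above), that name can be passed wherever a load format string is accepted, for example the offline LLM entrypoint. The "noop_loader" name and the no-op method bodies below are illustrative stand-ins for real weight-fetching logic.

from torch import nn

from vllm import LLM
from vllm.config import ModelConfig
from vllm.model_executor.model_loader import register_model_loader
from vllm.model_executor.model_loader.base_loader import BaseModelLoader


@register_model_loader("noop_loader")
class NoOpModelLoader(BaseModelLoader):

    def download_model(self, model_config: ModelConfig) -> None:
        pass  # e.g. fetch weights from a custom object store

    def load_weights(self, model: nn.Module,
                     model_config: ModelConfig) -> None:
        pass  # e.g. stream tensors into the already-initialized module


# The registered format name is now a valid load_format value.
llm = LLM(model="openai-community/gpt2", load_format="noop_loader")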

View File

@@ -2,9 +2,10 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 from vllm import SamplingParams
-from vllm.config import LoadConfig, LoadFormat
+from vllm.config import LoadConfig
 from vllm.model_executor.model_loader import get_model_loader
+load_format = "runai_streamer"
 test_model = "openai-community/gpt2"
 prompts = [
@@ -18,7 +19,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, seed=0)
 def get_runai_model_loader():
-    load_config = LoadConfig(load_format=LoadFormat.RUNAI_STREAMER)
+    load_config = LoadConfig(load_format=load_format)
     return get_model_loader(load_config)
@@ -28,6 +29,6 @@ def test_get_model_loader_with_runai_flag():
 def test_runai_model_loader_download_files(vllm_runner):
-    with vllm_runner(test_model, load_format=LoadFormat.RUNAI_STREAMER) as llm:
+    with vllm_runner(test_model, load_format=load_format) as llm:
         deserialized_outputs = llm.generate(prompts, sampling_params)
         assert deserialized_outputs

View File

@@ -65,7 +65,7 @@ if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization import QuantizationMethods
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
-    from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
 
     ConfigType = type[DataclassInstance]
@@ -78,6 +78,7 @@ else:
     QuantizationConfig = Any
     QuantizationMethods = Any
     BaseModelLoader = Any
+    LoadFormats = Any
     TensorizerConfig = Any
     ConfigType = type
     HfOverrides = Union[dict[str, Any], Callable[[type], type]]
@@ -1773,29 +1774,12 @@ class CacheConfig:
             logger.warning("Possibly too large swap space. %s", msg)
 
 
-class LoadFormat(str, enum.Enum):
-    AUTO = "auto"
-    PT = "pt"
-    SAFETENSORS = "safetensors"
-    NPCACHE = "npcache"
-    DUMMY = "dummy"
-    TENSORIZER = "tensorizer"
-    SHARDED_STATE = "sharded_state"
-    GGUF = "gguf"
-    BITSANDBYTES = "bitsandbytes"
-    MISTRAL = "mistral"
-    RUNAI_STREAMER = "runai_streamer"
-    RUNAI_STREAMER_SHARDED = "runai_streamer_sharded"
-    FASTSAFETENSORS = "fastsafetensors"
-
-
 @config
 @dataclass
 class LoadConfig:
     """Configuration for loading the model weights."""
 
-    load_format: Union[str, LoadFormat,
-                       "BaseModelLoader"] = LoadFormat.AUTO.value
+    load_format: Union[str, LoadFormats] = "auto"
     """The format of the model weights to load:\n
    - "auto" will try to load the weights in the safetensors format and fall
    back to the pytorch bin format if safetensors format is not available.\n
@@ -1816,7 +1800,8 @@ class LoadConfig:
    - "gguf" will load weights from GGUF format files (details specified in
    https://github.com/ggml-org/ggml/blob/master/docs/gguf.md).\n
    - "mistral" will load weights from consolidated safetensors files used by
-    Mistral models."""
+    Mistral models.\n
+    - Other custom values can be supported via plugins."""
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
@@ -1864,10 +1849,7 @@ class LoadConfig:
         return hash_str
 
     def __post_init__(self):
-        if isinstance(self.load_format, str):
-            load_format = self.load_format.lower()
-            self.load_format = LoadFormat(load_format)
-
+        self.load_format = self.load_format.lower()
         if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",

View File

@@ -26,13 +26,12 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          DetailedTraceModules, Device, DeviceConfig,
                          DistributedExecutorBackend, GuidedDecodingBackend,
                          GuidedDecodingBackendV1, HfOverrides, KVEventsConfig,
-                         KVTransferConfig, LoadConfig, LoadFormat,
-                         LogprobsMode, LoRAConfig, ModelConfig, ModelDType,
-                         ModelImpl, MultiModalConfig, ObservabilityConfig,
-                         ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
-                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
-                         TaskOption, TokenizerMode, VllmConfig, get_attr_docs,
-                         get_field)
+                         KVTransferConfig, LoadConfig, LogprobsMode,
+                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
+                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
+                         PoolerConfig, PrefixCachingHashAlgo, SchedulerConfig,
+                         SchedulerPolicy, SpeculativeConfig, TaskOption,
+                         TokenizerMode, VllmConfig, get_attr_docs, get_field)
 from vllm.logger import init_logger
 from vllm.platforms import CpuArchEnum, current_platform
 from vllm.plugins import load_general_plugins
@@ -47,10 +46,12 @@ from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
 if TYPE_CHECKING:
     from vllm.executor.executor_base import ExecutorBase
     from vllm.model_executor.layers.quantization import QuantizationMethods
+    from vllm.model_executor.model_loader import LoadFormats
     from vllm.usage.usage_lib import UsageContext
 else:
     ExecutorBase = Any
     QuantizationMethods = Any
+    LoadFormats = Any
     UsageContext = Any
 
 logger = init_logger(__name__)
@@ -276,7 +277,7 @@ class EngineArgs:
     trust_remote_code: bool = ModelConfig.trust_remote_code
     allowed_local_media_path: str = ModelConfig.allowed_local_media_path
     download_dir: Optional[str] = LoadConfig.download_dir
-    load_format: str = LoadConfig.load_format
+    load_format: Union[str, LoadFormats] = LoadConfig.load_format
     config_format: str = ModelConfig.config_format
     dtype: ModelDType = ModelConfig.dtype
     kv_cache_dtype: CacheDType = CacheConfig.cache_dtype
@@ -547,9 +548,7 @@
            title="LoadConfig",
            description=LoadConfig.__doc__,
        )
-        load_group.add_argument("--load-format",
-                                choices=[f.value for f in LoadFormat],
-                                **load_kwargs["load_format"])
+        load_group.add_argument("--load-format", **load_kwargs["load_format"])
         load_group.add_argument("--download-dir",
                                 **load_kwargs["download_dir"])
         load_group.add_argument("--model-loader-extra-config",
@@ -864,10 +863,9 @@
         # NOTE: This is to allow model loading from S3 in CI
         if (not isinstance(self, AsyncEngineArgs) and envs.VLLM_CI_USE_S3
-                and self.model in MODELS_ON_S3
-                and self.load_format == LoadFormat.AUTO):  # noqa: E501
+                and self.model in MODELS_ON_S3 and self.load_format == "auto"):
             self.model = f"{MODEL_WEIGHTS_S3_BUCKET}/{self.model}"
-            self.load_format = LoadFormat.RUNAI_STREAMER
+            self.load_format = "runai_streamer"
 
         return ModelConfig(
             model=self.model,
@@ -1299,7 +1297,7 @@
         #############################################################
         # Unsupported Feature Flags on V1.
 
-        if self.load_format == LoadFormat.SHARDED_STATE.value:
+        if self.load_format == "sharded_state":
             _raise_or_fallback(
                 feature_name=f"--load_format {self.load_format}",
                 recommend_to_remove=False)
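
A minimal sketch, not part of this diff: because the `choices` restriction was dropped from `--load-format`, a plugin-registered format name now passes CLI parsing and reaches LoadConfig unchanged. The "my_loader" format name is illustrative.

from vllm.engine.arg_utils import EngineArgs
from vllm.utils import FlexibleArgumentParser

parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
args = parser.parse_args(["--model", "openai-community/gpt2",
                          "--load-format", "my_loader"])
engine_args = EngineArgs.from_cli_args(args)
assert engine_args.load_format == "my_loader"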

View File

@@ -1,11 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from typing import Literal, Optional
 
 from torch import nn
 
-from vllm.config import LoadConfig, LoadFormat, ModelConfig, VllmConfig
+from vllm.config import LoadConfig, ModelConfig, VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.bitsandbytes_loader import (
     BitsAndBytesModelLoader)
@@ -20,34 +21,92 @@ from vllm.model_executor.model_loader.tensorizer_loader import TensorizerLoader
 from vllm.model_executor.model_loader.utils import (
     get_architecture_class_name, get_model_architecture, get_model_cls)
 
+logger = init_logger(__name__)
+
+# Reminder: Please update docstring in `LoadConfig`
+# if a new load format is added here
+LoadFormats = Literal[
+    "auto",
+    "bitsandbytes",
+    "dummy",
+    "fastsafetensors",
+    "gguf",
+    "mistral",
+    "npcache",
+    "pt",
+    "runai_streamer",
+    "runai_streamer_sharded",
+    "safetensors",
+    "sharded_state",
+    "tensorizer",
+]
+_LOAD_FORMAT_TO_MODEL_LOADER: dict[str, type[BaseModelLoader]] = {
+    "auto": DefaultModelLoader,
+    "bitsandbytes": BitsAndBytesModelLoader,
+    "dummy": DummyModelLoader,
+    "fastsafetensors": DefaultModelLoader,
+    "gguf": GGUFModelLoader,
+    "mistral": DefaultModelLoader,
+    "npcache": DefaultModelLoader,
+    "pt": DefaultModelLoader,
+    "runai_streamer": RunaiModelStreamerLoader,
+    "runai_streamer_sharded": ShardedStateLoader,
+    "safetensors": DefaultModelLoader,
+    "sharded_state": ShardedStateLoader,
+    "tensorizer": TensorizerLoader,
+}
+
+
+def register_model_loader(load_format: str):
+    """Register a customized vllm model loader.
+
+    When a load format is not supported by vllm, you can register a customized
+    model loader to support it.
+
+    Args:
+        load_format (str): The model loader format name.
+
+    Examples:
+        >>> from vllm.config import LoadConfig
+        >>> from vllm.model_executor.model_loader import get_model_loader, register_model_loader
+        >>> from vllm.model_executor.model_loader.base_loader import BaseModelLoader
+        >>>
+        >>> @register_model_loader("my_loader")
+        ... class MyModelLoader(BaseModelLoader):
+        ...     def download_model(self):
+        ...         pass
+        ...
+        ...     def load_weights(self):
+        ...         pass
+        >>>
+        >>> load_config = LoadConfig(load_format="my_loader")
+        >>> type(get_model_loader(load_config))
+        <class 'MyModelLoader'>
+    """  # noqa: E501
+
+    def _wrapper(model_loader_cls):
+        if load_format in _LOAD_FORMAT_TO_MODEL_LOADER:
+            logger.warning(
+                "Load format `%s` is already registered, and will be "
+                "overwritten by the new loader class `%s`.", load_format,
+                model_loader_cls)
+        if not issubclass(model_loader_cls, BaseModelLoader):
+            raise ValueError("The model loader must be a subclass of "
+                             "`BaseModelLoader`.")
+        _LOAD_FORMAT_TO_MODEL_LOADER[load_format] = model_loader_cls
+        logger.info("Registered model loader `%s` with load format `%s`",
+                    model_loader_cls, load_format)
+        return model_loader_cls
+
+    return _wrapper
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
-    if isinstance(load_config.load_format, type):
-        return load_config.load_format(load_config)
-
-    if load_config.load_format == LoadFormat.DUMMY:
-        return DummyModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.TENSORIZER:
-        return TensorizerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.SHARDED_STATE:
-        return ShardedStateLoader(load_config)
-
-    if load_config.load_format == LoadFormat.BITSANDBYTES:
-        return BitsAndBytesModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.GGUF:
-        return GGUFModelLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER:
-        return RunaiModelStreamerLoader(load_config)
-
-    if load_config.load_format == LoadFormat.RUNAI_STREAMER_SHARDED:
-        return ShardedStateLoader(load_config, runai_model_streamer=True)
-
-    return DefaultModelLoader(load_config)
+    load_format = load_config.load_format
+    if load_format not in _LOAD_FORMAT_TO_MODEL_LOADER:
+        raise ValueError(f"Load format `{load_format}` is not supported")
+    return _LOAD_FORMAT_TO_MODEL_LOADER[load_format](load_config)
 
 
 def get_model(*,
@@ -66,6 +125,7 @@ __all__ = [
     "get_architecture_class_name",
     "get_model_architecture",
    "get_model_cls",
+    "register_model_loader",
     "BaseModelLoader",
     "BitsAndBytesModelLoader",
     "GGUFModelLoader",

View File

@@ -13,7 +13,7 @@ from torch import nn
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from vllm import envs
-from vllm.config import LoadConfig, LoadFormat, ModelConfig
+from vllm.config import LoadConfig, ModelConfig
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader.base_loader import BaseModelLoader
 from vllm.model_executor.model_loader.weight_utils import (
@@ -104,19 +104,19 @@ class DefaultModelLoader(BaseModelLoader):
         use_safetensors = False
         index_file = SAFE_WEIGHTS_INDEX_NAME
         # Some quantized models use .pt files for storing the weights.
-        if load_format == LoadFormat.AUTO:
+        if load_format == "auto":
             allow_patterns = ["*.safetensors", "*.bin"]
-        elif (load_format == LoadFormat.SAFETENSORS
-              or load_format == LoadFormat.FASTSAFETENSORS):
+        elif (load_format == "safetensors"
+              or load_format == "fastsafetensors"):
             use_safetensors = True
             allow_patterns = ["*.safetensors"]
-        elif load_format == LoadFormat.MISTRAL:
+        elif load_format == "mistral":
             use_safetensors = True
             allow_patterns = ["consolidated*.safetensors"]
             index_file = "consolidated.safetensors.index.json"
-        elif load_format == LoadFormat.PT:
+        elif load_format == "pt":
             allow_patterns = ["*.pt"]
-        elif load_format == LoadFormat.NPCACHE:
+        elif load_format == "npcache":
             allow_patterns = ["*.bin"]
         else:
             raise ValueError(f"Unknown load_format: {load_format}")
@@ -178,7 +178,7 @@ class DefaultModelLoader(BaseModelLoader):
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
-        if self.load_config.load_format == LoadFormat.NPCACHE:
+        if self.load_config.load_format == "npcache":
             # Currently np_cache only support *.bin checkpoints
             assert use_safetensors is False
             weights_iterator = np_cache_weights_iterator(
@@ -189,7 +189,7 @@ class DefaultModelLoader(BaseModelLoader):
                 self.load_config.use_tqdm_on_load,
             )
         elif use_safetensors:
-            if self.load_config.load_format == LoadFormat.FASTSAFETENSORS:
+            if self.load_config.load_format == "fastsafetensors":
                 weights_iterator = fastsafetensors_weights_iterator(
                     hf_weights_files,
                     self.load_config.use_tqdm_on_load,

View File

@@ -32,12 +32,9 @@ class ShardedStateLoader(BaseModelLoader):
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
 
-    def __init__(self,
-                 load_config: LoadConfig,
-                 runai_model_streamer: bool = False):
+    def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        self.runai_model_streamer = runai_model_streamer
         extra_config = ({} if load_config.model_loader_extra_config is None
                         else load_config.model_loader_extra_config.copy())
         self.pattern = extra_config.pop("pattern", self.DEFAULT_PATTERN)
@@ -152,7 +149,7 @@ class ShardedStateLoader(BaseModelLoader):
     def iterate_over_files(
            self, paths) -> Generator[tuple[str, torch.Tensor], None, None]:
-        if self.runai_model_streamer:
+        if self.load_config.load_format == "runai_streamer_sharded":
             yield from runai_safetensors_weights_iterator(paths, True)
         else:
             from safetensors.torch import safe_open
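
A minimal sketch, not part of this diff: with the constructor flag removed, both sharded formats map to the same loader class in the registry, and the Run:ai streaming path is chosen from the format string stored on LoadConfig.

from vllm.config import LoadConfig
from vllm.model_executor.model_loader import ShardedStateLoader, get_model_loader

plain = get_model_loader(LoadConfig(load_format="sharded_state"))
streamed = get_model_loader(LoadConfig(load_format="runai_streamer_sharded"))
assert isinstance(plain, ShardedStateLoader)
assert isinstance(streamed, ShardedStateLoader)
# iterate_over_files() only uses the Run:ai streamer for the
# "runai_streamer_sharded" load format.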