Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Feature] Use pydantic validation in lora.py and load.py configs (#26413)
Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
@@ -2,9 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project

 import hashlib
-from dataclasses import field
 from typing import TYPE_CHECKING, Any, Optional, Union

+from pydantic import Field, field_validator
 from pydantic.dataclasses import dataclass

 from vllm.config.utils import config
@@ -64,7 +64,7 @@ class LoadConfig:
     was quantized using torchao and saved using safetensors.
     Needs torchao >= 0.14.0
     """
-    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+    model_loader_extra_config: Union[dict, TensorizerConfig] = Field(
         default_factory=dict
     )
     """Extra config for model loader. This will be passed to the model loader
@@ -72,7 +72,9 @@ class LoadConfig:
     device: Optional[str] = None
     """Device to which model weights will be loaded, default to
    device_config.device"""
-    ignore_patterns: Optional[Union[list[str], str]] = None
+    ignore_patterns: Union[list[str], str] = Field(
+        default_factory=lambda: ["original/**/*"]
+    )
     """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
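
Note: moving from a plain `None` default to `Field(default_factory=...)` follows the usual pydantic pattern for mutable defaults. A minimal standalone sketch of that pattern (illustrative class name, not vLLM's code):

# Sketch: a pydantic dataclass field with a mutable (list) default supplied via
# default_factory, as LoadConfig.ignore_patterns does above. Illustrative only.
from typing import Union

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class DownloadOptions:
    ignore_patterns: Union[list[str], str] = Field(
        default_factory=lambda: ["original/**/*"]
    )


a = DownloadOptions()
a.ignore_patterns.append("*.bin")
b = DownloadOptions()
print(b.ignore_patterns)  # ['original/**/*'] -- each instance gets a fresh list
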
@@ -107,12 +109,18 @@ class LoadConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str

-    def __post_init__(self):
-        self.load_format = self.load_format.lower()
-        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+    @field_validator("load_format", mode="after")
+    def _lowercase_load_format(cls, load_format: str) -> str:
+        return load_format.lower()
+
+    @field_validator("ignore_patterns", mode="after")
+    def _validate_ignore_patterns(
+        cls, ignore_patterns: Union[list[str], str]
+    ) -> Union[list[str], str]:
+        if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",
-                self.ignore_patterns,
+                ignore_patterns,
             )
-        else:
-            self.ignore_patterns = ["original/**/*"]
+
+        return ignore_patterns
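
The two `field_validator` hooks above replace the old `__post_init__` normalization. A rough self-contained sketch of the `mode="after"` pattern, with an illustrative class rather than the actual LoadConfig:

# Sketch: per-field normalization with pydantic's field_validator(mode="after"),
# mirroring the load_format lower-casing above. Illustrative names only.
from pydantic import field_validator
from pydantic.dataclasses import dataclass


@dataclass
class LoaderOptions:
    load_format: str = "auto"

    @field_validator("load_format", mode="after")
    @classmethod
    def _lowercase_load_format(cls, load_format: str) -> str:
        # Runs after the value has already been validated/coerced to str.
        return load_format.lower()


print(LoaderOptions(load_format="SafeTensors").load_format)  # safetensors
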
@@ -5,8 +5,9 @@ import hashlib
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union

 import torch
-from pydantic import ConfigDict
+from pydantic import ConfigDict, Field, model_validator
 from pydantic.dataclasses import dataclass
+from typing_extensions import Self

 import vllm.envs as envs
 from vllm.config.utils import config
@@ -23,6 +24,8 @@ else:
 logger = init_logger(__name__)

 LoRADType = Literal["auto", "float16", "bfloat16"]
+MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
+LoRAExtraVocabSize = Literal[256, 512]


 @config
@@ -30,9 +33,9 @@ LoRADType = Literal["auto", "float16", "bfloat16"]
 class LoRAConfig:
     """Configuration for LoRA."""

-    max_lora_rank: int = 16
+    max_lora_rank: MaxLoRARanks = 16
     """Max LoRA rank."""
-    max_loras: int = 1
+    max_loras: int = Field(default=1, ge=1)
     """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
     """By default, only half of the LoRA computation is sharded with tensor
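
Using the `MaxLoRARanks` Literal alias and `Field(ge=1)` moves the old manual membership and range checks (removed from `__post_init__` further down) into the type layer. A minimal sketch of how such constraints surface as `ValidationError`s (illustrative names, not vLLM's class):

# Sketch: Literal-typed fields and Field(ge=...) bounds turn invalid values into
# pydantic ValidationErrors instead of hand-written checks. Illustrative names only.
from typing import Literal

from pydantic import Field, ValidationError
from pydantic.dataclasses import dataclass

AllowedRanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]


@dataclass
class AdapterLimits:
    max_lora_rank: AllowedRanks = 16
    max_loras: int = Field(default=1, ge=1)


AdapterLimits(max_lora_rank=64, max_loras=2)  # accepted

try:
    AdapterLimits(max_lora_rank=7)  # not one of the allowed literal values
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # literal_error

try:
    AdapterLimits(max_loras=0)  # violates ge=1
except ValidationError as exc:
    print(exc.errors()[0]["type"])  # greater_than_equal
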
@@ -44,7 +47,14 @@ class LoRAConfig:
     `max_loras`."""
     lora_dtype: Union[torch.dtype, LoRADType] = "auto"
     """Data type for LoRA. If auto, will default to base model dtype."""
-    lora_extra_vocab_size: int = 256
+    lora_extra_vocab_size: LoRAExtraVocabSize = Field(
+        default=256,
+        deprecated=(
+            "`lora_extra_vocab_size` is deprecated and will be removed "
+            "in v0.12.0. Additional vocabulary support for "
+            "LoRA adapters is being phased out."
+        ),
+    )
+    """(Deprecated) Maximum size of extra vocabulary that can be present in a
+    LoRA adapter. Will be removed in v0.12.0."""
     lora_vocab_padding_size: ClassVar[int] = (
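
The `deprecated=` argument to `Field` (available in pydantic 2.7 and later, as I understand it) is what stands in for the unconditional `logger.warning` removed from `__post_init__` below: accessing a deprecated field is expected to emit a `DeprecationWarning` with the given message. A hedged sketch of that behaviour with illustrative names:

# Sketch: Field(deprecated=...) should emit a DeprecationWarning when the field is
# accessed (pydantic >= 2.7). Illustrative names, behaviour hedged.
import warnings

from pydantic import Field
from pydantic.dataclasses import dataclass


@dataclass
class LegacyOptions:
    extra_vocab_size: int = Field(
        default=256,
        deprecated="`extra_vocab_size` is deprecated and will be removed.",
    )


cfg = LegacyOptions()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    _ = cfg.extra_vocab_size  # attribute access is what triggers the warning
for w in caught:
    print(w.category.__name__, "-", w.message)
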
@@ -60,7 +70,10 @@ class LoRAConfig:
     per prompt. When run in offline mode, the lora IDs for n modalities
     will be automatically assigned to 1-n with the names of the modalities
     in alphabetic order."""
-    bias_enabled: bool = False
+    bias_enabled: bool = Field(
+        default=False,
+        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
+    )
     """[DEPRECATED] Enable bias for LoRA adapters. This option will be
     removed in v0.12.0."""

@@ -87,36 +100,8 @@ class LoRAConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str

-    def __post_init__(self):
-        # Deprecation warning for lora_extra_vocab_size
-        logger.warning(
-            "`lora_extra_vocab_size` is deprecated and will be removed "
-            "in v0.12.0. Additional vocabulary support for "
-            "LoRA adapters is being phased out."
-        )
-
-        # Deprecation warning for enable_lora_bias
-        if self.bias_enabled:
-            logger.warning(
-                "`enable_lora_bias` is deprecated and will be removed in v0.12.0."
-            )
-
-        # Setting the maximum rank to 512 should be able to satisfy the vast
-        # majority of applications.
-        possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
-        possible_lora_extra_vocab_size = (256, 512)
-        if self.max_lora_rank not in possible_max_ranks:
-            raise ValueError(
-                f"max_lora_rank ({self.max_lora_rank}) must be one of "
-                f"{possible_max_ranks}."
-            )
-        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
-            raise ValueError(
-                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
-                f"must be one of {possible_lora_extra_vocab_size}."
-            )
-        if self.max_loras < 1:
-            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+    @model_validator(mode="after")
+    def _validate_lora_config(self) -> Self:
         if self.max_cpu_loras is None:
             self.max_cpu_loras = self.max_loras
         elif self.max_cpu_loras < self.max_loras:
@@ -125,6 +110,8 @@ class LoRAConfig:
                 f"max_loras ({self.max_loras})"
             )

+        return self
+
     def verify_with_cache_config(self, cache_config: CacheConfig):
         if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1:
             raise ValueError("V0 LoRA does not support CPU offload, please use V1.")
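
Cross-field checks that used to live in `__post_init__` now run in a `model_validator(mode="after")`, which receives the fully validated instance and must return it. A small sketch of the same defaulting logic with illustrative names:

# Sketch: cross-field validation with model_validator(mode="after"), mirroring how
# max_cpu_loras is defaulted from max_loras above. Illustrative names only.
from typing import Optional

from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self


@dataclass
class AdapterPool:
    max_loras: int = 1
    max_cpu_loras: Optional[int] = None

    @model_validator(mode="after")
    def _fill_defaults(self) -> Self:
        # Runs after every field has been validated, so it can safely read and
        # adjust several fields at once; it must return the instance.
        if self.max_cpu_loras is None:
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras})"
            )
        return self


print(AdapterPool(max_loras=4).max_cpu_loras)  # 4
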
@@ -11,6 +11,7 @@ from itertools import pairwise
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar

 import regex as re
+from pydantic.fields import FieldInfo
 from typing_extensions import runtime_checkable

 if TYPE_CHECKING:
@@ -50,7 +51,14 @@ def get_field(cls: ConfigType, name: str) -> Field:
     if (default_factory := named_field.default_factory) is not MISSING:
         return field(default_factory=default_factory)
     if (default := named_field.default) is not MISSING:
+        if isinstance(default, FieldInfo):
+            # Handle pydantic.Field defaults
+            if default.default_factory is not None:
+                return field(default_factory=default.default_factory)
+            else:
+                default = default.default
         return field(default=default)
+
     raise ValueError(
         f"{cls.__name__}.{name} must have a default value or default factory."
     )
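
The extra `FieldInfo` branch matters because, for a pydantic dataclass, `dataclasses.fields()` can report the `pydantic.Field(...)` object itself as the field's default; `get_field` then has to unwrap it before handing a `dataclasses.field` back to callers such as `EngineArgs`. A simplified sketch of that unwrapping, under the assumption just stated (helper name is illustrative):

# Sketch: unwrapping a pydantic FieldInfo default into a plain dataclasses.field,
# a simplified take on the get_field change above. Illustrative helper name.
from dataclasses import MISSING, field, fields

from pydantic import Field
from pydantic.dataclasses import dataclass
from pydantic.fields import FieldInfo


@dataclass
class SomeConfig:
    patterns: list[str] = Field(default_factory=lambda: ["original/**/*"])
    retries: int = Field(default=3)


def get_field_default(cls, name: str):
    named_field = next(f for f in fields(cls) if f.name == name)
    if named_field.default_factory is not MISSING:
        return field(default_factory=named_field.default_factory)
    default = named_field.default
    if isinstance(default, FieldInfo):
        # pydantic stores the factory/default inside the FieldInfo object.
        if default.default_factory is not None:
            return field(default_factory=default.default_factory)
        default = default.default
    return field(default=default)


print(get_field_default(SomeConfig, "patterns").default_factory())  # ['original/**/*']
print(get_field_default(SomeConfig, "retries").default)             # 3
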
@@ -452,7 +452,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
-    ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
+    ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns")

     enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
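
The `EngineArgs` change follows from the new default: `ignore_patterns` now defaults to a factory-produced list rather than `None`, so the class attribute can no longer be copied directly as a plain dataclass default, and `get_field` re-wraps it instead. A quick sketch of why the direct assignment of a list default would fail (illustrative names):

# Sketch: a stdlib dataclass rejects a bare mutable default, which is why the list
# default for ignore_patterns is wired through get_field / default_factory instead
# of being referenced directly. Illustrative names only.
from dataclasses import dataclass, field

try:

    @dataclass
    class BadArgs:
        ignore_patterns: list[str] = ["original/**/*"]  # mutable default

except ValueError as exc:
    print(exc)  # mutable default ... is not allowed: use default_factory


@dataclass
class GoodArgs:
    # Roughly what get_field(LoadConfig, "ignore_patterns") provides:
    ignore_patterns: list[str] = field(default_factory=lambda: ["original/**/*"])


print(GoodArgs().ignore_patterns)  # ['original/**/*']
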