[Feature] Use pydantic validation in lora.py and load.py configs (#26413)

Signed-off-by: simondanielsson <simon.danielsson99@hotmail.com>
Author: Simon Danielsson
Date: 2025-10-09 11:38:33 +02:00
Committed by: GitHub
Parent: e6e898f95d
Commit: e4791438ed
4 changed files with 48 additions and 45 deletions

vllm/config/load.py

@@ -2,9 +2,9 @@
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 import hashlib
-from dataclasses import field
 from typing import TYPE_CHECKING, Any, Optional, Union
+from pydantic import Field, field_validator
 from pydantic.dataclasses import dataclass
 from vllm.config.utils import config
@@ -64,7 +64,7 @@ class LoadConfig:
     was quantized using torchao and saved using safetensors.
     Needs torchao >= 0.14.0
     """
-    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+    model_loader_extra_config: Union[dict, TensorizerConfig] = Field(
         default_factory=dict
     )
     """Extra config for model loader. This will be passed to the model loader
@@ -72,7 +72,9 @@ class LoadConfig:
     device: Optional[str] = None
    """Device to which model weights will be loaded, default to
    device_config.device"""
-    ignore_patterns: Optional[Union[list[str], str]] = None
+    ignore_patterns: Union[list[str], str] = Field(
+        default_factory=lambda: ["original/**/*"]
+    )
    """The list of patterns to ignore when loading the model. Default to
    "original/**/*" to avoid repeated loading of llama's checkpoints."""
     use_tqdm_on_load: bool = True
@@ -107,12 +109,18 @@ class LoadConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str

-    def __post_init__(self):
-        self.load_format = self.load_format.lower()
-        if self.ignore_patterns is not None and len(self.ignore_patterns) > 0:
+    @field_validator("load_format", mode="after")
+    def _lowercase_load_format(cls, load_format: str) -> str:
+        return load_format.lower()
+
+    @field_validator("ignore_patterns", mode="after")
+    def _validate_ignore_patterns(
+        cls, ignore_patterns: Union[list[str], str]
+    ) -> Union[list[str], str]:
+        if ignore_patterns != ["original/**/*"] and len(ignore_patterns) > 0:
             logger.info(
                 "Ignoring the following patterns when downloading weights: %s",
-                self.ignore_patterns,
+                ignore_patterns,
             )
-        else:
-            self.ignore_patterns = ["original/**/*"]
+        return ignore_patterns
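
Note: the `__post_init__` hook is replaced by per-field `field_validator` hooks, which run after pydantic has checked each field's declared type. A minimal standalone sketch of that pattern (the `DemoLoadConfig` class below is hypothetical, not vLLM code; assumes pydantic v2):

from typing import Union

from pydantic import Field, field_validator
from pydantic.dataclasses import dataclass


@dataclass
class DemoLoadConfig:
    load_format: str = "auto"
    ignore_patterns: Union[list[str], str] = Field(
        default_factory=lambda: ["original/**/*"]
    )

    @field_validator("load_format", mode="after")
    @classmethod
    def _lowercase_load_format(cls, load_format: str) -> str:
        # Runs after pydantic has validated the declared type.
        return load_format.lower()


cfg = DemoLoadConfig(load_format="SAFETENSORS")
print(cfg.load_format)      # safetensors
print(cfg.ignore_patterns)  # ['original/**/*']

With this layout, a value such as "SAFETENSORS" is normalized at construction time rather than in a post-init step.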

vllm/config/lora.py

@@ -5,8 +5,9 @@ import hashlib
 from typing import TYPE_CHECKING, Any, ClassVar, Literal, Optional, Union
 import torch
-from pydantic import ConfigDict
+from pydantic import ConfigDict, Field, model_validator
 from pydantic.dataclasses import dataclass
+from typing_extensions import Self
 import vllm.envs as envs
 from vllm.config.utils import config
@@ -23,6 +24,8 @@ else:
 logger = init_logger(__name__)
 LoRADType = Literal["auto", "float16", "bfloat16"]
+MaxLoRARanks = Literal[1, 8, 16, 32, 64, 128, 256, 320, 512]
+LoRAExtraVocabSize = Literal[256, 512]

 @config
@@ -30,9 +33,9 @@ LoRADType = Literal["auto", "float16", "bfloat16"]
 class LoRAConfig:
     """Configuration for LoRA."""

-    max_lora_rank: int = 16
+    max_lora_rank: MaxLoRARanks = 16
     """Max LoRA rank."""
-    max_loras: int = 1
+    max_loras: int = Field(default=1, ge=1)
     """Max number of LoRAs in a single batch."""
     fully_sharded_loras: bool = False
     """By default, only half of the LoRA computation is sharded with tensor
@@ -44,7 +47,14 @@ class LoRAConfig:
     `max_loras`."""
     lora_dtype: Union[torch.dtype, LoRADType] = "auto"
     """Data type for LoRA. If auto, will default to base model dtype."""
-    lora_extra_vocab_size: int = 256
+    lora_extra_vocab_size: LoRAExtraVocabSize = Field(
+        default=256,
+        deprecated=(
+            "`lora_extra_vocab_size` is deprecated and will be removed "
+            "in v0.12.0. Additional vocabulary support for "
+            "LoRA adapters is being phased out."
+        ),
+    )
     """(Deprecated) Maximum size of extra vocabulary that can be present in a
     LoRA adapter. Will be removed in v0.12.0."""
     lora_vocab_padding_size: ClassVar[int] = (
@@ -60,7 +70,10 @@ class LoRAConfig:
     per prompt. When run in offline mode, the lora IDs for n modalities
     will be automatically assigned to 1-n with the names of the modalities
     in alphabetic order."""
-    bias_enabled: bool = False
+    bias_enabled: bool = Field(
+        default=False,
+        deprecated="`bias_enabled` is deprecated and will be removed in v0.12.0.",
+    )
     """[DEPRECATED] Enable bias for LoRA adapters. This option will be
     removed in v0.12.0."""
@@ -87,36 +100,8 @@ class LoRAConfig:
         hash_str = hashlib.md5(str(factors).encode(), usedforsecurity=False).hexdigest()
         return hash_str

-    def __post_init__(self):
-        # Deprecation warning for lora_extra_vocab_size
-        logger.warning(
-            "`lora_extra_vocab_size` is deprecated and will be removed "
-            "in v0.12.0. Additional vocabulary support for "
-            "LoRA adapters is being phased out."
-        )
-        # Deprecation warning for enable_lora_bias
-        if self.bias_enabled:
-            logger.warning(
-                "`enable_lora_bias` is deprecated and will be removed in v0.12.0."
-            )
-        # Setting the maximum rank to 512 should be able to satisfy the vast
-        # majority of applications.
-        possible_max_ranks = (1, 8, 16, 32, 64, 128, 256, 320, 512)
-        possible_lora_extra_vocab_size = (256, 512)
-        if self.max_lora_rank not in possible_max_ranks:
-            raise ValueError(
-                f"max_lora_rank ({self.max_lora_rank}) must be one of "
-                f"{possible_max_ranks}."
-            )
-        if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size:
-            raise ValueError(
-                f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) "
-                f"must be one of {possible_lora_extra_vocab_size}."
-            )
-        if self.max_loras < 1:
-            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
+    @model_validator(mode="after")
+    def _validate_lora_config(self) -> Self:
         if self.max_cpu_loras is None:
             self.max_cpu_loras = self.max_loras
         elif self.max_cpu_loras < self.max_loras:
@@ -125,6 +110,8 @@ class LoRAConfig:
                 f"max_loras ({self.max_loras})"
             )
+        return self

     def verify_with_cache_config(self, cache_config: CacheConfig):
         if cache_config.cpu_offload_gb > 0 and not envs.VLLM_USE_V1:
             raise ValueError("V0 LoRA does not support CPU offload, please use V1.")

vllm/config/utils.py

@@ -11,6 +11,7 @@ from itertools import pairwise
 from typing import TYPE_CHECKING, Any, Protocol, TypeVar
 import regex as re
+from pydantic.fields import FieldInfo
 from typing_extensions import runtime_checkable

 if TYPE_CHECKING:
@@ -50,7 +51,14 @@ def get_field(cls: ConfigType, name: str) -> Field:
     if (default_factory := named_field.default_factory) is not MISSING:
         return field(default_factory=default_factory)
     if (default := named_field.default) is not MISSING:
+        if isinstance(default, FieldInfo):
+            # Handle pydantic.Field defaults
+            if default.default_factory is not None:
+                return field(default_factory=default.default_factory)
+            else:
+                default = default.default
         return field(default=default)
     raise ValueError(
         f"{cls.__name__}.{name} must have a default value or default factory."
     )
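
Background: `pydantic.Field(...)` evaluates to a `FieldInfo` object rather than a plain default value, so a helper that copies defaults into a stdlib `dataclasses.field` has to unwrap either its `default_factory` or its `default`. A small illustration of that behaviour (assumes pydantic v2; not vLLM code):

from pydantic import Field
from pydantic.fields import FieldInfo

# Field() returns a FieldInfo carrying either a default or a default_factory.
info = Field(default_factory=lambda: ["original/**/*"])
assert isinstance(info, FieldInfo)
assert info.default_factory is not None
print(info.default_factory())  # ['original/**/*']

info = Field(default=16)
assert info.default_factory is None
print(info.default)  # 16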

vllm/engine/arg_utils.py

@@ -452,7 +452,7 @@ class EngineArgs:
     num_gpu_blocks_override: Optional[int] = CacheConfig.num_gpu_blocks_override
     num_lookahead_slots: int = SchedulerConfig.num_lookahead_slots
     model_loader_extra_config: dict = get_field(LoadConfig, "model_loader_extra_config")
-    ignore_patterns: Optional[Union[str, list[str]]] = LoadConfig.ignore_patterns
+    ignore_patterns: Union[str, list[str]] = get_field(LoadConfig, "ignore_patterns")
     enable_chunked_prefill: Optional[bool] = SchedulerConfig.enable_chunked_prefill
     disable_chunked_mm_input: bool = SchedulerConfig.disable_chunked_mm_input
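
Note: the default is copied via `get_field` rather than by referencing `LoadConfig.ignore_patterns` directly because the class attribute no longer holds a plain value like the old `None`, and a mutable list default has to be supplied through a `default_factory` in a dataclass anyway. A hypothetical sketch of that constraint (illustrative names, not vLLM code):

from dataclasses import dataclass, field


@dataclass
class DemoArgs:
    # ignore_patterns: list[str] = ["original/**/*"]   # ValueError: mutable default
    ignore_patterns: list[str] = field(
        default_factory=lambda: ["original/**/*"]  # each instance gets a fresh list
    )


assert DemoArgs().ignore_patterns == ["original/**/*"]
assert DemoArgs().ignore_patterns is not DemoArgs().ignore_patterns  # no shared state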