Enable Pydantic mypy checks and convert configs to Pydantic dataclasses (#17599)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Harry Mellor authored on 2025-05-28 13:46:04 +01:00, committed by GitHub
parent d781930f90
commit 4c2b38ce9e

11 changed files with 115 additions and 102 deletions
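
The diff below converts vLLM's config classes from stdlib dataclasses to Pydantic dataclasses. As a rough, self-contained sketch of the pattern (class and field names here are illustrative, not vLLM's): fields are validated and coerced at construction time, config=ConfigDict(arbitrary_types_allowed=True) lets fields hold types Pydantic has no schema for, and fields that are only filled in later by __post_init__ opt out of validation via SkipValidation.

# Illustrative sketch of the conversion pattern; ExampleConfig is hypothetical.
from typing import Optional

from pydantic import ConfigDict, SkipValidation, ValidationError
from pydantic.dataclasses import dataclass


@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class ExampleConfig:
    model: str = "facebook/opt-125m"
    seed: Optional[int] = None
    # Filled in by __post_init__, so validation is skipped for this field.
    max_model_len: SkipValidation[int] = None  # type: ignore

    def __post_init__(self):
        if self.max_model_len is None:
            self.max_model_len = 2048


cfg = ExampleConfig(seed="0")  # lax mode coerces "0" to the int 0
assert cfg.seed == 0 and cfg.max_model_len == 2048

try:
    ExampleConfig(seed="not-an-int")
except ValidationError as e:
    print(e)  # construction now fails fast instead of storing a bad value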


@@ -58,7 +58,7 @@ repos:
     entry: tools/mypy.sh 0 "local"
     language: python
     types: [python]
-    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
+    additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
     stages: [pre-commit] # Don't run in CI
   - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
     name: Run mypy for Python 3.9


@@ -110,6 +110,7 @@ ignore = [
 ]

 [tool.mypy]
+plugins = ['pydantic.mypy']
 ignore_missing_imports = true
 check_untyped_defs = true
 follow_imports = "silent"
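
For context on the added plugins line: the Pydantic mypy plugin teaches mypy about the __init__ signatures Pydantic synthesizes for models and Pydantic dataclasses, so badly typed config construction can be flagged statically. A rough, hypothetical sketch of the kind of mistake it is meant to catch:

# Hypothetical example; with the Pydantic mypy plugin enabled, mypy understands
# the synthesized __init__ and flags the badly typed argument below instead of
# leaving it to fail at runtime as a ValidationError.
from pydantic.dataclasses import dataclass


@dataclass
class CacheSettings:
    block_size: int = 16
    swap_space: float = 4.0


settings = CacheSettings(block_size="sixteen")  # error: arg-type (per mypy)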


@@ -24,16 +24,16 @@ if current_platform.is_rocm():
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
 else:
     MODELS = [
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
-            quantization="AWQ"),
+            quantization="awq"),
         ModelWithQuantization(
             model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-            quantization="GPTQ"),
+            quantization="gptq"),
     ]
@@ -100,7 +100,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
             "#ff8050",
             "#ff8080",
         ]
-    elif model.quantization == "AWQ":
+    elif model.quantization == "awq":
         expected_no_lora_output = [
             "I'm sorry, I don't understand",
             "I'm sorry, I don't understand",
@@ -109,7 +109,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
             "#f07700: A v",
             "#f00000: A v",
         ]
-    elif model.quantization == "GPTQ":
+    elif model.quantization == "gptq":
         expected_no_lora_output = [
             "I'm sorry, I don't have",
             "I'm sorry, I don't have",
@@ -122,7 +122,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
     def expect_match(output, expected_output):
         # HACK: GPTQ lora outputs are just incredibly unstable.
         # Assert that the outputs changed.
-        if (model.quantization == "GPTQ"
+        if (model.quantization == "gptq"
                 and expected_output is expected_lora_output):
             assert output != expected_no_lora_output
             for i, o in enumerate(output):
@@ -172,7 +172,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
                                  model):
     if num_gpus_available < 2:
         pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
-    if model.quantization == "GPTQ":
+    if model.quantization == "gptq":
         pytest.skip("GPTQ lora outputs are just incredibly unstable")

     llm_tp1 = vllm.LLM(
         model=model.model_path,


@@ -173,7 +173,7 @@ def test_traces_with_detailed_steps(
     llm = LLM(
         model=model,
         otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
-        collect_detailed_traces="all",
+        collect_detailed_traces=["all"],
     )
     prompts = ["This is a short prompt"]
     outputs = llm.generate(prompts, sampling_params=sampling_params)


@@ -11,8 +11,8 @@ import uuid
 import warnings
 from collections import Counter
 from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-                         is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+                         replace)
 from functools import cached_property
 from importlib.util import find_spec
 from pathlib import Path
@@ -21,9 +21,12 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,

 import regex as re
 import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+                      model_validator)
+from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable

 import vllm.envs as envs
 from vllm import version
@@ -57,10 +60,15 @@ if TYPE_CHECKING:
     from vllm.model_executor.layers.quantization.base_config import (
         QuantizationConfig)
     from vllm.model_executor.model_loader import BaseModelLoader
+    from vllm.model_executor.model_loader.tensorizer import TensorizerConfig

     ConfigType = type[DataclassInstance]
 else:
+    PlacementGroup = Any
+    ExecutorBase = Any
     QuantizationConfig = Any
+    BaseModelLoader = Any
+    TensorizerConfig = Any

     ConfigType = type

 logger = init_logger(__name__)
@@ -92,6 +100,7 @@ HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
                                              PretrainedConfig]]


+@runtime_checkable
 class SupportsHash(Protocol):

     def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]

 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
     """Configuration for the model."""
@@ -236,7 +245,7 @@ class ModelConfig:
     task, even if the same model can be used for multiple tasks. When the model
     only supports one task, "auto" can be used to select it; otherwise, you
     must specify explicitly which task to use."""
-    tokenizer: str = None  # type: ignore
+    tokenizer: SkipValidation[str] = None  # type: ignore
     """Name or path of the Hugging Face tokenizer to use. If unspecified, model
     name or path will be used."""
     tokenizer_mode: TokenizerMode = "auto"
@@ -284,7 +293,7 @@ class ModelConfig:
     """The specific revision to use for the tokenizer on the Hugging Face Hub.
     It can be a branch name, a tag name, or a commit id. If unspecified, will
     use the default version."""
-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Model context length (prompt and output). If unspecified, will be
     automatically derived from the model config.
@@ -602,6 +611,22 @@ class ModelConfig:
         self._verify_cuda_graph()
         self._verify_bnb_config()

+    @field_validator("quantization", mode="before")
+    @classmethod
+    def validate_quantization_before(cls, value: Any) -> Any:
+        if isinstance(value, str):
+            return value.lower()
+        return value
+
+    @model_validator(mode="after")
+    def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
+        if not isinstance(self.tokenizer, str):
+            raise ValueError("tokenizer must be a string after __post_init__.")
+        if not isinstance(self.max_model_len, int):
+            raise ValueError(
+                "max_model_len must be an integer after __post_init__.")
+        return self
+
     @property
     def registry(self):
         return ModelRegistry
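
The validators added to ModelConfig above pair a mode="before" field validator, which normalizes raw input before field validation, with a mode="after" model validator, which checks invariants once __init__ and __post_init__ have run. A standalone sketch of the same two hooks, using illustrative names rather than vLLM's:

# Illustrative sketch; QuantConfig is hypothetical.
from typing import Any

from pydantic import field_validator, model_validator
from pydantic.dataclasses import dataclass


@dataclass
class QuantConfig:
    quantization: str = "none"

    @field_validator("quantization", mode="before")
    @classmethod
    def lowercase_quantization(cls, value: Any) -> Any:
        # Runs on the raw input, so "GPTQ" and "gptq" both end up as "gptq".
        return value.lower() if isinstance(value, str) else value

    @model_validator(mode="after")
    def check_quantization(self) -> "QuantConfig":
        # Runs after construction, when all fields hold their final values.
        if self.quantization not in ("none", "gptq", "awq"):
            raise ValueError(f"unknown quantization {self.quantization!r}")
        return self


assert QuantConfig(quantization="GPTQ").quantization == "gptq"
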
@@ -823,8 +848,7 @@ class ModelConfig:
             "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
         ]
         if self.quantization is not None:
-            self.quantization = cast(QuantizationMethods,
-                                     self.quantization.lower())
+            self.quantization = cast(QuantizationMethods, self.quantization)

         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
@@ -1397,7 +1421,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
 class CacheConfig:
     """Configuration for the KV cache."""

-    block_size: BlockSize = None  # type: ignore
+    block_size: SkipValidation[BlockSize] = None  # type: ignore
     """Size of a contiguous cache block in number of tokens. This is ignored on
     neuron devices and set to `--max-model-len`. On CUDA devices, only block
     sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
     download_dir: Optional[str] = None
     """Directory to download and load the weights, default to the default
     cache directory of Hugging Face."""
-    model_loader_extra_config: dict = field(default_factory=dict)
+    model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+        default_factory=dict)
     """Extra config for model loader. This will be passed to the model loader
     corresponding to the chosen load_format."""
     ignore_patterns: Optional[Union[list[str], str]] = None
@@ -1929,19 +1954,19 @@ class SchedulerConfig:
     runner_type: RunnerType = "generate"
     """The runner type to launch for the model."""

-    max_num_batched_tokens: int = None  # type: ignore
+    max_num_batched_tokens: SkipValidation[int] = None  # type: ignore
     """Maximum number of tokens to be processed in a single iteration.

     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""

-    max_num_seqs: int = None  # type: ignore
+    max_num_seqs: SkipValidation[int] = None  # type: ignore
     """Maximum number of sequences to be processed in a single iteration.

     This config has no static default. If left unspecified by the user, it will
     be set in `EngineArgs.create_engine_config` based on the usage context."""

-    max_model_len: int = None  # type: ignore
+    max_model_len: SkipValidation[int] = None  # type: ignore
     """Maximum length of a sequence (including prompt and generated text). This
     is primarily set in `ModelConfig` and that value should be manually
     duplicated here."""
@@ -1980,7 +2005,7 @@ class SchedulerConfig:
     """Apply a delay (of delay factor multiplied by previous
     prompt latency) before scheduling next prompt."""

-    enable_chunked_prefill: bool = None  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """If True, prefill requests can be chunked based
     on the remaining max_num_batched_tokens."""
@@ -2202,7 +2227,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]

 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
     """Configuration for the device to use for vLLM execution."""
@@ -2260,8 +2285,8 @@ class DeviceConfig:
             self.device = torch.device(self.device_type)


-SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-                            "draft_model", "deepseek_mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+                            "mlp_speculator", "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
                                       "typical_acceptance_sampler"]
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
     """Configuration for speculative decoding."""

     # General speculative decoding control
-    num_speculative_tokens: int = field(default=None,
-                                        init=True)  # type: ignore
+    num_speculative_tokens: SkipValidation[int] = None  # type: ignore
     """The number of speculative tokens, if provided. It will default to the
     number in the draft model config if present, otherwise, it is required."""
     model: Optional[str] = None
@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
     """Specifies the tree structure for speculative token generation.
     """
     # required configuration params passed from engine
-    target_model_config: ModelConfig = field(default=None,
-                                             init=True)  # type: ignore
+    target_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the target model."""
-    target_parallel_config: ParallelConfig = field(default=None,
-                                                   init=True)  # type: ignore
+    target_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the target model."""
-    enable_chunked_prefill: bool = field(default=None,
-                                         init=True)  # type: ignore
+    enable_chunked_prefill: SkipValidation[bool] = None  # type: ignore
     """Whether vLLM is configured to use chunked prefill or not. Used for
     raising an error since it's not yet compatible with speculative decode."""
-    disable_log_stats: bool = field(default=None, init=True)  # type: ignore
+    disable_log_stats: SkipValidation[bool] = None  # type: ignore
     """Whether to disable the periodic printing of stage times in speculative
     decoding."""

     # params generated in the post-init stage
-    draft_model_config: ModelConfig = field(default=None,
-                                            init=True)  # type: ignore
+    draft_model_config: SkipValidation[ModelConfig] = None  # type: ignore
     """The configuration of the draft model initialized internal."""
-    draft_parallel_config: ParallelConfig = field(default=None,
-                                                  init=True)  # type: ignore
+    draft_parallel_config: SkipValidation[
+        ParallelConfig] = None  # type: ignore
     """The parallel configuration for the draft model initialized internal."""

     def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ LoRADType = Literal["auto", "float16", "bfloat16"]

 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
     """Configuration for LoRA."""
@@ -2863,7 +2884,7 @@ class LoRAConfig:

 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class PromptAdapterConfig:
     """Configuration for PromptAdapters."""
@@ -3892,17 +3913,11 @@ class CompilationConfig:
             "pass_config",
             "traced_files",
         }
-        include = dict()
-        for k, v in asdict(self).items():
-            if k in exclude:
-                continue
-            f = get_field(CompilationConfig, k)
-            if (d := f.default) is not MISSING and d == v:
-                continue
-            if (df := f.default_factory) is not MISSING and df() == v:
-                continue
-            include[k] = v
-
-        return json.dumps(include)
+        # The cast to string is necessary because Pydantic is mocked in docs
+        # builds and sphinx-argparse doesn't know the return type of decode()
+        return str(
+            TypeAdapter(CompilationConfig).dump_json(
+                self, exclude=exclude, exclude_unset=True).decode())

     __str__ = __repr__
@@ -3911,7 +3926,7 @@ class CompilationConfig:
         """Parse the CLI value for the compilation config."""
         if cli_value in ["0", "1", "2", "3"]:
             return cls(level=int(cli_value))
-        return cls(**json.loads(cli_value))
+        return TypeAdapter(CompilationConfig).validate_json(cli_value)

     def __post_init__(self) -> None:
         count_none = self.custom_ops.count("none")
@config @config
@dataclass @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
class VllmConfig: class VllmConfig:
"""Dataclass which contains all vllm-related configuration. This """Dataclass which contains all vllm-related configuration. This
simplifies passing around the distinct configurations in the codebase. simplifies passing around the distinct configurations in the codebase.
@ -4294,9 +4309,6 @@ class VllmConfig:
"To workaround this limitation, vLLM will set 'ieee' input " "To workaround this limitation, vLLM will set 'ieee' input "
"precision for chunked prefill triton kernels.") "precision for chunked prefill triton kernels.")
if self.compilation_config is None:
self.compilation_config = CompilationConfig()
# async tp is built on top of sequence parallelism # async tp is built on top of sequence parallelism
# and requires it to be enabled. # and requires it to be enabled.
if self.compilation_config.pass_config.enable_async_tp: if self.compilation_config.pass_config.enable_async_tp:


@@ -14,6 +14,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,

 import regex as re
 import torch
+from pydantic import SkipValidation, TypeAdapter, ValidationError
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs
@@ -38,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-                        GiB_bytes, is_in_doc_build, is_in_ray_actor)
+                        GiB_bytes, is_in_ray_actor)

 # yapf: enable
@@ -156,7 +157,8 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
         # Get the set of possible types for the field
         type_hints: set[TypeHint] = set()
         if get_origin(field.type) in {Union, Annotated}:
-            type_hints.update(get_args(field.type))
+            predicate = lambda arg: not isinstance(arg, SkipValidation)
+            type_hints.update(filter(predicate, get_args(field.type)))
         else:
             type_hints.add(field.type)
@@ -168,10 +170,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
         if field.default is not MISSING:
             default = field.default
         elif field.default_factory is not MISSING:
-            if is_dataclass(field.default_factory) and is_in_doc_build():
-                default = {}
-            else:
-                default = field.default_factory()
+            default = field.default_factory()

         # Get the help text for the field
         name = field.name
@@ -189,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
            - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
            - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
         if dataclass_cls is not None:
-            dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
-            # Special case for configs with a from_cli method
-            if hasattr(dataclass_cls, "from_cli"):
-                from_cli = dataclass_cls.from_cli
-                dataclass_init = lambda x, f=from_cli: f(x)
-            kwargs[name]["type"] = dataclass_init
+
+            def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
+                try:
+                    if hasattr(cls, "from_cli"):
+                        return cls.from_cli(val)
+                    return TypeAdapter(cls).validate_json(val)
+                except ValidationError as e:
+                    raise argparse.ArgumentTypeError(repr(e)) from e
+
+            kwargs[name]["type"] = parse_dataclass
             kwargs[name]["help"] += json_tip
         elif contains_type(type_hints, bool):
             # Creates --no-<name> and --<name> flags
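
The parse_dataclass helper above follows the usual argparse convention: a type= callback that converts a validation failure into argparse.ArgumentTypeError, so malformed JSON on the command line produces a normal usage error instead of a traceback. A self-contained sketch with hypothetical names:

# Illustrative sketch; DemoConfig and parse_demo_config are hypothetical.
import argparse

from pydantic import TypeAdapter, ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class DemoConfig:
    level: int = 0


def parse_demo_config(val: str) -> DemoConfig:
    try:
        return TypeAdapter(DemoConfig).validate_json(val)
    except ValidationError as e:
        # argparse turns this into a clean "invalid value" usage error.
        raise argparse.ArgumentTypeError(repr(e)) from e


parser = argparse.ArgumentParser()
parser.add_argument("--demo-config", type=parse_demo_config)

args = parser.parse_args(['--demo-config', '{"level": 2}'])
assert args.demo_config.level == 2
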
@@ -225,12 +228,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
             kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
-        elif contains_type(type_hints,
-                           dict) and (contains_type(type_hints, str) or any(
-                               is_not_builtin(th) for th in type_hints)):
+        elif (contains_type(type_hints, dict)
+              and (contains_type(type_hints, str)
+                   or any(is_not_builtin(th) for th in type_hints))):
             kwargs[name]["type"] = union_dict_and_str
         elif contains_type(type_hints, dict):
-            # Dict arguments will always be optional
             kwargs[name]["type"] = parse_type(json.loads)
             kwargs[name]["help"] += json_tip
         elif (contains_type(type_hints, str)
@@ -317,8 +319,7 @@ class EngineArgs:
     rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
     rope_theta: Optional[float] = ModelConfig.rope_theta
     hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
-    hf_overrides: Optional[HfOverrides] = \
-        get_field(ModelConfig, "hf_overrides")
+    hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
     tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
     quantization: Optional[QuantizationMethods] = ModelConfig.quantization
     enforce_eager: bool = ModelConfig.enforce_eager
@@ -398,7 +399,8 @@ class EngineArgs:
         get_field(ModelConfig, "override_neuron_config")
     override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
         ModelConfig.override_pooler_config
-    compilation_config: Optional[CompilationConfig] = None
+    compilation_config: CompilationConfig = \
+        get_field(VllmConfig, "compilation_config")
     worker_cls: str = ParallelConfig.worker_cls
     worker_extension_cls: str = ParallelConfig.worker_extension_cls
@@ -413,7 +415,8 @@ class EngineArgs:
     calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

-    additional_config: Optional[Dict[str, Any]] = None
+    additional_config: dict[str, Any] = \
+        get_field(VllmConfig, "additional_config")
     enable_reasoning: Optional[bool] = None  # DEPRECATED
     reasoning_parser: str = DecodingConfig.reasoning_backend


@@ -207,6 +207,9 @@ class LLM:
         if isinstance(worker_cls, type):
             kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

+        if hf_overrides is None:
+            hf_overrides = {}
+
         if compilation_config is not None:
             if isinstance(compilation_config, int):
                 compilation_config_instance = CompilationConfig(
@@ -218,7 +221,7 @@ class LLM:
             else:
                 compilation_config_instance = compilation_config
         else:
-            compilation_config_instance = None
+            compilation_config_instance = CompilationConfig()

         engine_args = EngineArgs(
             model=model,


@@ -175,11 +175,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
     type: Literal["function"] = "function"


+# extra="forbid" is a workaround to have kwargs as a field,
+# see https://github.com/pydantic/pydantic/issues/3125
 class LogitsProcessorConstructor(BaseModel):
     qualname: str
     args: Optional[list[Any]] = None
     kwargs: Optional[dict[str, Any]] = None

+    model_config = ConfigDict(extra="forbid")
+

 LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
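
Beyond the linked workaround for having a field literally named kwargs, ConfigDict(extra="forbid") also makes the model reject unexpected keys instead of silently ignoring them (Pydantic's default is extra="ignore"). A short illustrative sketch:

# Illustrative sketch; DemoConstructor is a hypothetical stand-in.
from typing import Any, Optional

from pydantic import BaseModel, ConfigDict, ValidationError


class DemoConstructor(BaseModel):
    qualname: str
    args: Optional[list[Any]] = None
    kwargs: Optional[dict[str, Any]] = None

    model_config = ConfigDict(extra="forbid")


DemoConstructor(qualname="mypkg.MyProcessor", kwargs={"k": 1})  # accepted

try:
    DemoConstructor(qualname="mypkg.MyProcessor", unexpected=1)
except ValidationError as e:
    print(e)  # extra inputs are not permitted
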
@@ -234,7 +238,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     presence_penalty: Optional[float] = 0.0
     response_format: Optional[AnyResponseFormat] = None
     seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+    stop: Optional[Union[str, list[str]]] = []
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     temperature: Optional[float] = None
@@ -258,7 +262,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
     min_p: Optional[float] = None
     repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+    stop_token_ids: Optional[list[int]] = []
     include_stop_str_in_output: bool = False
     ignore_eos: bool = False
     min_tokens: int = 0
@@ -756,7 +760,7 @@ class CompletionRequest(OpenAIBaseModel):
     n: int = 1
     presence_penalty: Optional[float] = 0.0
     seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-    stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+    stop: Optional[Union[str, list[str]]] = []
     stream: Optional[bool] = False
     stream_options: Optional[StreamOptions] = None
     suffix: Optional[str] = None
@@ -770,7 +774,7 @@ class CompletionRequest(OpenAIBaseModel):
     min_p: Optional[float] = None
     repetition_penalty: Optional[float] = None
     length_penalty: float = 1.0
-    stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+    stop_token_ids: Optional[list[int]] = []
     include_stop_str_in_output: bool = False
     ignore_eos: bool = False
     min_tokens: int = 0
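
Swapping Field(default_factory=list) for a plain [] default is safe on Pydantic models because Pydantic copies mutable defaults for each instance; stdlib dataclasses, by contrast, reject mutable defaults outright. A quick illustrative check:

# Illustrative sketch; DemoRequest is a hypothetical stand-in for the request models.
from pydantic import BaseModel


class DemoRequest(BaseModel):
    stop: list[str] = []


a = DemoRequest()
b = DemoRequest()
a.stop.append("###")

assert a.stop == ["###"]
assert b.stop == []          # b received its own copy of the default
assert a.stop is not b.stop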


@@ -134,11 +134,9 @@ class RequestProcessingMixin(BaseModel):
     Mixin for request processing,
     handling prompt preparation and engine input.
     """
-    request_prompts: Optional[Sequence[RequestPrompt]] = \
-        Field(default_factory=list)
-    engine_prompts: Optional[Union[list[EngineTokensPrompt],
-                                   list[EngineEmbedsPrompt]]] = Field(
-                                       default_factory=list)
+    request_prompts: Optional[Sequence[RequestPrompt]] = []
+    engine_prompts: Optional[Union[list[EngineTokensPrompt],
+                                   list[EngineEmbedsPrompt]]] = []

     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -528,12 +526,14 @@ class OpenAIServing:
         if isinstance(request,
                       (EmbeddingChatRequest, EmbeddingCompletionRequest,
                        ScoreRequest, RerankRequest, ClassificationRequest)):
-            operation = {
-                ScoreRequest: "score",
-                ClassificationRequest: "classification"
-            }.get(type(request), "embedding generation")
-
             if token_num > self.max_model_len:
+                operations: dict[type[AnyRequest], str] = {
+                    ScoreRequest: "score",
+                    ClassificationRequest: "classification"
+                }
+                operation = operations.get(type(request),
+                                           "embedding generation")
                 raise ValueError(
                     f"This model's maximum context length is "
                     f"{self.max_model_len} tokens. However, you requested "


@@ -3,12 +3,10 @@
 from dataclasses import dataclass
 from typing import Optional, TypedDict, Union

-from pydantic import BaseModel
-

 # These classes are deprecated, see SamplingParams
 class LLMGuidedOptions(TypedDict, total=False):
-    guided_json: Union[dict, BaseModel, str]
+    guided_json: Union[dict, str]
     guided_regex: str
     guided_choice: list[str]
     guided_grammar: str
@@ -20,7 +18,7 @@ class LLMGuidedOptions(TypedDict, total=False):
 @dataclass
 class GuidedDecodingRequest:
     """One of the fields will be used to retrieve the logit processor."""
-    guided_json: Optional[Union[dict, BaseModel, str]] = None
+    guided_json: Optional[Union[dict, str]] = None
     guided_regex: Optional[str] = None
     guided_choice: Optional[list[str]] = None
     guided_grammar: Optional[str] = None


@@ -1878,14 +1878,6 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
     return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


-def is_in_doc_build() -> bool:
-    try:
-        from sphinx.ext.autodoc.mock import _MockModule
-        return isinstance(zmq, _MockModule)
-    except ModuleNotFoundError:
-        return False
-
-
 def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
     """
     Import a Python file according to its file path.