Enable Pydantic mypy checks and convert configs to Pydantic dataclasses (#17599)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
@@ -58,7 +58,7 @@ repos:
 entry: tools/mypy.sh 0 "local"
 language: python
 types: [python]
-additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests]
+additional_dependencies: &mypy_deps [mypy==1.11.1, types-cachetools, types-setuptools, types-PyYAML, types-requests, pydantic]
 stages: [pre-commit] # Don't run in CI
 - id: mypy-3.9 # TODO: Use https://github.com/pre-commit/mirrors-mypy when mypy setup is less awkward
 name: Run mypy for Python 3.9

@@ -110,6 +110,7 @@ ignore = [
 ]

 [tool.mypy]
+plugins = ['pydantic.mypy']
 ignore_missing_imports = true
 check_untyped_defs = true
 follow_imports = "silent"
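Together, these two hunks make the pydantic mypy plugin available both in CI and in the pre-commit hook: pydantic is added to the hook's additional_dependencies so the plugin can be imported, and pyproject.toml activates it. A rough illustration of the kind of constructor mistake that can now be flagged statically; the class and values are invented for the example, not taken from this commit:

    from pydantic.dataclasses import dataclass


    @dataclass
    class ExampleCacheConfig:
        block_size: int
        enable_prefix_caching: bool = False


    # With the plugin enabled (on top of pydantic's own typing support), mypy
    # knows the generated __init__ signature and reports an arg-type error
    # here; at runtime pydantic would coerce or reject the value instead.
    cfg = ExampleCacheConfig(block_size="16")  # mypy: "str" is not "int"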
@@ -24,16 +24,16 @@ if current_platform.is_rocm():
 MODELS = [
 ModelWithQuantization(
 model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-quantization="GPTQ"),
+quantization="gptq"),
 ]
 else:
 MODELS = [
 ModelWithQuantization(
 model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ",
-quantization="AWQ"),
+quantization="awq"),
 ModelWithQuantization(
 model_path="TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ",
-quantization="GPTQ"),
+quantization="gptq"),
 ]
@@ -100,7 +100,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
 "#ff8050",
 "#ff8080",
 ]
-elif model.quantization == "AWQ":
+elif model.quantization == "awq":
 expected_no_lora_output = [
 "I'm sorry, I don't understand",
 "I'm sorry, I don't understand",

@@ -109,7 +109,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
 "#f07700: A v",
 "#f00000: A v",
 ]
-elif model.quantization == "GPTQ":
+elif model.quantization == "gptq":
 expected_no_lora_output = [
 "I'm sorry, I don't have",
 "I'm sorry, I don't have",

@@ -122,7 +122,7 @@ def test_quant_model_lora(tinyllama_lora_files, model):
 def expect_match(output, expected_output):
 # HACK: GPTQ lora outputs are just incredibly unstable.
 # Assert that the outputs changed.
-if (model.quantization == "GPTQ"
+if (model.quantization == "gptq"
 and expected_output is expected_lora_output):
 assert output != expected_no_lora_output
 for i, o in enumerate(output):

@@ -172,7 +172,7 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
 model):
 if num_gpus_available < 2:
 pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
-if model.quantization == "GPTQ":
+if model.quantization == "gptq":
 pytest.skip("GPTQ lora outputs are just incredibly unstable")
 llm_tp1 = vllm.LLM(
 model=model.model_path,
@@ -173,7 +173,7 @@ def test_traces_with_detailed_steps(
 llm = LLM(
 model=model,
 otlp_traces_endpoint=FAKE_TRACE_SERVER_ADDRESS,
-collect_detailed_traces="all",
+collect_detailed_traces=["all"],
 )
 prompts = ["This is a short prompt"]
 outputs = llm.generate(prompts, sampling_params=sampling_params)
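The test now passes a list instead of a bare string, which suggests the corresponding config field is list-valued once pydantic validates it. The sketch below only illustrates how such a field could be declared; the class name and literal values are assumptions, not taken from this commit:

    from typing import Literal, Optional

    from pydantic.dataclasses import dataclass

    # Assumed set of module names; used only for this illustration.
    DetailedTraceModule = Literal["model", "worker", "all"]


    @dataclass
    class ExampleObservabilityConfig:
        # Each element is checked against the Literal, unlike a free-form str.
        collect_detailed_traces: Optional[list[DetailedTraceModule]] = None


    ExampleObservabilityConfig(collect_detailed_traces=["all"])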
vllm/config.py
@@ -11,8 +11,8 @@ import uuid
 import warnings
 from collections import Counter
 from contextlib import contextmanager
-from dataclasses import (MISSING, Field, asdict, dataclass, field, fields,
-is_dataclass, replace)
+from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
+replace)
 from functools import cached_property
 from importlib.util import find_spec
 from pathlib import Path

@@ -21,9 +21,12 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
 import regex as re
 import torch
+from pydantic import (ConfigDict, SkipValidation, TypeAdapter, field_validator,
+model_validator)
+from pydantic.dataclasses import dataclass
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated
+from typing_extensions import deprecated, runtime_checkable

 import vllm.envs as envs
 from vllm import version
@@ -57,10 +60,15 @@ if TYPE_CHECKING:
 from vllm.model_executor.layers.quantization.base_config import (
 QuantizationConfig)
 from vllm.model_executor.model_loader import BaseModelLoader
 from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
+
+ConfigType = type[DataclassInstance]
 else:
 PlacementGroup = Any
 ExecutorBase = Any
 QuantizationConfig = Any
 BaseModelLoader = Any
 TensorizerConfig = Any
+ConfigType = type

 logger = init_logger(__name__)

@@ -92,6 +100,7 @@ HfOverrides = Union[dict[str, Any], Callable[[PretrainedConfig],
 PretrainedConfig]]


+@runtime_checkable
 class SupportsHash(Protocol):

 def compute_hash(self) -> str:
@@ -223,7 +232,7 @@ ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class ModelConfig:
 """Configuration for the model."""

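Because `dataclass` is now imported from `pydantic.dataclasses`, every decorated config is validated on construction. `arbitrary_types_allowed=True` is what lets fields keep types pydantic has no schema for (torch objects, other vLLM classes): such fields are only isinstance-checked. A small self-contained sketch of that behaviour, with an invented class:

    import torch
    from pydantic import ConfigDict, ValidationError
    from pydantic.dataclasses import dataclass


    @dataclass(config=ConfigDict(arbitrary_types_allowed=True))
    class ExampleDeviceConfig:
        # Arbitrary (non-pydantic) type: accepted, but only isinstance-checked.
        device: torch.device = torch.device("cpu")
        # Ordinary annotated field: validated and coerced by pydantic.
        gpu_memory_utilization: float = 0.9


    print(ExampleDeviceConfig(gpu_memory_utilization="0.5"))  # str coerced to 0.5

    try:
        ExampleDeviceConfig(device="cpu")  # a str is not a torch.device instance
    except ValidationError as e:
        print(e.error_count(), "validation error")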
@@ -236,7 +245,7 @@ class ModelConfig:
 task, even if the same model can be used for multiple tasks. When the model
 only supports one task, "auto" can be used to select it; otherwise, you
 must specify explicitly which task to use."""
-tokenizer: str = None # type: ignore
+tokenizer: SkipValidation[str] = None # type: ignore
 """Name or path of the Hugging Face tokenizer to use. If unspecified, model
 name or path will be used."""
 tokenizer_mode: TokenizerMode = "auto"

@@ -284,7 +293,7 @@ class ModelConfig:
 """The specific revision to use for the tokenizer on the Hugging Face Hub.
 It can be a branch name, a tag name, or a commit id. If unspecified, will
 use the default version."""
-max_model_len: int = None # type: ignore
+max_model_len: SkipValidation[int] = None # type: ignore
 """Model context length (prompt and output). If unspecified, will be
 automatically derived from the model config.
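Fields like `tokenizer` and `max_model_len` keep their "None now, derived in `__post_init__` later" pattern; `SkipValidation` preserves the annotation for type checkers while telling pydantic not to validate the placeholder default at construction time. A rough sketch of the mechanism with an invented config:

    from pydantic import SkipValidation
    from pydantic.dataclasses import dataclass


    @dataclass
    class ExampleConfig:
        # Annotated as int for type checkers, but pydantic skips validating it,
        # so the placeholder None default is accepted at construction time.
        max_model_len: SkipValidation[int] = None  # type: ignore

        def __post_init__(self) -> None:
            if self.max_model_len is None:
                self.max_model_len = 2048  # e.g. derived from the model config


    cfg = ExampleConfig()
    print(cfg.max_model_len)  # 2048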
@@ -602,6 +611,22 @@ class ModelConfig:
 self._verify_cuda_graph()
 self._verify_bnb_config()
+
+@field_validator("quantization", mode="before")
+@classmethod
+def validate_quantization_before(cls, value: Any) -> Any:
+if isinstance(value, str):
+return value.lower()
+return value
+
+@model_validator(mode="after")
+def validate_model_config_after(self: "ModelConfig") -> "ModelConfig":
+if not isinstance(self.tokenizer, str):
+raise ValueError("tokenizer must be a string after __post_init__.")
+if not isinstance(self.max_model_len, int):
+raise ValueError(
+"max_model_len must be an integer after __post_init__.")
+return self

 @property
 def registry(self):
 return ModelRegistry
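These two hooks run at different points: the "before" validator normalises the raw input ahead of any type or Literal checking (which lines up with the tests above now using lowercase quantization names), and the "after" validator runs once `__init__` and `__post_init__` have finished, so it can assert that the skipped-validation placeholders were actually filled in. A reduced sketch of the same pattern on an invented class:

    from typing import Any, Literal, Optional

    from pydantic import field_validator, model_validator
    from pydantic.dataclasses import dataclass


    @dataclass
    class ExampleModelConfig:
        quantization: Optional[Literal["awq", "gptq"]] = None

        @field_validator("quantization", mode="before")
        @classmethod
        def _lowercase_quantization(cls, value: Any) -> Any:
            # Runs before the Literal check, so "GPTQ" is normalised, then accepted.
            return value.lower() if isinstance(value, str) else value

        @model_validator(mode="after")
        def _check_consistency(self) -> "ExampleModelConfig":
            # Runs after __init__/__post_init__; cross-field checks belong here.
            return self


    print(ExampleModelConfig(quantization="GPTQ").quantization)  # gptq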
@@ -823,8 +848,7 @@ class ModelConfig:
 "quark", "modelopt_fp4", "bitblas", "gptq_bitblas"
 ]
 if self.quantization is not None:
-self.quantization = cast(QuantizationMethods,
-self.quantization.lower())
+self.quantization = cast(QuantizationMethods, self.quantization)

 # Parse quantization method from the HF model config, if available.
 quant_cfg = self._parse_quant_hf_config()

@@ -1397,7 +1421,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
 class CacheConfig:
 """Configuration for the KV cache."""

-block_size: BlockSize = None # type: ignore
+block_size: SkipValidation[BlockSize] = None # type: ignore
 """Size of a contiguous cache block in number of tokens. This is ignored on
 neuron devices and set to `--max-model-len`. On CUDA devices, only block
 sizes up to 32 are supported. On HPU devices, block size defaults to 128.
@@ -1619,7 +1643,8 @@ class LoadConfig:
 download_dir: Optional[str] = None
 """Directory to download and load the weights, default to the default
 cache directory of Hugging Face."""
-model_loader_extra_config: dict = field(default_factory=dict)
+model_loader_extra_config: Union[dict, TensorizerConfig] = field(
+default_factory=dict)
 """Extra config for model loader. This will be passed to the model loader
 corresponding to the chosen load_format."""
 ignore_patterns: Optional[Union[list[str], str]] = None

@@ -1929,19 +1954,19 @@ class SchedulerConfig:
 runner_type: RunnerType = "generate"
 """The runner type to launch for the model."""

-max_num_batched_tokens: int = None # type: ignore
+max_num_batched_tokens: SkipValidation[int] = None # type: ignore
 """Maximum number of tokens to be processed in a single iteration.

 This config has no static default. If left unspecified by the user, it will
 be set in `EngineArgs.create_engine_config` based on the usage context."""

-max_num_seqs: int = None # type: ignore
+max_num_seqs: SkipValidation[int] = None # type: ignore
 """Maximum number of sequences to be processed in a single iteration.

 This config has no static default. If left unspecified by the user, it will
 be set in `EngineArgs.create_engine_config` based on the usage context."""

-max_model_len: int = None # type: ignore
+max_model_len: SkipValidation[int] = None # type: ignore
 """Maximum length of a sequence (including prompt and generated text). This
 is primarily set in `ModelConfig` and that value should be manually
 duplicated here."""

@@ -1980,7 +2005,7 @@ class SchedulerConfig:
 """Apply a delay (of delay factor multiplied by previous
 prompt latency) before scheduling next prompt."""

-enable_chunked_prefill: bool = None # type: ignore
+enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
 """If True, prefill requests can be chunked based
 on the remaining max_num_batched_tokens."""
@@ -2202,7 +2227,7 @@ Device = Literal["auto", "cuda", "neuron", "cpu", "tpu", "xpu", "hpu"]


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class DeviceConfig:
 """Configuration for the device to use for vLLM execution."""

@@ -2260,8 +2285,8 @@ class DeviceConfig:
 self.device = torch.device(self.device_type)


-SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
-"draft_model", "deepseek_mtp"]
+SpeculativeMethod = Literal["ngram", "eagle", "eagle3", "medusa",
+"mlp_speculator", "draft_model", "deepseek_mtp"]
 SpeculativeAcceptanceMethod = Literal["rejection_sampler",
 "typical_acceptance_sampler"]
@@ -2272,8 +2297,7 @@ class SpeculativeConfig:
 """Configuration for speculative decoding."""

 # General speculative decoding control
-num_speculative_tokens: int = field(default=None,
-init=True) # type: ignore
+num_speculative_tokens: SkipValidation[int] = None # type: ignore
 """The number of speculative tokens, if provided. It will default to the
 number in the draft model config if present, otherwise, it is required."""
 model: Optional[str] = None

@@ -2349,26 +2373,23 @@ class SpeculativeConfig:
 """Specifies the tree structure for speculative token generation.
 """
 # required configuration params passed from engine
-target_model_config: ModelConfig = field(default=None,
-init=True) # type: ignore
+target_model_config: SkipValidation[ModelConfig] = None # type: ignore
 """The configuration of the target model."""
-target_parallel_config: ParallelConfig = field(default=None,
-init=True) # type: ignore
+target_parallel_config: SkipValidation[
+ParallelConfig] = None # type: ignore
 """The parallel configuration for the target model."""
-enable_chunked_prefill: bool = field(default=None,
-init=True) # type: ignore
+enable_chunked_prefill: SkipValidation[bool] = None # type: ignore
 """Whether vLLM is configured to use chunked prefill or not. Used for
 raising an error since it's not yet compatible with speculative decode."""
-disable_log_stats: bool = field(default=None, init=True) # type: ignore
+disable_log_stats: SkipValidation[bool] = None # type: ignore
 """Whether to disable the periodic printing of stage times in speculative
 decoding."""

 # params generated in the post-init stage
-draft_model_config: ModelConfig = field(default=None,
-init=True) # type: ignore
+draft_model_config: SkipValidation[ModelConfig] = None # type: ignore
 """The configuration of the draft model initialized internal."""
-draft_parallel_config: ParallelConfig = field(default=None,
-init=True) # type: ignore
+draft_parallel_config: SkipValidation[
+ParallelConfig] = None # type: ignore
 """The parallel configuration for the draft model initialized internal."""

 def compute_hash(self) -> str:
@@ -2766,7 +2787,7 @@ LoRADType = Literal["auto", "float16", "bfloat16"]


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class LoRAConfig:
 """Configuration for LoRA."""

@@ -2863,7 +2884,7 @@ class LoRAConfig:


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class PromptAdapterConfig:
 """Configuration for PromptAdapters."""
@@ -3892,17 +3913,11 @@ class CompilationConfig:
 "pass_config",
 "traced_files",
 }
-include = dict()
-for k, v in asdict(self).items():
-if k in exclude:
-continue
-f = get_field(CompilationConfig, k)
-if (d := f.default) is not MISSING and d == v:
-continue
-if (df := f.default_factory) is not MISSING and df() == v:
-continue
-include[k] = v
-return json.dumps(include)
+# The cast to string is necessary because Pydantic is mocked in docs
+# builds and sphinx-argparse doesn't know the return type of decode()
+return str(
+TypeAdapter(CompilationConfig).dump_json(
+self, exclude=exclude, exclude_unset=True).decode())

 __str__ = __repr__

@@ -3911,7 +3926,7 @@ class CompilationConfig:
 """Parse the CLI value for the compilation config."""
 if cli_value in ["0", "1", "2", "3"]:
 return cls(level=int(cli_value))
-return cls(**json.loads(cli_value))
+return TypeAdapter(CompilationConfig).validate_json(cli_value)

 def __post_init__(self) -> None:
 count_none = self.custom_ops.count("none")
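`TypeAdapter` attaches pydantic's serialisation and validation machinery to a type that is not a `BaseModel`: `dump_json(..., exclude_unset=True)` replaces the hand-rolled default-comparison loop in `__repr__`, and `validate_json` replaces `cls(**json.loads(...))` in `from_cli` while also validating the payload. A hedged round-trip sketch with a stand-in config class:

    from dataclasses import field

    from pydantic import TypeAdapter
    from pydantic.dataclasses import dataclass


    @dataclass
    class ExampleCompilationConfig:
        level: int = 0
        backend: str = ""
        custom_ops: list[str] = field(default_factory=list)


    adapter = TypeAdapter(ExampleCompilationConfig)

    # Parse and validate a JSON/CLI value in one step (bad input raises
    # pydantic.ValidationError instead of a bare TypeError).
    cfg = adapter.validate_json('{"level": 3, "custom_ops": ["+rms_norm"]}')

    # Serialise back, keeping only the fields that were explicitly set.
    print(adapter.dump_json(cfg, exclude_unset=True).decode())
    # e.g. {"level":3,"custom_ops":["+rms_norm"]}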
@@ -4037,7 +4052,7 @@ class CompilationConfig:


 @config
-@dataclass
+@dataclass(config=ConfigDict(arbitrary_types_allowed=True))
 class VllmConfig:
 """Dataclass which contains all vllm-related configuration. This
 simplifies passing around the distinct configurations in the codebase.

@@ -4294,9 +4309,6 @@ class VllmConfig:
 "To workaround this limitation, vLLM will set 'ieee' input "
 "precision for chunked prefill triton kernels.")

-if self.compilation_config is None:
-self.compilation_config = CompilationConfig()
-
 # async tp is built on top of sequence parallelism
 # and requires it to be enabled.
 if self.compilation_config.pass_config.enable_async_tp:
@@ -14,6 +14,7 @@ from typing import (Annotated, Any, Callable, Dict, List, Literal, Optional,

 import regex as re
 import torch
+from pydantic import SkipValidation, TypeAdapter, ValidationError
 from typing_extensions import TypeIs, deprecated

 import vllm.envs as envs

@@ -38,7 +39,7 @@ from vllm.test_utils import MODEL_WEIGHTS_S3_BUCKET, MODELS_ON_S3
 from vllm.transformers_utils.utils import check_gguf_file
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import (STR_DUAL_CHUNK_FLASH_ATTN_VAL, FlexibleArgumentParser,
-GiB_bytes, is_in_doc_build, is_in_ray_actor)
+GiB_bytes, is_in_ray_actor)

 # yapf: enable
@@ -156,7 +157,8 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 # Get the set of possible types for the field
 type_hints: set[TypeHint] = set()
 if get_origin(field.type) in {Union, Annotated}:
-type_hints.update(get_args(field.type))
+predicate = lambda arg: not isinstance(arg, SkipValidation)
+type_hints.update(filter(predicate, get_args(field.type)))
 else:
 type_hints.add(field.type)
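`SkipValidation[int]` is shorthand for `Annotated[int, SkipValidation()]`, so `get_args()` on such an annotation yields the marker object next to the real type; filtering it out keeps the CLI type inference looking only at genuine types. Roughly:

    from typing import Annotated, get_args, get_origin

    from pydantic import SkipValidation

    hint = SkipValidation[int]  # equivalent to Annotated[int, SkipValidation()]
    assert get_origin(hint) is Annotated

    args = get_args(hint)  # (int, SkipValidation()) -- marker object included
    type_hints = {arg for arg in args if not isinstance(arg, SkipValidation)}
    print(type_hints)  # {<class 'int'>}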
@@ -168,10 +170,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 if field.default is not MISSING:
 default = field.default
 elif field.default_factory is not MISSING:
-if is_dataclass(field.default_factory) and is_in_doc_build():
-default = {}
-else:
-default = field.default_factory()
+default = field.default_factory()

 # Get the help text for the field
 name = field.name
@@ -189,12 +188,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 - `--json-arg '{"key1": "value1", "key2": {"key3": "value2"}}'`\n
 - `--json-arg.key1 value1 --json-arg.key2.key3 value2`\n\n"""
 if dataclass_cls is not None:
-dataclass_init = lambda x, f=dataclass_cls: f(**json.loads(x))
-# Special case for configs with a from_cli method
-if hasattr(dataclass_cls, "from_cli"):
-from_cli = dataclass_cls.from_cli
-dataclass_init = lambda x, f=from_cli: f(x)
-kwargs[name]["type"] = dataclass_init
+
+def parse_dataclass(val: str, cls=dataclass_cls) -> Any:
+try:
+if hasattr(cls, "from_cli"):
+return cls.from_cli(val)
+return TypeAdapter(cls).validate_json(val)
+except ValidationError as e:
+raise argparse.ArgumentTypeError(repr(e)) from e
+
+kwargs[name]["type"] = parse_dataclass
 kwargs[name]["help"] += json_tip
 elif contains_type(type_hints, bool):
 # Creates --no-<name> and --<name> flags

@@ -225,12 +228,11 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
 kwargs[name]["type"] = human_readable_int
 elif contains_type(type_hints, float):
 kwargs[name]["type"] = float
-elif contains_type(type_hints,
-dict) and (contains_type(type_hints, str) or any(
-is_not_builtin(th) for th in type_hints)):
+elif (contains_type(type_hints, dict)
+and (contains_type(type_hints, str)
+or any(is_not_builtin(th) for th in type_hints))):
 kwargs[name]["type"] = union_dict_and_str
 elif contains_type(type_hints, dict):
 # Dict arguments will always be optional
 kwargs[name]["type"] = parse_type(json.loads)
 kwargs[name]["help"] += json_tip
 elif (contains_type(type_hints, str)
@@ -317,8 +319,7 @@ class EngineArgs:
 rope_scaling: dict[str, Any] = get_field(ModelConfig, "rope_scaling")
 rope_theta: Optional[float] = ModelConfig.rope_theta
 hf_token: Optional[Union[bool, str]] = ModelConfig.hf_token
-hf_overrides: Optional[HfOverrides] = \
-get_field(ModelConfig, "hf_overrides")
+hf_overrides: HfOverrides = get_field(ModelConfig, "hf_overrides")
 tokenizer_revision: Optional[str] = ModelConfig.tokenizer_revision
 quantization: Optional[QuantizationMethods] = ModelConfig.quantization
 enforce_eager: bool = ModelConfig.enforce_eager

@@ -398,7 +399,8 @@ class EngineArgs:
 get_field(ModelConfig, "override_neuron_config")
 override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
 ModelConfig.override_pooler_config
-compilation_config: Optional[CompilationConfig] = None
+compilation_config: CompilationConfig = \
+get_field(VllmConfig, "compilation_config")
 worker_cls: str = ParallelConfig.worker_cls
 worker_extension_cls: str = ParallelConfig.worker_extension_cls
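`EngineArgs` now sources more defaults through `get_field`, which copies a field specifier (including its `default_factory`) from the config dataclass so the CLI argument and the config cannot drift apart. vLLM's actual helper lives in vllm/config.py; the version below is only a rough stand-in to show the idea, with invented names:

    from dataclasses import MISSING, dataclass, field, fields


    def get_field_sketch(cls, name: str):
        """Return a fresh field() reusing `cls.<name>`'s default or factory."""
        for f in fields(cls):
            if f.name == name:
                if f.default_factory is not MISSING:
                    return field(default_factory=f.default_factory)
                return field(default=f.default)
        raise ValueError(f"{cls.__name__} has no field named {name!r}")


    @dataclass
    class SourceConfig:
        hf_overrides: dict = field(default_factory=dict)


    @dataclass
    class ArgsSketch:
        # Mirrors the EngineArgs pattern: the default (an empty-dict factory)
        # is copied from the config class instead of being repeated here.
        hf_overrides: dict = get_field_sketch(SourceConfig, "hf_overrides")


    print(ArgsSketch().hf_overrides)  # {}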
@@ -413,7 +415,8 @@ class EngineArgs:

 calculate_kv_scales: bool = CacheConfig.calculate_kv_scales

-additional_config: Optional[Dict[str, Any]] = None
+additional_config: dict[str, Any] = \
+get_field(VllmConfig, "additional_config")
 enable_reasoning: Optional[bool] = None # DEPRECATED
 reasoning_parser: str = DecodingConfig.reasoning_backend
@@ -207,6 +207,9 @@ class LLM:
 if isinstance(worker_cls, type):
 kwargs["worker_cls"] = cloudpickle.dumps(worker_cls)

+if hf_overrides is None:
+hf_overrides = {}
+
 if compilation_config is not None:
 if isinstance(compilation_config, int):
 compilation_config_instance = CompilationConfig(

@@ -218,7 +221,7 @@ class LLM:
 else:
 compilation_config_instance = compilation_config
 else:
-compilation_config_instance = None
+compilation_config_instance = CompilationConfig()

 engine_args = EngineArgs(
 model=model,
@@ -175,11 +175,15 @@ class ChatCompletionNamedToolChoiceParam(OpenAIBaseModel):
 type: Literal["function"] = "function"


+# extra="forbid" is a workaround to have kwargs as a field,
+# see https://github.com/pydantic/pydantic/issues/3125
 class LogitsProcessorConstructor(BaseModel):
 qualname: str
 args: Optional[list[Any]] = None
 kwargs: Optional[dict[str, Any]] = None
+
+model_config = ConfigDict(extra="forbid")


 LogitsProcessors = list[Union[str, LogitsProcessorConstructor]]
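Per the comment above, `extra="forbid"` is the workaround that lets a model carry a field literally named `kwargs` (pydantic/pydantic#3125): unknown keys now fail validation instead of being silently ignored. A small illustration with a stand-in class:

    from typing import Any, Optional

    from pydantic import BaseModel, ConfigDict, ValidationError


    class ConstructorSketch(BaseModel):
        qualname: str
        args: Optional[list[Any]] = None
        kwargs: Optional[dict[str, Any]] = None

        model_config = ConfigDict(extra="forbid")


    # A field named "kwargs" behaves like any other field here...
    print(ConstructorSketch(qualname="my.module.Proc", kwargs={"k": 1}))

    # ...and unknown keys are rejected rather than silently dropped.
    try:
        ConstructorSketch(qualname="my.module.Proc", unexpected=1)
    except ValidationError as e:
        print(e.errors()[0]["type"])  # extra_forbidden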
@@ -234,7 +238,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 presence_penalty: Optional[float] = 0.0
 response_format: Optional[AnyResponseFormat] = None
 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+stop: Optional[Union[str, list[str]]] = []
 stream: Optional[bool] = False
 stream_options: Optional[StreamOptions] = None
 temperature: Optional[float] = None
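Replacing `Field(default_factory=list)` with a literal `[]` is safe on pydantic models because pydantic copies mutable defaults for every instance; the same spelling would be a shared-state bug on a plain class and a `TypeError` on a stdlib dataclass. A quick check of that behaviour:

    from pydantic import BaseModel


    class ExampleRequest(BaseModel):
        stop: list[str] = []          # pydantic deep-copies this per instance
        stop_token_ids: list[int] = []


    a, b = ExampleRequest(), ExampleRequest()
    a.stop.append("</s>")
    print(a.stop, b.stop)  # ['</s>'] [] -- no state shared between instances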
@@ -258,7 +262,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
 min_p: Optional[float] = None
 repetition_penalty: Optional[float] = None
 length_penalty: float = 1.0
-stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+stop_token_ids: Optional[list[int]] = []
 include_stop_str_in_output: bool = False
 ignore_eos: bool = False
 min_tokens: int = 0

@@ -756,7 +760,7 @@ class CompletionRequest(OpenAIBaseModel):
 n: int = 1
 presence_penalty: Optional[float] = 0.0
 seed: Optional[int] = Field(None, ge=_LONG_INFO.min, le=_LONG_INFO.max)
-stop: Optional[Union[str, list[str]]] = Field(default_factory=list)
+stop: Optional[Union[str, list[str]]] = []
 stream: Optional[bool] = False
 stream_options: Optional[StreamOptions] = None
 suffix: Optional[str] = None

@@ -770,7 +774,7 @@ class CompletionRequest(OpenAIBaseModel):
 min_p: Optional[float] = None
 repetition_penalty: Optional[float] = None
 length_penalty: float = 1.0
-stop_token_ids: Optional[list[int]] = Field(default_factory=list)
+stop_token_ids: Optional[list[int]] = []
 include_stop_str_in_output: bool = False
 ignore_eos: bool = False
 min_tokens: int = 0
@@ -134,11 +134,9 @@ class RequestProcessingMixin(BaseModel):
 Mixin for request processing,
 handling prompt preparation and engine input.
 """
-request_prompts: Optional[Sequence[RequestPrompt]] = \
-Field(default_factory=list)
+request_prompts: Optional[Sequence[RequestPrompt]] = []
 engine_prompts: Optional[Union[list[EngineTokensPrompt],
-list[EngineEmbedsPrompt]]] = Field(
-default_factory=list)
+list[EngineEmbedsPrompt]]] = []

 model_config = ConfigDict(arbitrary_types_allowed=True)

@@ -528,12 +526,14 @@ class OpenAIServing:
 if isinstance(request,
 (EmbeddingChatRequest, EmbeddingCompletionRequest,
 ScoreRequest, RerankRequest, ClassificationRequest)):
-operation = {
-ScoreRequest: "score",
-ClassificationRequest: "classification"
-}.get(type(request), "embedding generation")

 if token_num > self.max_model_len:
+operations: dict[type[AnyRequest], str] = {
+ScoreRequest: "score",
+ClassificationRequest: "classification"
+}
+operation = operations.get(type(request),
+"embedding generation")
 raise ValueError(
 f"This model's maximum context length is "
 f"{self.max_model_len} tokens. However, you requested "
@@ -3,12 +3,10 @@
 from dataclasses import dataclass
 from typing import Optional, TypedDict, Union

-from pydantic import BaseModel
-

 # These classes are deprecated, see SamplingParams
 class LLMGuidedOptions(TypedDict, total=False):
-guided_json: Union[dict, BaseModel, str]
+guided_json: Union[dict, str]
 guided_regex: str
 guided_choice: list[str]
 guided_grammar: str

@@ -20,7 +18,7 @@ class LLMGuidedOptions(TypedDict, total=False):
 @dataclass
 class GuidedDecodingRequest:
 """One of the fields will be used to retrieve the logit processor."""
-guided_json: Optional[Union[dict, BaseModel, str]] = None
+guided_json: Optional[Union[dict, str]] = None
 guided_regex: Optional[str] = None
 guided_choice: Optional[list[str]] = None
 guided_grammar: Optional[str] = None
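With `BaseModel` removed from the accepted `guided_json` types, callers that used to hand over a model class can pass its JSON schema as a plain dict instead; `model_json_schema()` is the standard pydantic way to produce one. The `Answer` model below is purely illustrative:

    from pydantic import BaseModel


    class Answer(BaseModel):
        text: str
        score: float


    # A plain dict JSON schema, which fits where guided_json expects
    # Union[dict, str].
    guided_json = Answer.model_json_schema()
    print(sorted(guided_json["properties"]))  # ['score', 'text']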
@@ -1878,14 +1878,6 @@ def get_cuda_view_from_cpu_tensor(cpu_tensor: torch.Tensor) -> torch.Tensor:
 return torch.ops._C.get_cuda_view_from_cpu_tensor(cpu_tensor)


-def is_in_doc_build() -> bool:
-try:
-from sphinx.ext.autodoc.mock import _MockModule
-return isinstance(zmq, _MockModule)
-except ModuleNotFoundError:
-return False
-
-
 def import_from_path(module_name: str, file_path: Union[str, os.PathLike]):
 """
 Import a Python file according to its file path.