[Perf][CLI] Improve overall startup time (#19941)

Aaron Pham
2025-06-22 19:11:22 -04:00
committed by GitHub
parent 33d51f599e
commit c4cf260677
14 changed files with 293 additions and 103 deletions

View File

@@ -115,6 +115,11 @@ repos:
entry: python tools/check_spdx_header.py
language: python
types: [python]
- id: check-root-lazy-imports
name: Check root lazy imports
entry: python tools/check_init_lazy_imports.py
language: python
types: [python]
- id: check-filenames
name: Check for spaces in all filenames
entry: bash

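The new pre-commit hook simply runs the checker script shown in the next file against the repository. As a rough illustration (not part of this commit), it can be exercised locally via `pre-commit run check-root-lazy-imports --all-files`, or by invoking the script directly from the repository root:

# Minimal sketch: run the new checker directly (assumes the repository root as CWD).
import subprocess
import sys

# Raises CalledProcessError (after printing the violations) if vllm/__init__.py
# contains disallowed eager internal imports.
subprocess.run([sys.executable, "tools/check_init_lazy_imports.py"], check=True)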
View File

@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Ensure we perform lazy loading in vllm/__init__.py.
i.e: appears only within the ``if typing.TYPE_CHECKING:`` guard,
**except** for a short whitelist.
"""
from __future__ import annotations
import ast
import pathlib
import sys
from collections.abc import Iterable
from typing import Final
REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"
# If you need to add items to whitelist, do it here.
ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
"vllm.env_override",
})
ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
".version",
})
def _is_internal(name: str | None, *, level: int = 0) -> bool:
if level > 0:
return True
if name is None:
return False
return name.startswith("vllm.") or name == "vllm"
def _fail(violations: Iterable[tuple[int, str]]) -> None:
print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
file=sys.stderr)
for lineno, msg in violations:
print(f" Line {lineno}: {msg}", file=sys.stderr)
sys.exit(1)
def main() -> None:
source = INIT_PATH.read_text(encoding="utf-8")
tree = ast.parse(source, filename=str(INIT_PATH))
violations: list[tuple[int, str]] = []
class Visitor(ast.NodeVisitor):
def __init__(self) -> None:
super().__init__()
self._in_type_checking = False
def visit_If(self, node: ast.If) -> None:
guard_is_type_checking = False
test = node.test
if isinstance(test, ast.Attribute) and isinstance(
test.value, ast.Name):
guard_is_type_checking = (test.value.id == "typing"
and test.attr == "TYPE_CHECKING")
elif isinstance(test, ast.Name):
guard_is_type_checking = test.id == "TYPE_CHECKING"
if guard_is_type_checking:
prev = self._in_type_checking
self._in_type_checking = True
for child in node.body:
self.visit(child)
self._in_type_checking = prev
for child in node.orelse:
self.visit(child)
else:
self.generic_visit(node)
def visit_Import(self, node: ast.Import) -> None:
if self._in_type_checking:
return
for alias in node.names:
module_name = alias.name
if _is_internal(
module_name) and module_name not in ALLOWED_IMPORTS:
violations.append((
node.lineno,
f"import '{module_name}' must be inside typing.TYPE_CHECKING", # noqa: E501
))
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if self._in_type_checking:
return
module_as_written = ("." * node.level) + (node.module or "")
if _is_internal(
node.module, level=node.level
) and module_as_written not in ALLOWED_FROM_MODULES:
violations.append((
node.lineno,
f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING", # noqa: E501
))
Visitor().visit(tree)
if violations:
_fail(violations)
if __name__ == "__main__":
main()

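For illustration (hypothetical fragment, not part of this commit), the hook flags an eager vLLM-internal import at module scope in vllm/__init__.py but accepts the same import when it sits under the typing.TYPE_CHECKING guard:

# Hypothetical vllm/__init__.py fragment -- the hook reports this line:
from vllm.entrypoints.llm import LLM  # eager internal import -> violation

# Accepted form: only evaluated by static type checkers, never at runtime.
import typing

if typing.TYPE_CHECKING:
    from vllm.entrypoints.llm import LLM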
View File

@@ -1,29 +1,72 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
# The version.py should be independent library, and we always import the
# version library first. Such assumption is critical for some customization.
from .version import __version__, __version_tuple__ # isort:skip
import typing
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import vllm.env_override # isort:skip # noqa: F401
import vllm.env_override # noqa: F401
MODULE_ATTRS = {
"AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
"EngineArgs": ".engine.arg_utils:EngineArgs",
"AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
"LLMEngine": ".engine.llm_engine:LLMEngine",
"LLM": ".entrypoints.llm:LLM",
"initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
"PromptType": ".inputs:PromptType",
"TextPrompt": ".inputs:TextPrompt",
"TokensPrompt": ".inputs:TokensPrompt",
"ModelRegistry": ".model_executor.models:ModelRegistry",
"SamplingParams": ".sampling_params:SamplingParams",
"PoolingParams": ".pooling_params:PoolingParams",
"ClassificationOutput": ".outputs:ClassificationOutput",
"ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
"CompletionOutput": ".outputs:CompletionOutput",
"EmbeddingOutput": ".outputs:EmbeddingOutput",
"EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
"PoolingOutput": ".outputs:PoolingOutput",
"PoolingRequestOutput": ".outputs:PoolingRequestOutput",
"RequestOutput": ".outputs:RequestOutput",
"ScoringOutput": ".outputs:ScoringOutput",
"ScoringRequestOutput": ".outputs:ScoringRequestOutput",
}
if typing.TYPE_CHECKING:
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput,
ClassificationRequestOutput, CompletionOutput,
EmbeddingOutput, EmbeddingRequestOutput,
PoolingOutput, PoolingRequestOutput,
RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
else:
def __getattr__(name: str) -> typing.Any:
from importlib import import_module
if name in MODULE_ATTRS:
module_name, attr_name = MODULE_ATTRS[name].split(":")
module = import_module(module_name, __package__)
return getattr(module, attr_name)
else:
raise AttributeError(
f'module {__package__} has no attribute {name}')
from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.llm_engine import LLMEngine
from vllm.entrypoints.llm import LLM
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import PromptType, TextPrompt, TokensPrompt
from vllm.model_executor.models import ModelRegistry
from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
CompletionOutput, EmbeddingOutput,
EmbeddingRequestOutput, PoolingOutput,
PoolingRequestOutput, RequestOutput, ScoringOutput,
ScoringRequestOutput)
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import SamplingParams
__all__ = [
"__version__",

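The MODULE_ATTRS mapping above implements PEP 562 module-level __getattr__: heavy submodules are imported only on first attribute access, while the TYPE_CHECKING branch keeps static type checkers happy. A minimal standalone sketch of the pattern (hypothetical package and module names):

# mypkg/__init__.py -- minimal sketch of lazy attribute loading via PEP 562
import typing

_MODULE_ATTRS = {"HeavyClass": ".heavy_module:HeavyClass"}

if typing.TYPE_CHECKING:
    from .heavy_module import HeavyClass
else:
    def __getattr__(name: str) -> typing.Any:
        from importlib import import_module
        if name in _MODULE_ATTRS:
            module_name, attr_name = _MODULE_ATTRS[name].split(":")
            return getattr(import_module(module_name, __package__), attr_name)
        raise AttributeError(f"module {__package__!r} has no attribute {name!r}")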
View File

@@ -28,7 +28,7 @@ from pydantic.dataclasses import dataclass
from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.distributed import ProcessGroup, ReduceOp
from transformers import PretrainedConfig
from typing_extensions import deprecated, runtime_checkable
from typing_extensions import Self, deprecated, runtime_checkable
import vllm.envs as envs
from vllm import version
@@ -1537,7 +1537,6 @@ class CacheConfig:
def __post_init__(self) -> None:
self.swap_space_bytes = self.swap_space * GiB_bytes
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
@@ -1546,7 +1545,8 @@ class CacheConfig:
# metrics info
return {key: str(value) for key, value in self.__dict__.items()}
def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.cpu_offload_gb < 0:
raise ValueError("CPU offload space must be non-negative"
f", but got {self.cpu_offload_gb}")
@@ -1556,6 +1556,8 @@ class CacheConfig:
"GPU memory utilization must be less than 1.0. Got "
f"{self.gpu_memory_utilization}.")
return self
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
@@ -1942,15 +1944,14 @@ class ParallelConfig:
if self.distributed_executor_backend is None and self.world_size == 1:
self.distributed_executor_backend = "uni"
self._verify_args()
@property
def use_ray(self) -> bool:
return self.distributed_executor_backend == "ray" or (
isinstance(self.distributed_executor_backend, type)
and self.distributed_executor_backend.uses_ray)
def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
# Lazy import to avoid circular import
from vllm.executor.executor_base import ExecutorBase
from vllm.platforms import current_platform
@@ -1977,8 +1978,7 @@ class ParallelConfig:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")
assert isinstance(self.worker_extension_cls, str), (
"worker_extension_cls must be a string (qualified class name).")
return self
PreemptionMode = Literal["swap", "recompute"]
@@ -2202,9 +2202,8 @@ class SchedulerConfig:
self.max_num_partial_prefills, self.max_long_partial_prefills,
self.long_prefill_token_threshold)
self._verify_args()
def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if (self.max_num_batched_tokens < self.max_model_len
and not self.chunked_prefill_enabled):
raise ValueError(
@@ -2263,6 +2262,8 @@ class SchedulerConfig:
"must be greater than or equal to 1 and less than or equal to "
f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
return self
@property
def is_multi_step(self) -> bool:
return self.num_scheduler_steps > 1
@@ -2669,8 +2670,6 @@ class SpeculativeConfig:
if self.posterior_alpha is None:
self.posterior_alpha = 0.3
self._verify_args()
@staticmethod
def _maybe_override_draft_max_model_len(
speculative_max_model_len: Optional[int],
@@ -2761,7 +2760,8 @@ class SpeculativeConfig:
return draft_parallel_config
def _verify_args(self) -> None:
@model_validator(mode='after')
def _verify_args(self) -> Self:
if self.num_speculative_tokens is None:
raise ValueError(
"num_speculative_tokens must be provided with "
@@ -2812,6 +2812,8 @@ class SpeculativeConfig:
"Eagle3 is only supported for Llama models. "
f"Got {self.target_model_config.hf_text_config.model_type=}")
return self
@property
def num_lookahead_slots(self) -> int:
"""The number of additional slots the scheduler should allocate per

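The recurring change in this file replaces eagerly-called _verify_args() methods with pydantic @model_validator(mode='after') hooks that run when the config is constructed and return Self. A minimal standalone sketch of the pattern (hypothetical config class, assuming pydantic v2 is installed):

from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self

@dataclass
class ExampleCacheConfig:
    gpu_memory_utilization: float = 0.9

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        # Runs automatically at construction time; no explicit call needed.
        if self.gpu_memory_utilization > 1.0:
            raise ValueError("GPU memory utilization must be less than 1.0. "
                             f"Got {self.gpu_memory_utilization}.")
        return self

ExampleCacheConfig(gpu_memory_utilization=0.5)    # validated on construction
# ExampleCacheConfig(gpu_memory_utilization=1.5)  # would raise ValueError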
View File

@@ -3,7 +3,9 @@
# yapf: disable
import argparse
import copy
import dataclasses
import functools
import json
import sys
import threading
@@ -168,7 +170,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
return type_hints
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
@functools.lru_cache(maxsize=30)
def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
cls_docs = get_attr_docs(cls)
kwargs = {}
for field in fields(cls):
@@ -269,6 +272,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
return kwargs
def get_kwargs(cls: ConfigType) -> dict[str, Any]:
"""Return argparse kwargs for the given Config dataclass.
The heavy computation is cached via functools.lru_cache, and a deep copy
is returned so callers can mutate the dictionary without affecting the
cached version.
"""
return copy.deepcopy(_compute_kwargs(cls))
@dataclass
class EngineArgs:
"""Arguments for vLLM engine."""

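As the new docstring notes, _compute_kwargs is cached with functools.lru_cache and get_kwargs hands out a deep copy, so callers may mutate the result without corrupting later calls. A small standalone sketch of why the copy matters (hypothetical names, not vLLM code):

import copy
import functools

@functools.lru_cache(maxsize=None)
def _compute() -> dict:
    return {"choices": ["a", "b"]}

def get() -> dict:
    # Without the deepcopy, every caller would share (and could corrupt)
    # the single cached dictionary.
    return copy.deepcopy(_compute())

result = get()
result["choices"].append("c")               # mutates only the copy
assert _compute()["choices"] == ["a", "b"]  # cached value is untouched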
View File

@@ -1,10 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import typing
from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class BenchmarkSubcommand(CLISubcommand):
@@ -23,7 +29,6 @@ class BenchmarkSubcommand(CLISubcommand):
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
bench_parser = subparsers.add_parser(
self.name,
help=self.help,

View File

@@ -1,19 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import typing
from vllm.collect_env import main as collect_env_main
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CollectEnvSubcommand(CLISubcommand):
"""The `collect-env` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "collect-env"
super().__init__()
name = "collect-env"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -23,12 +25,11 @@ class CollectEnvSubcommand(CLISubcommand):
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
collect_env_parser = subparsers.add_parser(
return subparsers.add_parser(
"collect-env",
help="Start collecting environment information.",
description="Start collecting environment information.",
usage="vllm collect-env")
return collect_env_parser
def cmd_init() -> list[CLISubcommand]:

View File

@@ -1,27 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
'''The CLI entrypoints of vLLM
# The CLI entrypoint to vLLM.
Note that all future modules must be lazily loaded within main
to avoid certain eager import breakage.'''
from __future__ import annotations
import importlib.metadata
import signal
import sys
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
import vllm.version
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
def register_signal_handlers():
@@ -33,16 +21,34 @@ def register_signal_handlers():
def main():
import vllm.entrypoints.cli.benchmark.main
import vllm.entrypoints.cli.collect_env
import vllm.entrypoints.cli.openai
import vllm.entrypoints.cli.run_batch
import vllm.entrypoints.cli.serve
from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup
from vllm.utils import FlexibleArgumentParser
CMD_MODULES = [
vllm.entrypoints.cli.openai,
vllm.entrypoints.cli.serve,
vllm.entrypoints.cli.benchmark.main,
vllm.entrypoints.cli.collect_env,
vllm.entrypoints.cli.run_batch,
]
cli_env_setup()
parser = FlexibleArgumentParser(
description="vLLM CLI",
epilog=VLLM_SUBCMD_PARSER_EPILOG,
)
parser.add_argument('-v',
'--version',
action='version',
version=vllm.version.__version__)
parser.add_argument(
'-v',
'--version',
action='version',
version=importlib.metadata.version('vllm'),
)
subparsers = parser.add_subparsers(required=False, dest="subparser")
cmds = {}
for cmd_module in CMD_MODULES:

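Two ideas drive this file's change: the subcommand modules are now imported inside main(), so importing the CLI module itself stays cheap, and the version string is read from installed package metadata instead of importing vllm. A minimal sketch of both (assumes a vLLM distribution is installed; importlib.metadata is stdlib):

import importlib.metadata

def main() -> None:
    # Heavy subcommand modules would be imported here, only when the CLI
    # actually runs, rather than at module import time.
    print("vllm", importlib.metadata.version("vllm"))

if __name__ == "__main__":
    main()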
View File

@@ -1,18 +1,21 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Commands that act as an interactive OpenAI API client
from __future__ import annotations
import argparse
import os
import signal
import sys
from typing import Optional
from typing import TYPE_CHECKING
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.utils import FlexibleArgumentParser
if TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
def _register_signal_handlers():
@@ -42,8 +45,7 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
return model_name, openai_client
def chat(system_prompt: Optional[str], model_name: str,
client: OpenAI) -> None:
def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
conversation: list[ChatCompletionMessageParam] = []
if system_prompt is not None:
conversation.append({"role": "system", "content": system_prompt})
@@ -92,10 +94,7 @@
class ChatCommand(CLISubcommand):
"""The `chat` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "chat"
super().__init__()
name = "chat"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -157,10 +156,7 @@ class ChatCommand(CLISubcommand):
class CompleteCommand(CLISubcommand):
"""The `complete` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "complete"
super().__init__()
name = 'complete'
@staticmethod
def cmd(args: argparse.Namespace) -> None:

View File

@@ -1,37 +1,42 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from __future__ import annotations
import argparse
import asyncio
from prometheus_client import start_http_server
import importlib.metadata
import typing
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.logger import logger
from vllm.entrypoints.openai.run_batch import main as run_batch_main
from vllm.entrypoints.openai.run_batch import make_arg_parser
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
show_filtered_argument_or_group_from_help)
from vllm.utils import FlexibleArgumentParser
from vllm.version import __version__ as VLLM_VERSION
from vllm.logger import init_logger
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
logger = init_logger(__name__)
class RunBatchSubcommand(CLISubcommand):
"""The `run-batch` subcommand for vLLM CLI."""
def __init__(self):
self.name = "run-batch"
super().__init__()
name = "run-batch"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
logger.info("vLLM batch processing API version %s", VLLM_VERSION)
from vllm.entrypoints.openai.run_batch import main as run_batch_main
logger.info("vLLM batch processing API version %s",
importlib.metadata.version("vllm"))
logger.info("args: %s", args)
# Start the Prometheus metrics server.
# LLMEngine uses the Prometheus client
# to publish metrics at the /metrics endpoint.
if args.enable_metrics:
from prometheus_client import start_http_server
logger.info("Prometheus metrics enabled")
start_http_server(port=args.port, addr=args.url)
else:
@@ -42,6 +47,8 @@ class RunBatchSubcommand(CLISubcommand):
def subparser_init(
self,
subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
from vllm.entrypoints.openai.run_batch import make_arg_parser
run_batch_parser = subparsers.add_parser(
"run-batch",
help="Run batch prompts and write results to file.",

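Here the Prometheus server and the run_batch machinery are imported only when they are actually needed. A standalone sketch of the conditional startup (hypothetical helper name, assuming prometheus_client is installed):

def maybe_start_metrics(enable: bool, port: int, addr: str) -> None:
    if not enable:
        return
    # Imported only when metrics are requested, keeping cold starts cheap.
    from prometheus_client import start_http_server
    start_http_server(port=port, addr=addr)

maybe_start_metrics(enable=False, port=8000, addr="127.0.0.1")  # no-op, nothing imported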
View File

@@ -9,8 +9,8 @@ import sys
import uvloop
import zmq
import vllm
import vllm.envs as envs
from vllm import AsyncEngineArgs
from vllm.entrypoints.cli.types import CLISubcommand
from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
setup_server)
@@ -38,10 +38,7 @@ logger = init_logger(__name__)
class ServeSubcommand(CLISubcommand):
"""The `serve` subcommand for the vLLM CLI. """
def __init__(self):
self.name = "serve"
super().__init__()
name = "serve"
@staticmethod
def cmd(args: argparse.Namespace) -> None:
@@ -115,7 +112,7 @@ def run_headless(args: argparse.Namespace):
raise ValueError("api_server_count can't be set in headless mode")
# Create the EngineConfig.
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
@@ -175,7 +172,7 @@ def run_multi_api_server(args: argparse.Namespace):
listen_address, sock = setup_server(args)
engine_args = AsyncEngineArgs.from_cli_args(args)
engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config

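Replacing `from vllm import AsyncEngineArgs` with `import vllm` plus `vllm.AsyncEngineArgs` defers the engine import until the attribute is first touched, via the lazy __getattr__ added in vllm/__init__.py. A rough way to observe this (a sketch, assuming a vLLM build with this change installed):

import sys

import vllm

print("engine imported after 'import vllm'?",
      "vllm.engine.arg_utils" in sys.modules)   # expected: False with this change
_ = vllm.AsyncEngineArgs                         # first access triggers the lazy import
print("engine imported after attribute access?",
      "vllm.engine.arg_utils" in sys.modules)   # expected: True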
View File

@@ -1,9 +1,13 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse
from __future__ import annotations
from vllm.utils import FlexibleArgumentParser
import argparse
import typing
if typing.TYPE_CHECKING:
from vllm.utils import FlexibleArgumentParser
class CLISubcommand:

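With `from __future__ import annotations`, the FlexibleArgumentParser annotations stay strings at runtime, so the import can live behind typing.TYPE_CHECKING and is only paid by type checkers. A minimal standalone sketch of the same pattern (the heavy dependency name is hypothetical):

from __future__ import annotations

import typing

if typing.TYPE_CHECKING:
    # Never imported at runtime; only seen by static type checkers.
    from some_heavy_package import HeavyParser  # hypothetical

def build_parser(parser: HeavyParser) -> HeavyParser:
    return parser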
View File

@@ -15,7 +15,7 @@ from tqdm import tqdm
from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.logger import RequestLogger, logger
from vllm.entrypoints.logger import RequestLogger
# yapf: disable
from vllm.entrypoints.openai.protocol import (BatchRequestInput,
BatchRequestOutput,
@@ -29,10 +29,13 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
from vllm.entrypoints.openai.serving_models import (BaseModelPath,
OpenAIServingModels)
from vllm.entrypoints.openai.serving_score import ServingScores
from vllm.logger import init_logger
from vllm.usage.usage_lib import UsageContext
from vllm.utils import FlexibleArgumentParser, random_uuid
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
def make_arg_parser(parser: FlexibleArgumentParser):
parser.add_argument(
@@ -201,13 +204,16 @@ async def upload_data(output_url: str, data_or_file: str,
except Exception as e:
if attempt < max_retries:
logger.error(
f"Failed to upload data (attempt {attempt}). "
f"Error message: {str(e)}.\nRetrying in {delay} seconds..."
"Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...", # noqa: E501
attempt,
e,
delay,
)
await asyncio.sleep(delay)
else:
raise Exception(f"Failed to upload data (attempt {attempt}). "
f"Error message: {str(e)}.") from e
raise Exception(
f"Failed to upload data (attempt {attempt}). Error message: {str(e)}." # noqa: E501
) from e
async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],

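The logging calls above also switch from f-strings to %-style lazy formatting, so the message is only interpolated if the record is actually emitted. A minimal standalone sketch:

import logging

logger = logging.getLogger("run_batch_example")
attempt, delay, err = 2, 5, "timeout"

# Interpolation happens inside the logging call, and only when the level is enabled.
logger.error("Failed to upload data (attempt %d). Error message: %s. Retrying in %d seconds...",
             attempt, err, delay)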
View File

@@ -67,9 +67,6 @@ from torch.library import Library
from typing_extensions import Never, ParamSpec, TypeIs, assert_never
import vllm.envs as envs
# NOTE: import triton_utils to make TritonPlaceholderModule work
# if triton is unavailable
import vllm.triton_utils # noqa: F401
from vllm.logger import enable_trace_function_call, init_logger
if TYPE_CHECKING: