diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 7534ae5590..e62b623b4e 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -115,6 +115,11 @@ repos:
         entry: python tools/check_spdx_header.py
         language: python
         types: [python]
+      - id: check-root-lazy-imports
+        name: Check root lazy imports
+        entry: python tools/check_init_lazy_imports.py
+        language: python
+        types: [python]
       - id: check-filenames
         name: Check for spaces in all filenames
         entry: bash
diff --git a/tools/check_init_lazy_imports.py b/tools/check_init_lazy_imports.py
new file mode 100644
index 0000000000..e8e6f07cc3
--- /dev/null
+++ b/tools/check_init_lazy_imports.py
@@ -0,0 +1,108 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+"""Ensure we perform lazy loading in vllm/__init__.py:
+internal imports may appear only within the ``if typing.TYPE_CHECKING:``
+guard, **except** for a short whitelist.
+"""
+
+from __future__ import annotations
+
+import ast
+import pathlib
+import sys
+from collections.abc import Iterable
+from typing import Final
+
+REPO_ROOT: Final = pathlib.Path(__file__).resolve().parent.parent
+INIT_PATH: Final = REPO_ROOT / "vllm" / "__init__.py"
+
+# If you need to add items to the whitelist, do it here.
+ALLOWED_IMPORTS: Final[frozenset[str]] = frozenset({
+    "vllm.env_override",
+})
+ALLOWED_FROM_MODULES: Final[frozenset[str]] = frozenset({
+    ".version",
+})
+
+
+def _is_internal(name: str | None, *, level: int = 0) -> bool:
+    if level > 0:
+        return True
+    if name is None:
+        return False
+    return name.startswith("vllm.") or name == "vllm"
+
+
+def _fail(violations: Iterable[tuple[int, str]]) -> None:
+    print("ERROR: Disallowed eager imports in vllm/__init__.py:\n",
+          file=sys.stderr)
+    for lineno, msg in violations:
+        print(f"  Line {lineno}: {msg}", file=sys.stderr)
+    sys.exit(1)
+
+
+def main() -> None:
+    source = INIT_PATH.read_text(encoding="utf-8")
+    tree = ast.parse(source, filename=str(INIT_PATH))
+
+    violations: list[tuple[int, str]] = []
+
+    class Visitor(ast.NodeVisitor):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self._in_type_checking = False
+
+        def visit_If(self, node: ast.If) -> None:
+            guard_is_type_checking = False
+            test = node.test
+            if isinstance(test, ast.Attribute) and isinstance(
+                    test.value, ast.Name):
+                guard_is_type_checking = (test.value.id == "typing"
+                                          and test.attr == "TYPE_CHECKING")
+            elif isinstance(test, ast.Name):
+                guard_is_type_checking = test.id == "TYPE_CHECKING"
+
+            if guard_is_type_checking:
+                prev = self._in_type_checking
+                self._in_type_checking = True
+                for child in node.body:
+                    self.visit(child)
+                self._in_type_checking = prev
+                for child in node.orelse:
+                    self.visit(child)
+            else:
+                self.generic_visit(node)
+
+        def visit_Import(self, node: ast.Import) -> None:
+            if self._in_type_checking:
+                return
+            for alias in node.names:
+                module_name = alias.name
+                if _is_internal(
+                        module_name) and module_name not in ALLOWED_IMPORTS:
+                    violations.append((
+                        node.lineno,
+                        f"import '{module_name}' must be inside typing.TYPE_CHECKING",  # noqa: E501
+                    ))
+
+        def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
+            if self._in_type_checking:
+                return
+            module_as_written = ("." * node.level) + (node.module or "")
+            if _is_internal(
+                    node.module, level=node.level
+            ) and module_as_written not in ALLOWED_FROM_MODULES:
+                violations.append((
+                    node.lineno,
+                    f"from '{module_as_written}' import ... must be inside typing.TYPE_CHECKING",  # noqa: E501
+                ))
+
+    Visitor().visit(tree)
+
+    if violations:
+        _fail(violations)
+
+
+if __name__ == "__main__":
+    main()
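The new hook walks the AST of `vllm/__init__.py` and reports any internal import that is not guarded by `typing.TYPE_CHECKING`. Below is a simplified, hypothetical sketch of the same idea, not part of the patch: it runs standalone on an inline source string and only inspects top-level statements, unlike the real checker, which also recurses into non-guarded blocks and honors the whitelist.

```python
import ast

SAMPLE = """\
import typing
import vllm.env_override          # whitelisted in the real checker
from vllm.outputs import RequestOutput   # eager internal import -> flagged

if typing.TYPE_CHECKING:
    from vllm.sampling_params import SamplingParams   # fine: guarded
"""


def find_eager_vllm_imports(source: str) -> list[tuple[int, str]]:
    """Return (lineno, module) pairs for top-level eager vllm imports."""
    hits: list[tuple[int, str]] = []
    for node in ast.parse(source).body:  # top-level statements only
        if isinstance(node, ast.Import):
            hits += [(node.lineno, a.name) for a in node.names
                     if a.name == "vllm" or a.name.startswith("vllm.")]
        elif isinstance(node, ast.ImportFrom):
            mod = node.module or ""
            if node.level > 0 or mod == "vllm" or mod.startswith("vllm."):
                hits.append((node.lineno, "." * node.level + mod))
    return hits


print(find_eager_vllm_imports(SAMPLE))
# -> [(2, 'vllm.env_override'), (3, 'vllm.outputs')]
```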
diff --git a/vllm/__init__.py b/vllm/__init__.py
index 6232b657e8..7b90fd3a24 100644
--- a/vllm/__init__.py
+++ b/vllm/__init__.py
@@ -1,29 +1,72 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """vLLM: a high-throughput and memory-efficient inference engine for LLMs"""
+
 # The version.py should be independent library, and we always import the
 # version library first. Such assumption is critical for some customization.
 from .version import __version__, __version_tuple__  # isort:skip
 
+import typing
+
 # The environment variables override should be imported before any other
 # modules to ensure that the environment variables are set before any
 # other modules are imported.
-import vllm.env_override  # isort:skip  # noqa: F401
+import vllm.env_override  # noqa: F401
+
+MODULE_ATTRS = {
+    "AsyncEngineArgs": ".engine.arg_utils:AsyncEngineArgs",
+    "EngineArgs": ".engine.arg_utils:EngineArgs",
+    "AsyncLLMEngine": ".engine.async_llm_engine:AsyncLLMEngine",
+    "LLMEngine": ".engine.llm_engine:LLMEngine",
+    "LLM": ".entrypoints.llm:LLM",
+    "initialize_ray_cluster": ".executor.ray_utils:initialize_ray_cluster",
+    "PromptType": ".inputs:PromptType",
+    "TextPrompt": ".inputs:TextPrompt",
+    "TokensPrompt": ".inputs:TokensPrompt",
+    "ModelRegistry": ".model_executor.models:ModelRegistry",
+    "SamplingParams": ".sampling_params:SamplingParams",
+    "PoolingParams": ".pooling_params:PoolingParams",
+    "ClassificationOutput": ".outputs:ClassificationOutput",
+    "ClassificationRequestOutput": ".outputs:ClassificationRequestOutput",
+    "CompletionOutput": ".outputs:CompletionOutput",
+    "EmbeddingOutput": ".outputs:EmbeddingOutput",
+    "EmbeddingRequestOutput": ".outputs:EmbeddingRequestOutput",
+    "PoolingOutput": ".outputs:PoolingOutput",
+    "PoolingRequestOutput": ".outputs:PoolingRequestOutput",
+    "RequestOutput": ".outputs:RequestOutput",
+    "ScoringOutput": ".outputs:ScoringOutput",
+    "ScoringRequestOutput": ".outputs:ScoringRequestOutput",
+}
+
+if typing.TYPE_CHECKING:
+    from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+    from vllm.engine.async_llm_engine import AsyncLLMEngine
+    from vllm.engine.llm_engine import LLMEngine
+    from vllm.entrypoints.llm import LLM
+    from vllm.executor.ray_utils import initialize_ray_cluster
+    from vllm.inputs import PromptType, TextPrompt, TokensPrompt
+    from vllm.model_executor.models import ModelRegistry
+    from vllm.outputs import (ClassificationOutput,
+                              ClassificationRequestOutput, CompletionOutput,
+                              EmbeddingOutput, EmbeddingRequestOutput,
+                              PoolingOutput, PoolingRequestOutput,
+                              RequestOutput, ScoringOutput,
+                              ScoringRequestOutput)
+    from vllm.pooling_params import PoolingParams
+    from vllm.sampling_params import SamplingParams
+else:
+
+    def __getattr__(name: str) -> typing.Any:
+        from importlib import import_module
+
+        if name in MODULE_ATTRS:
+            module_name, attr_name = MODULE_ATTRS[name].split(":")
+            module = import_module(module_name, __package__)
+            return getattr(module, attr_name)
+        else:
+            raise AttributeError(
+                f'module {__package__} has no attribute {name}')
 
-from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
-from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.engine.llm_engine import LLMEngine
-from vllm.entrypoints.llm import LLM
-from vllm.executor.ray_utils import initialize_ray_cluster
-from vllm.inputs import PromptType, TextPrompt, TokensPrompt
-from vllm.model_executor.models import ModelRegistry
-from vllm.outputs import (ClassificationOutput, ClassificationRequestOutput,
-                          CompletionOutput, EmbeddingOutput,
-                          EmbeddingRequestOutput, PoolingOutput,
-                          PoolingRequestOutput, RequestOutput, ScoringOutput,
-                          ScoringRequestOutput)
-from vllm.pooling_params import PoolingParams
-from vllm.sampling_params import SamplingParams
 
 __all__ = [
     "__version__",
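The rewritten `__init__.py` relies on PEP 562's module-level `__getattr__` hook. A minimal, hypothetical sketch of the same pattern follows; it would live in a package's `__init__.py` (here an invented `mypkg/` with a sibling `mypkg/engine.py` defining `Engine` — these names are illustrative, not vLLM's).

```python
# mypkg/__init__.py -- resolve public names on first attribute access
# instead of importing every submodule at package-import time.
import typing

_LAZY_ATTRS = {
    "Engine": ".engine:Engine",  # "<relative module>:<attribute>"
}

if typing.TYPE_CHECKING:
    # Static type checkers and IDEs still see the real symbol.
    from .engine import Engine
else:

    def __getattr__(name: str) -> typing.Any:
        from importlib import import_module

        if name in _LAZY_ATTRS:
            module_name, attr_name = _LAZY_ATTRS[name].split(":")
            module = import_module(module_name, __package__)
            return getattr(module, attr_name)
        raise AttributeError(
            f"module {__package__!r} has no attribute {name!r}")
```

With this arrangement `import mypkg` stays cheap; the first access to `mypkg.Engine` triggers the real submodule import, and later accesses hit the already-loaded module.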
diff --git a/vllm/config.py b/vllm/config.py
index b8232aae70..7549c97b4f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -28,7 +28,7 @@ from pydantic.dataclasses import dataclass
 from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
 from torch.distributed import ProcessGroup, ReduceOp
 from transformers import PretrainedConfig
-from typing_extensions import deprecated, runtime_checkable
+from typing_extensions import Self, deprecated, runtime_checkable
 
 import vllm.envs as envs
 from vllm import version
@@ -1537,7 +1537,6 @@ class CacheConfig:
 
     def __post_init__(self) -> None:
         self.swap_space_bytes = self.swap_space * GiB_bytes
-        self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
 
@@ -1546,7 +1545,8 @@ class CacheConfig:
         # metrics info
         return {key: str(value) for key, value in self.__dict__.items()}
 
-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if self.cpu_offload_gb < 0:
             raise ValueError("CPU offload space must be non-negative"
                              f", but got {self.cpu_offload_gb}")
@@ -1556,6 +1556,8 @@ class CacheConfig:
                 "GPU memory utilization must be less than 1.0. Got "
                 f"{self.gpu_memory_utilization}.")
 
+        return self
+
     def _verify_cache_dtype(self) -> None:
         if self.cache_dtype == "auto":
             pass
@@ -1942,15 +1944,14 @@ class ParallelConfig:
         if self.distributed_executor_backend is None and self.world_size == 1:
             self.distributed_executor_backend = "uni"
 
-        self._verify_args()
-
     @property
     def use_ray(self) -> bool:
         return self.distributed_executor_backend == "ray" or (
             isinstance(self.distributed_executor_backend, type)
             and self.distributed_executor_backend.uses_ray)
 
-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         # Lazy import to avoid circular import
         from vllm.executor.executor_base import ExecutorBase
         from vllm.platforms import current_platform
@@ -1977,8 +1978,7 @@ class ParallelConfig:
             raise ValueError("Unable to use nsight profiling unless workers "
                              "run with Ray.")
 
-        assert isinstance(self.worker_extension_cls, str), (
-            "worker_extension_cls must be a string (qualified class name).")
+        return self
 
 
 PreemptionMode = Literal["swap", "recompute"]
@@ -2202,9 +2202,8 @@ class SchedulerConfig:
                 self.max_num_partial_prefills, self.max_long_partial_prefills,
                 self.long_prefill_token_threshold)
 
-        self._verify_args()
-
-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if (self.max_num_batched_tokens < self.max_model_len
                 and not self.chunked_prefill_enabled):
             raise ValueError(
@@ -2263,6 +2262,8 @@ class SchedulerConfig:
             "must be greater than or equal to 1 and less than or equal to "
             f"max_num_partial_prefills ({self.max_num_partial_prefills}).")
 
+        return self
+
     @property
     def is_multi_step(self) -> bool:
         return self.num_scheduler_steps > 1
@@ -2669,8 +2670,6 @@ class SpeculativeConfig:
         if self.posterior_alpha is None:
             self.posterior_alpha = 0.3
 
-        self._verify_args()
-
     @staticmethod
     def _maybe_override_draft_max_model_len(
        speculative_max_model_len: Optional[int],
@@ -2761,7 +2760,8 @@ class SpeculativeConfig:
 
         return draft_parallel_config
 
-    def _verify_args(self) -> None:
+    @model_validator(mode='after')
+    def _verify_args(self) -> Self:
         if self.num_speculative_tokens is None:
             raise ValueError(
                 "num_speculative_tokens must be provided with "
@@ -2812,6 +2812,8 @@ class SpeculativeConfig:
                 "Eagle3 is only supported for Llama models. "
                 f"Got {self.target_model_config.hf_text_config.model_type=}")
 
+        return self
+
     @property
     def num_lookahead_slots(self) -> int:
         """The number of additional slots the scheduler should allocate per
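The `_verify_args` methods become pydantic `model_validator(mode='after')` hooks, so validation runs automatically after construction instead of being called explicitly from `__post_init__`. A hedged, self-contained sketch of the pattern, with invented class and field names:

```python
from pydantic import model_validator
from pydantic.dataclasses import dataclass
from typing_extensions import Self


@dataclass
class ExampleCacheConfig:
    gpu_memory_utilization: float = 0.9
    cpu_offload_gb: float = 0.0

    @model_validator(mode="after")
    def _verify_args(self) -> Self:
        # Runs automatically once the dataclass is built, so callers no
        # longer need to invoke _verify_args() from __post_init__.
        if self.cpu_offload_gb < 0:
            raise ValueError("CPU offload space must be non-negative, "
                             f"but got {self.cpu_offload_gb}")
        if self.gpu_memory_utilization > 1.0:
            raise ValueError("GPU memory utilization must be <= 1.0. Got "
                             f"{self.gpu_memory_utilization}.")
        return self


ExampleCacheConfig(gpu_memory_utilization=0.8)    # constructs fine
# ExampleCacheConfig(gpu_memory_utilization=1.5)  # raises ValueError
```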
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index bffc8ba8c9..dd09f51490 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -3,7 +3,9 @@
 # yapf: disable
 import argparse
+import copy
 import dataclasses
+import functools
 import json
 import sys
 import threading
@@ -168,7 +170,8 @@ def get_type_hints(type_hint: TypeHint) -> set[TypeHint]:
     return type_hints
 
 
-def get_kwargs(cls: ConfigType) -> dict[str, Any]:
+@functools.lru_cache(maxsize=30)
+def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
     cls_docs = get_attr_docs(cls)
     kwargs = {}
     for field in fields(cls):
@@ -269,6 +272,16 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
     return kwargs
 
 
+def get_kwargs(cls: ConfigType) -> dict[str, Any]:
+    """Return argparse kwargs for the given Config dataclass.
+
+    The heavy computation is cached via functools.lru_cache, and a deep copy
+    is returned so callers can mutate the dictionary without affecting the
+    cached version.
+    """
+    return copy.deepcopy(_compute_kwargs(cls))
+
+
 @dataclass
 class EngineArgs:
     """Arguments for vLLM engine."""
diff --git a/vllm/entrypoints/cli/benchmark/main.py b/vllm/entrypoints/cli/benchmark/main.py
index fdc5a047f6..8904a2468b 100644
--- a/vllm/entrypoints/cli/benchmark/main.py
+++ b/vllm/entrypoints/cli/benchmark/main.py
@@ -1,10 +1,16 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+from __future__ import annotations
+
 import argparse
+import typing
 
 from vllm.entrypoints.cli.benchmark.base import BenchmarkSubcommandBase
 from vllm.entrypoints.cli.types import CLISubcommand
-from vllm.utils import FlexibleArgumentParser
+
+if typing.TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
 
 
 class BenchmarkSubcommand(CLISubcommand):
@@ -23,7 +29,6 @@ class BenchmarkSubcommand(CLISubcommand):
     def subparser_init(
             self,
             subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
-
         bench_parser = subparsers.add_parser(
             self.name,
             help=self.help,
diff --git a/vllm/entrypoints/cli/collect_env.py b/vllm/entrypoints/cli/collect_env.py
index 141aafdb1a..785c18812a 100644
--- a/vllm/entrypoints/cli/collect_env.py
+++ b/vllm/entrypoints/cli/collect_env.py
@@ -1,19 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from __future__ import annotations
+
 import argparse
+import typing
 
 from vllm.collect_env import main as collect_env_main
 from vllm.entrypoints.cli.types import CLISubcommand
-from vllm.utils import FlexibleArgumentParser
+
+if typing.TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
 
 
 class CollectEnvSubcommand(CLISubcommand):
     """The `collect-env` subcommand for the vLLM CLI. """
-
-    def __init__(self):
-        self.name = "collect-env"
-        super().__init__()
+    name = "collect-env"
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
@@ -23,12 +25,11 @@ class CollectEnvSubcommand(CLISubcommand):
     def subparser_init(
             self,
             subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
-        collect_env_parser = subparsers.add_parser(
+        return subparsers.add_parser(
             "collect-env",
             help="Start collecting environment information.",
             description="Start collecting environment information.",
             usage="vllm collect-env")
-        return collect_env_parser
 
 
 def cmd_init() -> list[CLISubcommand]:
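`get_kwargs` now wraps a cached `_compute_kwargs` and hands back a deep copy. The general pattern, sketched below with placeholder names (not the vLLM functions themselves): the expensive work runs once per key, while each caller gets an independent copy it can safely mutate.

```python
import copy
import functools


@functools.lru_cache(maxsize=30)
def _compute_expensive(key: str) -> dict[str, list[int]]:
    print(f"computing for {key!r}")      # printed only once per key
    return {"values": [1, 2, 3]}


def get_expensive(key: str) -> dict[str, list[int]]:
    # Deep copy so mutation by the caller never corrupts the cached value.
    return copy.deepcopy(_compute_expensive(key))


a = get_expensive("model")
a["values"].append(99)                   # mutate the caller's copy ...
b = get_expensive("model")               # ... the cached result is untouched
assert b["values"] == [1, 2, 3]
```

Note that `lru_cache` requires hashable arguments; in the patch the argument is a config class object, which is hashable, so the cache keys off the class itself.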
""" - - def __init__(self): - self.name = "collect-env" - super().__init__() + name = "collect-env" @staticmethod def cmd(args: argparse.Namespace) -> None: @@ -23,12 +25,11 @@ class CollectEnvSubcommand(CLISubcommand): def subparser_init( self, subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser: - collect_env_parser = subparsers.add_parser( + return subparsers.add_parser( "collect-env", help="Start collecting environment information.", description="Start collecting environment information.", usage="vllm collect-env") - return collect_env_parser def cmd_init() -> list[CLISubcommand]: diff --git a/vllm/entrypoints/cli/main.py b/vllm/entrypoints/cli/main.py index 9bb1162e38..3e09d45b2e 100644 --- a/vllm/entrypoints/cli/main.py +++ b/vllm/entrypoints/cli/main.py @@ -1,27 +1,15 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +'''The CLI entrypoints of vLLM -# The CLI entrypoint to vLLM. +Note that all future modules must be lazily loaded within main +to avoid certain eager import breakage.''' +from __future__ import annotations + +import importlib.metadata import signal import sys -import vllm.entrypoints.cli.benchmark.main -import vllm.entrypoints.cli.collect_env -import vllm.entrypoints.cli.openai -import vllm.entrypoints.cli.run_batch -import vllm.entrypoints.cli.serve -import vllm.version -from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup -from vllm.utils import FlexibleArgumentParser - -CMD_MODULES = [ - vllm.entrypoints.cli.openai, - vllm.entrypoints.cli.serve, - vllm.entrypoints.cli.benchmark.main, - vllm.entrypoints.cli.collect_env, - vllm.entrypoints.cli.run_batch, -] - def register_signal_handlers(): @@ -33,16 +21,34 @@ def register_signal_handlers(): def main(): + import vllm.entrypoints.cli.benchmark.main + import vllm.entrypoints.cli.collect_env + import vllm.entrypoints.cli.openai + import vllm.entrypoints.cli.run_batch + import vllm.entrypoints.cli.serve + from vllm.entrypoints.utils import VLLM_SUBCMD_PARSER_EPILOG, cli_env_setup + from vllm.utils import FlexibleArgumentParser + + CMD_MODULES = [ + vllm.entrypoints.cli.openai, + vllm.entrypoints.cli.serve, + vllm.entrypoints.cli.benchmark.main, + vllm.entrypoints.cli.collect_env, + vllm.entrypoints.cli.run_batch, + ] + cli_env_setup() parser = FlexibleArgumentParser( description="vLLM CLI", epilog=VLLM_SUBCMD_PARSER_EPILOG, ) - parser.add_argument('-v', - '--version', - action='version', - version=vllm.version.__version__) + parser.add_argument( + '-v', + '--version', + action='version', + version=importlib.metadata.version('vllm'), + ) subparsers = parser.add_subparsers(required=False, dest="subparser") cmds = {} for cmd_module in CMD_MODULES: diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py index 58dcdfe217..5ddaee5b52 100644 --- a/vllm/entrypoints/cli/openai.py +++ b/vllm/entrypoints/cli/openai.py @@ -1,18 +1,21 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# Commands that act as an interactive OpenAI API client + +from __future__ import annotations import argparse import os import signal import sys -from typing import Optional +from typing import TYPE_CHECKING from openai import OpenAI from openai.types.chat import ChatCompletionMessageParam from vllm.entrypoints.cli.types import CLISubcommand -from vllm.utils import FlexibleArgumentParser + +if TYPE_CHECKING: + from vllm.utils import FlexibleArgumentParser def 
diff --git a/vllm/entrypoints/cli/openai.py b/vllm/entrypoints/cli/openai.py
index 58dcdfe217..5ddaee5b52 100644
--- a/vllm/entrypoints/cli/openai.py
+++ b/vllm/entrypoints/cli/openai.py
@@ -1,18 +1,21 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-# Commands that act as an interactive OpenAI API client
+
+from __future__ import annotations
 
 import argparse
 import os
 import signal
 import sys
-from typing import Optional
+from typing import TYPE_CHECKING
 
 from openai import OpenAI
 from openai.types.chat import ChatCompletionMessageParam
 
 from vllm.entrypoints.cli.types import CLISubcommand
-from vllm.utils import FlexibleArgumentParser
+
+if TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
 
 
 def _register_signal_handlers():
@@ -42,8 +45,7 @@ def _interactive_cli(args: argparse.Namespace) -> tuple[str, OpenAI]:
     return model_name, openai_client
 
 
-def chat(system_prompt: Optional[str], model_name: str,
-         client: OpenAI) -> None:
+def chat(system_prompt: str | None, model_name: str, client: OpenAI) -> None:
     conversation: list[ChatCompletionMessageParam] = []
     if system_prompt is not None:
         conversation.append({"role": "system", "content": system_prompt})
@@ -92,10 +94,7 @@ def _add_query_options(
 
 class ChatCommand(CLISubcommand):
     """The `chat` subcommand for the vLLM CLI. """
-
-    def __init__(self):
-        self.name = "chat"
-        super().__init__()
+    name = "chat"
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
@@ -157,10 +156,7 @@ class ChatCommand(CLISubcommand):
 
 class CompleteCommand(CLISubcommand):
     """The `complete` subcommand for the vLLM CLI. """
-
-    def __init__(self):
-        self.name = "complete"
-        super().__init__()
+    name = 'complete'
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
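Several CLI modules now combine `from __future__ import annotations` with a `TYPE_CHECKING`-guarded import, so `FlexibleArgumentParser` is available to type checkers but never imported at runtime. An illustrative standalone sketch, using `argparse.ArgumentParser` as a stand-in for vLLM's parser class:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Imported only while type checking; no runtime cost, no import cycle.
    from argparse import ArgumentParser


def build_parser() -> ArgumentParser:
    # With postponed evaluation, the annotation above is just a string at
    # runtime, so the real class is imported lazily where it is needed.
    from argparse import ArgumentParser
    return ArgumentParser(prog="example")


print(build_parser().prog)  # -> example
```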
diff --git a/vllm/entrypoints/cli/run_batch.py b/vllm/entrypoints/cli/run_batch.py
index 6bdd3b63c2..61a34cbc39 100644
--- a/vllm/entrypoints/cli/run_batch.py
+++ b/vllm/entrypoints/cli/run_batch.py
@@ -1,37 +1,42 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+from __future__ import annotations
+
 import argparse
 import asyncio
-
-from prometheus_client import start_http_server
+import importlib.metadata
+import typing
 
 from vllm.entrypoints.cli.types import CLISubcommand
-from vllm.entrypoints.logger import logger
-from vllm.entrypoints.openai.run_batch import main as run_batch_main
-from vllm.entrypoints.openai.run_batch import make_arg_parser
 from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
                                     show_filtered_argument_or_group_from_help)
-from vllm.utils import FlexibleArgumentParser
-from vllm.version import __version__ as VLLM_VERSION
+from vllm.logger import init_logger
+
+if typing.TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
+
+logger = init_logger(__name__)
 
 
 class RunBatchSubcommand(CLISubcommand):
     """The `run-batch` subcommand for vLLM CLI."""
-
-    def __init__(self):
-        self.name = "run-batch"
-        super().__init__()
+    name = "run-batch"
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
-        logger.info("vLLM batch processing API version %s", VLLM_VERSION)
+        from vllm.entrypoints.openai.run_batch import main as run_batch_main
+
+        logger.info("vLLM batch processing API version %s",
+                    importlib.metadata.version("vllm"))
         logger.info("args: %s", args)
 
         # Start the Prometheus metrics server.
         # LLMEngine uses the Prometheus client
         # to publish metrics at the /metrics endpoint.
         if args.enable_metrics:
+            from prometheus_client import start_http_server
+
             logger.info("Prometheus metrics enabled")
             start_http_server(port=args.port, addr=args.url)
         else:
@@ -42,6 +47,8 @@ class RunBatchSubcommand(CLISubcommand):
     def subparser_init(
             self,
             subparsers: argparse._SubParsersAction) -> FlexibleArgumentParser:
+        from vllm.entrypoints.openai.run_batch import make_arg_parser
+
         run_batch_parser = subparsers.add_parser(
             "run-batch",
             help="Run batch prompts and write results to file.",
diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py
index 9040877a42..897c222a3f 100644
--- a/vllm/entrypoints/cli/serve.py
+++ b/vllm/entrypoints/cli/serve.py
@@ -9,8 +9,8 @@ import sys
 import uvloop
 import zmq
 
+import vllm
 import vllm.envs as envs
-from vllm import AsyncEngineArgs
 from vllm.entrypoints.cli.types import CLISubcommand
 from vllm.entrypoints.openai.api_server import (run_server, run_server_worker,
                                                 setup_server)
@@ -38,10 +38,7 @@ logger = init_logger(__name__)
 
 class ServeSubcommand(CLISubcommand):
     """The `serve` subcommand for the vLLM CLI. """
-
-    def __init__(self):
-        self.name = "serve"
-        super().__init__()
+    name = "serve"
 
     @staticmethod
     def cmd(args: argparse.Namespace) -> None:
@@ -115,7 +112,7 @@ def run_headless(args: argparse.Namespace):
         raise ValueError("api_server_count can't be set in headless mode")
 
     # Create the EngineConfig.
-    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
     usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
 
@@ -175,7 +172,7 @@ def run_multi_api_server(args: argparse.Namespace):
 
     listen_address, sock = setup_server(args)
 
-    engine_args = AsyncEngineArgs.from_cli_args(args)
+    engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
     usage_context = UsageContext.OPENAI_API_SERVER
     vllm_config = engine_args.create_engine_config(usage_context=usage_context)
     model_config = vllm_config.model_config
diff --git a/vllm/entrypoints/cli/types.py b/vllm/entrypoints/cli/types.py
index 0a72443129..b88f094b30 100644
--- a/vllm/entrypoints/cli/types.py
+++ b/vllm/entrypoints/cli/types.py
@@ -1,9 +1,13 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
-import argparse
+from __future__ import annotations
 
-from vllm.utils import FlexibleArgumentParser
+import argparse
+import typing
+
+if typing.TYPE_CHECKING:
+    from vllm.utils import FlexibleArgumentParser
 
 
 class CLISubcommand:
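Each subcommand's `name` becomes a class attribute instead of an `__init__` assignment, so subcommands no longer need a constructor at all. A small sketch with a hypothetical stand-in for `CLISubcommand` (names below are illustrative):

```python
import argparse


class CLISubcommandSketch:
    """Stand-in for vllm.entrypoints.cli.types.CLISubcommand."""
    name: str

    def dispatches(self, args: argparse.Namespace) -> bool:
        # The dispatcher compares the parsed subcommand against `name`.
        return args.subparser == self.name


class ServeSketch(CLISubcommandSketch):
    name = "serve"          # was: self.name = "serve" inside __init__


class CollectEnvSketch(CLISubcommandSketch):
    name = "collect-env"


cmds = {cmd.name: cmd for cmd in (ServeSketch(), CollectEnvSketch())}
print(sorted(cmds))         # -> ['collect-env', 'serve']
```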
diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py
index 9994b3cae8..29740fc7e6 100644
--- a/vllm/entrypoints/openai/run_batch.py
+++ b/vllm/entrypoints/openai/run_batch.py
@@ -15,7 +15,7 @@ from tqdm import tqdm
 
 from vllm.engine.arg_utils import AsyncEngineArgs, optional_type
 from vllm.engine.async_llm_engine import AsyncLLMEngine
-from vllm.entrypoints.logger import RequestLogger, logger
+from vllm.entrypoints.logger import RequestLogger
 # yapf: disable
 from vllm.entrypoints.openai.protocol import (BatchRequestInput,
                                               BatchRequestOutput,
@@ -29,10 +29,13 @@ from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
 from vllm.entrypoints.openai.serving_models import (BaseModelPath,
                                                     OpenAIServingModels)
 from vllm.entrypoints.openai.serving_score import ServingScores
+from vllm.logger import init_logger
 from vllm.usage.usage_lib import UsageContext
 from vllm.utils import FlexibleArgumentParser, random_uuid
 from vllm.version import __version__ as VLLM_VERSION
 
+logger = init_logger(__name__)
+
 
 def make_arg_parser(parser: FlexibleArgumentParser):
     parser.add_argument(
@@ -201,13 +204,16 @@ async def upload_data(output_url: str, data_or_file: str,
         except Exception as e:
             if attempt < max_retries:
                 logger.error(
-                    f"Failed to upload data (attempt {attempt}). "
-                    f"Error message: {str(e)}.\nRetrying in {delay} seconds..."
+                    "Failed to upload data (attempt %d). Error message: %s.\nRetrying in %d seconds...",  # noqa: E501
+                    attempt,
+                    e,
+                    delay,
                 )
                 await asyncio.sleep(delay)
             else:
-                raise Exception(f"Failed to upload data (attempt {attempt}). "
-                                f"Error message: {str(e)}.") from e
+                raise Exception(
+                    f"Failed to upload data (attempt {attempt}). Error message: {str(e)}."  # noqa: E501
+                ) from e
 
 
 async def write_file(path_or_url: str, batch_outputs: list[BatchRequestOutput],
diff --git a/vllm/utils.py b/vllm/utils.py
index dc408e1676..34be4d52c4 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -67,9 +67,6 @@ from torch.library import Library
 from typing_extensions import Never, ParamSpec, TypeIs, assert_never
 
 import vllm.envs as envs
-# NOTE: import triton_utils to make TritonPlaceholderModule work
-# if triton is unavailable
-import vllm.triton_utils  # noqa: F401
 from vllm.logger import enable_trace_function_call, init_logger
 
 if TYPE_CHECKING:
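The logging call in `upload_data` switches from f-strings to printf-style arguments, so the message is only interpolated when a log record is actually emitted. A minimal standalone example of the same change, using the standard `logging` module rather than vLLM's logger:

```python
import logging

logging.basicConfig(level=logging.ERROR)
logger = logging.getLogger("example")

attempt, delay = 2, 5
err = RuntimeError("connection reset")

# Eager: the f-string is built even if this log level were disabled.
# logger.error(f"Failed to upload data (attempt {attempt}). Error: {err}")

# Lazy: formatting is deferred to the logging framework.
logger.error(
    "Failed to upload data (attempt %d). Error message: %s. Retrying in %d seconds...",
    attempt,
    err,
    delay,
)
```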