mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Fix clean shutdown issues (#8492)
This commit is contained in:
@ -26,6 +26,11 @@ class RequestOutput:
|
||||
finished: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class MockModelConfig:
|
||||
use_async_output_proc = True
|
||||
|
||||
|
||||
class MockEngine:
|
||||
|
||||
def __init__(self):
|
||||
@ -35,6 +40,7 @@ class MockEngine:
|
||||
self.request_id = None
|
||||
# Ugly, remove dependency when possible
|
||||
self.parallel_config = ParallelConfig(1, 1, False)
|
||||
self.model_config = MockModelConfig()
|
||||
|
||||
async def step_async(self, virtual_engine):
|
||||
# PP size is 1, ignore virtual engine
|
||||
@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_new_requests_event():
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=False)
|
||||
engine = MockAsyncLLMEngine()
|
||||
engine.start_background_loop()
|
||||
await asyncio.sleep(0.01)
|
||||
assert engine.engine.step_calls == 0
|
||||
@ -113,7 +119,7 @@ async def test_new_requests_event():
|
||||
assert engine.engine.add_request_calls == 3
|
||||
assert engine.engine.step_calls == old_step_calls + 1
|
||||
|
||||
engine = MockAsyncLLMEngine(worker_use_ray=True)
|
||||
engine = MockAsyncLLMEngine()
|
||||
assert engine.get_model_config() is not None
|
||||
assert engine.get_tokenizer() is not None
|
||||
assert engine.get_decoding_config() is not None
|
||||
|
@ -1,8 +1,10 @@
|
||||
import asyncio
|
||||
import time
|
||||
import weakref
|
||||
from functools import partial
|
||||
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
|
||||
Mapping, Optional, Set, Tuple, Type, Union)
|
||||
from weakref import ReferenceType
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
|
||||
@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import weak_bind
|
||||
|
||||
logger = init_logger(__name__)
|
||||
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
|
||||
@ -450,9 +453,6 @@ class AsyncLLMEngine:
|
||||
method yields the outputs from the :class:`LLMEngine` to the caller.
|
||||
|
||||
Args:
|
||||
worker_use_ray: Whether to use Ray for model workers. Required for
|
||||
distributed execution. Should be the same as
|
||||
`parallel_config.worker_use_ray`.
|
||||
log_requests: Whether to log the requests.
|
||||
start_engine_loop: If True, the background task to run the engine
|
||||
will be automatically started in the generate call.
|
||||
@ -463,23 +463,22 @@ class AsyncLLMEngine:
|
||||
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
|
||||
|
||||
def __init__(self,
|
||||
worker_use_ray: bool,
|
||||
*args,
|
||||
log_requests: bool = True,
|
||||
start_engine_loop: bool = True,
|
||||
**kwargs) -> None:
|
||||
self.worker_use_ray = worker_use_ray
|
||||
self.log_requests = log_requests
|
||||
self.engine = self._engine_class(*args, **kwargs)
|
||||
|
||||
# This ensures quick processing of request outputs
|
||||
# so the append to asyncio queues is not delayed,
|
||||
# especially for multi-step.
|
||||
#
|
||||
self.use_process_request_outputs_callback = True
|
||||
self.use_process_request_outputs_callback = (
|
||||
self.engine.model_config.use_async_output_proc)
|
||||
|
||||
if self.use_process_request_outputs_callback:
|
||||
self.engine.process_request_outputs_callback = \
|
||||
self.process_request_outputs
|
||||
weak_bind(self.process_request_outputs)
|
||||
|
||||
self.background_loop: Optional[asyncio.Future] = None
|
||||
# We need to keep a reference to unshielded
|
||||
@ -492,6 +491,11 @@ class AsyncLLMEngine:
|
||||
# Lazy initialized fields
|
||||
self._request_tracker: RequestTracker
|
||||
|
||||
def __del__(self):
|
||||
if rt := getattr(self, "request_tracker", None):
|
||||
# Wake up engine loop so that it will exit cleanly
|
||||
rt.new_requests_event.set()
|
||||
|
||||
@classmethod
|
||||
def _get_executor_cls(
|
||||
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
|
||||
@ -502,15 +506,12 @@ class AsyncLLMEngine:
|
||||
raise TypeError(
|
||||
"distributed_executor_backend must be a subclass of "
|
||||
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
|
||||
if distributed_executor_backend.uses_ray: # type: ignore
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
executor_class = distributed_executor_backend
|
||||
elif engine_config.device_config.device_type == "neuron":
|
||||
from vllm.executor.neuron_executor import NeuronExecutorAsync
|
||||
executor_class = NeuronExecutorAsync
|
||||
elif engine_config.device_config.device_type == "tpu":
|
||||
if distributed_executor_backend == "ray":
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
|
||||
executor_class = RayTPUExecutorAsync
|
||||
else:
|
||||
@ -531,11 +532,9 @@ class AsyncLLMEngine:
|
||||
from vllm.executor.xpu_executor import XPUExecutorAsync
|
||||
executor_class = XPUExecutorAsync
|
||||
elif distributed_executor_backend == "ray":
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
|
||||
executor_class = RayXPUExecutorAsync
|
||||
elif distributed_executor_backend == "mp":
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.multiproc_xpu_executor import (
|
||||
MultiprocessingXPUExecutorAsync)
|
||||
executor_class = MultiprocessingXPUExecutorAsync
|
||||
@ -543,7 +542,6 @@ class AsyncLLMEngine:
|
||||
raise RuntimeError(
|
||||
"Not supported distributed execution model on XPU device.")
|
||||
elif distributed_executor_backend == "ray":
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
|
||||
executor_class = RayGPUExecutorAsync
|
||||
elif distributed_executor_backend == "mp":
|
||||
@ -559,19 +557,23 @@ class AsyncLLMEngine:
|
||||
def from_engine_args(
|
||||
cls,
|
||||
engine_args: AsyncEngineArgs,
|
||||
engine_config: Optional[EngineConfig] = None,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
engine_config = engine_args.create_engine_config()
|
||||
if engine_config is None:
|
||||
engine_config = engine_args.create_engine_config()
|
||||
|
||||
executor_class = cls._get_executor_cls(engine_config)
|
||||
|
||||
if executor_class.uses_ray:
|
||||
initialize_ray_cluster(engine_config.parallel_config)
|
||||
|
||||
# Create the async LLM engine.
|
||||
engine = cls(
|
||||
executor_class.uses_ray,
|
||||
**engine_config.to_dict(),
|
||||
executor_class=executor_class,
|
||||
log_requests=not engine_args.disable_log_requests,
|
||||
@ -628,7 +630,7 @@ class AsyncLLMEngine:
|
||||
self._request_tracker = RequestTracker()
|
||||
|
||||
self._background_loop_unshielded = asyncio.get_event_loop(
|
||||
).create_task(self.run_engine_loop())
|
||||
).create_task(self.run_engine_loop(weakref.ref(self)))
|
||||
self._background_loop_unshielded.add_done_callback(
|
||||
partial(_log_task_completion, error_callback=self._error_callback))
|
||||
self.background_loop = asyncio.shield(self._background_loop_unshielded)
|
||||
@ -698,9 +700,16 @@ class AsyncLLMEngine:
|
||||
async def _engine_abort(self, request_ids: Iterable[str]):
|
||||
self.engine.abort_request(request_ids)
|
||||
|
||||
async def run_engine_loop(self):
|
||||
@staticmethod
|
||||
async def run_engine_loop(engine_ref: ReferenceType):
|
||||
"""We use a weakref to the engine so that the running loop
|
||||
doesn't prevent the engine being garbage collected."""
|
||||
engine: Optional["AsyncLLMEngine"] = engine_ref()
|
||||
if not engine:
|
||||
return
|
||||
|
||||
pipeline_parallel_size = \
|
||||
self.engine.parallel_config.pipeline_parallel_size
|
||||
engine.engine.parallel_config.pipeline_parallel_size
|
||||
has_requests_in_progress = [False] * pipeline_parallel_size
|
||||
while True:
|
||||
if not any(has_requests_in_progress):
|
||||
@ -711,11 +720,21 @@ class AsyncLLMEngine:
|
||||
# timeout, and unblocks the RPC thread in the workers so that
|
||||
# they can process any other queued control plane messages,
|
||||
# such as add/remove lora adapters.
|
||||
await self.engine.stop_remote_worker_execution_loop_async()
|
||||
await self._request_tracker.wait_for_new_requests()
|
||||
await engine.engine.stop_remote_worker_execution_loop_async()
|
||||
request_tracker = engine._request_tracker
|
||||
# Allow engine to be garbage collected while
|
||||
# waiting for new requests
|
||||
del engine
|
||||
await asyncio.sleep(0)
|
||||
if engine_ref() is None:
|
||||
return
|
||||
await request_tracker.wait_for_new_requests()
|
||||
engine = engine_ref()
|
||||
if not engine:
|
||||
return
|
||||
logger.debug("Got new requests!")
|
||||
requests_in_progress = [
|
||||
asyncio.create_task(self.engine_step(ve))
|
||||
asyncio.create_task(engine.engine_step(ve))
|
||||
for ve in range(pipeline_parallel_size)
|
||||
]
|
||||
has_requests_in_progress = [True] * pipeline_parallel_size
|
||||
@ -733,19 +752,20 @@ class AsyncLLMEngine:
|
||||
result = task.result()
|
||||
virtual_engine = requests_in_progress.index(task)
|
||||
has_unfinished_requests = (
|
||||
self.engine.has_unfinished_requests_for_virtual_engine(
|
||||
engine.engine.
|
||||
has_unfinished_requests_for_virtual_engine(
|
||||
virtual_engine))
|
||||
if result or has_unfinished_requests:
|
||||
requests_in_progress[virtual_engine] = (
|
||||
asyncio.create_task(
|
||||
self.engine_step(virtual_engine)))
|
||||
engine.engine_step(virtual_engine)))
|
||||
has_requests_in_progress[virtual_engine] = True
|
||||
else:
|
||||
has_requests_in_progress[virtual_engine] = False
|
||||
except asyncio.TimeoutError as exc:
|
||||
logger.error(
|
||||
"Engine iteration timed out. This should never happen!")
|
||||
self.set_errored(exc)
|
||||
engine.set_errored(exc)
|
||||
raise
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
@ -1,8 +1,8 @@
|
||||
import functools
|
||||
import time
|
||||
from collections import deque
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
|
||||
Iterable, List, Mapping, NamedTuple, Optional)
|
||||
from typing import Sequence as GenericSequence
|
||||
@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
|
||||
BaseTokenizerGroup, init_tokenizer_from_configs)
|
||||
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import Counter, Device
|
||||
from vllm.utils import Counter, Device, weak_bind
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -382,11 +382,16 @@ class LLMEngine:
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
self.async_callbacks = [
|
||||
functools.partial(self._process_model_outputs,
|
||||
ctx=self.scheduler_contexts[v_id])
|
||||
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
if model_config.use_async_output_proc:
|
||||
process_model_outputs = weak_bind(self._process_model_outputs)
|
||||
|
||||
self.async_callbacks = [
|
||||
partial(process_model_outputs,
|
||||
ctx=self.scheduler_contexts[v_id])
|
||||
for v_id in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
else:
|
||||
self.async_callbacks = []
|
||||
|
||||
# Currently used by AsyncLLMEngine to ensure quick append
|
||||
# of request outputs to asyncio queues
|
||||
@ -869,8 +874,8 @@ class LLMEngine:
|
||||
"""
|
||||
return self.scheduler[virtual_engine].has_unfinished_seqs()
|
||||
|
||||
@staticmethod
|
||||
def _process_sequence_group_outputs(
|
||||
self,
|
||||
seq_group: SequenceGroup,
|
||||
outputs: List[EmbeddingSequenceGroupOutput],
|
||||
) -> None:
|
||||
|
@ -1,21 +1,20 @@
|
||||
import asyncio
|
||||
import signal
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, Response
|
||||
from fastapi import FastAPI, Request, Response
|
||||
|
||||
from vllm import envs
|
||||
from vllm.engine.async_llm_engine import AsyncEngineDeadError
|
||||
from vllm.engine.protocol import AsyncEngineClient
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
async def serve_http(app: FastAPI, engine: AsyncEngineClient,
|
||||
async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
|
||||
**uvicorn_kwargs: Any):
|
||||
logger.info("Available routes are:")
|
||||
for route in app.routes:
|
||||
@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
|
||||
|
||||
# Set concurrency limits in uvicorn if running in multiprocessing mode
|
||||
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
|
||||
if engine.limit_concurrency is not None:
|
||||
if limit_concurrency is not None:
|
||||
logger.info(
|
||||
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
|
||||
"limit at the expense of performance run with "
|
||||
"--disable-frontend-multiprocessing", engine.limit_concurrency)
|
||||
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
|
||||
"--disable-frontend-multiprocessing", limit_concurrency)
|
||||
uvicorn_kwargs["limit_concurrency"] = limit_concurrency
|
||||
|
||||
config = uvicorn.Config(app, **uvicorn_kwargs)
|
||||
server = uvicorn.Server(config)
|
||||
_add_shutdown_handlers(app, server, engine)
|
||||
_add_shutdown_handlers(app, server)
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
|
||||
return server.shutdown()
|
||||
|
||||
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
|
||||
engine: AsyncEngineClient) -> None:
|
||||
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
|
||||
"""Adds handlers for fatal errors that should crash the server"""
|
||||
|
||||
@app.exception_handler(RuntimeError)
|
||||
async def runtime_error_handler(_, __):
|
||||
async def runtime_error_handler(request: Request, __):
|
||||
"""On generic runtime error, check to see if the engine has died.
|
||||
It probably has, in which case the server will no longer be able to
|
||||
handle requests. Trigger a graceful shutdown with a SIGTERM."""
|
||||
engine = request.app.state.engine_client
|
||||
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
|
||||
and not engine.is_running):
|
||||
logger.fatal("AsyncLLMEngine has failed, terminating server "
|
||||
|
@ -4,16 +4,20 @@ import inspect
|
||||
import multiprocessing
|
||||
import os
|
||||
import re
|
||||
import signal
|
||||
import tempfile
|
||||
from argparse import Namespace
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from http import HTTPStatus
|
||||
from typing import AsyncIterator, Optional, Set
|
||||
|
||||
import uvloop
|
||||
from fastapi import APIRouter, FastAPI, Request
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
from starlette.datastructures import State
|
||||
from starlette.routing import Mount
|
||||
from typing_extensions import assert_never
|
||||
|
||||
@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION
|
||||
|
||||
TIMEOUT_KEEP_ALIVE = 5 # seconds
|
||||
|
||||
async_engine_client: AsyncEngineClient
|
||||
engine_args: AsyncEngineArgs
|
||||
openai_serving_chat: OpenAIServingChat
|
||||
openai_serving_completion: OpenAIServingCompletion
|
||||
openai_serving_embedding: OpenAIServingEmbedding
|
||||
openai_serving_tokenization: OpenAIServingTokenization
|
||||
prometheus_multiproc_dir: tempfile.TemporaryDirectory
|
||||
|
||||
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
|
||||
@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
try:
|
||||
if app.state.log_stats:
|
||||
async_engine_client = app.state.engine_client
|
||||
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
await async_engine_client.do_log_stats()
|
||||
async def _force_log():
|
||||
while True:
|
||||
await asyncio.sleep(10)
|
||||
await async_engine_client.do_log_stats()
|
||||
|
||||
if not engine_args.disable_log_stats:
|
||||
task = asyncio.create_task(_force_log())
|
||||
_running_tasks.add(task)
|
||||
task.add_done_callback(_running_tasks.remove)
|
||||
|
||||
yield
|
||||
task = asyncio.create_task(_force_log())
|
||||
_running_tasks.add(task)
|
||||
task.add_done_callback(_running_tasks.remove)
|
||||
else:
|
||||
task = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
if task is not None:
|
||||
task.cancel()
|
||||
finally:
|
||||
# Ensure app state including engine ref is gc'd
|
||||
del app.state
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
@ -103,16 +111,10 @@ async def build_async_engine_client(
|
||||
|
||||
# Context manager to handle async_engine_client lifecycle
|
||||
# Ensures everything is shutdown and cleaned up on error/exit
|
||||
global engine_args
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
|
||||
# Backend itself still global for the silly lil' health handler
|
||||
global async_engine_client
|
||||
|
||||
async with build_async_engine_client_from_engine_args(
|
||||
engine_args, args.disable_frontend_multiprocessing) as engine:
|
||||
|
||||
async_engine_client = engine # type: ignore[assignment]
|
||||
yield engine
|
||||
|
||||
|
||||
@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
|
||||
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
|
||||
engine_args.quantization, engine_args.revision)
|
||||
or disable_frontend_multiprocessing):
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
|
||||
try:
|
||||
yield engine_client
|
||||
finally:
|
||||
engine_client.shutdown_background_loop()
|
||||
engine_config = engine_args.create_engine_config()
|
||||
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
|
||||
"uses_ray", False)
|
||||
|
||||
build_engine = partial(AsyncLLMEngine.from_engine_args,
|
||||
engine_args=engine_args,
|
||||
engine_config=engine_config,
|
||||
usage_context=UsageContext.OPENAI_API_SERVER)
|
||||
if uses_ray:
|
||||
# Must run in main thread with ray for its signal handlers to work
|
||||
engine_client = build_engine()
|
||||
else:
|
||||
engine_client = await asyncio.get_running_loop().run_in_executor(
|
||||
None, build_engine)
|
||||
|
||||
yield engine_client
|
||||
return
|
||||
|
||||
# Otherwise, use the multiprocessing AsyncLLMEngine.
|
||||
@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
|
||||
app.routes.append(metrics_route)
|
||||
|
||||
|
||||
def chat(request: Request) -> OpenAIServingChat:
|
||||
return request.app.state.openai_serving_chat
|
||||
|
||||
|
||||
def completion(request: Request) -> OpenAIServingCompletion:
|
||||
return request.app.state.openai_serving_completion
|
||||
|
||||
|
||||
def tokenization(request: Request) -> OpenAIServingTokenization:
|
||||
return request.app.state.openai_serving_tokenization
|
||||
|
||||
|
||||
def embedding(request: Request) -> OpenAIServingEmbedding:
|
||||
return request.app.state.openai_serving_embedding
|
||||
|
||||
|
||||
def engine_client(request: Request) -> AsyncEngineClient:
|
||||
return request.app.state.engine_client
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def health() -> Response:
|
||||
async def health(raw_request: Request) -> Response:
|
||||
"""Health check."""
|
||||
await async_engine_client.check_health()
|
||||
await engine_client(raw_request).check_health()
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.post("/tokenize")
|
||||
async def tokenize(request: TokenizeRequest):
|
||||
generator = await openai_serving_tokenization.create_tokenize(request)
|
||||
async def tokenize(request: TokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_tokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):
|
||||
|
||||
|
||||
@router.post("/detokenize")
|
||||
async def detokenize(request: DetokenizeRequest):
|
||||
generator = await openai_serving_tokenization.create_detokenize(request)
|
||||
async def detokenize(request: DetokenizeRequest, raw_request: Request):
|
||||
generator = await tokenization(raw_request).create_detokenize(request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
status_code=generator.code)
|
||||
@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):
|
||||
|
||||
|
||||
@router.get("/v1/models")
|
||||
async def show_available_models():
|
||||
models = await openai_serving_completion.show_available_models()
|
||||
async def show_available_models(raw_request: Request):
|
||||
models = await completion(raw_request).show_available_models()
|
||||
return JSONResponse(content=models.model_dump())
|
||||
|
||||
|
||||
@ -288,7 +320,7 @@ async def show_version():
|
||||
async def create_chat_completion(request: ChatCompletionRequest,
|
||||
raw_request: Request):
|
||||
|
||||
generator = await openai_serving_chat.create_chat_completion(
|
||||
generator = await chat(raw_request).create_chat_completion(
|
||||
request, raw_request)
|
||||
|
||||
if isinstance(generator, ErrorResponse):
|
||||
@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
|
||||
@router.post("/v1/completions")
|
||||
async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
generator = await openai_serving_completion.create_completion(
|
||||
generator = await completion(raw_request).create_completion(
|
||||
request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
|
||||
@router.post("/v1/embeddings")
|
||||
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
|
||||
generator = await openai_serving_embedding.create_embedding(
|
||||
generator = await embedding(raw_request).create_embedding(
|
||||
request, raw_request)
|
||||
if isinstance(generator, ErrorResponse):
|
||||
return JSONResponse(content=generator.model_dump(),
|
||||
@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
|
||||
"used for local development!")
|
||||
|
||||
@router.post("/start_profile")
|
||||
async def start_profile():
|
||||
async def start_profile(raw_request: Request):
|
||||
logger.info("Starting profiler...")
|
||||
await async_engine_client.start_profile()
|
||||
await engine_client(raw_request).start_profile()
|
||||
logger.info("Profiler started.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@router.post("/stop_profile")
|
||||
async def stop_profile():
|
||||
async def stop_profile(raw_request: Request):
|
||||
logger.info("Stopping profiler...")
|
||||
await async_engine_client.stop_profile()
|
||||
await engine_client(raw_request).stop_profile()
|
||||
logger.info("Profiler stopped.")
|
||||
return Response(status_code=200)
|
||||
|
||||
@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
"This should ONLY be used for local development!")
|
||||
|
||||
@router.post("/v1/load_lora_adapter")
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest):
|
||||
response = await openai_serving_chat.load_lora_adapter(request)
|
||||
async def load_lora_adapter(request: LoadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await openai_serving_completion.load_lora_adapter(request)
|
||||
response = await completion(raw_request).load_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
|
||||
return Response(status_code=200, content=response)
|
||||
|
||||
@router.post("/v1/unload_lora_adapter")
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
|
||||
response = await openai_serving_chat.unload_lora_adapter(request)
|
||||
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
|
||||
raw_request: Request):
|
||||
response = await chat(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
|
||||
response = await openai_serving_completion.unload_lora_adapter(request)
|
||||
response = await completion(raw_request).unload_lora_adapter(request)
|
||||
if isinstance(response, ErrorResponse):
|
||||
return JSONResponse(content=response.model_dump(),
|
||||
status_code=response.code)
|
||||
@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
|
||||
@app.exception_handler(RequestValidationError)
|
||||
async def validation_exception_handler(_, exc):
|
||||
err = openai_serving_chat.create_error_response(message=str(exc))
|
||||
chat = app.state.openai_serving_chat
|
||||
err = chat.create_error_response(message=str(exc))
|
||||
return JSONResponse(err.model_dump(),
|
||||
status_code=HTTPStatus.BAD_REQUEST)
|
||||
|
||||
@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
|
||||
return app
|
||||
|
||||
|
||||
async def init_app(
|
||||
def init_app_state(
|
||||
async_engine_client: AsyncEngineClient,
|
||||
model_config: ModelConfig,
|
||||
state: State,
|
||||
args: Namespace,
|
||||
) -> FastAPI:
|
||||
app = build_app(args)
|
||||
|
||||
) -> None:
|
||||
if args.served_model_name is not None:
|
||||
served_model_names = args.served_model_name
|
||||
else:
|
||||
served_model_names = [args.model]
|
||||
|
||||
model_config = await async_engine_client.get_model_config()
|
||||
|
||||
if args.disable_log_requests:
|
||||
request_logger = None
|
||||
else:
|
||||
request_logger = RequestLogger(max_log_len=args.max_log_len)
|
||||
|
||||
global openai_serving_chat
|
||||
global openai_serving_completion
|
||||
global openai_serving_embedding
|
||||
global openai_serving_tokenization
|
||||
state.engine_client = async_engine_client
|
||||
state.log_stats = not args.disable_log_stats
|
||||
|
||||
openai_serving_chat = OpenAIServingChat(
|
||||
state.openai_serving_chat = OpenAIServingChat(
|
||||
async_engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
@ -465,7 +496,7 @@ async def init_app(
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
enable_auto_tools=args.enable_auto_tool_choice,
|
||||
tool_parser=args.tool_call_parser)
|
||||
openai_serving_completion = OpenAIServingCompletion(
|
||||
state.openai_serving_completion = OpenAIServingCompletion(
|
||||
async_engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
@ -474,13 +505,13 @@ async def init_app(
|
||||
request_logger=request_logger,
|
||||
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
|
||||
)
|
||||
openai_serving_embedding = OpenAIServingEmbedding(
|
||||
state.openai_serving_embedding = OpenAIServingEmbedding(
|
||||
async_engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
request_logger=request_logger,
|
||||
)
|
||||
openai_serving_tokenization = OpenAIServingTokenization(
|
||||
state.openai_serving_tokenization = OpenAIServingTokenization(
|
||||
async_engine_client,
|
||||
model_config,
|
||||
served_model_names,
|
||||
@ -488,25 +519,31 @@ async def init_app(
|
||||
request_logger=request_logger,
|
||||
chat_template=args.chat_template,
|
||||
)
|
||||
app.root_path = args.root_path
|
||||
|
||||
return app
|
||||
|
||||
|
||||
async def run_server(args, **uvicorn_kwargs) -> None:
|
||||
logger.info("vLLM API server version %s", VLLM_VERSION)
|
||||
logger.info("args: %s", args)
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
async with build_async_engine_client(args) as async_engine_client:
|
||||
# If None, creation of the client failed and we exit.
|
||||
if async_engine_client is None:
|
||||
return
|
||||
|
||||
app = await init_app(async_engine_client, args)
|
||||
app = build_app(args)
|
||||
|
||||
model_config = await async_engine_client.get_model_config()
|
||||
init_app_state(async_engine_client, model_config, app.state, args)
|
||||
|
||||
shutdown_task = await serve_http(
|
||||
app,
|
||||
engine=async_engine_client,
|
||||
limit_concurrency=async_engine_client.limit_concurrency,
|
||||
host=args.host,
|
||||
port=args.port,
|
||||
log_level=args.uvicorn_log_level,
|
||||
@ -530,4 +567,4 @@ if __name__ == "__main__":
|
||||
parser = make_arg_parser(parser)
|
||||
args = parser.parse_args()
|
||||
|
||||
asyncio.run(run_server(args))
|
||||
uvloop.run(run_server(args))
|
||||
|
@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
|
||||
"""Cleanup all resources."""
|
||||
self.socket.close()
|
||||
self.context.destroy()
|
||||
self.engine.shutdown_background_loop()
|
||||
# Clear the engine reference so that it can be GC'ed.
|
||||
del self.engine
|
||||
|
||||
@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):
|
||||
|
||||
def run_rpc_server(async_engine_args: AsyncEngineArgs,
|
||||
usage_context: UsageContext, rpc_path: str):
|
||||
|
||||
def signal_handler(*_) -> None:
|
||||
# Interrupt server on sigterm while initializing
|
||||
raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
|
||||
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
|
||||
uvloop.run(run_server(server))
|
||||
|
@ -1,8 +1,5 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import threading
|
||||
import weakref
|
||||
from functools import partial
|
||||
from typing import Any, List, Optional
|
||||
|
||||
@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
|
||||
# Set up signal handlers to shutdown the executor cleanly
|
||||
# sometimes gc does not work well
|
||||
|
||||
# Use weakref to avoid holding a reference to self
|
||||
ref = weakref.ref(self)
|
||||
|
||||
def shutdown(signum, frame):
|
||||
if executor := ref():
|
||||
executor.shutdown()
|
||||
|
||||
if threading.current_thread() is threading.main_thread():
|
||||
signal.signal(signal.SIGINT, shutdown)
|
||||
signal.signal(signal.SIGTERM, shutdown)
|
||||
|
||||
self.driver_worker = self._create_worker(
|
||||
distributed_init_method=distributed_init_method)
|
||||
self._run_workers("init_device")
|
||||
|
@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
|
||||
logger.error("Worker %s pid %s died, exit code: %s",
|
||||
process.name, process.pid, process.exitcode)
|
||||
# Cleanup any remaining workers
|
||||
logger.info("Killing local vLLM worker processes")
|
||||
if logger:
|
||||
logger.info("Killing local vLLM worker processes")
|
||||
for worker in self.workers:
|
||||
worker.kill_worker()
|
||||
# Must be done after worker task queues are all closed
|
||||
@ -221,6 +222,8 @@ def _run_worker_process(
|
||||
try:
|
||||
executor = getattr(worker, method)
|
||||
output = executor(*args, **kwargs)
|
||||
except KeyboardInterrupt:
|
||||
break
|
||||
except BaseException as e:
|
||||
tb = traceback.format_exc()
|
||||
logger.error(
|
||||
|
@ -26,6 +26,8 @@ logger = init_logger(__name__)
|
||||
|
||||
class RayTPUExecutor(TPUExecutor):
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
# This is non-None when the execute model loop is running
|
||||
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.
|
||||
|
@ -1,11 +1,11 @@
|
||||
# The CLI entrypoint to vLLM.
|
||||
import argparse
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import uvloop
|
||||
from openai import OpenAI
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
|
||||
@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
|
||||
# EngineArgs expects the model name to be passed as --model.
|
||||
args.model = args.model_tag
|
||||
|
||||
asyncio.run(run_server(args))
|
||||
uvloop.run(run_server(args))
|
||||
|
||||
|
||||
def interactive_cli(args: argparse.Namespace) -> None:
|
||||
|
@ -12,6 +12,7 @@ import tempfile
|
||||
import threading
|
||||
import uuid
|
||||
import warnings
|
||||
import weakref
|
||||
from asyncio import FIRST_COMPLETED, ensure_future
|
||||
from functools import lru_cache, partial, wraps
|
||||
from platform import uname
|
||||
@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
|
||||
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
|
||||
|
||||
|
||||
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
|
||||
"""Make an instance method that weakly references
|
||||
its associated instance and no-ops once that
|
||||
instance is collected."""
|
||||
ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined]
|
||||
unbound = bound_method.__func__ # type: ignore[attr-defined]
|
||||
|
||||
def weak_bound(*args, **kwargs) -> None:
|
||||
if inst := ref():
|
||||
unbound(inst, *args, **kwargs)
|
||||
|
||||
return weak_bound
|
||||
|
||||
|
||||
#From: https://stackoverflow.com/a/4104188/2749989
|
||||
def run_once(f: Callable[P, None]) -> Callable[P, None]:
|
||||
|
||||
|
Reference in New Issue
Block a user