[BugFix] Fix clean shutdown issues (#8492)

Author: Nick Hill
Date: 2024-09-16 17:33:46 +01:00
Committed by: GitHub
Parent: 837c1968f9
Commit: acd5511b6d
11 changed files with 213 additions and 134 deletions

View File

@ -26,6 +26,11 @@ class RequestOutput:
finished: bool = False
@dataclass
class MockModelConfig:
use_async_output_proc = True
class MockEngine:
def __init__(self):
@ -35,6 +40,7 @@ class MockEngine:
self.request_id = None
# Ugly, remove dependency when possible
self.parallel_config = ParallelConfig(1, 1, False)
self.model_config = MockModelConfig()
async def step_async(self, virtual_engine):
# PP size is 1, ignore virtual engine
@ -80,7 +86,7 @@ class MockAsyncLLMEngine(AsyncLLMEngine):
@pytest.mark.asyncio
async def test_new_requests_event():
engine = MockAsyncLLMEngine(worker_use_ray=False)
engine = MockAsyncLLMEngine()
engine.start_background_loop()
await asyncio.sleep(0.01)
assert engine.engine.step_calls == 0
@ -113,7 +119,7 @@ async def test_new_requests_event():
assert engine.engine.add_request_calls == 3
assert engine.engine.step_calls == old_step_calls + 1
engine = MockAsyncLLMEngine(worker_use_ray=True)
engine = MockAsyncLLMEngine()
assert engine.get_model_config() is not None
assert engine.get_tokenizer() is not None
assert engine.get_decoding_config() is not None
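Note: the mock gains a model_config because AsyncLLMEngine (next file) now keys its output-processing callback on model_config.use_async_output_proc. A minimal sketch of that relationship, reusing the mock names from this test; the check shown is simplified, not vLLM's exact wiring:

from dataclasses import dataclass

@dataclass
class MockModelConfig:
    use_async_output_proc: bool = True

class MockEngine:
    def __init__(self) -> None:
        self.model_config = MockModelConfig()

# Simplified version of the check the async wrapper performs:
engine = MockEngine()
use_callback = engine.model_config.use_async_output_proc
assert use_callback is True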

View File

@ -1,8 +1,10 @@
import asyncio
import time
import weakref
from functools import partial
from typing import (Any, AsyncGenerator, Callable, Dict, Iterable, List,
Mapping, Optional, Set, Tuple, Type, Union)
from weakref import ReferenceType
import vllm.envs as envs
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
@ -26,6 +28,7 @@ from vllm.sampling_params import SamplingParams
from vllm.sequence import ExecuteModelRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.usage.usage_lib import UsageContext
from vllm.utils import weak_bind
logger = init_logger(__name__)
ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S
@ -450,9 +453,6 @@ class AsyncLLMEngine:
method yields the outputs from the :class:`LLMEngine` to the caller.
Args:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
log_requests: Whether to log the requests.
start_engine_loop: If True, the background task to run the engine
will be automatically started in the generate call.
@ -463,23 +463,22 @@ class AsyncLLMEngine:
_engine_class: Type[_AsyncLLMEngine] = _AsyncLLMEngine
def __init__(self,
worker_use_ray: bool,
*args,
log_requests: bool = True,
start_engine_loop: bool = True,
**kwargs) -> None:
self.worker_use_ray = worker_use_ray
self.log_requests = log_requests
self.engine = self._engine_class(*args, **kwargs)
# This ensures quick processing of request outputs
# so the append to asyncio queues is not delayed,
# especially for multi-step.
#
self.use_process_request_outputs_callback = True
self.use_process_request_outputs_callback = (
self.engine.model_config.use_async_output_proc)
if self.use_process_request_outputs_callback:
self.engine.process_request_outputs_callback = \
self.process_request_outputs
weak_bind(self.process_request_outputs)
self.background_loop: Optional[asyncio.Future] = None
# We need to keep a reference to unshielded
@ -492,6 +491,11 @@ class AsyncLLMEngine:
# Lazy initialized fields
self._request_tracker: RequestTracker
def __del__(self):
if rt := getattr(self, "request_tracker", None):
# Wake up engine loop so that it will exit cleanly
rt.new_requests_event.set()
@classmethod
def _get_executor_cls(
cls, engine_config: EngineConfig) -> Type[ExecutorAsyncBase]:
@ -502,15 +506,12 @@ class AsyncLLMEngine:
raise TypeError(
"distributed_executor_backend must be a subclass of "
f"ExecutorAsyncBase. Got {distributed_executor_backend}.")
if distributed_executor_backend.uses_ray: # type: ignore
initialize_ray_cluster(engine_config.parallel_config)
executor_class = distributed_executor_backend
elif engine_config.device_config.device_type == "neuron":
from vllm.executor.neuron_executor import NeuronExecutorAsync
executor_class = NeuronExecutorAsync
elif engine_config.device_config.device_type == "tpu":
if distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_tpu_executor import RayTPUExecutorAsync
executor_class = RayTPUExecutorAsync
else:
@ -531,11 +532,9 @@ class AsyncLLMEngine:
from vllm.executor.xpu_executor import XPUExecutorAsync
executor_class = XPUExecutorAsync
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync
executor_class = RayXPUExecutorAsync
elif distributed_executor_backend == "mp":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.multiproc_xpu_executor import (
MultiprocessingXPUExecutorAsync)
executor_class = MultiprocessingXPUExecutorAsync
@ -543,7 +542,6 @@ class AsyncLLMEngine:
raise RuntimeError(
"Not supported distributed execution model on XPU device.")
elif distributed_executor_backend == "ray":
initialize_ray_cluster(engine_config.parallel_config)
from vllm.executor.ray_gpu_executor import RayGPUExecutorAsync
executor_class = RayGPUExecutorAsync
elif distributed_executor_backend == "mp":
@ -559,19 +557,23 @@ class AsyncLLMEngine:
def from_engine_args(
cls,
engine_args: AsyncEngineArgs,
engine_config: Optional[EngineConfig] = None,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
engine_config = engine_args.create_engine_config()
if engine_config is None:
engine_config = engine_args.create_engine_config()
executor_class = cls._get_executor_cls(engine_config)
if executor_class.uses_ray:
initialize_ray_cluster(engine_config.parallel_config)
# Create the async LLM engine.
engine = cls(
executor_class.uses_ray,
**engine_config.to_dict(),
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
@ -628,7 +630,7 @@ class AsyncLLMEngine:
self._request_tracker = RequestTracker()
self._background_loop_unshielded = asyncio.get_event_loop(
).create_task(self.run_engine_loop())
).create_task(self.run_engine_loop(weakref.ref(self)))
self._background_loop_unshielded.add_done_callback(
partial(_log_task_completion, error_callback=self._error_callback))
self.background_loop = asyncio.shield(self._background_loop_unshielded)
@ -698,9 +700,16 @@ class AsyncLLMEngine:
async def _engine_abort(self, request_ids: Iterable[str]):
self.engine.abort_request(request_ids)
async def run_engine_loop(self):
@staticmethod
async def run_engine_loop(engine_ref: ReferenceType):
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
engine: Optional["AsyncLLMEngine"] = engine_ref()
if not engine:
return
pipeline_parallel_size = \
self.engine.parallel_config.pipeline_parallel_size
engine.engine.parallel_config.pipeline_parallel_size
has_requests_in_progress = [False] * pipeline_parallel_size
while True:
if not any(has_requests_in_progress):
@ -711,11 +720,21 @@ class AsyncLLMEngine:
# timeout, and unblocks the RPC thread in the workers so that
# they can process any other queued control plane messages,
# such as add/remove lora adapters.
await self.engine.stop_remote_worker_execution_loop_async()
await self._request_tracker.wait_for_new_requests()
await engine.engine.stop_remote_worker_execution_loop_async()
request_tracker = engine._request_tracker
# Allow engine to be garbage collected while
# waiting for new requests
del engine
await asyncio.sleep(0)
if engine_ref() is None:
return
await request_tracker.wait_for_new_requests()
engine = engine_ref()
if not engine:
return
logger.debug("Got new requests!")
requests_in_progress = [
asyncio.create_task(self.engine_step(ve))
asyncio.create_task(engine.engine_step(ve))
for ve in range(pipeline_parallel_size)
]
has_requests_in_progress = [True] * pipeline_parallel_size
@ -733,19 +752,20 @@ class AsyncLLMEngine:
result = task.result()
virtual_engine = requests_in_progress.index(task)
has_unfinished_requests = (
self.engine.has_unfinished_requests_for_virtual_engine(
engine.engine.
has_unfinished_requests_for_virtual_engine(
virtual_engine))
if result or has_unfinished_requests:
requests_in_progress[virtual_engine] = (
asyncio.create_task(
self.engine_step(virtual_engine)))
engine.engine_step(virtual_engine)))
has_requests_in_progress[virtual_engine] = True
else:
has_requests_in_progress[virtual_engine] = False
except asyncio.TimeoutError as exc:
logger.error(
"Engine iteration timed out. This should never happen!")
self.set_errored(exc)
engine.set_errored(exc)
raise
await asyncio.sleep(0)
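The core of the fix above: run_engine_loop is now a staticmethod that holds only a weak reference to the engine, and __del__ wakes the loop so it can observe the dead reference and exit. A minimal, self-contained sketch of that pattern; Owner, new_work, start, and run_loop are hypothetical names, not vLLM's API:

import asyncio
import weakref
from typing import Optional

class Owner:
    """Hypothetical stand-in for AsyncLLMEngine."""

    def __init__(self) -> None:
        self.new_work = asyncio.Event()
        self.loop_task: Optional[asyncio.Task] = None

    def start(self) -> None:
        # The background task holds only a weakref, so it never keeps
        # the owner alive by itself.
        self.loop_task = asyncio.get_running_loop().create_task(
            Owner.run_loop(weakref.ref(self)))

    def __del__(self) -> None:
        # Wake the loop so it can notice the dead weakref and exit cleanly.
        self.new_work.set()

    @staticmethod
    async def run_loop(owner_ref: "weakref.ref[Owner]") -> None:
        while True:
            owner = owner_ref()
            if owner is None:
                return                # owner was garbage collected
            event = owner.new_work
            del owner                 # drop the strong ref before waiting
            await event.wait()
            event.clear()

async def main() -> None:
    owner = Owner()
    owner.start()
    task = owner.loop_task
    await asyncio.sleep(0)            # let the loop reach its wait
    del owner                         # __del__ fires and wakes the loop
    await task                        # loop exits instead of leaking

asyncio.run(main())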

View File

@ -1,8 +1,8 @@
import functools
import time
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from functools import partial
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Deque, Dict,
Iterable, List, Mapping, NamedTuple, Optional)
from typing import Sequence as GenericSequence
@ -51,7 +51,7 @@ from vllm.transformers_utils.tokenizer_group import (
BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device
from vllm.utils import Counter, Device, weak_bind
from vllm.version import __version__ as VLLM_VERSION
logger = init_logger(__name__)
@ -382,11 +382,16 @@ class LLMEngine:
for _ in range(self.parallel_config.pipeline_parallel_size)
]
self.async_callbacks = [
functools.partial(self._process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
if model_config.use_async_output_proc:
process_model_outputs = weak_bind(self._process_model_outputs)
self.async_callbacks = [
partial(process_model_outputs,
ctx=self.scheduler_contexts[v_id])
for v_id in range(self.parallel_config.pipeline_parallel_size)
]
else:
self.async_callbacks = []
# Currently used by AsyncLLMEngine to ensure quick append
# of request outputs to asyncio queues
@ -869,8 +874,8 @@ class LLMEngine:
"""
return self.scheduler[virtual_engine].has_unfinished_seqs()
@staticmethod
def _process_sequence_group_outputs(
self,
seq_group: SequenceGroup,
outputs: List[EmbeddingSequenceGroupOutput],
) -> None:

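The async_callbacks change above exists to break the strong reference from long-lived callbacks back to the engine. A self-contained sketch of the idea, with a local copy of weak_bind (the real helper is added to vllm/utils.py at the end of this diff); Engine here is hypothetical, not vLLM's class:

import weakref
from functools import partial
from typing import Any, Callable

def weak_bind(bound_method: Callable[..., Any]) -> Callable[..., None]:
    # Same idea as the weak_bind added to vllm/utils.py below.
    ref = weakref.ref(bound_method.__self__)
    unbound = bound_method.__func__

    def weak_bound(*args, **kwargs) -> None:
        if inst := ref():
            unbound(inst, *args, **kwargs)

    return weak_bound

class Engine:
    def __init__(self, pipeline_parallel_size: int) -> None:
        process = weak_bind(self._process_model_outputs)
        # Storing partial(self._process_model_outputs, ...) instead would keep
        # a strong bound method around, so the callbacks alone could keep the
        # engine alive; the weakly-bound version cannot.
        self.async_callbacks = [
            partial(process, ctx=v_id)
            for v_id in range(pipeline_parallel_size)
        ]

    def _process_model_outputs(self, ctx: int) -> None:
        print("processing outputs for virtual engine", ctx)

engine = Engine(pipeline_parallel_size=2)
callback = engine.async_callbacks[1]
callback()      # runs while the engine is alive
del engine
callback()      # silently no-ops once the engine is collected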
View File

@ -1,21 +1,20 @@
import asyncio
import signal
from http import HTTPStatus
from typing import Any
from typing import Any, Optional
import uvicorn
from fastapi import FastAPI, Response
from fastapi import FastAPI, Request, Response
from vllm import envs
from vllm.engine.async_llm_engine import AsyncEngineDeadError
from vllm.engine.protocol import AsyncEngineClient
from vllm.logger import init_logger
from vllm.utils import find_process_using_port
logger = init_logger(__name__)
async def serve_http(app: FastAPI, engine: AsyncEngineClient,
async def serve_http(app: FastAPI, limit_concurrency: Optional[int],
**uvicorn_kwargs: Any):
logger.info("Available routes are:")
for route in app.routes:
@ -29,16 +28,16 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
# Set concurrency limits in uvicorn if running in multiprocessing mode
# since zmq has maximum socket limit of zmq.constants.SOCKET_LIMIT (65536).
if engine.limit_concurrency is not None:
if limit_concurrency is not None:
logger.info(
"Launching Uvicorn with --limit_concurrency %s. To avoid this "
"limit at the expense of performance run with "
"--disable-frontend-multiprocessing", engine.limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = engine.limit_concurrency
"--disable-frontend-multiprocessing", limit_concurrency)
uvicorn_kwargs["limit_concurrency"] = limit_concurrency
config = uvicorn.Config(app, **uvicorn_kwargs)
server = uvicorn.Server(config)
_add_shutdown_handlers(app, server, engine)
_add_shutdown_handlers(app, server)
loop = asyncio.get_running_loop()
@ -68,15 +67,15 @@ async def serve_http(app: FastAPI, engine: AsyncEngineClient,
return server.shutdown()
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server,
engine: AsyncEngineClient) -> None:
def _add_shutdown_handlers(app: FastAPI, server: uvicorn.Server) -> None:
"""Adds handlers for fatal errors that should crash the server"""
@app.exception_handler(RuntimeError)
async def runtime_error_handler(_, __):
async def runtime_error_handler(request: Request, __):
"""On generic runtime error, check to see if the engine has died.
It probably has, in which case the server will no longer be able to
handle requests. Trigger a graceful shutdown with a SIGTERM."""
engine = request.app.state.engine_client
if (not envs.VLLM_KEEP_ALIVE_ON_ENGINE_DEATH and engine.errored
and not engine.is_running):
logger.fatal("AsyncLLMEngine has failed, terminating server "

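serve_http no longer receives the engine; the error handler looks the client up from the request's application state when it fires. A minimal FastAPI sketch of that lookup pattern; DummyEngine and the 500 response are illustrative, not vLLM's exact behavior:

from fastapi import FastAPI, Request, Response

class DummyEngine:
    errored = False
    is_running = True

    async def check_health(self) -> None:
        pass

app = FastAPI()
app.state.engine_client = DummyEngine()

@app.exception_handler(RuntimeError)
async def runtime_error_handler(request: Request, _exc: RuntimeError) -> Response:
    # No captured engine reference: fetch it from the app the request belongs
    # to, so the handler itself never keeps the engine alive.
    engine = request.app.state.engine_client
    if engine.errored and not engine.is_running:
        ...  # e.g. trigger a graceful shutdown
    return Response(status_code=500)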
View File

@ -4,16 +4,20 @@ import inspect
import multiprocessing
import os
import re
import signal
import tempfile
from argparse import Namespace
from contextlib import asynccontextmanager
from functools import partial
from http import HTTPStatus
from typing import AsyncIterator, Optional, Set
import uvloop
from fastapi import APIRouter, FastAPI, Request
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
from starlette.datastructures import State
from starlette.routing import Mount
from typing_extensions import assert_never
@ -54,12 +58,6 @@ from vllm.version import __version__ as VLLM_VERSION
TIMEOUT_KEEP_ALIVE = 5 # seconds
async_engine_client: AsyncEngineClient
engine_args: AsyncEngineArgs
openai_serving_chat: OpenAIServingChat
openai_serving_completion: OpenAIServingCompletion
openai_serving_embedding: OpenAIServingEmbedding
openai_serving_tokenization: OpenAIServingTokenization
prometheus_multiproc_dir: tempfile.TemporaryDirectory
# Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
@ -83,18 +81,28 @@ def model_is_embedding(model_name: str, trust_remote_code: bool,
@asynccontextmanager
async def lifespan(app: FastAPI):
try:
if app.state.log_stats:
async_engine_client = app.state.engine_client
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
async def _force_log():
while True:
await asyncio.sleep(10)
await async_engine_client.do_log_stats()
if not engine_args.disable_log_stats:
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
yield
task = asyncio.create_task(_force_log())
_running_tasks.add(task)
task.add_done_callback(_running_tasks.remove)
else:
task = None
try:
yield
finally:
if task is not None:
task.cancel()
finally:
# Ensure app state including engine ref is gc'd
del app.state
@asynccontextmanager
@ -103,16 +111,10 @@ async def build_async_engine_client(
# Context manager to handle async_engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
global engine_args
engine_args = AsyncEngineArgs.from_cli_args(args)
# Backend itself still global for the silly lil' health handler
global async_engine_client
async with build_async_engine_client_from_engine_args(
engine_args, args.disable_frontend_multiprocessing) as engine:
async_engine_client = engine # type: ignore[assignment]
yield engine
@ -134,12 +136,22 @@ async def build_async_engine_client_from_engine_args(
if (model_is_embedding(engine_args.model, engine_args.trust_remote_code,
engine_args.quantization, engine_args.revision)
or disable_frontend_multiprocessing):
engine_client = AsyncLLMEngine.from_engine_args(
engine_args, usage_context=UsageContext.OPENAI_API_SERVER)
try:
yield engine_client
finally:
engine_client.shutdown_background_loop()
engine_config = engine_args.create_engine_config()
uses_ray = getattr(AsyncLLMEngine._get_executor_cls(engine_config),
"uses_ray", False)
build_engine = partial(AsyncLLMEngine.from_engine_args,
engine_args=engine_args,
engine_config=engine_config,
usage_context=UsageContext.OPENAI_API_SERVER)
if uses_ray:
# Must run in main thread with ray for its signal handlers to work
engine_client = build_engine()
else:
engine_client = await asyncio.get_running_loop().run_in_executor(
None, build_engine)
yield engine_client
return
# Otherwise, use the multiprocessing AsyncLLMEngine.
@ -241,16 +253,36 @@ def mount_metrics(app: FastAPI):
app.routes.append(metrics_route)
def chat(request: Request) -> OpenAIServingChat:
return request.app.state.openai_serving_chat
def completion(request: Request) -> OpenAIServingCompletion:
return request.app.state.openai_serving_completion
def tokenization(request: Request) -> OpenAIServingTokenization:
return request.app.state.openai_serving_tokenization
def embedding(request: Request) -> OpenAIServingEmbedding:
return request.app.state.openai_serving_embedding
def engine_client(request: Request) -> AsyncEngineClient:
return request.app.state.engine_client
@router.get("/health")
async def health() -> Response:
async def health(raw_request: Request) -> Response:
"""Health check."""
await async_engine_client.check_health()
await engine_client(raw_request).check_health()
return Response(status_code=200)
@router.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_tokenization.create_tokenize(request)
async def tokenize(request: TokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@ -261,8 +293,8 @@ async def tokenize(request: TokenizeRequest):
@router.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_tokenization.create_detokenize(request)
async def detokenize(request: DetokenizeRequest, raw_request: Request):
generator = await tokenization(raw_request).create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
@ -273,8 +305,8 @@ async def detokenize(request: DetokenizeRequest):
@router.get("/v1/models")
async def show_available_models():
models = await openai_serving_completion.show_available_models()
async def show_available_models(raw_request: Request):
models = await completion(raw_request).show_available_models()
return JSONResponse(content=models.model_dump())
@ -288,7 +320,7 @@ async def show_version():
async def create_chat_completion(request: ChatCompletionRequest,
raw_request: Request):
generator = await openai_serving_chat.create_chat_completion(
generator = await chat(raw_request).create_chat_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
@ -303,7 +335,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
@router.post("/v1/completions")
async def create_completion(request: CompletionRequest, raw_request: Request):
generator = await openai_serving_completion.create_completion(
generator = await completion(raw_request).create_completion(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
@ -316,7 +348,7 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
@router.post("/v1/embeddings")
async def create_embedding(request: EmbeddingRequest, raw_request: Request):
generator = await openai_serving_embedding.create_embedding(
generator = await embedding(raw_request).create_embedding(
request, raw_request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
@ -333,16 +365,16 @@ if envs.VLLM_TORCH_PROFILER_DIR:
"used for local development!")
@router.post("/start_profile")
async def start_profile():
async def start_profile(raw_request: Request):
logger.info("Starting profiler...")
await async_engine_client.start_profile()
await engine_client(raw_request).start_profile()
logger.info("Profiler started.")
return Response(status_code=200)
@router.post("/stop_profile")
async def stop_profile():
async def stop_profile(raw_request: Request):
logger.info("Stopping profiler...")
await async_engine_client.stop_profile()
await engine_client(raw_request).stop_profile()
logger.info("Profiler stopped.")
return Response(status_code=200)
@ -353,13 +385,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
"This should ONLY be used for local development!")
@router.post("/v1/load_lora_adapter")
async def load_lora_adapter(request: LoadLoraAdapterRequest):
response = await openai_serving_chat.load_lora_adapter(request)
async def load_lora_adapter(request: LoadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await openai_serving_completion.load_lora_adapter(request)
response = await completion(raw_request).load_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
@ -367,13 +400,14 @@ if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
return Response(status_code=200, content=response)
@router.post("/v1/unload_lora_adapter")
async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
response = await openai_serving_chat.unload_lora_adapter(request)
async def unload_lora_adapter(request: UnloadLoraAdapterRequest,
raw_request: Request):
response = await chat(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
response = await openai_serving_completion.unload_lora_adapter(request)
response = await completion(raw_request).unload_lora_adapter(request)
if isinstance(response, ErrorResponse):
return JSONResponse(content=response.model_dump(),
status_code=response.code)
@ -398,7 +432,8 @@ def build_app(args: Namespace) -> FastAPI:
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(_, exc):
err = openai_serving_chat.create_error_response(message=str(exc))
chat = app.state.openai_serving_chat
err = chat.create_error_response(message=str(exc))
return JSONResponse(err.model_dump(),
status_code=HTTPStatus.BAD_REQUEST)
@ -430,30 +465,26 @@ def build_app(args: Namespace) -> FastAPI:
return app
async def init_app(
def init_app_state(
async_engine_client: AsyncEngineClient,
model_config: ModelConfig,
state: State,
args: Namespace,
) -> FastAPI:
app = build_app(args)
) -> None:
if args.served_model_name is not None:
served_model_names = args.served_model_name
else:
served_model_names = [args.model]
model_config = await async_engine_client.get_model_config()
if args.disable_log_requests:
request_logger = None
else:
request_logger = RequestLogger(max_log_len=args.max_log_len)
global openai_serving_chat
global openai_serving_completion
global openai_serving_embedding
global openai_serving_tokenization
state.engine_client = async_engine_client
state.log_stats = not args.disable_log_stats
openai_serving_chat = OpenAIServingChat(
state.openai_serving_chat = OpenAIServingChat(
async_engine_client,
model_config,
served_model_names,
@ -465,7 +496,7 @@ async def init_app(
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
enable_auto_tools=args.enable_auto_tool_choice,
tool_parser=args.tool_call_parser)
openai_serving_completion = OpenAIServingCompletion(
state.openai_serving_completion = OpenAIServingCompletion(
async_engine_client,
model_config,
served_model_names,
@ -474,13 +505,13 @@ async def init_app(
request_logger=request_logger,
return_tokens_as_token_ids=args.return_tokens_as_token_ids,
)
openai_serving_embedding = OpenAIServingEmbedding(
state.openai_serving_embedding = OpenAIServingEmbedding(
async_engine_client,
model_config,
served_model_names,
request_logger=request_logger,
)
openai_serving_tokenization = OpenAIServingTokenization(
state.openai_serving_tokenization = OpenAIServingTokenization(
async_engine_client,
model_config,
served_model_names,
@ -488,25 +519,31 @@ async def init_app(
request_logger=request_logger,
chat_template=args.chat_template,
)
app.root_path = args.root_path
return app
async def run_server(args, **uvicorn_kwargs) -> None:
logger.info("vLLM API server version %s", VLLM_VERSION)
logger.info("args: %s", args)
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("terminated")
signal.signal(signal.SIGTERM, signal_handler)
async with build_async_engine_client(args) as async_engine_client:
# If None, creation of the client failed and we exit.
if async_engine_client is None:
return
app = await init_app(async_engine_client, args)
app = build_app(args)
model_config = await async_engine_client.get_model_config()
init_app_state(async_engine_client, model_config, app.state, args)
shutdown_task = await serve_http(
app,
engine=async_engine_client,
limit_concurrency=async_engine_client.limit_concurrency,
host=args.host,
port=args.port,
log_level=args.uvicorn_log_level,
@ -530,4 +567,4 @@ if __name__ == "__main__":
parser = make_arg_parser(parser)
args = parser.parse_args()
asyncio.run(run_server(args))
uvloop.run(run_server(args))
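One piece of the api_server change above, isolated: the lifespan hook starts the periodic stats task only when logging is enabled and cancels it on shutdown. A minimal sketch, assuming app.state.log_stats and app.state.engine_client are populated the way init_app_state does (do_log_stats is the client method used above):

import asyncio
from contextlib import asynccontextmanager

from fastapi import FastAPI

@asynccontextmanager
async def lifespan(app: FastAPI):
    task = None
    if app.state.log_stats:
        engine_client = app.state.engine_client

        async def _force_log():
            while True:
                await asyncio.sleep(10)
                await engine_client.do_log_stats()

        task = asyncio.create_task(_force_log())
    try:
        yield
    finally:
        if task is not None:
            task.cancel()

app = FastAPI(lifespan=lifespan)
app.state.log_stats = False   # normally set by init_app_state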

View File

@ -46,7 +46,6 @@ class AsyncEngineRPCServer:
"""Cleanup all resources."""
self.socket.close()
self.context.destroy()
self.engine.shutdown_background_loop()
# Clear the engine reference so that it can be GC'ed.
del self.engine
@ -233,5 +232,12 @@ async def run_server(server: AsyncEngineRPCServer):
def run_rpc_server(async_engine_args: AsyncEngineArgs,
usage_context: UsageContext, rpc_path: str):
def signal_handler(*_) -> None:
# Interrupt server on sigterm while initializing
raise KeyboardInterrupt("AsyncEngineRPCServer terminated")
signal.signal(signal.SIGTERM, signal_handler)
server = AsyncEngineRPCServer(async_engine_args, usage_context, rpc_path)
uvloop.run(run_server(server))
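Both run_server entrypoints in this diff install the same guard: convert SIGTERM into KeyboardInterrupt while initialization is still blocking, so a termination signal during startup unwinds cleanly instead of being lost. A minimal sketch of the pattern:

import signal
import time

def signal_handler(*_) -> None:
    # Re-raise as KeyboardInterrupt so the normal cleanup path runs.
    raise KeyboardInterrupt("terminated")

signal.signal(signal.SIGTERM, signal_handler)

try:
    time.sleep(60)   # stand-in for slow engine / server initialization
except KeyboardInterrupt:
    print("interrupted during startup, shutting down cleanly")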

View File

@ -1,8 +1,5 @@
import asyncio
import os
import signal
import threading
import weakref
from functools import partial
from typing import Any, List, Optional
@ -108,17 +105,6 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor):
# Set up signal handlers to shutdown the executor cleanly
# sometimes gc does not work well
# Use weakref to avoid holding a reference to self
ref = weakref.ref(self)
def shutdown(signum, frame):
if executor := ref():
executor.shutdown()
if threading.current_thread() is threading.main_thread():
signal.signal(signal.SIGINT, shutdown)
signal.signal(signal.SIGTERM, shutdown)
self.driver_worker = self._create_worker(
distributed_init_method=distributed_init_method)
self._run_workers("init_device")

View File

@ -120,7 +120,8 @@ class WorkerMonitor(threading.Thread):
logger.error("Worker %s pid %s died, exit code: %s",
process.name, process.pid, process.exitcode)
# Cleanup any remaining workers
logger.info("Killing local vLLM worker processes")
if logger:
logger.info("Killing local vLLM worker processes")
for worker in self.workers:
worker.kill_worker()
# Must be done after worker task queues are all closed
@ -221,6 +222,8 @@ def _run_worker_process(
try:
executor = getattr(worker, method)
output = executor(*args, **kwargs)
except KeyboardInterrupt:
break
except BaseException as e:
tb = traceback.format_exc()
logger.error(

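Context for the "except KeyboardInterrupt: break" above: the worker's busy loop now treats the interrupt it receives at shutdown as a request to exit quietly rather than as a task failure. A simplified sketch of that loop shape; the queue plumbing is illustrative, not vLLM's actual worker code:

import queue

def worker_busy_loop(task_queue: "queue.Queue",
                     result_queue: "queue.Queue") -> None:
    while True:
        method, args, kwargs = task_queue.get()
        try:
            output = method(*args, **kwargs)
        except KeyboardInterrupt:
            break                        # clean shutdown, no error reporting
        except BaseException as e:       # anything else is reported as a failure
            result_queue.put(("error", repr(e)))
            continue
        result_queue.put(("ok", output))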
View File

@ -26,6 +26,8 @@ logger = init_logger(__name__)
class RayTPUExecutor(TPUExecutor):
uses_ray: bool = True
def __init__(self, *args, **kwargs):
# This is non-None when the execute model loop is running
# in the parallel workers. It's a coroutine in the AsyncLLMEngine case.

View File

@ -1,11 +1,11 @@
# The CLI entrypoint to vLLM.
import argparse
import asyncio
import os
import signal
import sys
from typing import List, Optional
import uvloop
from openai import OpenAI
from openai.types.chat import ChatCompletionMessageParam
@ -34,7 +34,7 @@ def serve(args: argparse.Namespace) -> None:
# EngineArgs expects the model name to be passed as --model.
args.model = args.model_tag
asyncio.run(run_server(args))
uvloop.run(run_server(args))
def interactive_cli(args: argparse.Namespace) -> None:

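Both vllm serve and the OpenAI API server entrypoint now drive run_server with uvloop.run instead of asyncio.run; uvloop's run() helper (available in recent uvloop releases) is a drop-in replacement that executes the coroutine on uvloop's event loop. A minimal sketch:

import asyncio

import uvloop

async def main() -> None:
    await asyncio.sleep(0)
    print("event loop:", type(asyncio.get_running_loop()).__name__)

uvloop.run(main())   # drop-in for asyncio.run(main())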
View File

@ -12,6 +12,7 @@ import tempfile
import threading
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from functools import lru_cache, partial, wraps
from platform import uname
@ -1079,6 +1080,20 @@ def cuda_device_count_stateless() -> int:
return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES)
def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]:
"""Make an instance method that weakly references
its associated instance and no-ops once that
instance is collected."""
ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined]
unbound = bound_method.__func__ # type: ignore[attr-defined]
def weak_bound(*args, **kwargs) -> None:
if inst := ref():
unbound(inst, *args, **kwargs)
return weak_bound
#From: https://stackoverflow.com/a/4104188/2749989
def run_once(f: Callable[P, None]) -> Callable[P, None]: