[V1][Core] Generic mechanism for handling engine utility (#13060)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-02-19 01:09:22 -08:00
Committed by: GitHub
Parent: f525c0be8b
Commit: caf7ff4456

5 changed files with 197 additions and 56 deletions

View File

@@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
     ]
     for tokenizer_file in tokenizer_files:
         del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
-        del_path.unlink()
+        del_path.unlink(missing_ok=True)


 @pytest.fixture(autouse=True)

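For context on the one-line fix above: `Path.unlink()` raises `FileNotFoundError` when the target is absent, while `missing_ok=True` (available since Python 3.8) turns that case into a no-op, making repeated test cleanup idempotent. A minimal standalone sketch, using a hypothetical path:

```python
from pathlib import Path

p = Path("/tmp/definitely-not-here.json")  # hypothetical path for illustration

# Without the flag, unlinking a missing file raises.
try:
    p.unlink()
except FileNotFoundError:
    print("raised: file was already gone")

# With missing_ok=True (Python 3.8+), a missing file is silently ignored,
# so cleanup code can safely run more than once.
p.unlink(missing_ok=True)
```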
View File

@@ -3,7 +3,8 @@
 import asyncio
 import time
 import uuid
-from typing import Dict, List
+from contextlib import ExitStack
+from typing import Dict, List, Optional
 
 import pytest
 from transformers import AutoTokenizer
@@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
 from vllm.v1.engine import EngineCoreRequest
-from vllm.v1.engine.core_client import EngineCoreClient
+from vllm.v1.engine.core import EngineCore
+from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
+                                        SyncMPClient)
 from vllm.v1.executor.abstract import Executor
 
 if not current_platform.is_cuda():
@@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
 async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
     while True:
-        engine_core_outputs = await client.get_output_async().outputs
+        engine_core_outputs = (await client.get_output_async()).outputs
 
         if len(engine_core_outputs) == 0:
             break
@@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
             break
 
 
+# Dummy utility function to monkey-patch into engine core.
+def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
+    print(f"echo util function called: {msg}, {err_msg}")
+    if err_msg is not None:
+        raise ValueError(err_msg)
+    return msg
+
+
 @fork_new_process_for_each_test
 @pytest.mark.parametrize("multiprocessing_mode", [True, False])
 def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
@@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
     with monkeypatch.context() as m:
         m.setenv("VLLM_USE_V1", "1")
-        engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
+
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
         client.abort_requests([request.request_id])
 
+        if multiprocessing_mode:
+            """Utility method invocation"""
+
+            core_client: SyncMPClient = client
+            result = core_client._call_utility("echo", "testarg")
+            assert result == "testarg"
+
+            with pytest.raises(Exception) as e_info:
+                core_client._call_utility("echo", None, "help!")
+            assert str(e_info.value) == "Call to echo method failed: help!"
+
 
 @fork_new_process_for_each_test
-@pytest.mark.asyncio
+@pytest.mark.asyncio(loop_scope="function")
 async def test_engine_core_client_asyncio(monkeypatch):
-    with monkeypatch.context() as m:
+    with monkeypatch.context() as m, ExitStack() as after:
         m.setenv("VLLM_USE_V1", "1")
-        engine_args = EngineArgs(model=MODEL_NAME)
+
+        # Monkey-patch core engine utility function to test.
+        m.setattr(EngineCore, "echo", echo, raising=False)
+
+        engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
         vllm_config = engine_args.create_engine_config(
             usage_context=UsageContext.UNKNOWN_CONTEXT)
         executor_class = Executor.get_class(vllm_config)
@@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
             executor_class=executor_class,
             log_stats=True,
         )
+        after.callback(client.shutdown)
 
         MAX_TOKENS = 20
         params = SamplingParams(max_tokens=MAX_TOKENS)
@@ -204,3 +234,14 @@
         else:
             assert len(outputs[req_id]) == MAX_TOKENS, (
                 f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+
+        """Utility method invocation"""
+
+        core_client: AsyncMPClient = client
+        result = await core_client._call_utility_async("echo", "testarg")
+        assert result == "testarg"
+
+        with pytest.raises(Exception) as e_info:
+            await core_client._call_utility_async("echo", None, "help!")
+        assert str(e_info.value) == "Call to echo method failed: help!"
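The two tests above exercise the new utility path end to end: a function monkey-patched onto `EngineCore` becomes callable from the client by name, and a raised exception surfaces client-side as `Call to <method> method failed: <msg>`. A self-contained sketch of that call-by-name dispatch with error wrapping (toy names, not vLLM's API):

```python
from concurrent.futures import Future


class ToyCore:
    """Stand-in for EngineCore: any bound method is reachable by name."""

    def echo(self, msg, err_msg=None):
        if err_msg is not None:
            raise ValueError(err_msg)
        return msg


def call_utility(core, method_name, *args):
    """In-process mimic of the client/engine utility round trip."""
    future = Future()
    try:
        future.set_result(getattr(core, method_name)(*args))
    except Exception as e:
        # Failures travel back as a message string, not a pickled exception.
        future.set_exception(Exception(f"Call to {method_name} method failed: {e}"))
    return future.result()


assert call_utility(ToyCore(), "echo", "testarg") == "testarg"
try:
    call_utility(ToyCore(), "echo", None, "help!")
except Exception as e:
    assert str(e) == "Call to echo method failed: help!"
```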

View File

@@ -2,7 +2,7 @@
 import enum
 import time
-from typing import List, Optional, Union
+from typing import Any, List, Optional, Union
 
 import msgspec
@@ -106,6 +106,18 @@ class EngineCoreOutput(
         return self.finish_reason is not None
 
 
+class UtilityOutput(
+        msgspec.Struct,
+        array_like=True,  # type: ignore[call-arg]
+        gc=False):  # type: ignore[call-arg]
+
+    call_id: int
+
+    # Non-None implies the call failed, result should be None.
+    failure_message: Optional[str] = None
+    result: Any = None
+
+
 class EngineCoreOutputs(
     msgspec.Struct,
     array_like=True,  # type: ignore[call-arg]
@@ -116,10 +128,12 @@ class EngineCoreOutputs(
     # e.g. columnwise layout
 
     # [num_reqs]
-    outputs: List[EngineCoreOutput]
-    scheduler_stats: Optional[SchedulerStats]
+    outputs: List[EngineCoreOutput] = []
+    scheduler_stats: Optional[SchedulerStats] = None
     timestamp: float = 0.0
 
+    utility_output: Optional[UtilityOutput] = None
+
     def __post_init__(self):
         if self.timestamp == 0.0:
             self.timestamp = time.monotonic()
@@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
     """
     ADD = b'\x00'
     ABORT = b'\x01'
-    PROFILE = b'\x02'
-    RESET_PREFIX_CACHE = b'\x03'
-    ADD_LORA = b'\x04'
+    UTILITY = b'\x02'
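`UtilityOutput` rides the existing msgpack output channel as one more optional field on `EngineCoreOutputs`; `array_like=True` makes msgspec encode the struct positionally as an array rather than a field map, keeping frames compact. A small round-trip sketch using the standard msgspec API (the struct is redeclared here just for the example):

```python
from typing import Any, Optional

import msgspec


class UtilityOutput(msgspec.Struct, array_like=True, gc=False):
    call_id: int
    # Non-None implies the call failed, result should be None.
    failure_message: Optional[str] = None
    result: Any = None


encoded = msgspec.msgpack.encode(UtilityOutput(call_id=7, result="ok"))
decoded = msgspec.msgpack.decode(encoded, type=UtilityOutput)
assert (decoded.call_id, decoded.result) == (7, "ok")

# array_like=True: the wire form is a positional array, not a map.
print(msgspec.msgpack.decode(encoded))  # -> [7, None, 'ok']
```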

View File

@@ -5,9 +5,11 @@ import signal
 import threading
 import time
 from concurrent.futures import Future
+from inspect import isclass, signature
 from multiprocessing.connection import Connection
 from typing import Any, List, Optional, Tuple, Type
 
+import msgspec
 import psutil
 import zmq
 import zmq.asyncio
@@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
 from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.mm_input_cache import MMInputCacheServer
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.outputs import ModelRunnerOutput
@@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
             self.add_request(request)
         elif request_type == EngineCoreRequestType.ABORT:
             self.abort_requests(request)
-        elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
-            self.reset_prefix_cache()
-        elif request_type == EngineCoreRequestType.PROFILE:
-            self.model_executor.profile(request)
-        elif request_type == EngineCoreRequestType.ADD_LORA:
-            self.model_executor.add_lora(request)
+        elif request_type == EngineCoreRequestType.UTILITY:
+            call_id, method_name, args = request
+            output = UtilityOutput(call_id)
+            try:
+                method = getattr(self, method_name)
+                output.result = method(
+                    *self._convert_msgspec_args(method, args))
+            except BaseException as e:
+                logger.exception("Invocation of %s method failed", method_name)
+                output.failure_message = (f"Call to {method_name} method"
+                                          f" failed: {str(e)}")
+            self.output_queue.put_nowait(
+                EngineCoreOutputs(utility_output=output))
+
+    @staticmethod
+    def _convert_msgspec_args(method, args):
+        """If a provided arg type doesn't match corresponding target method
+        arg type, try converting to msgspec object."""
+        if not args:
+            return args
+        arg_types = signature(method).parameters.values()
+        assert len(args) <= len(arg_types)
+        return tuple(
+            msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
+            and issubclass(p.annotation, msgspec.Struct)
+            and not isinstance(v, p.annotation) else v
+            for v, p in zip(args, arg_types))
 
     def process_input_socket(self, input_path: str):
         """Input socket IO thread."""
 
         # Msgpack serialization decoding.
         add_request_decoder = MsgpackDecoder(EngineCoreRequest)
-        add_lora_decoder = MsgpackDecoder(LoRARequest)
         generic_decoder = MsgpackDecoder()
 
         with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
@@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
                 request_type = EngineCoreRequestType(bytes(type_frame.buffer))
 
                 # Deserialize the request data.
-                decoder = None
-                if request_type == EngineCoreRequestType.ADD:
-                    decoder = add_request_decoder
-                elif request_type == EngineCoreRequestType.ADD_LORA:
-                    decoder = add_lora_decoder
-                else:
-                    decoder = generic_decoder
+                decoder = add_request_decoder if (
+                    request_type
+                    == EngineCoreRequestType.ADD) else generic_decoder
                 request = decoder.decode(data_frame.buffer)
 
                 # Push to input queue for core busy loop.

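One subtlety handled by `_convert_msgspec_args` above: utility requests are decoded by the generic `MsgpackDecoder`, so a struct-typed argument (such as a `LoRARequest`) arrives as a plain dict. The helper inspects the target method's signature and rebuilds any parameter annotated as a `msgspec.Struct` via `msgspec.convert`. A self-contained sketch of the same conversion, with a toy struct standing in for a real request type:

```python
from inspect import isclass, signature

import msgspec


class ToyRequest(msgspec.Struct):  # hypothetical stand-in for e.g. LoRARequest
    name: str
    id: int


def add_toy(req: ToyRequest) -> str:
    return f"{req.name}:{req.id}"


def convert_args(method, args):
    """Coerce positional args to msgspec structs where the annotation asks for one."""
    params = signature(method).parameters.values()
    return tuple(
        msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
        and issubclass(p.annotation, msgspec.Struct)
        and not isinstance(v, p.annotation) else v
        for v, p in zip(args, params))


# The generic decoder hands us a dict, not a ToyRequest:
args = convert_args(add_toy, ({"name": "adapter", "id": 3}, ))
assert add_toy(*args) == "adapter:3"
```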
View File

@@ -2,10 +2,14 @@
 import asyncio
 import os
+import queue
 import signal
+import uuid
 import weakref
 from abc import ABC, abstractmethod
-from typing import Any, List, Optional, Type
+from concurrent.futures import Future
+from threading import Thread
+from typing import Any, Dict, List, Optional, Type, Union
 
 import zmq
 import zmq.asyncio
@@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
 from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
                         make_zmq_socket)
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
-                            EngineCoreRequestType)
+                            EngineCoreRequestType, UtilityOutput)
 from vllm.v1.engine.core import EngineCore, EngineCoreProc
 from vllm.v1.executor.abstract import Executor
 from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
@@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle
 logger = init_logger(__name__)
 
+AnyFuture = Union[asyncio.Future[Any], Future[Any]]
+
 
 class EngineCoreClient(ABC):
     """
@@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
             "log_stats": log_stats,
         })
 
+        self.utility_results: Dict[int, AnyFuture] = {}
+
     def shutdown(self):
         """Clean up background resources."""
         if hasattr(self, "proc_handle"):
@@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
         self._finalizer()
 
 
+def _process_utility_output(output: UtilityOutput,
+                            utility_results: Dict[int, AnyFuture]):
+    """Set the result from a utility method in the waiting future"""
+    future = utility_results.pop(output.call_id)
+    if output.failure_message is not None:
+        future.set_exception(Exception(output.failure_message))
+    else:
+        future.set_result(output.result)
+
+
 class SyncMPClient(MPClient):
     """Synchronous client for multi-proc EngineCore."""
@@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-    def get_output(self) -> EngineCoreOutputs:
+        self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
 
-        (frame, ) = self.output_socket.recv_multipart(copy=False)
-        return self.decoder.decode(frame.buffer)
+        # Ensure that the outputs socket processing thread does not have
+        # a ref to the client which prevents gc.
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        def process_outputs_socket():
+            while True:
+                (frame, ) = output_socket.recv_multipart(copy=False)
+                outputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        # Process outputs from engine in separate thread.
+        Thread(target=process_outputs_socket, daemon=True).start()
+
+    def get_output(self) -> EngineCoreOutputs:
+        return self.outputs_queue.get()
 
     def _send_input(self, request_type: EngineCoreRequestType,
                     request: Any) -> None:
@@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         self.input_socket.send_multipart(msg, copy=False)
 
+    def _call_utility(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future: Future[Any] = Future()
+        self.utility_results[call_id] = future
+        self._send_input(EngineCoreRequestType.UTILITY,
+                         (call_id, method, args))
+
+        return future.result()
+
     def add_request(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
         self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     def profile(self, is_start: bool = True) -> None:
-        self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        self._call_utility("profile", is_start)
 
     def reset_prefix_cache(self) -> None:
-        self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        self._call_utility("reset_prefix_cache")
 
     def add_lora(self, lora_request: LoRARequest) -> None:
-        self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        self._call_utility("add_lora", lora_request)
 
 
 class AsyncMPClient(MPClient):
@@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
             log_stats=log_stats,
         )
 
-        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
+        self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
         self.queue_task: Optional[asyncio.Task] = None
 
+    async def _start_output_queue_task(self):
+        # Perform IO in separate task to parallelize as much as possible.
+        # Avoid task having direct reference back to the client.
+        self.outputs_queue = asyncio.Queue()
+        output_socket = self.output_socket
+        decoder = self.decoder
+        utility_results = self.utility_results
+        outputs_queue = self.outputs_queue
+
+        async def process_outputs_socket():
+            while True:
+                (frame, ) = await output_socket.recv_multipart(copy=False)
+                outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
+                if outputs.utility_output:
+                    _process_utility_output(outputs.utility_output,
+                                            utility_results)
+                else:
+                    outputs_queue.put_nowait(outputs)
+
+        self.queue_task = asyncio.create_task(process_outputs_socket())
+
     async def get_output_async(self) -> EngineCoreOutputs:
         if self.outputs_queue is None:
-            # Perform IO in separate task to parallelize as much as possible
-            self.outputs_queue = asyncio.Queue()
-
-            async def process_outputs_socket():
-                assert self.outputs_queue is not None
-                while True:
-                    (frame, ) = await self.output_socket.recv_multipart(
-                        copy=False)
-                    self.outputs_queue.put_nowait(frame.buffer)
-
-            self.queue_task = asyncio.create_task(process_outputs_socket())
-
-        return self.decoder.decode(await self.outputs_queue.get())
+            await self._start_output_queue_task()
+            assert self.outputs_queue is not None
+        return await self.outputs_queue.get()
 
     async def _send_input(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
@@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
         msg = (request_type.value, self.encoder.encode(request))
         await self.input_socket.send_multipart(msg, copy=False)
 
+        if self.outputs_queue is None:
+            await self._start_output_queue_task()
+
+    async def _call_utility_async(self, method: str, *args) -> Any:
+        call_id = uuid.uuid1().int >> 64
+        future = asyncio.get_running_loop().create_future()
+        self.utility_results[call_id] = future
+        await self._send_input(EngineCoreRequestType.UTILITY,
+                               (call_id, method, args))
+        return await future
+
     async def add_request_async(self, request: EngineCoreRequest) -> None:
         # NOTE: text prompt is not needed in the core engine as it has been
         # tokenized.
@@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
         await self._send_input(EngineCoreRequestType.ABORT, request_ids)
 
     async def profile_async(self, is_start: bool = True) -> None:
-        await self._send_input(EngineCoreRequestType.PROFILE, is_start)
+        await self._call_utility_async("profile", is_start)
 
     async def reset_prefix_cache_async(self) -> None:
-        await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
+        await self._call_utility_async("reset_prefix_cache")
 
     async def add_lora_async(self, lora_request: LoRARequest) -> None:
-        await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
+        await self._call_utility_async("add_lora", lora_request)
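Putting the client pieces together: `_call_utility` is request/response correlation over a one-way input socket. Each call registers a `Future` under a fresh 64-bit `call_id` (the upper half of a `uuid1`), sends `(call_id, method, args)`, and the background output reader completes the matching future when a `UtilityOutput` comes back. A minimal threaded sketch of the same pattern, with in-process queues standing in for the ZMQ sockets (toy names throughout):

```python
import queue
import threading
import uuid
from concurrent.futures import Future


class ToyClient:
    """Toy analogue of SyncMPClient's utility plumbing (queues, not ZMQ)."""

    def __init__(self, handler):
        self.utility_results: dict[int, Future] = {}
        self.requests: queue.Queue = queue.Queue()
        self.responses: queue.Queue = queue.Queue()

        def server():  # plays the EngineCoreProc busy loop
            while True:
                call_id, method, args = self.requests.get()
                try:
                    self.responses.put((call_id, None, handler(method, *args)))
                except Exception as e:
                    self.responses.put(
                        (call_id, f"Call to {method} method failed: {e}", None))

        def reader():  # plays process_outputs_socket
            while True:
                call_id, failure, result = self.responses.get()
                future = self.utility_results.pop(call_id)
                if failure is not None:
                    future.set_exception(Exception(failure))
                else:
                    future.set_result(result)

        threading.Thread(target=server, daemon=True).start()
        threading.Thread(target=reader, daemon=True).start()

    def call_utility(self, method, *args):
        call_id = uuid.uuid1().int >> 64  # same id derivation as the commit
        future: Future = Future()
        self.utility_results[call_id] = future
        self.requests.put((call_id, method, args))
        return future.result()


client = ToyClient(lambda method, *a: a[0] if method == "echo" else None)
assert client.call_utility("echo", "testarg") == "testarg"
```

The blocking `future.result()` mirrors the sync client; the async client does the same dance but parks an `asyncio` future on the running loop and awaits it instead.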