mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1][Core] Generic mechanism for handling engine utility (#13060)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -41,7 +41,7 @@ def download_and_prepare_lora_module():
|
||||
]
|
||||
for tokenizer_file in tokenizer_files:
|
||||
del_path = Path(LORA_MODULE_DOWNLOAD_PATH) / tokenizer_file
|
||||
del_path.unlink()
|
||||
del_path.unlink(missing_ok=True)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
|
@ -3,7 +3,8 @@
|
||||
import asyncio
|
||||
import time
|
||||
import uuid
|
||||
from typing import Dict, List
|
||||
from contextlib import ExitStack
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import pytest
|
||||
from transformers import AutoTokenizer
|
||||
@ -14,7 +15,9 @@ from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.v1.engine import EngineCoreRequest
|
||||
from vllm.v1.engine.core_client import EngineCoreClient
|
||||
from vllm.v1.engine.core import EngineCore
|
||||
from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
|
||||
SyncMPClient)
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
@ -63,7 +66,7 @@ def loop_until_done(client: EngineCoreClient, outputs: Dict):
|
||||
async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
|
||||
|
||||
while True:
|
||||
engine_core_outputs = await client.get_output_async().outputs
|
||||
engine_core_outputs = (await client.get_output_async()).outputs
|
||||
|
||||
if len(engine_core_outputs) == 0:
|
||||
break
|
||||
@ -78,6 +81,14 @@ async def loop_until_done_async(client: EngineCoreClient, outputs: Dict):
|
||||
break
|
||||
|
||||
|
||||
# Dummy utility function to monkey-patch into engine core.
|
||||
def echo(self, msg: str, err_msg: Optional[str] = None) -> str:
|
||||
print(f"echo util function called: {msg}, {err_msg}")
|
||||
if err_msg is not None:
|
||||
raise ValueError(err_msg)
|
||||
return msg
|
||||
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.parametrize("multiprocessing_mode", [True, False])
|
||||
def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
@ -85,7 +96,10 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3)
|
||||
# Monkey-patch core engine utility function to test.
|
||||
m.setattr(EngineCore, "echo", echo, raising=False)
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
@ -147,15 +161,30 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
|
||||
|
||||
client.abort_requests([request.request_id])
|
||||
|
||||
if multiprocessing_mode:
|
||||
"""Utility method invocation"""
|
||||
|
||||
@fork_new_process_for_each_test
|
||||
@pytest.mark.asyncio
|
||||
core_client: SyncMPClient = client
|
||||
|
||||
result = core_client._call_utility("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
core_client._call_utility("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_engine_core_client_asyncio(monkeypatch):
|
||||
|
||||
with monkeypatch.context() as m:
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME)
|
||||
# Monkey-patch core engine utility function to test.
|
||||
m.setattr(EngineCore, "echo", echo, raising=False)
|
||||
|
||||
engine_args = EngineArgs(model=MODEL_NAME, enforce_eager=True)
|
||||
vllm_config = engine_args.create_engine_config(
|
||||
usage_context=UsageContext.UNKNOWN_CONTEXT)
|
||||
executor_class = Executor.get_class(vllm_config)
|
||||
@ -166,6 +195,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
|
||||
executor_class=executor_class,
|
||||
log_stats=True,
|
||||
)
|
||||
after.callback(client.shutdown)
|
||||
|
||||
MAX_TOKENS = 20
|
||||
params = SamplingParams(max_tokens=MAX_TOKENS)
|
||||
@ -204,3 +234,14 @@ async def test_engine_core_client_asyncio(monkeypatch):
|
||||
else:
|
||||
assert len(outputs[req_id]) == MAX_TOKENS, (
|
||||
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
|
||||
"""Utility method invocation"""
|
||||
|
||||
core_client: AsyncMPClient = client
|
||||
|
||||
result = await core_client._call_utility_async("echo", "testarg")
|
||||
assert result == "testarg"
|
||||
|
||||
with pytest.raises(Exception) as e_info:
|
||||
await core_client._call_utility_async("echo", None, "help!")
|
||||
|
||||
assert str(e_info.value) == "Call to echo method failed: help!"
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
import enum
|
||||
import time
|
||||
from typing import List, Optional, Union
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import msgspec
|
||||
|
||||
@ -106,6 +106,18 @@ class EngineCoreOutput(
|
||||
return self.finish_reason is not None
|
||||
|
||||
|
||||
class UtilityOutput(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
gc=False): # type: ignore[call-arg]
|
||||
|
||||
call_id: int
|
||||
|
||||
# Non-None implies the call failed, result should be None.
|
||||
failure_message: Optional[str] = None
|
||||
result: Any = None
|
||||
|
||||
|
||||
class EngineCoreOutputs(
|
||||
msgspec.Struct,
|
||||
array_like=True, # type: ignore[call-arg]
|
||||
@ -116,10 +128,12 @@ class EngineCoreOutputs(
|
||||
# e.g. columnwise layout
|
||||
|
||||
# [num_reqs]
|
||||
outputs: List[EngineCoreOutput]
|
||||
scheduler_stats: Optional[SchedulerStats]
|
||||
outputs: List[EngineCoreOutput] = []
|
||||
scheduler_stats: Optional[SchedulerStats] = None
|
||||
timestamp: float = 0.0
|
||||
|
||||
utility_output: Optional[UtilityOutput] = None
|
||||
|
||||
def __post_init__(self):
|
||||
if self.timestamp == 0.0:
|
||||
self.timestamp = time.monotonic()
|
||||
@ -132,6 +146,4 @@ class EngineCoreRequestType(enum.Enum):
|
||||
"""
|
||||
ADD = b'\x00'
|
||||
ABORT = b'\x01'
|
||||
PROFILE = b'\x02'
|
||||
RESET_PREFIX_CACHE = b'\x03'
|
||||
ADD_LORA = b'\x04'
|
||||
UTILITY = b'\x02'
|
||||
|
@ -5,9 +5,11 @@ import signal
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from multiprocessing.connection import Connection
|
||||
from typing import Any, List, Optional, Tuple, Type
|
||||
|
||||
import msgspec
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -21,7 +23,7 @@ from vllm.utils import get_exception_traceback, zmq_socket_ctx
|
||||
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
|
||||
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
@ -330,19 +332,39 @@ class EngineCoreProc(EngineCore):
|
||||
self.add_request(request)
|
||||
elif request_type == EngineCoreRequestType.ABORT:
|
||||
self.abort_requests(request)
|
||||
elif request_type == EngineCoreRequestType.RESET_PREFIX_CACHE:
|
||||
self.reset_prefix_cache()
|
||||
elif request_type == EngineCoreRequestType.PROFILE:
|
||||
self.model_executor.profile(request)
|
||||
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||
self.model_executor.add_lora(request)
|
||||
elif request_type == EngineCoreRequestType.UTILITY:
|
||||
call_id, method_name, args = request
|
||||
output = UtilityOutput(call_id)
|
||||
try:
|
||||
method = getattr(self, method_name)
|
||||
output.result = method(
|
||||
*self._convert_msgspec_args(method, args))
|
||||
except BaseException as e:
|
||||
logger.exception("Invocation of %s method failed", method_name)
|
||||
output.failure_message = (f"Call to {method_name} method"
|
||||
f" failed: {str(e)}")
|
||||
self.output_queue.put_nowait(
|
||||
EngineCoreOutputs(utility_output=output))
|
||||
|
||||
@staticmethod
|
||||
def _convert_msgspec_args(method, args):
|
||||
"""If a provided arg type doesn't match corresponding target method
|
||||
arg type, try converting to msgspec object."""
|
||||
if not args:
|
||||
return args
|
||||
arg_types = signature(method).parameters.values()
|
||||
assert len(args) <= len(arg_types)
|
||||
return tuple(
|
||||
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
|
||||
and issubclass(p.annotation, msgspec.Struct)
|
||||
and not isinstance(v, p.annotation) else v
|
||||
for v, p in zip(args, arg_types))
|
||||
|
||||
def process_input_socket(self, input_path: str):
|
||||
"""Input socket IO thread."""
|
||||
|
||||
# Msgpack serialization decoding.
|
||||
add_request_decoder = MsgpackDecoder(EngineCoreRequest)
|
||||
add_lora_decoder = MsgpackDecoder(LoRARequest)
|
||||
generic_decoder = MsgpackDecoder()
|
||||
|
||||
with zmq_socket_ctx(input_path, zmq.constants.PULL) as socket:
|
||||
@ -352,14 +374,9 @@ class EngineCoreProc(EngineCore):
|
||||
request_type = EngineCoreRequestType(bytes(type_frame.buffer))
|
||||
|
||||
# Deserialize the request data.
|
||||
decoder = None
|
||||
if request_type == EngineCoreRequestType.ADD:
|
||||
decoder = add_request_decoder
|
||||
elif request_type == EngineCoreRequestType.ADD_LORA:
|
||||
decoder = add_lora_decoder
|
||||
else:
|
||||
decoder = generic_decoder
|
||||
|
||||
decoder = add_request_decoder if (
|
||||
request_type
|
||||
== EngineCoreRequestType.ADD) else generic_decoder
|
||||
request = decoder.decode(data_frame.buffer)
|
||||
|
||||
# Push to input queue for core busy loop.
|
||||
|
@ -2,10 +2,14 @@
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import queue
|
||||
import signal
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, List, Optional, Type
|
||||
from concurrent.futures import Future
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -16,7 +20,7 @@ from vllm.lora.request import LoRARequest
|
||||
from vllm.utils import (get_open_zmq_ipc_path, kill_process_tree,
|
||||
make_zmq_socket)
|
||||
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
|
||||
EngineCoreRequestType)
|
||||
EngineCoreRequestType, UtilityOutput)
|
||||
from vllm.v1.engine.core import EngineCore, EngineCoreProc
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
@ -24,6 +28,8 @@ from vllm.v1.utils import BackgroundProcHandle
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
AnyFuture = Union[asyncio.Future[Any], Future[Any]]
|
||||
|
||||
|
||||
class EngineCoreClient(ABC):
|
||||
"""
|
||||
@ -204,6 +210,8 @@ class MPClient(EngineCoreClient):
|
||||
"log_stats": log_stats,
|
||||
})
|
||||
|
||||
self.utility_results: Dict[int, AnyFuture] = {}
|
||||
|
||||
def shutdown(self):
|
||||
"""Clean up background resources."""
|
||||
if hasattr(self, "proc_handle"):
|
||||
@ -212,6 +220,16 @@ class MPClient(EngineCoreClient):
|
||||
self._finalizer()
|
||||
|
||||
|
||||
def _process_utility_output(output: UtilityOutput,
|
||||
utility_results: Dict[int, AnyFuture]):
|
||||
"""Set the result from a utility method in the waiting future"""
|
||||
future = utility_results.pop(output.call_id)
|
||||
if output.failure_message is not None:
|
||||
future.set_exception(Exception(output.failure_message))
|
||||
else:
|
||||
future.set_result(output.result)
|
||||
|
||||
|
||||
class SyncMPClient(MPClient):
|
||||
"""Synchronous client for multi-proc EngineCore."""
|
||||
|
||||
@ -224,10 +242,30 @@ class SyncMPClient(MPClient):
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
self.outputs_queue: queue.Queue[EngineCoreOutputs] = queue.Queue()
|
||||
|
||||
(frame, ) = self.output_socket.recv_multipart(copy=False)
|
||||
return self.decoder.decode(frame.buffer)
|
||||
# Ensure that the outputs socket processing thread does not have
|
||||
# a ref to the client which prevents gc.
|
||||
output_socket = self.output_socket
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
|
||||
def process_outputs_socket():
|
||||
while True:
|
||||
(frame, ) = output_socket.recv_multipart(copy=False)
|
||||
outputs = decoder.decode(frame.buffer)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
else:
|
||||
outputs_queue.put_nowait(outputs)
|
||||
|
||||
# Process outputs from engine in separate thread.
|
||||
Thread(target=process_outputs_socket, daemon=True).start()
|
||||
|
||||
def get_output(self) -> EngineCoreOutputs:
|
||||
return self.outputs_queue.get()
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
@ -236,6 +274,16 @@ class SyncMPClient(MPClient):
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
def _call_utility(self, method: str, *args) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future: Future[Any] = Future()
|
||||
self.utility_results[call_id] = future
|
||||
|
||||
self._send_input(EngineCoreRequestType.UTILITY,
|
||||
(call_id, method, args))
|
||||
|
||||
return future.result()
|
||||
|
||||
def add_request(self, request: EngineCoreRequest) -> None:
|
||||
# NOTE: text prompt is not needed in the core engine as it has been
|
||||
# tokenized.
|
||||
@ -247,13 +295,13 @@ class SyncMPClient(MPClient):
|
||||
self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
def profile(self, is_start: bool = True) -> None:
|
||||
self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||
self._call_utility("profile", is_start)
|
||||
|
||||
def reset_prefix_cache(self) -> None:
|
||||
self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||
self._call_utility("reset_prefix_cache")
|
||||
|
||||
def add_lora(self, lora_request: LoRARequest) -> None:
|
||||
self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
|
||||
self._call_utility("add_lora", lora_request)
|
||||
|
||||
|
||||
class AsyncMPClient(MPClient):
|
||||
@ -268,24 +316,35 @@ class AsyncMPClient(MPClient):
|
||||
log_stats=log_stats,
|
||||
)
|
||||
|
||||
self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
|
||||
self.outputs_queue: Optional[asyncio.Queue[EngineCoreOutputs]] = None
|
||||
self.queue_task: Optional[asyncio.Task] = None
|
||||
|
||||
async def _start_output_queue_task(self):
|
||||
# Perform IO in separate task to parallelize as much as possible.
|
||||
# Avoid task having direct reference back to the client.
|
||||
self.outputs_queue = asyncio.Queue()
|
||||
output_socket = self.output_socket
|
||||
decoder = self.decoder
|
||||
utility_results = self.utility_results
|
||||
outputs_queue = self.outputs_queue
|
||||
|
||||
async def process_outputs_socket():
|
||||
while True:
|
||||
(frame, ) = await output_socket.recv_multipart(copy=False)
|
||||
outputs: EngineCoreOutputs = decoder.decode(frame.buffer)
|
||||
if outputs.utility_output:
|
||||
_process_utility_output(outputs.utility_output,
|
||||
utility_results)
|
||||
else:
|
||||
outputs_queue.put_nowait(outputs)
|
||||
|
||||
self.queue_task = asyncio.create_task(process_outputs_socket())
|
||||
|
||||
async def get_output_async(self) -> EngineCoreOutputs:
|
||||
if self.outputs_queue is None:
|
||||
# Perform IO in separate task to parallelize as much as possible
|
||||
self.outputs_queue = asyncio.Queue()
|
||||
|
||||
async def process_outputs_socket():
|
||||
assert self.outputs_queue is not None
|
||||
while True:
|
||||
(frame, ) = await self.output_socket.recv_multipart(
|
||||
copy=False)
|
||||
self.outputs_queue.put_nowait(frame.buffer)
|
||||
|
||||
self.queue_task = asyncio.create_task(process_outputs_socket())
|
||||
|
||||
return self.decoder.decode(await self.outputs_queue.get())
|
||||
await self._start_output_queue_task()
|
||||
assert self.outputs_queue is not None
|
||||
return await self.outputs_queue.get()
|
||||
|
||||
async def _send_input(self, request_type: EngineCoreRequestType,
|
||||
request: Any) -> None:
|
||||
@ -293,6 +352,18 @@ class AsyncMPClient(MPClient):
|
||||
msg = (request_type.value, self.encoder.encode(request))
|
||||
await self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
if self.outputs_queue is None:
|
||||
await self._start_output_queue_task()
|
||||
|
||||
async def _call_utility_async(self, method: str, *args) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
future = asyncio.get_running_loop().create_future()
|
||||
self.utility_results[call_id] = future
|
||||
await self._send_input(EngineCoreRequestType.UTILITY,
|
||||
(call_id, method, args))
|
||||
|
||||
return await future
|
||||
|
||||
async def add_request_async(self, request: EngineCoreRequest) -> None:
|
||||
# NOTE: text prompt is not needed in the core engine as it has been
|
||||
# tokenized.
|
||||
@ -304,10 +375,10 @@ class AsyncMPClient(MPClient):
|
||||
await self._send_input(EngineCoreRequestType.ABORT, request_ids)
|
||||
|
||||
async def profile_async(self, is_start: bool = True) -> None:
|
||||
await self._send_input(EngineCoreRequestType.PROFILE, is_start)
|
||||
await self._call_utility_async("profile", is_start)
|
||||
|
||||
async def reset_prefix_cache_async(self) -> None:
|
||||
await self._send_input(EngineCoreRequestType.RESET_PREFIX_CACHE, None)
|
||||
await self._call_utility_async("reset_prefix_cache")
|
||||
|
||||
async def add_lora_async(self, lora_request: LoRARequest) -> None:
|
||||
await self._send_input(EngineCoreRequestType.ADD_LORA, lora_request)
|
||||
await self._call_utility_async("add_lora", lora_request)
|
||||
|
Reference in New Issue
Block a user