# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import contextlib
import queue
import sys
import uuid
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict, deque
from collections.abc import Awaitable, Sequence
from concurrent.futures import Future
from dataclasses import dataclass
from threading import Thread
from typing import Any, Callable, Optional, TypeVar, Union

import msgspec.msgpack
import zmq
import zmq.asyncio

from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.tasks import SupportedTask
from vllm.utils import get_open_port, get_open_zmq_inproc_path, make_zmq_socket
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType,
                            ReconfigureDistributedRequest, ReconfigureRankType,
                            UtilityOutput)
from vllm.v1.engine.coordinator import DPCoordinator
from vllm.v1.engine.core import EngineCore, EngineCoreProc
from vllm.v1.engine.exceptions import EngineDeadError
from vllm.v1.engine.utils import (CoreEngineActorManager,
                                  CoreEngineProcManager, launch_core_engines)
from vllm.v1.executor.abstract import Executor
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder, bytestr

logger = init_logger(__name__)

AnyFuture = Union[asyncio.Future[Any], Future[Any]]

_R = TypeVar('_R')  # Return type for collective_rpc

EngineIdentity = bytes


class EngineCoreClient(ABC):
    """
    EngineCoreClient: subclasses handle different methods for pushing
    and pulling from the EngineCore for asyncio / multiprocessing.

    Subclasses:
    * InprocClient: In process EngineCore (for V0-style LLMEngine use)
    * SyncMPClient: ZMQ + background proc EngineCore (for LLM)
    * AsyncMPClient: ZMQ + background proc EngineCore w/ asyncio (for AsyncLLM)
    """

    @staticmethod
    def make_client(
        multiprocess_mode: bool,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
    ) -> "EngineCoreClient":

        # TODO: support this for debugging purposes.
        if asyncio_mode and not multiprocess_mode:
            raise NotImplementedError(
                "Running EngineCore in asyncio without multiprocessing "
                "is not currently supported.")

        if multiprocess_mode and asyncio_mode:
            return EngineCoreClient.make_async_mp_client(
                vllm_config, executor_class, log_stats)

        if multiprocess_mode and not asyncio_mode:
            return SyncMPClient(vllm_config, executor_class, log_stats)

        return InprocClient(vllm_config, executor_class, log_stats)

    @staticmethod
    def make_async_mp_client(
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
        client_index: int = 0,
    ) -> "MPClient":
        parallel_config = vllm_config.parallel_config
        client_args = (vllm_config, executor_class, log_stats,
                       client_addresses, client_index)
        if parallel_config.data_parallel_size > 1:
            if parallel_config.data_parallel_external_lb:
                # External load balancer - client per DP rank.
                return DPAsyncMPClient(*client_args)
            # Internal load balancer - client balances to all DP ranks.
            return DPLBAsyncMPClient(*client_args)
        return AsyncMPClient(*client_args)

    @abstractmethod
    def shutdown(self):
        ...

    def get_output(self) -> EngineCoreOutputs:
        raise NotImplementedError

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        raise NotImplementedError

    def add_request(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    def profile(self, is_start: bool = True) -> None:
        raise NotImplementedError

    def reset_mm_cache(self) -> None:
        raise NotImplementedError

    def reset_prefix_cache(self) -> None:
        raise NotImplementedError

    def sleep(self, level: int = 1) -> None:
        raise NotImplementedError

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    def is_sleeping(self) -> bool:
        raise NotImplementedError

    def execute_dummy_batch(self) -> None:
        raise NotImplementedError

    async def execute_dummy_batch_async(self) -> None:
        raise NotImplementedError

    def abort_requests(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    def add_lora(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    def remove_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def list_loras(self) -> set[int]:
        raise NotImplementedError

    def pin_lora(self, lora_id: int) -> bool:
        raise NotImplementedError

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError

    def dp_engines_running(self) -> bool:
        """Returns True if data parallel engines are collectively in a
        running state."""
        raise NotImplementedError

    async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
        raise NotImplementedError

    async def get_output_async(self) -> EngineCoreOutputs:
        raise NotImplementedError

    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
        raise NotImplementedError

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        raise NotImplementedError

    async def profile_async(self, is_start: bool = True) -> None:
        raise NotImplementedError

    async def reset_mm_cache_async(self) -> None:
        raise NotImplementedError

    async def reset_prefix_cache_async(self) -> None:
        raise NotImplementedError

    async def sleep_async(self, level: int = 1) -> None:
        raise NotImplementedError

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        raise NotImplementedError

    async def is_sleeping_async(self) -> bool:
        raise NotImplementedError

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        raise NotImplementedError

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        raise NotImplementedError

    async def remove_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def list_loras_async(self) -> set[int]:
        raise NotImplementedError

    async def pin_lora_async(self, lora_id: int) -> bool:
        raise NotImplementedError

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        raise NotImplementedError

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        raise NotImplementedError
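
# Illustrative sketch (not executed): how a frontend might obtain a client via
# EngineCoreClient.make_client(). The VllmConfig / Executor construction shown
# below is assumed context, not something this module provides.
#
#     client = EngineCoreClient.make_client(
#         multiprocess_mode=True,   # run EngineCore in a background process
#         asyncio_mode=False,       # False -> SyncMPClient, True -> AsyncMPClient
#         vllm_config=vllm_config,
#         executor_class=Executor.get_class(vllm_config),
#         log_stats=False,
#     )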


class InprocClient(EngineCoreClient):
    """
    InprocClient: client for in-process EngineCore. Intended
    for use in LLMEngine for V0-style add_request() and step()
    EngineCore setup in this process (no busy loop).

    * pushes EngineCoreRequest directly into the EngineCore
    * pulls EngineCoreOutputs by stepping the EngineCore
    """

    def __init__(self, *args, **kwargs):
        self.engine_core = EngineCore(*args, **kwargs)

    def get_output(self) -> EngineCoreOutputs:
        outputs, _ = self.engine_core.step()
        return outputs.get(0) or EngineCoreOutputs()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.engine_core.get_supported_tasks()

    def add_request(self, request: EngineCoreRequest) -> None:
        self.engine_core.add_request(request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if len(request_ids) > 0:
            self.engine_core.abort_requests(request_ids)

    def shutdown(self) -> None:
        self.engine_core.shutdown()

    def profile(self, is_start: bool = True) -> None:
        self.engine_core.profile(is_start)

    def reset_mm_cache(self) -> None:
        self.engine_core.reset_mm_cache()

    def reset_prefix_cache(self) -> None:
        self.engine_core.reset_prefix_cache()

    def sleep(self, level: int = 1) -> None:
        self.engine_core.sleep(level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.engine_core.wake_up(tags)

    def is_sleeping(self) -> bool:
        return self.engine_core.is_sleeping()

    def execute_dummy_batch(self) -> None:
        self.engine_core.execute_dummy_batch()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.engine_core.add_lora(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.engine_core.remove_lora(lora_id)

    def list_loras(self) -> set[int]:
        return self.engine_core.list_loras()

    def pin_lora(self, lora_id: int) -> bool:
        return self.engine_core.pin_lora(lora_id)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.engine_core.save_sharded_state(path, pattern, max_size)

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.engine_core.collective_rpc(method, timeout, args, kwargs)

    def dp_engines_running(self) -> bool:
        return False
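
# Illustrative sketch (not executed): the in-process pattern this client is
# built for - push a request, then repeatedly step the engine for outputs.
# `request` here is an assumed, already-constructed EngineCoreRequest.
#
#     client = InprocClient(vllm_config, executor_class, log_stats=False)
#     client.add_request(request)
#     while True:
#         outputs = client.get_output()   # one EngineCore.step() under the hood
#         if outputs.outputs:
#             ...                         # consume EngineCoreOutput items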


@dataclass
class BackgroundResources:
    """Used as a finalizer for clean shutdown, avoiding
    circular reference back to the client object."""

    ctx: Union[zmq.Context]
    # If CoreEngineProcManager, it manages local engines;
    # if CoreEngineActorManager, it manages all engines.
    engine_manager: Optional[Union[CoreEngineProcManager,
                                   CoreEngineActorManager]] = None
    coordinator: Optional[DPCoordinator] = None
    output_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    input_socket: Optional[Union[zmq.Socket, zmq.asyncio.Socket]] = None
    first_req_send_socket: Optional[zmq.asyncio.Socket] = None
    output_queue_task: Optional[asyncio.Task] = None
    stats_update_task: Optional[asyncio.Task] = None
    shutdown_path: Optional[str] = None

    # Set if any of the engines are dead. Here so that the output
    # processing threads can access it without holding a ref to the client.
    engine_dead: bool = False

    def __call__(self):
        """Clean up background resources."""

        self.engine_dead = True
        if self.engine_manager is not None:
            self.engine_manager.close()
        if self.coordinator is not None:
            self.coordinator.close()

        if self.output_queue_task is not None:
            self.output_queue_task.cancel()
        if self.stats_update_task is not None:
            self.stats_update_task.cancel()

        # ZMQ context termination can hang if the sockets
        # aren't explicitly closed first.
        for socket in (self.output_socket, self.input_socket,
                       self.first_req_send_socket):
            if socket is not None:
                socket.close(linger=0)

        if self.shutdown_path is not None:
            # We must ensure that the sync output socket is
            # closed cleanly in its own thread.
            with self.ctx.socket(zmq.PAIR) as shutdown_sender:
                shutdown_sender.connect(self.shutdown_path)
                # Send shutdown signal.
                shutdown_sender.send(b'')

    def validate_alive(self, frames: Sequence[zmq.Frame]):
        if len(frames) == 1 and (frames[0].buffer
                                 == EngineCoreProc.ENGINE_CORE_DEAD):
            self.engine_dead = True
            raise EngineDeadError()
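
# Note on the pattern above: MPClient registers this dataclass with
# weakref.finalize(self, self.resources). Because the finalizer callable holds
# no reference back to the client, the client can be garbage collected, at
# which point __call__ runs and tears down sockets, tasks and engine managers.
# Minimal sketch of the same idiom (illustrative only; `Owner` is an assumed
# name, not part of this module):
#
#     class Owner:
#         def __init__(self):
#             self.resources = BackgroundResources(ctx=zmq.Context())
#             self._finalizer = weakref.finalize(self, self.resources)
#
#         def shutdown(self):
#             self._finalizer()   # runs at most once; also fires on gc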


class MPClient(EngineCoreClient):
    """
    MPClient: base client for multi-proc EngineCore.
    EngineCore runs in a background process busy loop, getting
    new EngineCoreRequests and returning EngineCoreOutputs

    * pushes EngineCoreRequests via input_socket
    * pulls EngineCoreOutputs via output_socket

    * AsyncMPClient subclass for AsyncLLM usage
    * SyncMPClient subclass for LLM usage
    """

    def __init__(
        self,
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
    ):
        self.vllm_config = vllm_config
        # Serialization setup.
        self.encoder = MsgpackEncoder()
        self.decoder = MsgpackDecoder(EngineCoreOutputs)

        # ZMQ setup.
        sync_ctx = zmq.Context(io_threads=2)
        self.ctx = zmq.asyncio.Context(sync_ctx) if asyncio_mode else sync_ctx

        # This will ensure resources created so far are closed
        # when the client is garbage collected, even if an
        # exception is raised mid-construction.
        self.resources = BackgroundResources(ctx=sync_ctx)
        self._finalizer = weakref.finalize(self, self.resources)
        success = False
        try:
            # State used for data parallel.
            self.engines_running = False

            self.stats_update_address: Optional[str] = None
            if client_addresses is not None:
                # Engines are managed externally to this client.
                input_address = client_addresses["input_address"]
                output_address = client_addresses["output_address"]
                self.stats_update_address = client_addresses.get(
                    "stats_update_address")
            else:
                # Engines are managed by this client.
                with launch_core_engines(vllm_config, executor_class,
                                         log_stats) as (engine_manager,
                                                        coordinator,
                                                        addresses):
                    self.resources.coordinator = coordinator
                    self.resources.engine_manager = engine_manager

                (input_address, ) = addresses.inputs
                (output_address, ) = addresses.outputs
                self.stats_update_address = (
                    addresses.frontend_stats_publish_address)
                if coordinator is not None:
                    assert self.stats_update_address == (
                        coordinator.get_stats_publish_address())

            # Create input and output sockets.
            self.input_socket = self.resources.input_socket = make_zmq_socket(
                self.ctx, input_address, zmq.ROUTER, bind=True)
            self.resources.output_socket = make_zmq_socket(
                self.ctx, output_address, zmq.PULL)

            parallel_config = vllm_config.parallel_config
            dp_size = parallel_config.data_parallel_size
            dp_rank = parallel_config.data_parallel_rank
            dp_local_size = parallel_config.data_parallel_size_local
            offline_mode = parallel_config.data_parallel_rank_local is not None
            # Client manages local+remote EngineCores in pure internal LB case.
            # Client manages local EngineCores in hybrid and external LB case.
            local_engines_only = (parallel_config.data_parallel_hybrid_lb
                                  or parallel_config.data_parallel_external_lb)

            num_ranks = dp_local_size if local_engines_only else dp_size
            self.engine_ranks_managed = [dp_rank] if offline_mode else list(
                range(dp_rank, dp_rank + num_ranks))
            assert parallel_config.data_parallel_size_local <= len(
                self.engine_ranks_managed)

            # ZMQ identity of each engine that this client will talk to.
            self.core_engines: list[EngineIdentity] = [
                rank.to_bytes(2, "little")
                for rank in self.engine_ranks_managed
            ]

            # Wait for ready messages from each engine on the input socket.
            identities = set(self.core_engines)
            sync_input_socket = zmq.Socket.shadow(self.input_socket)
            while identities:
                if not sync_input_socket.poll(timeout=600_000):
                    raise TimeoutError("Timed out waiting for engines to send "
                                       "initial message on input socket.")
                identity, _ = sync_input_socket.recv_multipart()
                identities.remove(identity)

            self.core_engine: EngineIdentity = self.core_engines[0]
            self.utility_results: dict[int, AnyFuture] = {}

            # Request objects which may contain pytorch-allocated tensors
            # that we need to keep references to until zmq is done with the
            # underlying data.
            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()

            success = True
        finally:
            if not success:
                self._finalizer()

    def shutdown(self):
        # Terminate background resources.
        self._finalizer()

    def _format_exception(self, e: Exception) -> Exception:
        """If errored, use EngineDeadError so root cause is clear."""
        return EngineDeadError(
            suppress_context=True) if self.resources.engine_dead else e

    def ensure_alive(self):
        if self.resources.engine_dead:
            raise EngineDeadError()

    def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
        if not tracker.done:
            self.pending_messages.appendleft((tracker, msg))

    def free_pending_messages(self):
        while self.pending_messages and self.pending_messages[-1][0].done:
            self.pending_messages.pop()

    def dp_engines_running(self) -> bool:
        return self.engines_running
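
# Wire format used by the MPClient subclasses on the ROUTER input socket
# (summarized from their _send_input implementations below): each multipart
# message is
#
#     (engine_identity, request_type, *serialized_frames)
#
# where engine_identity is the DP rank encoded as rank.to_bytes(2, "little"),
# request_type is an EngineCoreRequestType value, and the remaining frames come
# from MsgpackEncoder. Illustrative example (not executed) for rank 3:
#
#     identity = (3).to_bytes(2, "little")   # b'\x03\x00'
#     msg = (identity, EngineCoreRequestType.ADD.value, *encoder.encode(request))
#     input_socket.send_multipart(msg, copy=False)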


def _process_utility_output(output: UtilityOutput,
                            utility_results: dict[int, AnyFuture]):
    """Set the result from a utility method in the waiting future"""
    future = utility_results.pop(output.call_id)
    if output.failure_message is not None:
        future.set_exception(Exception(output.failure_message))
    else:
        future.set_result(output.result)
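
# Utility-call protocol in brief: the client allocates a call_id, parks a
# Future in utility_results, and sends (client_index, call_id, method, args)
# as an EngineCoreRequestType.UTILITY message; when the engine replies with a
# UtilityOutput carrying the same call_id, the helper above resolves that
# Future. Illustrative sketch of the client side (not executed; mirrors
# SyncMPClient.call_utility below):
#
#     call_id = uuid.uuid1().int >> 64
#     future: Future[Any] = Future()
#     self.utility_results[call_id] = future
#     self._send_input(EngineCoreRequestType.UTILITY,
#                      (0, call_id, "sleep", (1, )))
#     result = future.result()   # blocks until _process_utility_output runs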


class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig, executor_class: type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
        )

        self.is_dp = self.vllm_config.parallel_config.data_parallel_size > 1
        self.outputs_queue = queue.Queue[Union[EngineCoreOutputs, Exception]]()

        # Ensure that the outputs socket processing thread does not have
        # a ref to the client which prevents gc.
        ctx = self.ctx
        out_socket = self.resources.output_socket
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue

        shutdown_path = get_open_zmq_inproc_path()
        resources = self.resources
        resources.shutdown_path = shutdown_path

        def process_outputs_socket():
            assert isinstance(out_socket, zmq.Socket)
            shutdown_socket = ctx.socket(zmq.PAIR)
            try:
                shutdown_socket.bind(shutdown_path)
                poller = zmq.Poller()
                poller.register(shutdown_socket, zmq.POLLIN)
                poller.register(out_socket, zmq.POLLIN)
                while True:
                    socks = poller.poll()
                    if not socks:
                        continue
                    if len(socks) == 2 or socks[0][0] == shutdown_socket:
                        # shutdown signal, exit thread.
                        break

                    frames = out_socket.recv_multipart(copy=False)
                    resources.validate_alive(frames)
                    outputs: EngineCoreOutputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                    else:
                        outputs_queue.put_nowait(outputs)
            except Exception as e:
                outputs_queue.put_nowait(e)
            finally:
                # Close sockets.
                shutdown_socket.close(linger=0)
                out_socket.close(linger=0)

        # Process outputs from engine in separate thread.
        self.output_queue_thread = Thread(target=process_outputs_socket,
                                          name="EngineCoreOutputQueueThread",
                                          daemon=True)
        self.output_queue_thread.start()

        # The thread takes on responsibility for closing the socket.
        self.resources.output_socket = None

    def get_output(self) -> EngineCoreOutputs:
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
        # from this (run_output_handler) task to shut down the server.
        outputs = self.outputs_queue.get()
        if isinstance(outputs, Exception):
            raise self._format_exception(outputs) from None
        if outputs.wave_complete is not None:
            self.engines_running = False
        return outputs

    def _send_input(self, request_type: EngineCoreRequestType, request: Any):
        self.ensure_alive()
        self.free_pending_messages()
        # (Identity, RequestType, SerializedRequest)
        msg = (self.core_engine, request_type.value,
               *self.encoder.encode(request))

        if len(msg) <= 3:
            # No auxiliary buffers => no tensor backing buffers in request.
            self.input_socket.send_multipart(msg, copy=False)
            return

        tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
        self.add_pending_message(tracker, request)

    def call_utility(self, method: str, *args) -> Any:
        call_id = uuid.uuid1().int >> 64
        future: Future[Any] = Future()
        self.utility_results[call_id] = future
        self._send_input(EngineCoreRequestType.UTILITY,
                         (0, call_id, method, args))

        return future.result()

    def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
        return self.call_utility("get_supported_tasks")

    def add_request(self, request: EngineCoreRequest) -> None:
        if self.is_dp:
            self.engines_running = True
        self._send_input(EngineCoreRequestType.ADD, request)

    def abort_requests(self, request_ids: list[str]) -> None:
        if request_ids and not self.resources.engine_dead:
            self._send_input(EngineCoreRequestType.ABORT, request_ids)

    def profile(self, is_start: bool = True) -> None:
        self.call_utility("profile", is_start)

    def reset_mm_cache(self) -> None:
        self.call_utility("reset_mm_cache")

    def reset_prefix_cache(self) -> None:
        self.call_utility("reset_prefix_cache")

    def add_lora(self, lora_request: LoRARequest) -> bool:
        return self.call_utility("add_lora", lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        return self.call_utility("remove_lora", lora_id)

    def list_loras(self) -> set[int]:
        return self.call_utility("list_loras")

    def pin_lora(self, lora_id: int) -> bool:
        return self.call_utility("pin_lora", lora_id)

    def sleep(self, level: int = 1) -> None:
        self.call_utility("sleep", level)

    def wake_up(self, tags: Optional[list[str]] = None) -> None:
        self.call_utility("wake_up", tags)

    def is_sleeping(self) -> bool:
        return self.call_utility("is_sleeping")

    def execute_dummy_batch(self) -> None:
        self.call_utility("execute_dummy_batch")

    def collective_rpc(self,
                       method: Union[str, Callable[..., _R]],
                       timeout: Optional[float] = None,
                       args: tuple = (),
                       kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return self.call_utility("collective_rpc", method, timeout, args,
                                 kwargs)

    def save_sharded_state(self,
                           path: str,
                           pattern: Optional[str] = None,
                           max_size: Optional[int] = None) -> None:
        self.call_utility("save_sharded_state", path, pattern, max_size)
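
# Illustrative sketch (not executed): typical SyncMPClient flow as used by the
# synchronous LLM path - requests go out over ZMQ, and the background thread
# above feeds decoded EngineCoreOutputs into outputs_queue for get_output().
# `request` is an assumed EngineCoreRequest built by the frontend.
#
#     client = SyncMPClient(vllm_config, executor_class, log_stats=False)
#     try:
#         client.add_request(request)
#         outputs = client.get_output()   # blocks on the outputs queue
#     finally:
#         client.shutdown()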


class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):
        super().__init__(
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=log_stats,
            client_addresses=client_addresses,
        )

        self.client_index = client_index
        self.outputs_queue = asyncio.Queue[Union[EngineCoreOutputs,
                                                 Exception]]()
        try:
            # If we are running in an asyncio event loop, start the queue task.
            # Otherwise, it will be started lazily. If it is not started here,
            # we could miss EXECUTOR_FAILED messages from engine core if they
            # occur prior to any requests being sent.
            asyncio.get_running_loop()
            self._ensure_output_queue_task()
        except RuntimeError:
            pass

    def _ensure_output_queue_task(self):
        resources = self.resources
        if resources.output_queue_task is not None:
            return

        # Perform IO in separate task to parallelize as much as possible.
        # Avoid task having direct reference back to the client.
        decoder = self.decoder
        utility_results = self.utility_results
        outputs_queue = self.outputs_queue
        output_handler: Optional[Callable[[AsyncMPClient, EngineCoreOutputs],
                                          Awaitable[None]]] = getattr(
                                              self.__class__,
                                              "process_engine_outputs", None)
        _self_ref = weakref.ref(self) if output_handler else None
        output_socket = resources.output_socket
        assert output_socket is not None

        async def process_outputs_socket():
            try:
                while True:
                    frames = await output_socket.recv_multipart(copy=False)
                    resources.validate_alive(frames)
                    outputs: EngineCoreOutputs = decoder.decode(frames)
                    if outputs.utility_output:
                        _process_utility_output(outputs.utility_output,
                                                utility_results)
                        continue

                    if output_handler is not None:
                        assert _self_ref is not None
                        _self = _self_ref()
                        if not _self:
                            # Client has been garbage collected, abort.
                            return
                        await output_handler(_self, outputs)

                    if outputs.outputs or outputs.scheduler_stats:
                        outputs_queue.put_nowait(outputs)
            except Exception as e:
                outputs_queue.put_nowait(e)

        resources.output_queue_task = asyncio.create_task(
            process_outputs_socket(), name="EngineCoreOutputQueueTask")

    async def get_output_async(self) -> EngineCoreOutputs:
        self._ensure_output_queue_task()
        # If an exception arises in process_outputs_socket task,
        # it is forwarded to the outputs_queue so we can raise it
        # from this (run_output_handler) task to shut down the server.
        assert self.outputs_queue is not None
        outputs = await self.outputs_queue.get()
        if isinstance(outputs, Exception):
            raise self._format_exception(outputs) from None
        return outputs

    def _send_input(self,
                    request_type: EngineCoreRequestType,
                    request: Any,
                    engine: Optional[EngineIdentity] = None) -> Awaitable[Any]:
        if engine is None:
            engine = self.core_engine

        message = (request_type.value, *self.encoder.encode(request))
        return self._send_input_message(message, engine, request)

    def _send_input_message(self, message: tuple[bytestr,
                                                 ...], engine: EngineIdentity,
                            objects: Any) -> Awaitable[Any]:
        """
        objects is a reference to retain until zmq is finished with the
        buffers, in case they were extracted from tensors in the request.
        """
        self.ensure_alive()
        self.free_pending_messages()

        msg = (engine, ) + message
        if not objects or len(msg) <= 3:
            # No auxiliary buffers => no tensor backing buffers in request.
            return self.input_socket.send_multipart(msg, copy=False)

        future: asyncio.Future[zmq.MessageTracker]
        future = self.input_socket.send_multipart(msg, copy=False, track=True)

        def add_pending(f: asyncio.Future[zmq.MessageTracker]):
            with contextlib.suppress(BaseException):
                self.add_pending_message(f.result(), objects)

        future.add_done_callback(add_pending)
        return future

    async def call_utility_async(self, method: str, *args) -> Any:
        return await self._call_utility_async(method,
                                              *args,
                                              engine=self.core_engine)

    async def _call_utility_async(self, method: str, *args,
                                  engine: EngineIdentity) -> Any:
        call_id = uuid.uuid1().int >> 64
        future = asyncio.get_running_loop().create_future()
        self.utility_results[call_id] = future
        message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
            (self.client_index, call_id, method, args)))
        await self._send_input_message(message, engine, args)
        self._ensure_output_queue_task()
        return await future

    async def get_supported_tasks_async(self) -> tuple[SupportedTask, ...]:
        return await self.call_utility_async("get_supported_tasks")

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        request.client_index = self.client_index
        await self._send_input(EngineCoreRequestType.ADD, request)
        self._ensure_output_queue_task()

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if request_ids and not self.resources.engine_dead:
            await self._send_input(EngineCoreRequestType.ABORT, request_ids)

    async def profile_async(self, is_start: bool = True) -> None:
        await self.call_utility_async("profile", is_start)

    async def reset_mm_cache_async(self) -> None:
        await self.call_utility_async("reset_mm_cache")

    async def reset_prefix_cache_async(self) -> None:
        await self.call_utility_async("reset_prefix_cache")

    async def sleep_async(self, level: int = 1) -> None:
        await self.call_utility_async("sleep", level)

    async def wake_up_async(self, tags: Optional[list[str]] = None) -> None:
        await self.call_utility_async("wake_up", tags)

    async def is_sleeping_async(self) -> bool:
        return await self.call_utility_async("is_sleeping")

    async def execute_dummy_batch_async(self) -> None:
        await self.call_utility_async("execute_dummy_batch")

    async def add_lora_async(self, lora_request: LoRARequest) -> bool:
        return await self.call_utility_async("add_lora", lora_request)

    async def remove_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("remove_lora", lora_id)

    async def list_loras_async(self) -> set[int]:
        return await self.call_utility_async("list_loras")

    async def pin_lora_async(self, lora_id: int) -> bool:
        return await self.call_utility_async("pin_lora", lora_id)

    async def save_sharded_state_async(self,
                                       path: str,
                                       pattern: Optional[str] = None,
                                       max_size: Optional[int] = None) -> None:
        await self.call_utility_async("save_sharded_state", path, pattern,
                                      max_size)

    async def collective_rpc_async(
            self,
            method: Union[str, Callable[..., _R]],
            timeout: Optional[float] = None,
            args: tuple = (),
            kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
        return await self.call_utility_async("collective_rpc", method, timeout,
                                             args, kwargs)
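
# Illustrative sketch (not executed): how an async frontend drives this client
# inside a running event loop. The output-queue task is started lazily on
# first use if the client was constructed outside a loop.
#
#     async def generate_once(client: AsyncMPClient, request) -> None:
#         await client.add_request_async(request)   # assumed EngineCoreRequest
#         outputs = await client.get_output_async()
#         ...  # hand outputs to the output processor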


class DPAsyncMPClient(AsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore. Assumes external load-balancing by default."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):
        self.current_wave = 0

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

        # List of [waiting, running] pair per engine.
        # Used only by DPLBAsyncMPClient subclass.
        self.lb_engines: list[list[int]] = []

        self.first_req_sock_addr = get_open_zmq_inproc_path()
        self.first_req_send_socket = self.resources.first_req_send_socket = (
            make_zmq_socket(self.ctx,
                            self.first_req_sock_addr,
                            zmq.PAIR,
                            bind=True))
        try:
            # If we are running in an asyncio event loop, start the stats task.
            # Otherwise, it will be started lazily.
            asyncio.get_running_loop()
            self._ensure_stats_update_task()
        except RuntimeError:
            pass

    def _ensure_stats_update_task(self):
        resources = self.resources
        if resources.stats_update_task is not None:
            return

        assert self.stats_update_address is not None
        assert len(self.engine_ranks_managed) > 0
        # NOTE: running and waiting counts from the Coordinator are global
        # and include all EngineCores. This slice includes just the cores
        # managed by this client.
        count_slice = slice(self.engine_ranks_managed[0],
                            self.engine_ranks_managed[-1] + 1)

        async def run_engine_stats_update_task():
            with make_zmq_socket(self.ctx, self.stats_update_address,
                                 zmq.XSUB) as socket, make_zmq_socket(
                                     self.ctx,
                                     self.first_req_sock_addr,
                                     zmq.PAIR,
                                     bind=False) as first_req_rcv_socket:
                assert isinstance(socket, zmq.asyncio.Socket)
                assert isinstance(first_req_rcv_socket, zmq.asyncio.Socket)
                # Send subscription message.
                await socket.send(b'\x01')

                poller = zmq.asyncio.Poller()
                poller.register(socket, zmq.POLLIN)
                poller.register(first_req_rcv_socket, zmq.POLLIN)

                while True:
                    events = await poller.poll()
                    if not self.engines_running and len(events) == 2 or (
                            events[0][0] == first_req_rcv_socket):
                        # Check if this is a regular request notification or
                        # scale up notification
                        buf = first_req_rcv_socket.recv(
                            flags=zmq.NOBLOCK).result()

                        decoded = msgspec.msgpack.decode(buf)
                        if isinstance(
                                decoded,
                            (list, tuple)) and len(decoded) == 2 and decoded[
                                    0] == "SCALE_ELASTIC_EP":
                            # Extract new engine count from the decoded message
                            new_engine_count = decoded[1]
                            # Send scale up notification to coordinator
                            scale_msg = msgspec.msgpack.encode(
                                ("SCALE_ELASTIC_EP", new_engine_count))
                            await socket.send(scale_msg)
                            continue

                        # we're sending a request while the engines are
                        # paused, so that it can wake the others up
                        # (to run dummy EP loop).
                        assert decoded[0] == "FIRST_REQ"
                        target_eng_index = decoded[1]
                        self.engines_running = True
                        msg = msgspec.msgpack.encode(
                            (target_eng_index, self.current_wave))
                        await socket.send(msg)

                    buf = None
                    while True:
                        # Drain all stats events (we only care about latest).
                        future: asyncio.Future[bytes] = socket.recv(
                            flags=zmq.NOBLOCK)
                        if isinstance(future.exception(), zmq.Again):
                            break
                        buf = future.result()
                    if buf is None:
                        continue

                    # Update local load-balancing state.
                    counts, wave, running = msgspec.msgpack.decode(buf)
                    self.current_wave = wave
                    self.engines_running = running
                    self.lb_engines = counts[count_slice]

        resources.stats_update_task = asyncio.create_task(
            run_engine_stats_update_task())

    async def add_request_async(self, request: EngineCoreRequest) -> None:
        self._ensure_stats_update_task()

        request.current_wave = self.current_wave
        request.client_index = self.client_index

        chosen_engine = self.get_core_engine_for_request(request)
        to_await = self._send_input(EngineCoreRequestType.ADD, request,
                                    chosen_engine)
        if not self.engines_running:
            # Notify coordinator that we're sending a request
            req_msg = msgspec.msgpack.encode(("FIRST_REQ", chosen_engine))
            await self.first_req_send_socket.send(req_msg)

        await to_await

        self._ensure_output_queue_task()

    def get_core_engine_for_request(self, request: EngineCoreRequest):
        return self.core_engine


class DPLBAsyncMPClient(DPAsyncMPClient):
    """Asyncio-compatible client for multi-proc, multi-engine (data parallel)
    EngineCore. Load-balances between multiple engine processes."""

    def __init__(self,
                 vllm_config: VllmConfig,
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_index: int = 0):

        # To route aborts to the correct engine.
        self.reqs_in_flight: dict[str, EngineIdentity] = {}

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)

        assert len(self.core_engines) > 1

    def get_core_engine_for_request(
            self, request: EngineCoreRequest) -> EngineIdentity:
        # Engines are in rank order.
        if (eng_index := request.data_parallel_rank) is None:
            if not self.lb_engines:
                return self.core_engine
            # TODO use P2C alg for larger DP sizes
            num_engines = len(self.lb_engines)
            min_counts = [sys.maxsize, sys.maxsize]
            eng_index = 0
            for i in range(num_engines):
                # Start from client_index to help with balancing when engines
                # are empty.
                idx = (self.client_index + i) % num_engines
                counts = self.lb_engines[idx]
                if counts < min_counts:
                    min_counts = counts
                    eng_index = idx
            # Adjust local counts for better balancing between stats updates
            # from the coordinator (which happen every 100ms).
            if min_counts[0]:
                min_counts[0] += 1
            else:
                min_counts[1] += 1

        chosen_engine = self.core_engines[eng_index]
        # Record which engine is chosen for this request, to handle aborts.
        self.reqs_in_flight[request.request_id] = chosen_engine
        return chosen_engine
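
    # Illustrative walk-through of the selection above (not executed),
    # assuming client_index == 0 and stats for three engines:
    #
    #     self.lb_engines = [[2, 4], [0, 5], [0, 3]]   # [waiting, running]
    #
    # The lexicographic comparison picks index 2 ([0, 3]): zero waiting, fewest
    # running. Its local count is then bumped to [0, 4] so that back-to-back
    # requests arriving before the next coordinator stats update (every ~100ms)
    # spread across engines instead of piling onto one.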

    async def call_utility_async(self, method: str, *args) -> Any:
        # Only the result from the first engine is returned.
        return (await asyncio.gather(*[
            self._call_utility_async(method, *args, engine=engine)
            for engine in self.core_engines
        ]))[0]

    @staticmethod
    async def process_engine_outputs(self: "DPLBAsyncMPClient",
                                     outputs: EngineCoreOutputs):
        if outputs.finished_requests and self.reqs_in_flight:
            for req_id in outputs.finished_requests:
                self.reqs_in_flight.pop(req_id, None)

    async def abort_requests_async(self, request_ids: list[str]) -> None:
        if not request_ids or self.resources.engine_dead:
            return

        if len(request_ids) == 1:
            # Fast-path common case.
            if engine := self.reqs_in_flight.get(request_ids[0]):
                await self._abort_requests(request_ids, engine)
            return

        by_engine = defaultdict[EngineIdentity, list[str]](list)
        for req_id in request_ids:
            if engine := self.reqs_in_flight.get(req_id):
                by_engine[engine].append(req_id)
        for engine, req_ids in by_engine.items():
            await self._abort_requests(req_ids, engine)

    async def _abort_requests(self, request_ids: list[str],
                              engine: EngineIdentity) -> None:
        await self._send_input(EngineCoreRequestType.ABORT, request_ids,
                               engine)

    async def _send_reconfig_message(
            self, reconfig_request: ReconfigureDistributedRequest,
            engine: EngineIdentity) -> asyncio.Future:
        """Send reconfiguration message and return the result future without
        waiting for completion."""
        call_id = uuid.uuid1().int >> 64
        future = asyncio.get_running_loop().create_future()
        self.utility_results[call_id] = future
        message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
            (self.client_index, call_id, "reinitialize_distributed",
             (reconfig_request, ))))
        await self._send_input_message(message, engine, reconfig_request)
        self._ensure_output_queue_task()
        return future

    async def scale_elastic_ep(self, new_data_parallel_size: int) -> None:
        """Scale elastic EP data parallel size"""
        cur_data_parallel_size = len(self.core_engines)

        assert new_data_parallel_size != cur_data_parallel_size, (
            f"new_data_parallel_size {new_data_parallel_size} must be "
            f"different from cur_data_parallel_size {cur_data_parallel_size}")

        assert self.vllm_config.parallel_config.data_parallel_backend == \
            "ray", ("Only ray DP backend supports scaling elastic EP")

        scale_up = new_data_parallel_size > cur_data_parallel_size

        if scale_up:
            await self._scale_up_elastic_ep(cur_data_parallel_size,
                                            new_data_parallel_size)
        else:
            await self._scale_down_elastic_ep(cur_data_parallel_size,
                                              new_data_parallel_size)

    async def _scale_up_elastic_ep(self, cur_data_parallel_size: int,
                                   new_data_parallel_size: int) -> None:
        """Scale up the data parallel size by creating new engine cores
        and reconfiguring existing ones."""
        cur_data_parallel_size = len(self.core_engines)

        # Phase 1: Send reconfigure messages to all existing engines and wait
        # for them to be sent
        reconfig_futures = []
        self.vllm_config.parallel_config.data_parallel_master_port = \
            get_open_port()
        for engine in self.core_engines:
            reconfig_request = ReconfigureDistributedRequest(
                new_data_parallel_size=new_data_parallel_size,
                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_rank_local=\
                    ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_master_ip=self.vllm_config.parallel_config.
                data_parallel_master_ip,
                new_data_parallel_master_port=self.vllm_config.parallel_config.
                data_parallel_master_port)
            future = await self._send_reconfig_message(reconfig_request,
                                                       engine)
            reconfig_futures.append(future)

        logger.info("All reconfigure messages sent, starting engine creation")

        # Phase 2: Create new engines now that reconfig messages have been sent
        # self.resources.engine_manager is guaranteed to be
        # CoreEngineActorManager for RayDPClient
        assert isinstance(self.resources.engine_manager,
                          CoreEngineActorManager)
        self.resources.engine_manager.scale_up_elastic_ep(
            self.vllm_config, new_data_parallel_size)

        # Create new CoreEngine objects for the new engines
        new_engine_identities = set()
        for i in range(cur_data_parallel_size, new_data_parallel_size):
            new_engine = i.to_bytes(2, "little")
            self.core_engines.append(new_engine)
            new_engine_identities.add(new_engine)

        # Wait for ready messages from new engines on the input socket
        sync_input_socket = zmq.Socket.shadow(self.input_socket)
        while new_engine_identities:
            if not sync_input_socket.poll(timeout=600_000):
                raise TimeoutError(
                    "Timed out waiting for new engines to send initial "
                    "message on input socket.")
            identity, _ = sync_input_socket.recv_multipart()
            new_engine_identities.discard(identity)

        # Phase 3: Wait for all existing engines to complete reconfiguration
        logger.info("Waiting for existing engines to complete reconfiguration")
        await asyncio.gather(*reconfig_futures)

        # Notify coordinator about scale up through existing
        # stats_update_task connection
        self._ensure_stats_update_task()
        scale_up_marker = msgspec.msgpack.encode(
            ("SCALE_ELASTIC_EP", new_data_parallel_size))
        await self.first_req_send_socket.send(scale_up_marker)

        # Update the parallel config
        self.vllm_config.parallel_config.data_parallel_size = \
            new_data_parallel_size
        logger.info(
            "[Elastic EP] Scale up completed, new data parallel size: %s",
            new_data_parallel_size)

    async def _scale_down_elastic_ep(self, cur_data_parallel_size: int,
                                     new_data_parallel_size: int) -> None:
        """Scale down the data parallel size by shutting down and
        reconfiguring existing engine cores."""
        cur_data_parallel_size = len(self.core_engines)

        self.vllm_config.parallel_config.data_parallel_master_port = \
            get_open_port()

        reconfig_futures = []
        for cur_dp_rank, engine in enumerate(self.core_engines):
            reconfig_request = ReconfigureDistributedRequest(
                new_data_parallel_size=new_data_parallel_size,
                new_data_parallel_rank=ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_rank_local=\
                    ReconfigureRankType.KEEP_CURRENT_RANK,
                new_data_parallel_master_ip=self.vllm_config.parallel_config.
                data_parallel_master_ip,
                new_data_parallel_master_port=self.vllm_config.parallel_config.
                data_parallel_master_port)
            if cur_dp_rank >= new_data_parallel_size:
                reconfig_request.new_data_parallel_rank = \
                    ReconfigureRankType.SHUTDOWN_CURRENT_RANK
            future = await self._send_reconfig_message(reconfig_request,
                                                       engine)
            reconfig_futures.append(future)

        for _ in range(new_data_parallel_size, cur_data_parallel_size):
            self.core_engines.pop()

        await asyncio.gather(*reconfig_futures)

        assert isinstance(self.resources.engine_manager,
                          CoreEngineActorManager)
        self.resources.engine_manager.scale_down_elastic_ep(
            cur_data_parallel_size, new_data_parallel_size)

        self._ensure_stats_update_task()
        scale_down_marker = msgspec.msgpack.encode(
            ("SCALE_ELASTIC_EP", new_data_parallel_size))
        await self.first_req_send_socket.send(scale_down_marker)

        self.vllm_config.parallel_config.data_parallel_size = \
            new_data_parallel_size
        logger.info(
            "[Elastic EP] Scale down completed, new data parallel size: %s",
            new_data_parallel_size)
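
# Illustrative sketch (not executed): elastic EP scaling entry point. Only the
# Ray data-parallel backend is supported, and the new size must differ from
# the current one (both enforced by asserts in scale_elastic_ep above).
#
#     await client.scale_elastic_ep(new_data_parallel_size=4)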