Files
vllm-dev/vllm/v1/engine/core.py
2025-07-25 05:36:45 -07:00

1162 lines
49 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import os
import queue
import signal
import sys
import threading
import time
from collections import deque
from collections.abc import Generator
from concurrent.futures import Future
from contextlib import ExitStack, contextmanager
from inspect import isclass, signature
from logging import DEBUG
from typing import Any, Callable, Optional, TypeVar, Union
import msgspec
import zmq
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import stateless_destroy_torch_distributed_process_group
from vllm.executor.multiproc_worker_utils import _add_prefix
from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
maybe_register_config_serialize_by_value)
from vllm.utils import (bind_process_name, make_zmq_socket,
resolve_obj_by_qualname)
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
unify_kv_cache_configs)
from vllm.v1.core.sched.interface import SchedulerInterface
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType,
ReconfigureDistributedRequest, ReconfigureRankType,
UtilityOutput)
from vllm.v1.engine.mm_input_cache import MirroredProcessingCache
from vllm.v1.engine.utils import EngineHandshakeMetadata, EngineZmqAddresses
from vllm.v1.executor.abstract import Executor
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.structured_output import StructuredOutputManager
from vllm.version import __version__ as VLLM_VERSION
# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

# NOTE(review): not referenced in the visible part of this file; presumably a
# polling timeout in seconds -- confirm where it is consumed.
POLLING_TIMEOUT_S = 2.5

# Number of minutes startup_handshake() waits for the front-end's init
# message before raising RuntimeError.
HANDSHAKE_TIMEOUT_MINS = 5

_R = TypeVar('_R') # Return type for collective_rpc
class EngineCore:
"""Inner loop of vLLM's Engine."""
def __init__(self,
             vllm_config: VllmConfig,
             executor_class: type[Executor],
             log_stats: bool,
             executor_fail_callback: Optional[Callable] = None):
    """Create the executor, KV caches, scheduler and auxiliary caches.

    Args:
        vllm_config: Engine configuration. Its ``cache_config`` block
            counts are overwritten with the values computed here.
        executor_class: Executor type, instantiated with ``vllm_config``.
        log_stats: Stored on the instance and forwarded to the scheduler.
        executor_fail_callback: When given, registered on the executor as
            its failure callback.
    """
    # Plugins have to be loaded at the engine/scheduler level as well.
    from vllm.plugins import load_general_plugins
    load_general_plugins()

    self.vllm_config = vllm_config
    logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                VLLM_VERSION, vllm_config)
    self.log_stats = log_stats

    # --- Model executor -------------------------------------------------
    self.model_executor = executor_class(vllm_config)
    if executor_fail_callback is not None:
        self.model_executor.register_failure_callback(
            executor_fail_callback)

    # Placeholder; _initialize_kv_caches() replaces it when the model
    # has KV cache specs.
    self.available_gpu_memory_for_kv_cache = -1

    # --- KV caches ------------------------------------------------------
    # The block counts are only known after _initialize_kv_caches(), so
    # write them back into the shared cache config.
    (num_gpu_blocks, num_cpu_blocks,
     kv_cache_config) = self._initialize_kv_caches(vllm_config)
    cache_config = vllm_config.cache_config
    cache_config.num_gpu_blocks = num_gpu_blocks
    cache_config.num_cpu_blocks = num_cpu_blocks
    self.collective_rpc("initialize_cache",
                        args=(num_gpu_blocks, num_cpu_blocks))

    self.structured_output_manager = StructuredOutputManager(vllm_config)

    # --- Scheduler ------------------------------------------------------
    # The configured scheduler is either a class or a qualified name.
    scheduler_cls = vllm_config.scheduler_config.scheduler_cls
    if isinstance(scheduler_cls, str):
        Scheduler = resolve_obj_by_qualname(scheduler_cls)
    else:
        Scheduler = scheduler_cls

    # This warning can be removed once the V1 Scheduler interface is
    # finalized and we can maintain support for scheduler classes that
    # implement it.
    if Scheduler is not V1Scheduler:
        logger.warning(
            "Using configured V1 scheduler class %s. "
            "This scheduler interface is not public and "
            "compatibility may not be maintained.", scheduler_cls)

    multiple_dp_ranks = vllm_config.parallel_config.data_parallel_size > 1
    self.scheduler: SchedulerInterface = Scheduler(
        vllm_config=vllm_config,
        kv_cache_config=kv_cache_config,
        structured_output_manager=self.structured_output_manager,
        include_finished_set=multiple_dp_ranks,
        log_stats=self.log_stats,
    )

    # --- Multi-modal input cache ----------------------------------------
    self.mm_input_cache_server = MirroredProcessingCache(
        vllm_config.model_config)

    # --- Batch queue (pipeline parallelism) -----------------------------
    # Queue of scheduled batches. It lets us schedule and execute batches
    # asynchronously, which pipeline parallelism needs to eliminate
    # pipeline bubbles. Only created when more than one batch may be in
    # flight.
    self.batch_queue_size = self.model_executor.max_concurrent_batches
    self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
                                                 SchedulerOutput]]] = None
    if self.batch_queue_size > 1:
        logger.info("Batch queue is enabled with size %d",
                    self.batch_queue_size)
        self.batch_queue = queue.Queue(self.batch_queue_size)
def _initialize_kv_caches(
        self, vllm_config: VllmConfig) -> tuple[int, int, KVCacheConfig]:
    """Size the KV caches, initialize them on the executor and warm up.

    Args:
        vllm_config: Engine configuration passed to
            ``get_kv_cache_config`` for every worker.

    Returns:
        ``(num_gpu_blocks, num_cpu_blocks, kv_cache_config)`` where the
        block count and config come from the first worker's unified
        config and ``num_cpu_blocks`` is always 0.
    """
    # Use a monotonic clock: the value is only used to report a duration,
    # and wall-clock time can jump (e.g. NTP adjustments).
    start = time.perf_counter()

    # Get all kv cache needed by the model
    kv_cache_specs = self.model_executor.get_kv_cache_specs()

    has_kv_cache = any(kv_cache_spec for kv_cache_spec in kv_cache_specs)
    if has_kv_cache:
        if os.environ.get("VLLM_ELASTIC_EP_SCALE_UP_LAUNCH") == "1":
            # Scale-up launch: take the memory size from the DP group
            # instead of profiling locally.
            dp_group = getattr(self, "dp_group", None)
            assert dp_group is not None
            self.available_gpu_memory_for_kv_cache = \
                ParallelConfig.sync_kv_cache_memory_size(dp_group, -1)
            available_gpu_memory = [
                self.available_gpu_memory_for_kv_cache
            ] * len(kv_cache_specs)
        else:
            # Profiles the peak memory usage of the model to determine how
            # much memory can be allocated for kv cache.
            available_gpu_memory = (
                self.model_executor.determine_available_memory())
            self.available_gpu_memory_for_kv_cache = \
                available_gpu_memory[0]
    else:
        # Attention free models don't need memory for kv cache
        available_gpu_memory = [0] * len(kv_cache_specs)

    assert len(kv_cache_specs) == len(available_gpu_memory)

    # Get the kv cache tensor size
    kv_cache_configs = [
        get_kv_cache_config(vllm_config, kv_cache_spec_one_worker,
                            available_gpu_memory_one_worker)
        for kv_cache_spec_one_worker, available_gpu_memory_one_worker in
        zip(kv_cache_specs, available_gpu_memory)
    ]

    # Since we use a shared centralized controller, we need the
    # `kv_cache_config` to be consistent across all workers to make sure
    # all the memory operators can be applied to all workers.
    unify_kv_cache_configs(kv_cache_configs)

    # All workers have the same kv_cache_config except layer names, so use
    # an arbitrary one to initialize the scheduler.
    assert all(cfg.num_blocks == kv_cache_configs[0].num_blocks
               for cfg in kv_cache_configs)
    num_gpu_blocks = kv_cache_configs[0].num_blocks
    num_cpu_blocks = 0
    scheduler_kv_cache_config = kv_cache_configs[0]

    # Initialize kv cache and warmup the execution
    self.model_executor.initialize_from_config(kv_cache_configs)

    elapsed = time.perf_counter() - start
    logger.info(("init engine (profile, create kv cache, "
                 "warmup model) took %.2f seconds"), elapsed)
    return num_gpu_blocks, num_cpu_blocks, scheduler_kv_cache_config
def get_supported_tasks(self) -> tuple[SupportedTask, ...]:
    """Return the ``supported_tasks`` exposed by the model executor."""
    executor = self.model_executor
    return executor.supported_tasks
def add_request(self, request: EngineCoreRequest):
    """Validate a client request, convert it and pass it to the scheduler.

    Raises:
        ValueError: If the request carries pooling params whose task is
            not among the supported pooling tasks.
    """
    pooling_params = request.pooling_params
    if pooling_params:
        supported_pooling_tasks = [
            task for task in self.get_supported_tasks()
            if task in POOLING_TASKS
        ]
        if pooling_params.task not in supported_pooling_tasks:
            raise ValueError(f"Unsupported task: {pooling_params.task!r} "
                             f"Supported tasks: {supported_pooling_tasks}")

    mm_hashes = request.mm_hashes
    if mm_hashes is not None:
        # An input whose hash is already known is fetched from the cache;
        # otherwise it is added to it. This cache mirrors the client-side
        # cache, so anything that has a hash must have a HIT cache entry
        # here as well.
        assert request.mm_inputs is not None
        request.mm_inputs = self.mm_input_cache_server.get_and_update_p1(
            request.mm_inputs, mm_hashes)

    req = Request.from_engine_core_request(request)

    if req.use_structured_output:
        # Kick off grammar compilation asynchronously.
        self.structured_output_manager.grammar_init(req)

    has_transfer_params = req.kv_transfer_params is not None
    if has_transfer_params and not self.scheduler.get_kv_connector():
        logger.warning("Got kv_transfer_params, but no KVConnector found. "
                       "Disabling KVTransfer for this request.")

    self.scheduler.add_request(req)
def abort_requests(self, request_ids: list[str]):
    """Finish the given requests in the scheduler with the aborted status."""
    # TODO: The scheduler doesn't really need to know the
    # specific finish reason, TBD whether we propagate that
    # (i.e. client-aborted vs stop criteria met).
    scheduler = self.scheduler
    scheduler.finish_requests(request_ids, RequestStatus.FINISHED_ABORTED)
def execute_model_with_error_logging(
    self,
    model_fn: Callable[[SchedulerOutput], ModelRunnerOutput],
    scheduler_output: SchedulerOutput,
) -> ModelRunnerOutput:
    """Execute the model and log detailed info on failure.

    Args:
        model_fn: Callable invoked with ``scheduler_output``.
        scheduler_output: The scheduled batch handed to ``model_fn`` and,
            on failure, to the exception dump.

    Returns:
        Whatever ``model_fn`` returns.

    Raises:
        Exception: Any ``Exception`` raised by ``model_fn`` is re-raised
            unchanged after the engine state has been dumped.
    """
    try:
        return model_fn(scheduler_output)
    except Exception:
        # We do not want to catch BaseException here since we're only
        # interested in dumping info when the exception is due to an
        # error from execute_model itself.

        # NOTE: This method is exception-free
        dump_engine_exception(self.vllm_config, scheduler_output,
                              self.scheduler.make_stats())
        # Bare raise re-raises the active exception without adding an
        # extra traceback entry for this frame.
        raise
def step(self) -> tuple[dict[int, EngineCoreOutputs], bool]:
    """Run one schedule -> execute -> update cycle.

    Returns:
        A tuple of the engine core outputs and a flag telling whether the
        step scheduled any tokens. When the scheduler holds no requests
        the result is ``({}, False)`` and nothing is executed.
    """
    scheduler = self.scheduler

    # The scheduler may still hold requests that are unfinished, or
    # finished but not yet removed from the batch.
    if not scheduler.has_requests():
        return {}, False

    scheduler_output = scheduler.schedule()
    model_output = self.execute_model_with_error_logging(
        self.model_executor.execute_model,  # type: ignore
        scheduler_output)
    engine_core_outputs = scheduler.update_from_output(
        scheduler_output, model_output)  # type: ignore

    model_executed = scheduler_output.total_num_scheduled_tokens > 0
    return engine_core_outputs, model_executed
def step_with_batch_queue(
        self) -> tuple[Optional[dict[int, EngineCoreOutputs]], bool]:
    """Schedule and execute batches through the batch queue.

    The flow of one call is:

    1. If the batch queue is not full, ask the scheduler for a new batch.
       When that batch contains tokens it is submitted to the executor
       and queued, and the call returns ``(None, True)`` right away:
       filling the batch queue takes priority over collecting outputs.
    2. Otherwise (queue full, or nothing schedulable) and if the queue is
       not empty, block until the oldest queued batch has finished.
    3. Update the scheduler from that batch's output.

    Returns:
        ``(engine_core_outputs, scheduled_batch)``. The outputs are
        ``None`` when this call produced nothing to output.
    """
    batch_queue = self.batch_queue
    assert batch_queue is not None

    scheduler_output = None

    # Non-blocking attempt to schedule another batch. The scheduler may
    # hand back an empty batch when every request is already scheduled.
    if not batch_queue.full():
        scheduler_output = self.scheduler.schedule()
        if scheduler_output.total_num_scheduled_tokens > 0:
            future = self.model_executor.execute_model(scheduler_output)
            batch_queue.put_nowait(
                (future, scheduler_output))  # type: ignore

    scheduled_batch = (scheduler_output is not None
                       and scheduler_output.total_num_scheduled_tokens > 0)

    # Nothing to collect: either a batch was just queued, or there is no
    # batch in flight.
    if scheduled_batch or batch_queue.empty():
        return None, scheduled_batch

    # TODO(comaniac): Ideally we should peek the first batch in the
    # job queue to check if it's finished before scheduling a new batch,
    # but peeking the first element in a queue is not thread-safe,
    # so we need more work.
    future, scheduler_output = batch_queue.get_nowait()

    # Blocks until the oldest batch's result is available.
    model_output = self.execute_model_with_error_logging(
        lambda _: future.result(), scheduler_output)
    batch_queue.task_done()

    engine_core_outputs = self.scheduler.update_from_output(
        scheduler_output, model_output)
    return engine_core_outputs, scheduled_batch
def shutdown(self):
    """Clear the structured-output backend, then shut down the executor
    and the scheduler (each only if it is truthy)."""
    self.structured_output_manager.clear_backend()

    executor = self.model_executor
    if executor:
        executor.shutdown()

    scheduler = self.scheduler
    if scheduler:
        scheduler.shutdown()
def profile(self, is_start: bool = True):
    """Forward the ``is_start`` flag to the executor's ``profile``."""
    executor = self.model_executor
    executor.profile(is_start)
def reset_mm_cache(self):
    """Reset the multi-modal input cache, warning first if requests are
    still unfinished."""
    # NOTE: Since this is mainly for debugging, we don't attempt to
    # re-sync the internal caches (P0 processor, P0 mirror, P1 mirror)
    requests_in_flight = self.scheduler.has_unfinished_requests()
    if requests_in_flight:
        logger.warning("Resetting the multi-modal cache when requests are "
                       "in progress may lead to desynced internal caches.")

    self.mm_input_cache_server.reset()
def reset_prefix_cache(self):
    """Ask the scheduler to reset its prefix cache."""
    scheduler = self.scheduler
    scheduler.reset_prefix_cache()
def sleep(self, level: int = 1):
    """Forward ``level`` to the executor's ``sleep``."""
    executor = self.model_executor
    executor.sleep(level)
def wake_up(self, tags: Optional[list[str]] = None):
    """Forward ``tags`` to the executor's ``wake_up``."""
    executor = self.model_executor
    executor.wake_up(tags)
def is_sleeping(self) -> bool:
    """Return the executor's ``is_sleeping`` attribute."""
    executor = self.model_executor
    return executor.is_sleeping
def execute_dummy_batch(self):
    """Invoke ``execute_dummy_batch`` through the executor's
    ``collective_rpc``."""
    executor = self.model_executor
    executor.collective_rpc("execute_dummy_batch")
def add_lora(self, lora_request: LoRARequest) -> bool:
    """Pass ``lora_request`` to the executor's ``add_lora`` and return
    its result."""
    executor = self.model_executor
    return executor.add_lora(lora_request)
def remove_lora(self, lora_id: int) -> bool:
    """Pass ``lora_id`` to the executor's ``remove_lora`` and return its
    result."""
    executor = self.model_executor
    return executor.remove_lora(lora_id)
def list_loras(self) -> set[int]:
    """Return the result of the executor's ``list_loras``."""
    executor = self.model_executor
    return executor.list_loras()
def pin_lora(self, lora_id: int) -> bool:
    """Pass ``lora_id`` to the executor's ``pin_lora`` and return its
    result."""
    executor = self.model_executor
    return executor.pin_lora(lora_id)
def save_sharded_state(
    self,
    path: str,
    pattern: Optional[str] = None,
    max_size: Optional[int] = None,
) -> None:
    """Forward ``path``, ``pattern`` and ``max_size`` as keyword arguments
    to the executor's ``save_sharded_state``."""
    executor = self.model_executor
    executor.save_sharded_state(path=path, pattern=pattern, max_size=max_size)
def collective_rpc(self,
                   method: Union[str, Callable[..., _R]],
                   timeout: Optional[float] = None,
                   args: tuple = (),
                   kwargs: Optional[dict[str, Any]] = None) -> list[_R]:
    """Delegate to the executor's ``collective_rpc``.

    All four arguments are passed through positionally, in order, and the
    executor's return value is returned unchanged.
    """
    executor = self.model_executor
    return executor.collective_rpc(method, timeout, args, kwargs)
def save_tensorized_model(
    self,
    tensorizer_config,
) -> None:
    """Forward ``tensorizer_config`` as a keyword argument to the
    executor's ``save_tensorized_model``."""
    executor = self.model_executor
    executor.save_tensorized_model(tensorizer_config=tensorizer_config)
class EngineCoreProc(EngineCore):
"""ZMQ-wrapper for running EngineCore in background process."""
ENGINE_CORE_DEAD = b'ENGINE_CORE_DEAD'
def __init__(
    self,
    vllm_config: VllmConfig,
    local_client: bool,
    handshake_address: str,
    executor_class: type[Executor],
    log_stats: bool,
    client_handshake_address: Optional[str] = None,
    engine_index: int = 0,
):
    """Handshake with the front-end, build the engine core and start the
    socket IO threads.

    Args:
        vllm_config: Engine configuration.
        local_client: Forwarded to the handshake logic.
        handshake_address: Address used for the (first) handshake.
        executor_class: Executor type handed to ``EngineCore.__init__``.
        log_stats: Handed to ``EngineCore.__init__``.
        client_handshake_address: Optional address for a second handshake.
        engine_index: Index of this engine; also encoded (2 bytes,
            little-endian) as the ZMQ socket identity.
    """
    bind_process_name(self.__class__.__name__, f"{engine_index}")

    # Queues that decouple the socket IO threads from the busy loop.
    self.input_queue = queue.Queue[tuple[EngineCoreRequestType, Any]]()
    self.output_queue = queue.Queue[Union[tuple[int, EngineCoreOutputs],
                                          bytes]]()

    def executor_fail_callback():
        # Surface the executor failure to the busy loop as an input
        # request.
        self.input_queue.put_nowait(
            (EngineCoreRequestType.EXECUTOR_FAILED, b''))

    self.engine_index = engine_index
    identity = self.engine_index.to_bytes(length=2, byteorder="little")
    self.engines_running = False

    handshakes = self._perform_handshakes(handshake_address, identity,
                                          local_client, vllm_config,
                                          client_handshake_address)
    with handshakes as addresses:
        self.client_count = len(addresses.outputs)

        # Set up data parallel environment.
        self.has_coordinator = addresses.coordinator_output is not None
        self.frontend_stats_publish_address = (
            addresses.frontend_stats_publish_address)
        # Only publish request queue stats to coordinator for "internal"
        # LB mode.
        self.publish_dp_lb_stats = (
            self.has_coordinator
            and not vllm_config.parallel_config.data_parallel_external_lb)

        self._init_data_parallel(vllm_config)

        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)

    # Use the batch-queue variant of step() only when a batch queue exists.
    if self.batch_queue is None:
        self.step_fn = self.step
    else:
        self.step_fn = self.step_with_batch_queue

    # Background Threads and Queues for IO. These enable us to
    # overlap ZMQ socket IO with GPU since they release the GIL,
    # and to overlap some serialization/deserialization with the
    # model forward pass.
    # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
    input_thread = threading.Thread(
        target=self.process_input_sockets,
        args=(addresses.inputs, addresses.coordinator_input, identity),
        daemon=True)
    input_thread.start()

    self.output_thread = threading.Thread(
        target=self.process_output_sockets,
        args=(addresses.outputs, addresses.coordinator_output,
              self.engine_index),
        daemon=True)
    self.output_thread.start()
@contextmanager
def _perform_handshakes(
    self,
    handshake_address: str,
    identity: bytes,
    local_client: bool,
    vllm_config: VllmConfig,
    client_handshake_address: Optional[str],
) -> Generator[EngineZmqAddresses, None, None]:
    """
    Perform startup handshakes.

    For DP=1 or offline mode, this is with the colocated front-end process.

    For DP>1 with internal loadbalancing this is with the shared front-end
    process which may reside on a different node.

    For DP>1 with external or hybrid loadbalancing, two handshakes are
    performed:
        - With the rank 0 front-end process which retrieves the
          DP Coordinator ZMQ addresses and DP process group address.
        - With the colocated front-end process which retrieves the
          client input/output socket addresses.
    with the exception of the rank 0 and colocated engines themselves which
    don't require the second handshake.

    Here, "front-end" process can mean the process containing the engine
    core client (which is the API server process in the case the API
    server is not scaled out), OR the launcher process running the
    run_multi_api_server() function in serve.py.

    Yields:
        The addresses object from the first handshake. When a second
        (client) handshake is performed, its ``inputs`` and ``outputs``
        are overwritten with the ones from the second handshake.
    """
    # NOTE(review): this context is created here and is not explicitly
    # terminated anywhere in this method -- confirm whether that is
    # intentional.
    input_ctx = zmq.Context()
    # Only treated as local for the first handshake when there is no
    # separate client handshake.
    is_local = local_client and client_handshake_address is None
    headless = not local_client
    # The first handshake is the only one allowed to update
    # vllm_config.parallel_config (passed as the last argument).
    handshake = self._perform_handshake(input_ctx, handshake_address,
                                        identity, is_local, headless,
                                        vllm_config,
                                        vllm_config.parallel_config)
    if client_handshake_address is None:
        with handshake as addresses:
            yield addresses
    else:
        assert local_client
        local_handshake = self._perform_handshake(
            input_ctx, client_handshake_address, identity, True, False,
            vllm_config)
        with handshake as addresses, local_handshake as client_addresses:
            # Client input/output sockets come from the second handshake.
            addresses.inputs = client_addresses.inputs
            addresses.outputs = client_addresses.outputs
            yield addresses

    # Update config which may have changed from the handshake
    vllm_config.__post_init__()
@contextmanager
def _perform_handshake(
    self,
    ctx: zmq.Context,
    handshake_address: str,
    identity: bytes,
    local_client: bool,
    headless: bool,
    vllm_config: VllmConfig,
    parallel_config_to_update: Optional[ParallelConfig] = None,
) -> Generator[EngineZmqAddresses, None, None]:
    """Run one HELLO/READY handshake over a DEALER socket.

    The HELLO exchange happens on entry and the resulting addresses are
    yielded. When the ``with`` body completes, a READY message is sent on
    the same socket before it is closed.
    """
    socket_cm = make_zmq_socket(ctx,
                                handshake_address,
                                zmq.DEALER,
                                identity=identity,
                                linger=5000,
                                bind=False)
    with socket_cm as handshake_socket:
        # Register engine with front-end.
        addresses = self.startup_handshake(handshake_socket, local_client,
                                           headless,
                                           parallel_config_to_update)
        yield addresses

        # Send ready message.
        # We pass back the coordinator stats update address here for the
        # external LB case for our colocated front-end to use (coordinator
        # only runs with rank 0).
        ready_message = {
            "status": "READY",
            "local": local_client,
            "headless": headless,
            "num_gpu_blocks": vllm_config.cache_config.num_gpu_blocks,
            "dp_stats_address": self.frontend_stats_publish_address,
        }
        handshake_socket.send(msgspec.msgpack.encode(ready_message))
@staticmethod
def startup_handshake(
    handshake_socket: zmq.Socket,
    local_client: bool,
    headless: bool,
    parallel_config: Optional[ParallelConfig] = None,
) -> EngineZmqAddresses:
    """Send HELLO to the front-end and wait for its init message.

    Args:
        handshake_socket: Socket used for the exchange.
        local_client: Sent as the ``local`` field of the HELLO message.
        headless: Sent as the ``headless`` field of the HELLO message.
        parallel_config: When given, every key/value in the init message's
            ``parallel_config`` mapping is set on it as an attribute.

    Returns:
        The ``addresses`` field of the decoded init message.

    Raises:
        RuntimeError: If nothing arrives within ``HANDSHAKE_TIMEOUT_MINS``
            minutes.
    """
    # Send registration message.
    hello_message = {
        "status": "HELLO",
        "local": local_client,
        "headless": headless,
    }
    handshake_socket.send(msgspec.msgpack.encode(hello_message))

    # Receive initialization message.
    logger.info("Waiting for init message from front-end.")
    timeout_ms = HANDSHAKE_TIMEOUT_MINS * 60_000
    if not handshake_socket.poll(timeout=timeout_ms):
        raise RuntimeError("Did not receive response from front-end "
                           f"process within {HANDSHAKE_TIMEOUT_MINS} "
                           f"minutes")

    init_bytes = handshake_socket.recv()
    init_message: EngineHandshakeMetadata = msgspec.msgpack.decode(
        init_bytes, type=EngineHandshakeMetadata)
    logger.debug("Received init message: %s", init_message)

    if parallel_config is not None:
        overrides = init_message.parallel_config
        for key, value in overrides.items():
            setattr(parallel_config, key, value)

    return init_message.addresses
@staticmethod
def run_engine_core(*args,
                    dp_rank: int = 0,
                    local_dp_rank: int = 0,
                    **kwargs):
    """Launch EngineCore busy loop in background process.

    Args:
        *args: Positional arguments for the engine core constructor.
        dp_rank: Data parallel rank written into the parallel config when
            a data-parallel engine is created.
        local_dp_rank: Local data parallel rank written alongside it.
        **kwargs: Keyword arguments for the engine core constructor; must
            contain ``vllm_config``.

    Raises:
        SystemExit: Re-raised when triggered by SIGTERM/SIGINT.
        Exception: Any other failure is logged and re-raised; if the
            engine had already been constructed, the client is notified
            first via ``_send_engine_dead()``.
    """

    # Signal handler used for graceful termination.
    # SystemExit exception is only raised once to allow this and worker
    # processes to terminate without error
    shutdown_requested = False

    # Ensure we can serialize transformer config after spawning
    maybe_register_config_serialize_by_value()

    def signal_handler(signum, frame):
        nonlocal shutdown_requested
        if not shutdown_requested:
            shutdown_requested = True
            raise SystemExit()

    # Either SIGTERM or SIGINT will terminate the engine_core
    signal.signal(signal.SIGTERM, signal_handler)
    signal.signal(signal.SIGINT, signal_handler)

    engine_core: Optional[EngineCoreProc] = None
    try:
        parallel_config: ParallelConfig = kwargs[
            "vllm_config"].parallel_config
        if parallel_config.data_parallel_size > 1 or dp_rank > 0:
            # Set data parallel rank for this engine process.
            parallel_config.data_parallel_rank = dp_rank
            parallel_config.data_parallel_rank_local = local_dp_rank
            engine_core = DPEngineCoreProc(*args, **kwargs)
        else:
            engine_core = EngineCoreProc(*args, **kwargs)

        engine_core.run_busy_loop()

    except SystemExit:
        logger.debug("EngineCore exiting.")
        raise
    except Exception:
        if engine_core is None:
            logger.exception("EngineCore failed to start.")
        else:
            logger.exception("EngineCore encountered a fatal error.")
            engine_core._send_engine_dead()
        # Bare raise re-raises the active exception without adding an
        # extra traceback entry for this frame.
        raise
    finally:
        # Runs on every exit path, but only once an engine exists.
        if engine_core is not None:
            engine_core.shutdown()
def _init_data_parallel(self, vllm_config: VllmConfig):
    """No-op in this class; called from ``__init__`` during the handshake
    and overridden by ``DPEngineCoreProc``."""
def run_busy_loop(self):
    """Core busy loop of the EngineCore.

    Never returns normally: it alternates between draining the input
    queue and stepping the engine until an exception (for example the
    SystemExit raised on SIGINT/SIGTERM) propagates out.
    """
    while True:
        # Step 1: wait on / drain the input queue until there is work.
        self._process_input_queue()

        # Step 2: run one engine step and publish its outputs.
        self._process_engine_step()
def _process_input_queue(self):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.engines_running and not self.scheduler.has_requests():
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
req = self.input_queue.get()
self._handle_client_request(*req)
if waited:
logger.debug("EngineCore loop active.")
# Handle any more client requests.
while not self.input_queue.empty():
req = self.input_queue.get_nowait()
self._handle_client_request(*req)
def _process_engine_step(self) -> bool:
"""Called only when there are unfinished local requests."""
# Step the engine core.
outputs, model_executed = self.step_fn()
# Put EngineCoreOutputs into the output queue.
for output in (outputs.items() if outputs else ()):
self.output_queue.put_nowait(output)
return model_executed
def _handle_client_request(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
    """Dispatch request from client.

    Args:
        request_type: Kind of request taken from the input queue.
        request: Payload; its shape depends on ``request_type``. For
            UTILITY it is ``(client_idx, call_id, method_name, args)``.

    Raises:
        RuntimeError: For an EXECUTOR_FAILED request.
    """
    if request_type == EngineCoreRequestType.ADD:
        self.add_request(request)
    elif request_type == EngineCoreRequestType.ABORT:
        self.abort_requests(request)
    elif request_type == EngineCoreRequestType.UTILITY:
        client_idx, call_id, method_name, args = request
        output = UtilityOutput(call_id)
        try:
            method = getattr(self, method_name)
            output.result = method(
                *self._convert_msgspec_args(method, args))
        except Exception as e:
            # Deliberately not BaseException: the SIGTERM/SIGINT handler
            # raises SystemExit only once, so swallowing it here would
            # lose a shutdown request that arrives during a utility call.
            logger.exception("Invocation of %s method failed", method_name)
            output.failure_message = (f"Call to {method_name} method"
                                      f" failed: {str(e)}")
        # The result (or failure message) is always reported back to the
        # calling client.
        self.output_queue.put_nowait(
            (client_idx, EngineCoreOutputs(utility_output=output)))
    elif request_type == EngineCoreRequestType.EXECUTOR_FAILED:
        raise RuntimeError("Executor failed.")
    else:
        logger.error("Unrecognized input request type encountered: %s",
                     request_type)
@staticmethod
def _convert_msgspec_args(method, args):
"""If a provided arg type doesn't match corresponding target method
arg type, try converting to msgspec object."""
if not args:
return args
arg_types = signature(method).parameters.values()
assert len(args) <= len(arg_types)
return tuple(
msgspec.convert(v, type=p.annotation) if isclass(p.annotation)
and issubclass(p.annotation, msgspec.Struct)
and not isinstance(v, p.annotation) else v
for v, p in zip(args, arg_types))
def _send_engine_dead(self):
"""Send EngineDead status to the EngineCoreClient."""
# Put ENGINE_CORE_DEAD in the queue.
self.output_queue.put_nowait(EngineCoreProc.ENGINE_CORE_DEAD)
# Wait until msg sent by the daemon before shutdown.
self.output_thread.join(timeout=5.0)
if self.output_thread.is_alive():
logger.fatal("vLLM shutdown signal from EngineCore failed "
"to send. Please report this issue.")
def process_input_sockets(self, input_addresses: list[str],
                          coord_input_address: Optional[str],
                          identity: bytes):
    """Input socket IO thread.

    Connects one DEALER socket per input address (plus an XSUB socket to
    the coordinator when an address is given), then loops forever
    decoding incoming messages and pushing ``(request_type, request)``
    tuples onto ``self.input_queue``.
    """
    # Msgpack decoders: ADD requests have a typed decoder, everything
    # else uses the generic one.
    add_request_decoder = MsgpackDecoder(EngineCoreRequest)
    generic_decoder = MsgpackDecoder()

    with ExitStack() as stack, zmq.Context() as ctx:
        input_sockets = []
        for input_address in input_addresses:
            dealer = make_zmq_socket(ctx,
                                     input_address,
                                     zmq.DEALER,
                                     identity=identity,
                                     bind=False)
            input_sockets.append(stack.enter_context(dealer))

        coord_socket = None
        if coord_input_address is not None:
            coord_socket = stack.enter_context(
                make_zmq_socket(ctx,
                                coord_input_address,
                                zmq.XSUB,
                                identity=identity,
                                bind=False))
            # Send subscription message to coordinator.
            coord_socket.send(b'\x01')

        # Register sockets with poller.
        poller = zmq.Poller()
        for input_socket in input_sockets:
            # Send initial message to each input socket - this is required
            # before the front-end ROUTER socket can send input messages
            # back to us.
            input_socket.send(b'')
            poller.register(input_socket, zmq.POLLIN)
        if coord_socket is not None:
            poller.register(coord_socket, zmq.POLLIN)

        while True:
            for ready_socket, _ in poller.poll():
                # Frames are (RequestType, RequestData...).
                type_frame, *data_frames = ready_socket.recv_multipart(
                    copy=False)
                request_type = EngineCoreRequestType(
                    bytes(type_frame.buffer))

                # Deserialize the request data.
                if request_type == EngineCoreRequestType.ADD:
                    decoder = add_request_decoder
                else:
                    decoder = generic_decoder
                request = decoder.decode(data_frames)

                # Push to input queue for core busy loop.
                self.input_queue.put_nowait((request_type, request))
def process_output_sockets(self, output_paths: list[str],
                           coord_output_path: Optional[str],
                           engine_index: int):
    """Output socket IO thread.

    Takes items from ``self.output_queue`` and sends them over PUSH
    sockets: one per entry in ``output_paths`` (selected by the item's
    client index) plus an optional coordinator socket used for client
    index -1. The loop ends after the ENGINE_CORE_DEAD sentinel has been
    forwarded to every client socket.

    Args:
        output_paths: Addresses of the client output sockets.
        coord_output_path: Address of the coordinator socket, or None.
        engine_index: Written into every outgoing outputs object.
    """

    # Msgpack serialization encoding.
    encoder = MsgpackEncoder()
    # Send buffers to reuse.
    reuse_buffers: list[bytearray] = []
    # Keep references to outputs and buffers until zmq is finished
    # with them (outputs may contain tensors/np arrays whose
    # backing buffers were extracted for zero-copy send).
    pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()

    # We must set linger to ensure the ENGINE_CORE_DEAD
    # message is sent prior to closing the socket.
    with ExitStack() as stack, zmq.Context() as ctx:
        sockets = [
            stack.enter_context(
                make_zmq_socket(ctx, output_path, zmq.PUSH, linger=4000))
            for output_path in output_paths
        ]
        coord_socket = stack.enter_context(
            make_zmq_socket(
                ctx, coord_output_path, zmq.PUSH, bind=False,
                linger=4000)) if coord_output_path is not None else None
        max_reuse_bufs = len(sockets) + 1

        while True:
            output = self.output_queue.get()
            if output == EngineCoreProc.ENGINE_CORE_DEAD:
                # Sentinel: notify every client, then leave the loop so
                # the sockets and context get closed.
                for socket in sockets:
                    socket.send(output)
                break
            assert not isinstance(output, bytes)
            client_index, outputs = output
            outputs.engine_index = engine_index

            if client_index == -1:
                # Don't reuse buffer for coordinator message
                # which will be very small.
                assert coord_socket is not None
                coord_socket.send_multipart(encoder.encode(outputs))
                continue

            # Reclaim buffers that zmq is finished with.
            # New entries are pushed on the left, so the oldest pending
            # send sits at the right end.
            while pending and pending[-1][0].done:
                reuse_buffers.append(pending.pop()[2])

            buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
            buffers = encoder.encode_into(outputs, buffer)
            tracker = sockets[client_index].send_multipart(buffers,
                                                           copy=False,
                                                           track=True)
            if not tracker.done:
                # Send still in flight: keep the buffer (and, when there
                # is more than one frame, the outputs object) referenced
                # until the tracker reports completion.
                ref = outputs if len(buffers) > 1 else None
                pending.appendleft((tracker, ref, buffer))
            elif len(reuse_buffers) < max_reuse_bufs:
                # Limit the number of buffers to reuse.
                reuse_buffers.append(buffer)
class DPEngineCoreProc(EngineCoreProc):
"""ZMQ-wrapper for running EngineCore in background process
in a data parallel context."""
def __init__(
    self,
    vllm_config: VllmConfig,
    local_client: bool,
    handshake_address: str,
    executor_class: type[Executor],
    log_stats: bool,
    client_handshake_address: Optional[str] = None,
):
    """Set up data-parallel bookkeeping, then initialize the engine with
    the data parallel rank as its engine index."""
    self._decorate_logs()

    # Forward-pass counter used to synchronize the finished state with DP
    # peers only every N steps.
    self.counter = 0
    self.current_wave = 0
    self.last_counts = (0, 0)

    # Initialize the engine.
    super().__init__(
        vllm_config,
        local_client,
        handshake_address,
        executor_class,
        log_stats,
        client_handshake_address=client_handshake_address,
        engine_index=vllm_config.parallel_config.data_parallel_rank,
    )
def _decorate_logs(self):
    """Prefix stdout and stderr with this process's name and pid.

    Done before the engine is initialized.
    """
    from multiprocessing import current_process

    process_name = current_process().name
    pid = os.getpid()
    for stream in (sys.stdout, sys.stderr):
        _add_prefix(stream, process_name, pid)
def _init_data_parallel(self, vllm_config: VllmConfig):
    """Validate the DP ranks, make the KV-transfer engine id unique per
    rank and create the stateless DP process group."""
    # Configure GPUs and stateless process group for data parallel.
    parallel_config = vllm_config.parallel_config
    dp_rank = parallel_config.data_parallel_rank
    dp_size = parallel_config.data_parallel_size
    local_dp_rank = parallel_config.data_parallel_rank_local

    assert dp_size > 1
    assert 0 <= local_dp_rank <= dp_rank < dp_size

    if vllm_config.kv_transfer_config is not None:
        # Append the local_dp_rank to the engine_id so that the
        # kv_transfer_config is unique for each DP rank.
        vllm_config.kv_transfer_config.engine_id = (
            f"{vllm_config.kv_transfer_config.engine_id}_dp{local_dp_rank}"
        )
        logger.debug("Setting kv_transfer_config.engine_id to %s",
                     vllm_config.kv_transfer_config.engine_id)

    self.dp_rank = dp_rank
    self.dp_group = parallel_config.stateless_init_dp_group()
def shutdown(self):
    """Run the base-class shutdown, then destroy the DP process group if
    one is set."""
    super().shutdown()
    # dp_group may not exist yet if initialization did not get that far.
    dp_group = getattr(self, "dp_group", None)
    if dp_group:
        stateless_destroy_torch_distributed_process_group(dp_group)
def add_request(self, request: EngineCoreRequest):
    """Reconcile the request's wave number with ours, then add it.

    Only applies when a coordinator is present: a newer wave advances
    ``current_wave``; a request for an older wave received while the
    engines are not running makes us send a ``start_wave`` message.
    """
    if self.has_coordinator:
        request_wave = request.current_wave
        if request_wave > self.current_wave:
            self.current_wave = request_wave
        elif request_wave != self.current_wave and not self.engines_running:
            # Request received for an already-completed wave, notify
            # front-end that we need to start the next one.
            start_wave_msg = EngineCoreOutputs(start_wave=self.current_wave)
            self.output_queue.put_nowait((-1, start_wave_msg))

    super().add_request(request)
def _handle_client_request(self, request_type: EngineCoreRequestType,
                           request: Any) -> None:
    """Handle START_DP_WAVE here; defer every other request type to the
    base class."""
    if request_type != EngineCoreRequestType.START_DP_WAVE:
        super()._handle_client_request(request_type, request)
        return

    new_wave, exclude_eng_index = request
    is_excluded = exclude_eng_index == self.engine_index
    if not is_excluded and new_wave >= self.current_wave:
        self.current_wave = new_wave
        if not self.engines_running:
            logger.debug("EngineCore starting idle loop for wave %d.",
                         new_wave)
            self.engines_running = True
def _maybe_publish_request_counts(self):
if not self.publish_dp_lb_stats:
return
# Publish our request counts (if they've changed).
counts = self.scheduler.get_request_counts()
if counts != self.last_counts:
self.last_counts = counts
stats = SchedulerStats(*counts)
self.output_queue.put_nowait(
(-1, EngineCoreOutputs(scheduler_stats=stats)))
def run_busy_loop(self):
    """Core busy loop of the EngineCore for data parallel case.

    Each iteration processes input, runs one engine step, optionally
    publishes request counts, runs a dummy batch if the model did not
    execute while engines are (or may be) running, and then refreshes
    ``self.engines_running`` via ``_has_global_unfinished_reqs``. When the
    engines stop running, the wave counter is advanced and, on rank 0 or
    when there is no coordinator, a ``wave_complete`` message is queued.
    """

    # Loop until process is sent a SIGINT or SIGTERM
    while True:
        # 1) Poll the input queue until there is work to do.
        self._process_input_queue()

        # 2) Step the engine core.
        executed = self._process_engine_step()
        self._maybe_publish_request_counts()

        local_unfinished_reqs = self.scheduler.has_unfinished_requests()
        if not executed:
            if not local_unfinished_reqs and not self.engines_running:
                # All engines are idle.
                continue

            # We are in a running state and so must execute a dummy pass
            # if the model didn't execute any ready requests.
            self.execute_dummy_batch()

        # 3) All-reduce operation to determine global unfinished reqs.
        self.engines_running = self._has_global_unfinished_reqs(
            local_unfinished_reqs)

        if not self.engines_running:
            if self.dp_rank == 0 or not self.has_coordinator:
                # Notify client that we are pausing the loop.
                logger.debug("Wave %d finished, pausing engine loop.",
                             self.current_wave)
                # In the coordinator case, dp rank 0 sends updates to the
                # coordinator. Otherwise (offline spmd case), each rank
                # sends the update to its colocated front-end process.
                client_index = -1 if self.has_coordinator else 0
                self.output_queue.put_nowait(
                    (client_index,
                     EngineCoreOutputs(wave_complete=self.current_wave)))
            # Every rank advances its wave counter, whether or not it
            # sent the notification above.
            self.current_wave += 1
def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:
    """Report whether the DP group should be considered still running.

    As an optimization the finish-sync all-reduce is only performed on
    every 32nd call; all other calls bump ``self.counter`` and return
    ``True`` without synchronizing.

    Args:
        local_unfinished: Whether this engine has unfinished requests.

    Returns:
        ``True`` on non-synchronizing calls, otherwise the result of
        ``ParallelConfig.has_unfinished_dp`` for ``self.dp_group``.
    """
    self.counter += 1
    if self.counter == 32:
        # Time to synchronize: restart the step count and ask the group.
        self.counter = 0
        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
    return True
def reinitialize_distributed(
        self, reconfig_request: ReconfigureDistributedRequest) -> None:
    """Tear down and rebuild the data parallel setup for a new DP layout.

    The current DP process group is destroyed and ``self.shutdown()`` is
    called, then the parallel config is updated from ``reconfig_request``
    (size, optionally rank, master ip/port). Unless the requested rank is
    ``-2``, ``self.dp_rank`` and ``self.dp_group`` are re-created. The
    model executor is then asked to reinitialize as well. When the DP size
    grew, the available KV cache memory is synchronized over the new group
    and the model is warmed up. Finally, if the request asks this rank to
    shut down, ``self.shutdown()`` is called again.

    Args:
        reconfig_request: The requested new DP configuration. Its
            ``new_data_parallel_master_port`` is overwritten with the port
            held by the parallel config after re-initialization.
    """
    stateless_destroy_torch_distributed_process_group(self.dp_group)
    self.shutdown()

    dp_config = self.vllm_config.parallel_config
    previous_dp_size = dp_config.data_parallel_size
    dp_config.data_parallel_size = reconfig_request.new_data_parallel_size

    # NOTE(review): the literals -1 / -2 below presumably correspond to
    # ReconfigureRankType.KEEP_CURRENT_RANK / SHUTDOWN_CURRENT_RANK -
    # confirm against the enum and consider using the named constants.
    requested_rank = reconfig_request.new_data_parallel_rank
    if requested_rank != -1:
        dp_config.data_parallel_rank = requested_rank

    # local rank specifies device visibility, it should not be changed
    assert (reconfig_request.new_data_parallel_rank_local ==
            ReconfigureRankType.KEEP_CURRENT_RANK)

    dp_config.data_parallel_master_ip = (
        reconfig_request.new_data_parallel_master_ip)
    dp_config.data_parallel_master_port = (
        reconfig_request.new_data_parallel_master_port)

    if requested_rank != -2:
        self.dp_rank = dp_config.data_parallel_rank
        self.dp_group = dp_config.stateless_init_dp_group()

    # Hand the config's (possibly updated) master port back to the request
    # before it is forwarded to the executor.
    reconfig_request.new_data_parallel_master_port = (
        dp_config.data_parallel_master_port)

    self.model_executor.reinitialize_distributed(reconfig_request)

    if reconfig_request.new_data_parallel_size > previous_dp_size:
        assert self.available_gpu_memory_for_kv_cache > 0
        # pass available_gpu_memory_for_kv_cache from existing
        # engine-cores to new engine-cores so they can directly
        # use it in _initialize_kv_caches() rather than profiling.
        ParallelConfig.sync_kv_cache_memory_size(
            self.dp_group, self.available_gpu_memory_for_kv_cache)
        # NOTE(yongji): newly joined workers require dummy_run even
        # CUDA graph is not used
        self.model_executor.collective_rpc("compile_or_warm_up_model")

    rank_is_leaving = (reconfig_request.new_data_parallel_rank ==
                       ReconfigureRankType.SHUTDOWN_CURRENT_RANK)
    if rank_is_leaving:
        self.shutdown()
        logger.info("DPEngineCoreProc %s shutdown", self.dp_rank)
    else:
        logger.info("Distributed environment reinitialized for DP rank %s",
                    self.dp_rank)
class DPEngineCoreActor(DPEngineCoreProc):
    """
    Ray actor for running EngineCore in a data parallel context
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        local_client: bool,
        addresses: EngineZmqAddresses,
        executor_class: type[Executor],
        log_stats: bool,
        dp_rank: int = 0,
        local_dp_rank: int = 0,
    ):
        """Record the DP ranks, pin device visibility, then init the base.

        Args:
            vllm_config: Engine configuration; its parallel config is
                updated in place with ``dp_rank`` and ``local_dp_rank``.
            local_client: Forwarded unchanged to the base initializer.
            addresses: ZMQ addresses, stored and later yielded by
                ``_perform_handshakes``.
            executor_class: Forwarded unchanged to the base initializer.
            log_stats: Forwarded unchanged to the base initializer.
            dp_rank: Data parallel rank of this actor.
            local_dp_rank: Local data parallel rank, used to pick the
                device range made visible to this actor.
        """
        self.addresses = addresses

        dp_settings = vllm_config.parallel_config
        dp_settings.data_parallel_rank = dp_rank
        dp_settings.data_parallel_rank_local = local_dp_rank

        # Set CUDA_VISIBLE_DEVICES as early as possible in actor life cycle
        # NOTE: in MP we set CUDA_VISIBLE_DEVICES at process creation time,
        # and this cannot be done in the same way for Ray because:
        # 1) Ray manages life cycle of all ray workers (including
        # DPEngineCoreActor)
        # 2) Ray sets CUDA_VISIBLE_DEVICES based on num_gpus configuration
        # To bypass 2, we need to also set
        # RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES, but vLLM workers created
        # thereafter would have CUDA_VISIBLE_DEVICES set, which is sticky:
        # https://github.com/ray-project/ray/blob/e752fc319ddedd9779a0989b6d3613909bad75c9/python/ray/_private/worker.py#L456 # noqa: E501
        # This is problematic because when the vLLM worker (a Ray actor)
        # executes a task, it indexes into the sticky CUDA_VISIBLE_DEVICES
        # rather than directly using the GPU ID, potentially resulting in
        # index out of bounds error. See:
        # https://github.com/ray-project/ray/pull/40461/files#diff-31e8159767361e4bc259b6d9883d9c0d5e5db780fcea4a52ead4ee3ee4a59a78R1860 # noqa: E501
        # and get_accelerator_ids_for_accelerator_resource() in worker.py
        # of ray.
        self._set_cuda_visible_devices(vllm_config, local_dp_rank)

        # The empty string takes the place of the handshake address, which
        # is not needed here (see _perform_handshakes).
        super().__init__(vllm_config, local_client, "", executor_class,
                         log_stats)

    def _set_cuda_visible_devices(self, vllm_config: VllmConfig,
                                  local_dp_rank: int):
        """Restrict visible devices to this local DP rank's slice.

        The slice is ``[local_dp_rank * world_size,
        (local_dp_rank + 1) * world_size)``; each logical id in it is
        mapped to a physical device id and the comma-joined result is
        written to the platform's device control environment variable.

        Raises:
            Exception: If mapping a device id raises ``IndexError``; the
                original error is chained as the cause.
        """
        from vllm.platforms import current_platform

        env_name = current_platform.device_control_env_var
        devices_per_rank = vllm_config.parallel_config.world_size
        range_start = local_dp_rank * devices_per_rank
        range_stop = (local_dp_rank + 1) * devices_per_rank

        # Set CUDA_VISIBLE_DEVICES or equivalent.
        try:
            physical_ids = [
                str(current_platform.device_id_to_physical_device_id(idx))
                for idx in range(range_start, range_stop)
            ]
            os.environ[env_name] = ",".join(physical_ids)
        except IndexError as e:
            raise Exception(
                f"Error setting {env_name}: "
                f"local range: [{range_start}, "
                f"{range_stop}) "
                f"base value: \"{os.getenv(env_name)}\"") from e

    def _decorate_logs(self):
        """Intentionally a no-op for this actor."""
        pass

    @contextmanager
    def _perform_handshakes(self, handshake_address: str, identity: bytes,
                            local_client: bool, vllm_config: VllmConfig,
                            client_handshake_address: Optional[str]):
        """
        For Ray, we don't need to actually perform handshake.
        All addresses information is known before the actor creation.
        Therefore, we simply yield these addresses.
        """
        yield self.addresses

    def wait_for_init(self):
        """
        Wait until the engine core is initialized.

        This is just an empty method. When ray.get() on this method
        (or any other method of the actor) returns, it is guaranteed
        that actor creation (i.e., __init__) is complete.
        """
        pass

    def run(self):
        """
        Run the engine core busy loop.

        ``SystemExit`` and any other ``Exception`` are logged and
        re-raised; ``self.shutdown()`` is always called on the way out.
        """
        try:
            self.run_busy_loop()
        except SystemExit:
            logger.debug("EngineCore exiting.")
            raise
        except Exception:
            logger.exception("EngineCore encountered a fatal error.")
            raise
        finally:
            # Runs for normal return and for every propagated exception.
            self.shutdown()