Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

Compare commits: efcb786d52 ... wide_ep_wo

6 Commits:
5c2a80c37d
844022b188
ff6c72db76
8f731cf3ab
a46c1de7b6
ca11caf59e
@@ -1025,6 +1025,11 @@ class NixlConnectorWorker:
        # Sorted dict, oldest requests are put first so we can exit early.
        if now < expires:
            break
        count = self.consumer_notification_counts_by_req.pop(req_id, 0)
        logger.warning(
            "Releasing expired KV blocks for request %s which were "
            "retrieved by %d decode worker(s) within %d seconds.", req_id,
            count, envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT)
        del self._reqs_to_send[req_id]
        done_sending.add(req_id)

@@ -1040,6 +1045,12 @@ class NixlConnectorWorker:
        for notifs in self.nixl_wrapper.get_new_notifs().values():
            for notif in notifs:
                req_id, tp_ratio = notif.decode("utf-8").rsplit(":", 1)
                if req_id not in self._reqs_to_send:
                    logger.warning(
                        "Potentially invalid KV blocks for "
                        "unrecognized request %s were retrieved by "
                        "a decode worker. They may have expired.", req_id)

                self.consumer_notification_counts_by_req[req_id] += 1
                # Wait all consumers (D) to be done reading before freeing.
                if self.consumer_notification_counts_by_req[req_id] == int(
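The two hunks above add producer-side expiry to the NIXL connector: requests whose KV blocks are awaiting reads by decode workers live in an insertion-ordered dict alongside an expiry timestamp, notifications from decode workers are counted per request, and anything still unread after VLLM_NIXL_ABORT_REQUEST_TIMEOUT is released with a warning. A minimal sketch of the same pattern, assuming simplified names and a fixed stand-in timeout rather than vLLM's actual classes:

import time
from collections import defaultdict

ABORT_TIMEOUT_S = 120  # stand-in for VLLM_NIXL_ABORT_REQUEST_TIMEOUT

class ExpiringSendTracker:
    def __init__(self):
        # req_id -> expiry timestamp; oldest entries first (insertion order).
        self.reqs_to_send: dict[str, float] = {}
        self.notification_counts: defaultdict[str, int] = defaultdict(int)

    def track(self, req_id: str) -> None:
        self.reqs_to_send[req_id] = time.monotonic() + ABORT_TIMEOUT_S

    def on_notification(self, req_id: str, expected_consumers: int) -> set[str]:
        """Count one consumer read; return req_ids whose blocks are fully read.

        `expected_consumers` plays the role of the tp_ratio factor above."""
        done: set[str] = set()
        if req_id not in self.reqs_to_send:
            return done
        self.notification_counts[req_id] += 1
        if self.notification_counts[req_id] == expected_consumers:
            del self.notification_counts[req_id]
            del self.reqs_to_send[req_id]
            done.add(req_id)
        return done

    def expire(self) -> set[str]:
        """Release requests whose consumers never finished reading in time."""
        now = time.monotonic()
        done: set[str] = set()
        for req_id, expires in list(self.reqs_to_send.items()):
            if now < expires:
                break  # entries are oldest-first, so we can stop early
            self.notification_counts.pop(req_id, 0)
            del self.reqs_to_send[req_id]
            done.add(req_id)
        return done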
@@ -199,6 +199,8 @@ async def build_async_engine_client_from_engine_args(

        from vllm.v1.engine.async_llm import AsyncLLM
        async_llm: Optional[AsyncLLM] = None
        client_count = client_config.pop(
            "client_count") if client_config else 1
        client_index = client_config.pop(
            "client_index") if client_config else 0
        try:

@@ -208,6 +210,7 @@ async def build_async_engine_client_from_engine_args(
                disable_log_requests=engine_args.disable_log_requests,
                disable_log_stats=engine_args.disable_log_stats,
                client_addresses=client_config,
                client_count=client_count,
                client_index=client_index)

            # Don't keep the dummy data in memory
@@ -2827,6 +2827,9 @@ def make_zmq_socket(
    if linger is not None:
        socket.setsockopt(zmq.LINGER, linger)

    if socket_type == zmq.XPUB:
        socket.setsockopt(zmq.XPUB_VERBOSE, True)

    # Determine if the path is a TCP socket with an IPv6 address.
    # Enable IPv6 on the zmq socket if so.
    scheme, host, _ = split_zmq_path(path)
@@ -67,6 +67,22 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
        self.cg_buf_tile_scheduler_metadata = None
        self.cg_buf_num_splits = None

        device_properties = torch.cuda.get_device_properties(self.device)
        num_sms = device_properties.multi_processor_count

        if self.compilation_config.full_cuda_graph:
            self.cg_buf_tile_scheduler_metadata = torch.empty(
                # Upper bound on size (<= #SMs, TileSchedulerMetaDataSize)
                # TileSchedulerMetaDataSize = 8
                (num_sms, 8),
                device=self.device,
                dtype=torch.int32,
            )
            self.cg_buf_num_splits = torch.empty(
                (vllm_config.scheduler_config.max_num_seqs + 1),
                device=self.device,
                dtype=torch.int32)

    def _build_decode(self, block_table_tensor: torch.Tensor,
                      seq_lens: torch.Tensor) -> FlashMLADecodeMetadata:
        tile_scheduler_metadata, num_splits = \
@@ -77,28 +93,25 @@ class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
        )

        if self.compilation_config.full_cuda_graph:
            # First time around (CUDAGraph capture), allocate the static buffer
            if self.cg_buf_tile_scheduler_metadata is None:
                self.cg_buf_tile_scheduler_metadata = tile_scheduler_metadata
                self.cg_buf_num_splits = num_splits
            else:
                assert self.cg_buf_num_splits is not None
            assert self.cg_buf_tile_scheduler_metadata is not None
            assert self.cg_buf_num_splits is not None

                # Metadata per-SM, fixed size (#SMs, TileMetadataSize)
                assert (self.cg_buf_tile_scheduler_metadata.size() ==
                        tile_scheduler_metadata.size())
                self.cg_buf_tile_scheduler_metadata.\
                    copy_(tile_scheduler_metadata)
                tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata
            # Metadata per-SM, upper bound on size (<= #SMs, TileMetadataSize)
            assert (self.cg_buf_tile_scheduler_metadata.size(0)
                    >= tile_scheduler_metadata.size(0))
            sm_parts = tile_scheduler_metadata.size(0)
            self.cg_buf_tile_scheduler_metadata[:sm_parts].\
                copy_(tile_scheduler_metadata)
            tile_scheduler_metadata = self.cg_buf_tile_scheduler_metadata

                # Num splits is per-batch, varying size (batch_size,)
                n = num_splits.size(0)
                # make sure static buffer is large enough
                assert n <= self.cg_buf_num_splits.size(0)
                num_splits_view = self.cg_buf_num_splits[:n]
                num_splits_view.copy_(num_splits)
                self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
                num_splits = num_splits_view
            # Num splits is per-batch, varying size (batch_size,)
            n = num_splits.size(0)
            # make sure static buffer is large enough
            assert n <= self.cg_buf_num_splits.size(0)
            num_splits_view = self.cg_buf_num_splits[:n]
            num_splits_view.copy_(num_splits)
            self.cg_buf_num_splits[n:].fill_(0)  # fill the rest with 0s
            num_splits = num_splits_view

        return FlashMLADecodeMetadata(
            block_table=block_table_tensor,
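The two FlashMLA hunks switch the full-CUDA-graph path from capturing whatever tensors the first build produced to pre-allocating fixed-size persistent buffers (sized by SM count and max_num_seqs) and copying each step's metadata into a prefix view of those buffers. This is the usual pattern for feeding variable-sized metadata into a captured CUDA graph: the graph replays against fixed storage, so the builder must copy into it rather than rebind. A simplified, hypothetical sketch of that pattern, not the FlashMLA code itself:

import torch

class PersistentBufferExample:
    """Copy per-step, variable-length data into a fixed buffer for CUDA graphs."""

    def __init__(self, max_rows: int, device: str = "cuda"):
        # Allocated once; this storage is what the CUDA graph captures.
        self.buf = torch.empty(max_rows, dtype=torch.int32, device=device)

    def update(self, step_data: torch.Tensor) -> torch.Tensor:
        n = step_data.size(0)
        assert n <= self.buf.size(0), "static buffer too small"
        view = self.buf[:n]
        view.copy_(step_data)   # in-place copy keeps the captured storage valid
        self.buf[n:].fill_(0)   # zero the unused tail so stale data is ignored
        return view             # callers use the view; storage never changes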
@@ -57,6 +57,7 @@ class AsyncLLM(EngineClient):
        start_engine_loop: bool = True,
        stat_loggers: Optional[list[StatLoggerFactory]] = None,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> None:
        """

@@ -120,6 +121,7 @@ class AsyncLLM(EngineClient):
            executor_class=executor_class,
            log_stats=self.log_stats,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )

@@ -151,6 +153,7 @@ class AsyncLLM(EngineClient):
        disable_log_requests: bool = False,
        disable_log_stats: bool = False,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "AsyncLLM":
        if not envs.VLLM_USE_V1:

@@ -170,6 +173,7 @@ class AsyncLLM(EngineClient):
            log_stats=not disable_log_stats,
            usage_context=usage_context,
            client_addresses=client_addresses,
            client_count=client_count,
            client_index=client_index,
        )
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import multiprocessing
import time
import weakref
@@ -66,18 +67,14 @@ class DPCoordinator:

        # Assume coordinator is colocated with front-end procs when not in
        # either external or hybrid DP LB mode.
        local_only = not (external_lb or hybrid_lb)
        front_publish_address = get_engine_client_zmq_addr(
            local_only=not external_lb and not hybrid_lb, host=host)
            local_only=local_only, host=host)

        local_only_eng = dp_size == parallel_config.data_parallel_size_local
        back_publish_address = get_engine_client_zmq_addr(local_only_eng, host)
        back_output_address = get_engine_client_zmq_addr(local_only_eng, host)

        # When in external LB mode, load stats aren't published, only changes
        # to request wave / running state, so we don't need to rate-limit the
        # updates to the front-end proc(s).
        min_stats_update_interval_ms = 0 if external_lb else 100

        context = get_mp_context()
        self.proc: multiprocessing.Process = context.Process(
            target=DPCoordinatorProc.run_coordinator,
@@ -87,7 +84,6 @@ class DPCoordinator:
                "front_publish_address": front_publish_address,
                "back_output_address": back_output_address,
                "back_publish_address": back_publish_address,
                "min_stats_update_interval_ms": min_stats_update_interval_ms,
            },
            daemon=True)
        self.proc.start()
@@ -126,10 +122,6 @@ class DPCoordinatorProc:

        self.stats_update_interval_ms = min_stats_update_interval_ms

        self.current_wave = 0
        self.engines_running = False
        self.stats_changed = False

    @staticmethod
    def run_coordinator(
        engine_count: int,
@@ -156,6 +148,14 @@ class DPCoordinatorProc:

        decoder = MsgpackDecoder(EngineCoreOutputs)

        current_wave = 0
        engines_running = False

        stats_changed = False
        last_stats_step = -1
        last_stats_wave = -1
        last_step_counts: Optional[list[list[int]]] = None

        with make_zmq_socket(
                path=front_publish_address,  # IPC
                ctx=self.ctx,
@@ -173,6 +173,18 @@ class DPCoordinatorProc:
                bind=True,
        ) as publish_back:

            # Wait until all engines subscribe.
            for _ in self.engines:
                if publish_back.recv() != b'\x01':
                    logger.error(
                        "DP Coordinator received unexpected message while "
                        "waiting for engines to subscribe")
                    return
            # Send ready message to engines.
            publish_back.send(b"READY")

            logger.info("All engine subscriptions received by DP coordinator")

            poller = zmq.Poller()
            poller.register(publish_front, zmq.POLLIN)
            poller.register(output_back, zmq.POLLIN)
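This hunk is the publisher half of a new startup handshake: because publish_back is an XPUB socket (see the XPUB_VERBOSE change in make_zmq_socket above), each engine's subscription arrives as a frame beginning with b'\x01', so the coordinator can count one frame per engine before broadcasting READY, and nothing published afterwards is lost. A self-contained sketch of both halves under assumed addresses and counts, not vLLM's actual wiring:

import threading
import zmq

ADDR = "tcp://127.0.0.1:5599"   # illustrative address
NUM_ENGINES = 2                 # illustrative engine count

def engine(ready: threading.Event):
    ctx = zmq.Context.instance()
    sub = ctx.socket(zmq.SUB)
    sub.setsockopt(zmq.SUBSCRIBE, b"")   # emits a b'\x01' frame to the XPUB peer
    sub.connect(ADDR)
    assert sub.recv() == b"READY"        # block until the coordinator has seen us
    ready.set()

ctx = zmq.Context.instance()
pub = ctx.socket(zmq.XPUB)
pub.setsockopt(zmq.XPUB_VERBOSE, True)   # surface every subscription frame
pub.bind(ADDR)

events = [threading.Event() for _ in range(NUM_ENGINES)]
for ev in events:
    threading.Thread(target=engine, args=(ev,), daemon=True).start()

# One subscription frame per engine arrives before we broadcast anything.
for _ in range(NUM_ENGINES):
    assert pub.recv().startswith(b"\x01")
pub.send(b"READY")

for ev in events:
    ev.wait(5)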
@@ -180,21 +192,33 @@ class DPCoordinatorProc:
            while True:
                elapsed = int(time.time() * 1000) - last_publish_time
                # Send at stats_update_interval_ms interval if the stats have
                # changed, or otherwise every 4 seconds.
                # changed, or otherwise every 5 seconds.
                wait_for = (self.stats_update_interval_ms
                            if self.stats_changed else 4000)
                events = poller.poll(timeout=max(0, wait_for - elapsed))
                            if stats_changed else 5000)

                # Wait at least 50ms to ensure we've received all stats for
                # the current step.
                min_timeout = 50 if last_step_counts is None else 0

                events = poller.poll(timeout=max(min_timeout, wait_for -
                                                 elapsed))
                if not events:
                    # Poller timeout - publish current stats to front-ends.
                    engine_req_counts_list = self._get_engine_counts()
                    to_publish = (engine_req_counts_list, self.current_wave,
                                  self.engines_running)
                    if last_step_counts is not None:
                        engine_req_counts_list = last_step_counts
                        last_step_counts = None
                    else:
                        engine_req_counts_list = self._get_engine_counts()
                        stats_changed = False

                    to_publish = (engine_req_counts_list, current_wave,
                                  engines_running)
                    publish_front.send(msgspec.msgpack.encode(to_publish))
                    last_publish_time = int(time.time() * 1000)
                    self.stats_changed = False
                    continue

                events = dict(events)
                wave_state_changed = False

                if publish_front in events:
                    buffer = publish_front.recv()
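The reworked timing above combines three ingredients: publish quickly (every stats_update_interval_ms) when stats have changed, fall back to a slow heartbeat (now 5 seconds) when idle, and wait at least 50 ms when no snapshot is pending so all stats belonging to the current step can arrive before publishing. A simplified stand-in for that loop, with poll, publish, and get_counts as assumed callables rather than vLLM APIs:

import time

def publish_loop(poll, publish, get_counts, interval_ms=100, idle_ms=5000):
    """Rate-limited stats publisher; a sketch of the timing logic above."""
    last_publish_ms = 0
    stats_changed = False
    last_step_counts = None
    while True:
        elapsed = int(time.time() * 1000) - last_publish_ms
        # Publish soon after a change, otherwise as a slow heartbeat.
        wait_for = interval_ms if stats_changed else idle_ms
        # Give stats for the current step at least 50ms to all arrive.
        min_timeout = 50 if last_step_counts is None else 0
        events = poll(max(min_timeout, wait_for - elapsed))
        if not events:
            if last_step_counts is not None:
                counts, last_step_counts = last_step_counts, None
            else:
                counts = get_counts()
                stats_changed = False
            publish(counts)
            last_publish_ms = int(time.time() * 1000)
            continue
        # ... handle socket events here; set stats_changed = True on updates
        # and snapshot last_step_counts when a newer step begins.
        stats_changed = True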
@@ -221,7 +245,7 @@ class DPCoordinatorProc:
                            # current_wave
                            # we note that 0 is the wave number for the new
                            # engine
                            self.engines_running = False
                            engines_running = False
                            logger.info(
                                "DPCoordinator scaled up from %s to %s "
                                "engines", current_count, new_engine_count)
@@ -237,15 +261,15 @@ class DPCoordinatorProc:
                        # engines are paused, so that we can wake the other
                        # engines.
                        engine_to_exclude, wave = decoded
                        if not self.engines_running:
                            if wave < self.current_wave:
                        if not engines_running:
                            if wave < current_wave:
                                # If the wave number is stale, ensure the message
                                # is handled by all the engines.
                                engine_to_exclude = None

                            self.engines_running = True
                            self.stats_changed = True
                            self._send_start_wave(publish_back, self.current_wave,
                            engines_running = True
                            wave_state_changed = True
                            self._send_start_wave(publish_back, current_wave,
                                                  engine_to_exclude)

                if output_back in events:
@@ -263,36 +287,56 @@ class DPCoordinatorProc:
                        # 1. Updated request load stats - update our local
                        # state with these.
                        stats = self.engines[eng_index].request_counts
                        stats_step = scheduler_stats.step_counter
                        stats_wave = scheduler_stats.current_wave
                        if (stats_wave > last_stats_wave
                                or stats_wave == last_stats_wave
                                and stats_step > last_stats_step):
                            if stats_changed:
                                last_step_counts = self._get_engine_counts(
                                    do_copy=True)
                            last_stats_step = stats_step
                            last_stats_wave = stats_wave
                        elif stats_wave != last_stats_wave or (
                                stats_step != last_stats_step):
                            logger.warning(
                                "Received stats for out-of-order "
                                "step (%d, %d) from engine %d (expected "
                                "> (%d, %d))", stats_wave, stats_step,
                                eng_index, last_stats_wave, last_stats_step)
                        stats[0] = scheduler_stats.num_waiting_reqs
                        stats[1] = scheduler_stats.num_running_reqs
                        self.stats_changed = True
                        stats_changed = True

                    if (wave := outputs.wave_complete) is not None:
                        # 2. Notification from rank 0 engine that we've
                        # moved into the global paused state
                        # (engines_running==False).
                        if self.current_wave <= wave:
                        if current_wave <= wave:
                            new_wave = wave + 1
                            logger.debug("Moving DP wave from %d to %d.",
                                         self.current_wave, new_wave)
                            self.current_wave = new_wave
                            self.engines_running = False
                            self.stats_changed = True
                                         current_wave, new_wave)
                            current_wave = new_wave
                            engines_running = False
                            wave_state_changed = True
                    elif (wave := outputs.start_wave) is not None and (
                            wave > self.current_wave or
                            (wave == self.current_wave
                             and not self.engines_running)):
                            wave > current_wave or
                            (wave == current_wave and not engines_running)):
                        # 3. The engine received request for a non-current wave
                        # so we must ensure that other engines progress to the
                        # next wave (race condition handling).
                        logger.debug(
                            "Starting wave %d after notification of "
                            "stale wave request from engine.", wave)
                        self.current_wave = wave
                        self.engines_running = True
                        self.stats_changed = True
                        current_wave = wave
                        engines_running = True
                        wave_state_changed = True
                        self._send_start_wave(publish_back, wave, eng_index)

                if wave_state_changed:
                    message = (None, current_wave, engines_running)
                    publish_front.send(msgspec.msgpack.encode(message))

    @staticmethod
    def _send_start_wave(socket: zmq.Socket, wave: int,
                         exclude_engine_index: Optional[int]):
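The new stats handling above orders updates by the pair (wave, step): an update is applied only if that pair is strictly greater than the last one seen, and when a newer step begins while unpublished changes are pending, the previous step's counts are copied into last_step_counts so front-ends always receive a consistent snapshot of a single step. The chained comparison in the hunk is equivalent to a lexicographic tuple compare, as this sketch with assumed names shows:

import copy
from typing import Optional

def on_stats(engine_counts: list[list[int]],
             last_seen: tuple[int, int],     # (last_stats_wave, last_stats_step)
             incoming: tuple[int, int],      # (stats_wave, stats_step)
             stats_changed: bool,
             last_step_counts: Optional[list[list[int]]]):
    """Accept an update only if (wave, step) advanced; snapshot the prior step."""
    if incoming > last_seen:                 # lexicographic: wave first, then step
        if stats_changed:
            # Changes from the previous step haven't been published yet:
            # keep a copy so they go out as one consistent snapshot.
            last_step_counts = [copy.copy(c) for c in engine_counts]
        last_seen = incoming
    elif incoming != last_seen:
        print(f"out-of-order stats {incoming}, expected > {last_seen}")
    return last_seen, last_step_counts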
@@ -305,6 +349,8 @@ class DPCoordinatorProc:
        socket.send_multipart(
            (EngineCoreRequestType.START_DP_WAVE.value, wave_encoded))

    def _get_engine_counts(self) -> list[list[int]]:
    def _get_engine_counts(self, do_copy=False) -> list[list[int]]:
        """Return list of [waiting, running] count lists for each engine."""
        if do_copy:
            return [copy.copy(e.request_counts) for e in self.engines]
        return [e.request_counts for e in self.engines]
@@ -331,10 +331,11 @@ class EngineCore:
        # Blocking until the first result is available.
        model_output = self.execute_model_with_error_logging(
            lambda _: future.result(), scheduler_output)
        assert model_output is not None

        self.batch_queue.task_done()
        engine_core_outputs = (self.scheduler.update_from_output(
            scheduler_output, model_output))
        engine_core_outputs = self.scheduler.update_from_output(
            scheduler_output, model_output)

        return engine_core_outputs, scheduled_batch
@@ -445,8 +446,11 @@ class EngineCoreProc(EngineCore):
        self.has_coordinator = addresses.coordinator_output is not None
        self.frontend_stats_publish_address = (
            addresses.frontend_stats_publish_address)
        logger.debug("Has DP Coordinator: %s, stats publish address: %s",
                     self.has_coordinator,
                     self.frontend_stats_publish_address)
        # Only publish request queue stats to coordinator for "internal"
        # LB mode.
        # and "hybrid" LB modes.
        self.publish_dp_lb_stats = (
            self.has_coordinator
            and not vllm_config.parallel_config.data_parallel_external_lb)
@@ -456,25 +460,38 @@ class EngineCoreProc(EngineCore):
        super().__init__(vllm_config, executor_class, log_stats,
                         executor_fail_callback)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        ready_event = threading.Event()
        input_thread = threading.Thread(target=self.process_input_sockets,
                                        args=(addresses.inputs,
                                              addresses.coordinator_input,
                                              identity, ready_event),
                                        daemon=True)
        input_thread.start()

        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

        # Don't complete handshake until DP coordinator ready message is
        # received.
        while not ready_event.wait(10):
            if not input_thread.is_alive():
                raise RuntimeError(
                    "Input socket thread died during startup")
            assert addresses.coordinator_input is not None
            logger.info("Waiting for READY message from DP Coordinator")

        self.step_fn = (self.step if self.batch_queue is None else
                        self.step_with_batch_queue)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
        # and to overlap some serialization/deserialization with the
        # model forward pass.
        # Threads handle Socket <-> Queues and core_busy_loop uses Queue.
        threading.Thread(target=self.process_input_sockets,
                         args=(addresses.inputs, addresses.coordinator_input,
                               identity),
                         daemon=True).start()
        self.output_thread = threading.Thread(
            target=self.process_output_sockets,
            args=(addresses.outputs, addresses.coordinator_output,
                  self.engine_index),
            daemon=True)
        self.output_thread.start()

    @contextmanager
    def _perform_handshakes(
        self,
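The constructor refactor above starts the input IO thread earlier and blocks __init__ on a threading.Event until that thread has completed the coordinator READY handshake, re-checking thread liveness every 10 seconds instead of hanging indefinitely if the thread dies. A compact sketch of that wait-with-liveness-check pattern for a generic worker (worker_fn is an assumed callable, not the vLLM thread target):

import threading

def start_and_wait(worker_fn, timeout_s: float = 10.0) -> threading.Thread:
    """Start a daemon thread and block until it signals readiness.

    `worker_fn(ready_event)` should set the event once its own startup
    handshake has finished.
    """
    ready_event = threading.Event()
    thread = threading.Thread(target=worker_fn, args=(ready_event,), daemon=True)
    thread.start()
    # Re-check every `timeout_s` so a crashed worker surfaces as an error
    # instead of an indefinite hang.
    while not ready_event.wait(timeout_s):
        if not thread.is_alive():
            raise RuntimeError("worker thread died during startup")
    return thread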
@@ -489,10 +506,10 @@ class EngineCoreProc(EngineCore):

        For DP=1 or offline mode, this is with the colocated front-end process.

        For DP>1 with internal loadbalancing this is with the shared front-end
        For DP>1 with internal load-balancing this is with the shared front-end
        process which may reside on a different node.

        For DP>1 with external or hybrid loadbalancing, two handshakes are
        For DP>1 with external or hybrid load-balancing, two handshakes are
        performed:
        - With the rank 0 front-end process which retrieves the
          DP Coordinator ZMQ addresses and DP process group address.
@@ -751,7 +768,7 @@ class EngineCoreProc(EngineCore):

    def process_input_sockets(self, input_addresses: list[str],
                              coord_input_address: Optional[str],
                              identity: bytes):
                              identity: bytes, ready_event: threading.Event):
        """Input socket IO thread."""

        # Msgpack serialization decoding.
@@ -788,9 +805,14 @@ class EngineCoreProc(EngineCore):
            # back to us.
            input_socket.send(b'')
            poller.register(input_socket, zmq.POLLIN)

        if coord_socket is not None:
            # Wait for ready message from coordinator.
            assert coord_socket.recv() == b"READY"
            poller.register(coord_socket, zmq.POLLIN)

        ready_event.set()
        del ready_event
        while True:
            for input_socket, _ in poller.poll():
                # (RequestType, RequestData)
@@ -887,7 +909,7 @@ class DPEngineCoreProc(EngineCoreProc):

        # Counts forward-passes of the model so that we can synchronize
        # finished with DP peers every N steps.
        self.counter = 0
        self.step_counter = 0
        self.current_wave = 0
        self.last_counts = (0, 0)
@@ -959,6 +981,25 @@ class DPEngineCoreProc(EngineCoreProc):
        else:
            super()._handle_client_request(request_type, request)

    def _process_engine_step(self) -> bool:
        # Step the engine core.
        outputs, model_executed = self.step_fn()
        # Put EngineCoreOutputs into the output queue.
        for output in (outputs.items() if outputs else ()):
            self.output_queue.put_nowait(output)

        if outputs and not model_executed:
            # NOTE(woosuk): This branch is taken when the previous step_fn call
            # updated the scheduler or worker states without actually executing
            # the model. With asynchronous scheduling, this typically occurs
            # every other step. To avoid unnecessary dummy runs, we give
            # step_fn a second chance to execute the model if possible.
            outputs, model_executed = self.step_fn()
            for output in (outputs.items() if outputs else ()):
                self.output_queue.put_nowait(output)

        return model_executed

    def _maybe_publish_request_counts(self):
        if not self.publish_dp_lb_stats:
            return
@@ -967,7 +1008,9 @@ class DPEngineCoreProc(EngineCoreProc):
        counts = self.scheduler.get_request_counts()
        if counts != self.last_counts:
            self.last_counts = counts
            stats = SchedulerStats(*counts)
            stats = SchedulerStats(*counts,
                                   step_counter=self.step_counter,
                                   current_wave=self.current_wave)
            self.output_queue.put_nowait(
                (-1, EngineCoreOutputs(scheduler_stats=stats)))
@@ -1009,15 +1052,16 @@ class DPEngineCoreProc(EngineCoreProc):
            self.output_queue.put_nowait(
                (client_index,
                 EngineCoreOutputs(wave_complete=self.current_wave)))
            # Increment wave count and reset step counter.
            self.current_wave += 1
            self.step_counter = 0

    def _has_global_unfinished_reqs(self, local_unfinished: bool) -> bool:

        # Optimization - only perform finish-sync all-reduce every 32 steps.
        self.counter += 1
        if self.counter != 32:
        self.step_counter += 1
        if self.step_counter % 32 != 0:
            return True
        self.counter = 0

        return ParallelConfig.has_unfinished_dp(self.dp_group,
                                                local_unfinished)
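The change above replaces a private roll-over counter with a monotonically increasing step_counter that is reset only on wave boundaries, so the same counter can both throttle the finish-sync all-reduce (every 32 steps via the modulo check) and tag the SchedulerStats published to the coordinator. A small sketch of the modulo-throttled sync, with an assumed expensive_sync callable standing in for the DP all-reduce:

class StepThrottle:
    """Run an expensive synchronization only every `period` steps."""

    def __init__(self, period: int = 32):
        self.period = period
        self.step_counter = 0  # also usable as a step tag when publishing stats

    def maybe_sync(self, expensive_sync) -> bool:
        self.step_counter += 1
        if self.step_counter % self.period != 0:
            # Skip the sync; optimistically report "still unfinished work".
            return True
        return expensive_sync()

Compared with counting to 32 and resetting, the modulo form keeps step_counter growing, so it remains meaningful when reported alongside the current wave.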
@@ -85,11 +85,12 @@ class EngineCoreClient(ABC):
        executor_class: type[Executor],
        log_stats: bool,
        client_addresses: Optional[dict[str, str]] = None,
        client_count: int = 1,
        client_index: int = 0,
    ) -> "MPClient":
        parallel_config = vllm_config.parallel_config
        client_args = (vllm_config, executor_class, log_stats,
                       client_addresses, client_index)
                       client_addresses, client_count, client_index)
        if parallel_config.data_parallel_size > 1:
            if parallel_config.data_parallel_external_lb:
                # External load balancer - client per DP rank.
@@ -686,6 +687,7 @@ class AsyncMPClient(MPClient):
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_count: int = 1,
                 client_index: int = 0):
        super().__init__(
            asyncio_mode=True,
@@ -886,11 +888,12 @@ class DPAsyncMPClient(AsyncMPClient):
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_count: int = 1,
                 client_index: int = 0):
        self.current_wave = 0

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)
                         client_addresses, client_count, client_index)

        # List of [waiting, running] pair per engine.
        # Used only by DPLBAsyncMPClient subclass.
@@ -986,7 +989,11 @@ class DPAsyncMPClient(AsyncMPClient):
                counts, wave, running = msgspec.msgpack.decode(buf)
                self.current_wave = wave
                self.engines_running = running
                self.lb_engines = counts[count_slice]
                if counts is not None:
                    sliced_counts = counts[count_slice]
                    self.lb_engines = sliced_counts
                    logger.debug("Received counts: %s (%s)", sliced_counts,
                                 count_slice)

        resources.stats_update_task = asyncio.create_task(
            run_engine_stats_update_task())
@@ -1022,40 +1029,45 @@ class DPLBAsyncMPClient(DPAsyncMPClient):
                 executor_class: type[Executor],
                 log_stats: bool,
                 client_addresses: Optional[dict[str, str]] = None,
                 client_count: int = 1,
                 client_index: int = 0):

        self.client_count = client_count

        # To route aborts to the correct engine.
        self.reqs_in_flight: dict[str, EngineIdentity] = {}

        super().__init__(vllm_config, executor_class, log_stats,
                         client_addresses, client_index)
                         client_addresses, client_count, client_index)

        assert len(self.core_engines) > 1

        self.eng_start_index = (len(self.core_engines) *
                                self.client_index) // client_count

    def get_core_engine_for_request(
            self, request: EngineCoreRequest) -> EngineIdentity:
        # Engines are in rank order.
        current_counts = self.lb_engines
        if (eng_index := request.data_parallel_rank) is None:
            if not self.lb_engines:
            if not current_counts:
                return self.core_engine
            # TODO use P2C alg for larger DP sizes
            num_engines = len(self.lb_engines)
            min_counts = [sys.maxsize, sys.maxsize]
            num_engines = len(current_counts)
            min_score = sys.maxsize
            eng_index = 0
            for i in range(num_engines):
                # Start from client_index to help with balancing when engines
                # are empty.
                idx = (self.client_index + i) % num_engines
                counts = self.lb_engines[idx]
                if counts < min_counts:
                    min_counts = counts
                idx = (self.eng_start_index + i) % num_engines
                waiting, running = current_counts[idx]
                score = waiting * 4 + running
                if score < min_score:
                    min_score = score
                    eng_index = idx
            # Adjust local counts for better balancing between stats updates
            # from the coordinator (which happen every 100ms).
            if min_counts[0]:
                min_counts[0] += 1
            else:
                min_counts[1] += 1
            # Increment local waiting count for better balancing between stats
            # updates from the coordinator (which happen every 100ms).
            current_counts[eng_index][0] += self.client_count

        chosen_engine = self.core_engines[eng_index]
        # Record which engine is chosen for this request, to handle aborts.
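The scheduling change above replaces lexicographic comparison of [waiting, running] lists with a single weighted score (waiting * 4 + running), starts the scan from a per-client offset so multiple API-server clients don't all pick rank 0 when every engine is idle, and bumps the chosen engine's local waiting count by client_count to approximate what the other clients are doing between coordinator updates. A standalone sketch of that selection logic under those assumptions (names are illustrative):

import sys

def choose_engine(counts: list[list[int]], start_index: int,
                  client_count: int = 1) -> int:
    """Pick the least-loaded engine by score = waiting * 4 + running."""
    num_engines = len(counts)
    min_score = sys.maxsize
    chosen = 0
    for i in range(num_engines):
        idx = (start_index + i) % num_engines  # offset scan to spread clients
        waiting, running = counts[idx]
        score = waiting * 4 + running
        if score < min_score:
            min_score = score
            chosen = idx
    # Locally account for the request we are about to send, scaled by the
    # number of clients, so repeated calls before the next stats update
    # don't all dogpile the same engine.
    counts[chosen][0] += client_count
    return chosen

# Example: with counts [[0, 2], [1, 0], [0, 0]] and start_index 1,
# engine 2 wins (score 0) and its waiting count becomes client_count.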
|
@ -33,6 +33,9 @@ class SchedulerStats:
|
||||
num_running_reqs: int = 0
|
||||
num_waiting_reqs: int = 0
|
||||
|
||||
step_counter: int = 0
|
||||
current_wave: int = 0
|
||||
|
||||
kv_cache_usage: float = 0.0
|
||||
|
||||
prefix_cache_stats: PrefixCacheStats = field(
|
||||
|
@@ -154,6 +154,7 @@ class APIServerProcessManager:
        client_config = {
            "input_address": in_addr,
            "output_address": out_addr,
            "client_count": num_servers,
            "client_index": i
        }
        if stats_update_address is not None: