[BugFix] Async scheduling and PP compatibility with DP (#23770)

Signed-off-by: Nick Hill <nhill@redhat.com>
Authored by Nick Hill on 2025-08-29 08:17:27 -07:00; committed by GitHub
parent 0a2f4c0793
commit d90d8eb674
7 changed files with 105 additions and 98 deletions

View File

@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
# Schedule Batch 1: (10, req0)
assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 1
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 10
# num_computed_tokens should have been updated immediately.
assert engine_core.scheduler.requests[
req0.request_id].num_computed_tokens == 10
# Schedule Batch 2: (2, req0), (8, req1)
assert engine_core.step_with_batch_queue()[0] is None
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert engine_core.step_with_batch_queue()[0] == {}
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 2
assert scheduler_output.num_scheduled_tokens["1"] == 8
# num_computed_tokens should have been updated immediately.
@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
assert engine_core.scheduler.get_num_unfinished_requests() == 2
# Batch queue is full. Finish Batch 1.
engine_core.step_with_batch_queue()
# Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
# Finish Batch 1 and schedule Batch 3: (4, req1).
# Note that req0 cannot be scheduled
# because it is in the decoding stage now.
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
assert len(engine_core.batch_queue) == 1
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 4
# Batch queue is full. Finish Batch 2. Get first token of req0.
# Finish Batch 2. Get first token of req0.
# Schedule Batch 4: (1, req0).
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
# Schedule Batch 4: (1, req0).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["0"] == 1
# Batch queue is full. Finish Batch 3. Get first token of req1.
# Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
output = engine_core.step_with_batch_queue()[0].get(0)
assert output is not None
assert len(output.outputs) == 1
assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
# Schedule Batch 5: (1, req1).
engine_core.step_with_batch_queue()
assert engine_core.batch_queue.qsize() == 2
scheduler_output = engine_core.batch_queue.queue[-1][1]
scheduler_output = engine_core.batch_queue[-1][1]
assert scheduler_output.num_scheduled_tokens["1"] == 1
# Loop until req0 is finished.
step = 0
req_id = 0
expected_num_tokens = [
engine_core.scheduler.requests["0"].num_tokens + 1,
@@ -368,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
]
while engine_core.scheduler.get_num_unfinished_requests() == 2:
output = engine_core.step_with_batch_queue()[0]
if step % 2 == 0:
# Even steps consumes an output.
assert output is not None
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
else:
# Odd steps schedules a new batch.
assert output is None
step += 1
# Every step consumes an output.
assert output is not None
assert len(output[0].outputs) == 1
if req_id in engine_core.scheduler.requests:
assert engine_core.scheduler.requests[
req_id].num_tokens == expected_num_tokens[req_id]
expected_num_tokens[req_id] += 1
req_id = (req_id + 1) % 2
@multi_gpu_test(num_gpus=2)
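
The reworked assertions above reflect two changes: the batch queue is now a plain collections.deque (inspected with len() and indexing rather than qsize()/.queue), and once the pipeline is primed, every step_with_batch_queue() call both schedules a new batch and consumes the oldest result instead of alternating between the two. A minimal standalone sketch of that FIFO discipline, with a hypothetical make_batch helper standing in for execute_model():

# Standalone sketch (not vLLM code) of the two-deep pipelining pattern the
# updated test asserts: schedule on the left, consume the oldest on the right.
from collections import deque
from concurrent.futures import Future

BATCH_QUEUE_SIZE = 2  # assumed to mirror max_concurrent_batches in the test

def make_batch(tokens: int) -> Future:
    """Stand-in for execute_model(); resolves immediately for illustration."""
    f: Future = Future()
    f.set_result(f"ran batch of {tokens} tokens")
    return f

batch_queue: deque = deque(maxlen=BATCH_QUEUE_SIZE)

# Step 1: schedule only, nothing to consume yet.
batch_queue.appendleft((make_batch(10), "batch-1"))
assert len(batch_queue) == 1

# Step 2: schedule batch-2; the queue is now full, so consume batch-1.
batch_queue.appendleft((make_batch(10), "batch-2"))
future, label = batch_queue.pop()      # oldest entry (FIFO order)
print(label, "->", future.result())    # batch-1 -> ran batch of 10 tokens
assert len(batch_queue) == 1           # matches the test's len() checks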

View File

@@ -75,9 +75,10 @@ async def generate(
],
)
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
@pytest.mark.asyncio
async def test_load(output_kind: RequestOutputKind,
data_parallel_backend: str):
async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
async_scheduling: bool):
stats_loggers = {}
@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend
engine_args.async_scheduling = async_scheduling
engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown)
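
For reference, a hypothetical usage sketch (not part of the diff) of the combination the new parameter exercises: async scheduling together with a data-parallel backend. The attribute names follow the test above; the model name and data_parallel_size value are illustrative assumptions.

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.v1.engine.async_llm import AsyncLLM

engine_args = AsyncEngineArgs(
    model="<your-model>",        # placeholder
    data_parallel_size=2,        # assumed field; mirrors --data-parallel-size
)
engine_args.data_parallel_backend = "mp"   # or "ray", as parametrized above
engine_args.async_scheduling = True

engine = AsyncLLM.from_engine_args(engine_args)
try:
    ...  # submit requests via engine.generate(...)
finally:
    engine.shutdown()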

View File

@@ -10,6 +10,7 @@ import msgspec
import vllm.platforms
from vllm.config import ParallelConfig
from vllm.distributed import get_pp_group
from vllm.executor.msgspec_utils import decode_hook, encode_hook
from vllm.logger import init_logger
from vllm.platforms import current_platform
@@ -136,6 +137,11 @@ try:
scheduler_output, intermediate_tensors)
if isinstance(output, IntermediateTensors):
output = scheduler_output, output
elif not get_pp_group().is_last_rank:
# Case where there are no scheduled requests
# but may still be finished requests.
assert not output or not output.req_ids
output = scheduler_output, None
return output
def override_env_vars(self, vars: Dict[str, str]):
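
The new elif branch covers a non-last PP rank that was invoked with an empty schedule (for example, only finished requests to propagate): instead of returning a ModelRunnerOutput, it forwards (scheduler_output, None) so downstream ranks still receive the scheduler output. A standalone sketch of the branching, using stand-in classes rather than vLLM's real types:

from dataclasses import dataclass, field

@dataclass
class IntermediateTensors:        # stand-in for vLLM's IntermediateTensors
    tensors: dict = field(default_factory=dict)

@dataclass
class ModelRunnerOutput:          # stand-in for vLLM's ModelRunnerOutput
    req_ids: list = field(default_factory=list)

def wrap_pp_output(output, scheduler_output, is_last_rank: bool):
    if isinstance(output, IntermediateTensors):
        # Mid-pipeline rank that ran a forward pass: ship tensors downstream
        # together with the scheduler output.
        return scheduler_output, output
    if not is_last_rank:
        # No scheduled requests on this rank, but there may still be finished
        # requests; forward the scheduler output with no tensors.
        assert not output or not output.req_ids
        return scheduler_output, None
    # Last rank: return the model runner output to the engine.
    return output

# Example: an empty batch on a non-last rank.
print(wrap_pp_output(None, "scheduler_output", is_last_rank=False))
# -> ('scheduler_output', None)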

View File

@@ -138,12 +138,12 @@ class EngineCore:
# schedule and execute batches, and is required by pipeline parallelism
# to eliminate pipeline bubbles.
self.batch_queue_size = self.model_executor.max_concurrent_batches
self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None
self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
SchedulerOutput]]] = None
if self.batch_queue_size > 1:
logger.info("Batch queue is enabled with size %d",
self.batch_queue_size)
self.batch_queue = queue.Queue(self.batch_queue_size)
self.batch_queue = deque(maxlen=self.batch_queue_size)
self.request_block_hasher: Optional[Callable[[Request],
list[BlockHash]]] = None
@@ -319,41 +319,43 @@ class EngineCore:
batch in the job queue is finished.
3. Update the scheduler from the output.
"""
assert self.batch_queue is not None
batch_queue = self.batch_queue
assert batch_queue is not None
engine_core_outputs = None
scheduler_output = None
# Try to schedule a new batch if the batch queue is not full, but
# the scheduler may return an empty batch if all requests are scheduled.
# Note that this is not blocking.
if not self.batch_queue.full():
assert len(batch_queue) < self.batch_queue_size
model_executed = False
if self.scheduler.has_requests():
scheduler_output = self.scheduler.schedule()
if scheduler_output.total_num_scheduled_tokens > 0:
future = self.model_executor.execute_model(scheduler_output)
self.batch_queue.put_nowait(
(future, scheduler_output)) # type: ignore
future = self.model_executor.execute_model(scheduler_output)
batch_queue.appendleft(
(future, scheduler_output)) # type: ignore[arg-type]
scheduled_batch = (scheduler_output is not None
and scheduler_output.total_num_scheduled_tokens > 0)
model_executed = scheduler_output.total_num_scheduled_tokens > 0
if model_executed and len(batch_queue) < self.batch_queue_size \
and not batch_queue[-1][0].done():
# Don't block on next worker response unless the queue is full
# or there are no more requests to schedule.
return None, True
# If no more requests can be scheduled and the job queue is not empty,
# block until the first batch in the job queue is finished.
# TODO(comaniac): Ideally we should peek the first batch in the
# job queue to check if it's finished before scheduling a new batch,
# but peeking the first element in a queue is not thread-safe,
# so we need more work.
if not scheduled_batch and not self.batch_queue.empty():
future, scheduler_output = self.batch_queue.get_nowait()
elif not batch_queue:
# Queue is empty. We should not reach here since this method should
# only be called when the scheduler contains requests or the queue
# is non-empty.
return None, False
# Blocking until the first result is available.
model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output)
# Block until the next result is available.
future, scheduler_output = batch_queue.pop()
model_output = self.execute_model_with_error_logging(
lambda _: future.result(), scheduler_output)
self.batch_queue.task_done()
engine_core_outputs = (self.scheduler.update_from_output(
scheduler_output, model_output))
engine_core_outputs = self.scheduler.update_from_output(
scheduler_output, model_output)
return engine_core_outputs, scheduled_batch
return engine_core_outputs, model_executed
def shutdown(self):
self.structured_output_manager.clear_backend()
@@ -388,7 +390,7 @@ class EngineCore:
return self.model_executor.is_sleeping
def execute_dummy_batch(self):
self.model_executor.collective_rpc("execute_dummy_batch")
self.model_executor.execute_dummy_batch()
def add_lora(self, lora_request: LoRARequest) -> bool:
return self.model_executor.add_lora(lora_request)
@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore):
"""Exits when an engine step needs to be performed."""
waited = False
while not self.engines_running and not self.scheduler.has_requests():
while not self.engines_running and not self.scheduler.has_requests() \
and not self.batch_queue:
if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
logger.debug("EngineCore waiting for work.")
waited = True
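
Because the old and new bodies are interleaved above, the new step_with_batch_queue() control flow is easier to read consolidated. The paraphrase below is a sketch, not the literal method: engine stands in for self, and the execute_model_with_error_logging wrapper is elided.

def step_with_batch_queue_sketch(engine):
    batch_queue = engine.batch_queue                     # deque used as a FIFO
    assert len(batch_queue) < engine.batch_queue_size    # never full on entry

    model_executed = False
    if engine.scheduler.has_requests():
        scheduler_output = engine.scheduler.schedule()
        future = engine.model_executor.execute_model(scheduler_output)
        batch_queue.appendleft((future, scheduler_output))
        model_executed = scheduler_output.total_num_scheduled_tokens > 0

        if (model_executed and len(batch_queue) < engine.batch_queue_size
                and not batch_queue[-1][0].done()):
            # There is still room in the pipeline and the oldest batch is not
            # finished yet, so don't block on the worker.
            return None, True
    elif not batch_queue:
        # Nothing to schedule and nothing in flight.
        return None, False

    # Block on the oldest in-flight batch and feed its result to the scheduler.
    future, scheduler_output = batch_queue.pop()
    outputs = engine.scheduler.update_from_output(scheduler_output,
                                                  future.result())
    return outputs, model_executed

Note two behavioral shifts visible in the diff: a batch is now submitted to the executor whenever the scheduler has requests, even if zero tokens are scheduled, and the blocking path also runs right after scheduling when the queue is full or the oldest future is already done.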

View File

@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
pass
def determine_available_memory(self) -> list[int]: # in bytes
output = self.collective_rpc("determine_available_memory")
return output
return self.collective_rpc("determine_available_memory")
def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
output = self.collective_rpc("get_kv_cache_spec")
return output
return self.collective_rpc("get_kv_cache_spec")
def execute_model(
self,
@@ -96,6 +94,9 @@
args=(scheduler_output, ))
return output[0]
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch")
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
output = self.collective_rpc("take_draft_token_ids")
return output[0]
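
The simplifications above rely on the collective_rpc convention: the call fans out to every worker and returns a list of per-rank results, so single-value methods return element 0 while the new execute_dummy_batch() discards the results entirely. A toy stand-in (not the real Executor) illustrating the pattern:

class ToyExecutor:
    """Toy illustration only; not vLLM's Executor."""

    def __init__(self, num_workers: int = 2):
        self.num_workers = num_workers

    def collective_rpc(self, method: str, args: tuple = ()) -> list:
        # Pretend each worker ran `method` and returned a per-rank result.
        return [f"{method}(rank={r}, args={args})"
                for r in range(self.num_workers)]

    def determine_available_memory(self) -> list:
        # All per-rank values are needed, so the whole list is returned.
        return self.collective_rpc("determine_available_memory")

    def execute_model(self, scheduler_output):
        # Only one rank's output matters to the engine, hence the [0].
        return self.collective_rpc("execute_model", (scheduler_output,))[0]

    def execute_dummy_batch(self) -> None:
        # Keeps all ranks stepping together; per-rank results are discarded.
        self.collective_rpc("execute_dummy_batch")


executor = ToyExecutor()
print(executor.execute_model("batch-0"))   # a single result
executor.execute_dummy_batch()             # no return value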

View File

@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
outputs, self.output_rank)
return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
def execute_dummy_batch(self) -> None:
self.collective_rpc("execute_dummy_batch",
unique_reply_rank=self.output_rank)
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
# OPTIMIZATION: Get output only from a single worker (output_rank)
outputs = self.collective_rpc("take_draft_token_ids",
@@ -242,12 +246,17 @@
dequeue_timeout = None if deadline is None else (
deadline - time.monotonic())
if non_block:
if self.io_thread_pool is not None:
# We must consume worker_response_mq from a single thread.
result = self.io_thread_pool.submit( # type: ignore
get_response, w, dequeue_timeout, self.shutdown_event)
else:
if not non_block:
result = result.result()
elif not non_block:
result = get_response(w, dequeue_timeout)
else:
raise RuntimeError("non_block can only be used when"
" max_concurrent_batches > 1")
responses.append(result)
return responses
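
The restructured block above enforces a single-consumer rule: each worker's response message queue may only ever be read from one thread, so reads are funneled through the one-thread io_thread_pool when it exists, and non_block=True without that pool is now an error. A standalone sketch of the pattern, using a toy queue and helper rather than vLLM's MessageQueue:

import queue
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Optional, Union

response_mq: "queue.Queue[str]" = queue.Queue()
io_thread_pool: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(max_workers=1)

def get_response(timeout: Optional[float] = None) -> str:
    # Stand-in for dequeuing one worker response.
    return response_mq.get(timeout=timeout)

def collect(non_block: bool) -> Union[str, "Future[str]"]:
    if io_thread_pool is not None:
        # Every dequeue goes through the single IO thread, even blocking ones.
        result = io_thread_pool.submit(get_response, 1.0)
        return result if non_block else result.result()
    if not non_block:
        return get_response(1.0)
    raise RuntimeError("non_block can only be used when"
                       " max_concurrent_batches > 1")

response_mq.put("worker-0 done")
fut = collect(non_block=True)   # a Future; the IO thread does the blocking get
print(fut.result())             # -> worker-0 done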

View File

@@ -354,36 +354,37 @@ class Worker(WorkerBase):
scheduler_output: "SchedulerOutput",
) -> Optional[ModelRunnerOutput]:
intermediate_tensors = None
if not get_pp_group().is_first_rank:
forward_pass = scheduler_output.total_num_scheduled_tokens > 0
if forward_pass and not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
output = self.model_runner.execute_model(scheduler_output,
intermediate_tensors)
parallel_config = self.vllm_config.parallel_config
if parallel_config.distributed_executor_backend != "external_launcher" \
and not get_pp_group().is_last_rank:
assert isinstance(output, IntermediateTensors)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
kv_connector_output = output.kv_connector_output
if not kv_connector_output:
return None
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
if isinstance(output, ModelRunnerOutput):
return output
assert isinstance(output, ModelRunnerOutput)
assert isinstance(output, IntermediateTensors)
parallel_config = self.vllm_config.parallel_config
assert parallel_config.distributed_executor_backend != (
"external_launcher") and not get_pp_group().is_last_rank
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
kv_connector_output = output.kv_connector_output
if not kv_connector_output:
return None
# In case of PP with kv transfer, we need to pass through the
# kv_connector_output
if (not kv_connector_output.finished_sending
and not kv_connector_output.finished_recving):
return EMPTY_MODEL_RUNNER_OUTPUT
output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
output.kv_connector_output = kv_connector_output
return output
def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
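
As with the engine-core hunk, the interleaving above makes the new Worker.execute_model() flow hard to follow, so here is a consolidated paraphrase. It is a sketch, not the literal method: worker stands in for self, and the assertion that the backend is not "external_launcher" on the send path is elided.

def execute_model_sketch(worker, scheduler_output):
    forward_pass = scheduler_output.total_num_scheduled_tokens > 0

    # Only receive activations from the previous PP rank when a forward pass
    # will actually run; an empty batch has nothing travelling the pipeline.
    intermediate_tensors = None
    if forward_pass and not get_pp_group().is_first_rank:
        intermediate_tensors = IntermediateTensors(
            get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group()))

    output = worker.model_runner.execute_model(scheduler_output,
                                               intermediate_tensors)
    if isinstance(output, ModelRunnerOutput):
        # Last PP rank: hand the sampled output back to the engine.
        return output

    # Non-last PP rank: ship activations downstream; only KV-transfer status
    # (if any) is surfaced to the engine.
    assert isinstance(output, IntermediateTensors)
    get_pp_group().send_tensor_dict(output.tensors,
                                    all_gather_group=get_tp_group())
    kv_connector_output = output.kv_connector_output
    if not kv_connector_output:
        return None
    if (not kv_connector_output.finished_sending
            and not kv_connector_output.finished_recving):
        return EMPTY_MODEL_RUNNER_OUTPUT
    out = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
    out.kv_connector_output = kv_connector_output
    return out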