[BugFix] Async scheduling and PP compatibility with DP (#23770)
Signed-off-by: Nick Hill <nhill@redhat.com>
@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):

     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 1
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 10
     # num_computed_tokens should have been updated immediately.
     assert engine_core.scheduler.requests[
         req0.request_id].num_computed_tokens == 10

     # Schedule Batch 2: (2, req0), (8, req1)
-    assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert engine_core.step_with_batch_queue()[0] == {}
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 2
     assert scheduler_output.num_scheduled_tokens["1"] == 8
     # num_computed_tokens should have been updated immediately.
@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):

     assert engine_core.scheduler.get_num_unfinished_requests() == 2

-    # Batch queue is full. Finish Batch 1.
-    engine_core.step_with_batch_queue()
-
-    # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+    # Finish Batch 1 and schedule Batch 3: (4, req1).
+    # Note that req0 cannot be scheduled
     # because it is in the decoding stage now.
     engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["1"] == 4

-    # Batch queue is full. Finish Batch 2. Get first token of req0.
+    # Finish Batch 2. Get first token of req0.
+    # Schedule Batch 4: (1, req0).
     output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
-
-    # Schedule Batch 4: (1, req0).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 1

-    # Batch queue is full. Finish Batch 3. Get first token of req1.
+    # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
     output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
-
-    # Schedule Batch 5: (1, req1).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["1"] == 1

     # Loop until req0 is finished.
-    step = 0
     req_id = 0
     expected_num_tokens = [
         engine_core.scheduler.requests["0"].num_tokens + 1,
@@ -368,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     ]
     while engine_core.scheduler.get_num_unfinished_requests() == 2:
         output = engine_core.step_with_batch_queue()[0]
-        if step % 2 == 0:
-            # Even steps consumes an output.
-            assert output is not None
-            assert len(output[0].outputs) == 1
-            if req_id in engine_core.scheduler.requests:
-                assert engine_core.scheduler.requests[
-                    req_id].num_tokens == expected_num_tokens[req_id]
-            expected_num_tokens[req_id] += 1
-            req_id = (req_id + 1) % 2
-        else:
-            # Odd steps schedules a new batch.
-            assert output is None
-        step += 1
+        # Every step consumes an output.
+        assert output is not None
+        assert len(output[0].outputs) == 1
+        if req_id in engine_core.scheduler.requests:
+            assert engine_core.scheduler.requests[
+                req_id].num_tokens == expected_num_tokens[req_id]
+        expected_num_tokens[req_id] += 1
+        req_id = (req_id + 1) % 2


 @multi_gpu_test(num_gpus=2)
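The rewritten loop above reflects the new stepping contract: a step that has work both returns the oldest finished batch's output and schedules the next batch, so the test no longer alternates between "output" steps and "schedule" steps. A minimal sketch of such a drain loop, using a hypothetical `step_fn` callable rather than the real engine:

```python
# Illustrative sketch only (not the vLLM test): with the reworked batch queue,
# every step that has work is expected to yield an output, so a drain loop can
# assert on each iteration instead of on every other one.
def drain(step_fn, max_steps: int = 100) -> list:
    outputs = []
    for _ in range(max_steps):
        out = step_fn()
        if out is None:          # nothing left to schedule or to drain
            break
        outputs.append(out)
    return outputs


# Toy usage with a canned sequence standing in for engine steps.
canned = iter(["tok-a", "tok-b", None])
assert drain(lambda: next(canned)) == ["tok-a", "tok-b"]
```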
@@ -75,9 +75,10 @@ async def generate(
     ],
 )
 @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
+@pytest.mark.parametrize("async_scheduling", [True, False])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind,
-                    data_parallel_backend: str):
+async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
+                    async_scheduling: bool):

     stats_loggers = {}
@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
     prompt = "This is a test of data parallel"

     engine_args.data_parallel_backend = data_parallel_backend
+    engine_args.async_scheduling = async_scheduling
     engine = AsyncLLM.from_engine_args(engine_args,
                                        stat_loggers=[SimpleStatsLogger])
     after.callback(engine.shutdown)
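For context, stacking `pytest.mark.parametrize` decorators as in the hunk above runs the test over the full cross product of the parameter lists. A self-contained sketch (the test name and assertions here are hypothetical):

```python
import pytest


# 2 backends x 2 scheduling modes = 4 generated test cases.
@pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_parametrize_matrix(data_parallel_backend: str, async_scheduling: bool):
    assert data_parallel_backend in ("mp", "ray")
    assert isinstance(async_scheduling, bool)
```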
@@ -10,6 +10,7 @@ import msgspec

 import vllm.platforms
 from vllm.config import ParallelConfig
+from vllm.distributed import get_pp_group
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -136,6 +137,11 @@ try:
                 scheduler_output, intermediate_tensors)
             if isinstance(output, IntermediateTensors):
                 output = scheduler_output, output
+            elif not get_pp_group().is_last_rank:
+                # Case where there are no scheduled requests
+                # but may still be finished requests.
+                assert not output or not output.req_ids
+                output = scheduler_output, None
             return output

         def override_env_vars(self, vars: Dict[str, str]):
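The new `elif` branch handles non-last pipeline-parallel ranks when no forward pass ran: they still return the scheduler output, paired with `None` instead of an empty model-runner output. A rough sketch of that shape with simplified stand-in arguments (`normalize_pp_output` is hypothetical, not a vLLM helper):

```python
from typing import Any, Optional, Tuple, Union


def normalize_pp_output(scheduler_output: Any, output: Any,
                        is_intermediate: bool, is_last_rank: bool
                        ) -> Union[Any, Tuple[Any, Optional[Any]]]:
    if is_intermediate:
        # Non-last rank that ran a forward pass: forward its tensors.
        return scheduler_output, output
    if not is_last_rank:
        # Nothing was scheduled, so there must be no per-request outputs.
        assert not output or not getattr(output, "req_ids", None)
        return scheduler_output, None
    return output


assert normalize_pp_output("sched", None, False, False) == ("sched", None)
```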
@@ -138,12 +138,12 @@ class EngineCore:
         # schedule and execute batches, and is required by pipeline parallelism
         # to eliminate pipeline bubbles.
         self.batch_queue_size = self.model_executor.max_concurrent_batches
-        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
-                                                      SchedulerOutput]]] = None
+        self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
+                                               SchedulerOutput]]] = None
         if self.batch_queue_size > 1:
             logger.info("Batch queue is enabled with size %d",
                         self.batch_queue_size)
-            self.batch_queue = queue.Queue(self.batch_queue_size)
+            self.batch_queue = deque(maxlen=self.batch_queue_size)

         self.request_block_hasher: Optional[Callable[[Request],
                                                      list[BlockHash]]] = None
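Replacing `queue.Queue` with `collections.deque` keeps the queue bounded while letting the engine peek at the oldest in-flight batch without consuming it, which the old blocking queue could not do. A minimal sketch of the access pattern, assuming the newest batch is pushed with `appendleft` and the oldest is read or popped from the right:

```python
from collections import deque
from concurrent.futures import Future

# Toy bounded FIFO of in-flight batches (not the vLLM implementation).
batch_queue: deque = deque(maxlen=2)

f1, f2 = Future(), Future()
batch_queue.appendleft((f1, "batch-1"))    # newest entry goes on the left
batch_queue.appendleft((f2, "batch-2"))

assert len(batch_queue) == 2               # len() replaces Queue.qsize()
assert batch_queue[-1] == (f1, "batch-1")  # peek oldest without consuming it
assert not batch_queue[-1][0].done()       # non-blocking readiness check

f1.set_result("output-1")
future, batch = batch_queue.pop()          # FIFO: the oldest batch comes out
assert batch == "batch-1" and future.result() == "output-1"
```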
@@ -319,41 +319,43 @@ class EngineCore:
         batch in the job queue is finished.
         3. Update the scheduler from the output.
         """
-        assert self.batch_queue is not None
+        batch_queue = self.batch_queue
+        assert batch_queue is not None

-        engine_core_outputs = None
-        scheduler_output = None
         # Try to schedule a new batch if the batch queue is not full, but
         # the scheduler may return an empty batch if all requests are scheduled.
         # Note that this is not blocking.
-        if not self.batch_queue.full():
+        assert len(batch_queue) < self.batch_queue_size
+
+        model_executed = False
+        if self.scheduler.has_requests():
             scheduler_output = self.scheduler.schedule()
-            if scheduler_output.total_num_scheduled_tokens > 0:
-                future = self.model_executor.execute_model(scheduler_output)
-                self.batch_queue.put_nowait(
-                    (future, scheduler_output))  # type: ignore
+            future = self.model_executor.execute_model(scheduler_output)
+            batch_queue.appendleft(
+                (future, scheduler_output))  # type: ignore[arg-type]

-        scheduled_batch = (scheduler_output is not None
-                           and scheduler_output.total_num_scheduled_tokens > 0)
+            model_executed = scheduler_output.total_num_scheduled_tokens > 0
+            if model_executed and len(batch_queue) < self.batch_queue_size \
+                    and not batch_queue[-1][0].done():
+                # Don't block on next worker response unless the queue is full
+                # or there are no more requests to schedule.
+                return None, True

-        # If no more requests can be scheduled and the job queue is not empty,
-        # block until the first batch in the job queue is finished.
-        # TODO(comaniac): Ideally we should peek the first batch in the
-        # job queue to check if it's finished before scheduling a new batch,
-        # but peeking the first element in a queue is not thread-safe,
-        # so we need more work.
-        if not scheduled_batch and not self.batch_queue.empty():
-            future, scheduler_output = self.batch_queue.get_nowait()
+        elif not batch_queue:
+            # Queue is empty. We should not reach here since this method should
+            # only be called when the scheduler contains requests or the queue
+            # is non-empty.
+            return None, False

-            # Blocking until the first result is available.
-            model_output = self.execute_model_with_error_logging(
-                lambda _: future.result(), scheduler_output)
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        model_output = self.execute_model_with_error_logging(
+            lambda _: future.result(), scheduler_output)

-            self.batch_queue.task_done()
-            engine_core_outputs = (self.scheduler.update_from_output(
-                scheduler_output, model_output))
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output)

-        return engine_core_outputs, scheduled_batch
+        return engine_core_outputs, model_executed

     def shutdown(self):
         self.structured_output_manager.clear_backend()
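The reshaped step can be read as: always schedule when there is work, but only block on the oldest in-flight batch when the queue is full, its result is already available, or nothing else can be scheduled. A self-contained toy model of that control flow (names like `run_batch`, `pending`, and `step` are illustrative, not vLLM APIs):

```python
import time
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from typing import Optional

QUEUE_SIZE = 2
executor = ThreadPoolExecutor(max_workers=QUEUE_SIZE)
pending: deque = deque(maxlen=QUEUE_SIZE)


def run_batch(batch_id: int) -> str:
    time.sleep(0.01)
    return f"output-{batch_id}"


def step(next_batch_id: Optional[int]) -> Optional[str]:
    if next_batch_id is not None:
        assert len(pending) < QUEUE_SIZE
        pending.appendleft((executor.submit(run_batch, next_batch_id),
                            next_batch_id))
        # Don't block unless the queue is full or the oldest batch is done.
        if len(pending) < QUEUE_SIZE and not pending[-1][0].done():
            return None
    elif not pending:
        return None
    future, _ = pending.pop()      # oldest in-flight batch
    return future.result()         # the only blocking point


# Typically prints [None, 'output-0', 'output-1', None].
print([step(0), step(1), step(None), step(None)])
executor.shutdown()
```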
@@ -388,7 +390,7 @@ class EngineCore:
         return self.model_executor.is_sleeping

     def execute_dummy_batch(self):
-        self.model_executor.collective_rpc("execute_dummy_batch")
+        self.model_executor.execute_dummy_batch()

     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore):
         """Exits when an engine step needs to be performed."""

         waited = False
-        while not self.engines_running and not self.scheduler.has_requests():
+        while not self.engines_running and not self.scheduler.has_requests() \
+                and not self.batch_queue:
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True
@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
         pass

     def determine_available_memory(self) -> list[int]:  # in bytes
-        output = self.collective_rpc("determine_available_memory")
-        return output
+        return self.collective_rpc("determine_available_memory")

     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
-        output = self.collective_rpc("get_kv_cache_spec")
-        return output
+        return self.collective_rpc("get_kv_cache_spec")

     def execute_model(
         self,
@@ -96,6 +94,9 @@ class Executor(ExecutorBase):
                                      args=(scheduler_output, ))
         return output[0]

+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch")
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         output = self.collective_rpc("take_draft_token_ids")
         return output[0]
@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
                 outputs, self.output_rank)
         return self.kv_output_aggregator.aggregate(outputs, self.output_rank)

+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch",
+                            unique_reply_rank=self.output_rank)
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         # OPTIMIZATION: Get output only from a single worker (output_rank)
         outputs = self.collective_rpc("take_draft_token_ids",
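Both executor classes route `execute_dummy_batch` through `collective_rpc`, with the multiprocess variant asking for a reply only from its output rank via `unique_reply_rank`. A toy stand-in for that fan-out pattern (`ToyExecutor`/`ToyWorker` are illustrative, not vLLM classes):

```python
from typing import Any, List, Optional


class ToyWorker:
    def __init__(self, rank: int):
        self.rank = rank

    def execute_dummy_batch(self) -> str:
        return f"dummy-batch-on-rank-{self.rank}"


class ToyExecutor:
    def __init__(self, workers: List[Any], output_rank: int = 0):
        self.workers = workers
        self.output_rank = output_rank

    def collective_rpc(self, method: str, *args,
                       unique_reply_rank: Optional[int] = None) -> Any:
        # Call the named method on every worker; optionally keep one reply.
        replies = [getattr(w, method)(*args) for w in self.workers]
        return replies if unique_reply_rank is None else replies[unique_reply_rank]


executor = ToyExecutor([ToyWorker(0), ToyWorker(1)], output_rank=0)
assert executor.collective_rpc("execute_dummy_batch") == [
    "dummy-batch-on-rank-0", "dummy-batch-on-rank-1"
]
assert executor.collective_rpc(
    "execute_dummy_batch", unique_reply_rank=0) == "dummy-batch-on-rank-0"
```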
@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor):
                 dequeue_timeout = None if deadline is None else (
                     deadline - time.monotonic())

-                if non_block:
+                if self.io_thread_pool is not None:
+                    # We must consume worker_response_mq from a single thread.
                     result = self.io_thread_pool.submit(  # type: ignore
                         get_response, w, dequeue_timeout, self.shutdown_event)
-                else:
+                    if not non_block:
+                        result = result.result()
+                elif not non_block:
                     result = get_response(w, dequeue_timeout)
-
+                else:
+                    raise RuntimeError("non_block can only be used when"
+                                       " max_concurrent_batches > 1")
                 responses.append(result)

             return responses
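The reworked dispatch always consumes worker responses on the IO thread when a thread pool exists, resolving the future immediately for callers that asked for a blocking call, and only raises when a non-blocking call is requested without one. A simplified sketch of that decision (the helper name and signature here are hypothetical):

```python
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Callable, Optional, Union


def get_worker_response(fetch: Callable[[], object],
                        io_pool: Optional[ThreadPoolExecutor],
                        non_block: bool) -> Union[Future, object]:
    if io_pool is not None:
        # Responses must be consumed from a single thread.
        future = io_pool.submit(fetch)
        return future if non_block else future.result()
    if not non_block:
        return fetch()
    raise RuntimeError("non_block requires max_concurrent_batches > 1")


assert get_worker_response(lambda: 42, None, non_block=False) == 42
with ThreadPoolExecutor(max_workers=1) as pool:
    assert get_worker_response(lambda: 7, pool, non_block=True).result() == 7
```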
@@ -354,36 +354,37 @@ class Worker(WorkerBase):
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
+        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))

         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)

-        parallel_config = self.vllm_config.parallel_config
-        if parallel_config.distributed_executor_backend != "external_launcher" \
-            and not get_pp_group().is_last_rank:
-            assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
-
-            kv_connector_output = output.kv_connector_output
-            if not kv_connector_output:
-                return None
-
-            # In case of PP with kv transfer, we need to pass through the
-            # kv_connector_output
-            if (not kv_connector_output.finished_sending
-                    and not kv_connector_output.finished_recving):
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-            output.kv_connector_output = kv_connector_output
+        if isinstance(output, ModelRunnerOutput):
             return output

-        assert isinstance(output, ModelRunnerOutput)
+        assert isinstance(output, IntermediateTensors)
+        parallel_config = self.vllm_config.parallel_config
+        assert parallel_config.distributed_executor_backend != (
+            "external_launcher") and not get_pp_group().is_last_rank
+
+        get_pp_group().send_tensor_dict(output.tensors,
+                                        all_gather_group=get_tp_group())
+
+        kv_connector_output = output.kv_connector_output
+        if not kv_connector_output:
+            return None
+
+        # In case of PP with kv transfer, we need to pass through the
+        # kv_connector_output
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.kv_connector_output = kv_connector_output
         return output

     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
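The worker now returns as soon as the model runner hands back a final `ModelRunnerOutput`, and otherwise (a non-last pipeline rank that produced intermediate tensors) ships the tensors downstream and propagates any KV-connector status. A toy rendering of that control flow with stand-in types (`ToyRunnerOutput`/`ToyIntermediate`/`finish_step` are illustrative, not vLLM classes):

```python
import copy
from dataclasses import dataclass, field
from typing import List, Optional, Union


@dataclass
class ToyRunnerOutput:
    req_ids: List[str] = field(default_factory=list)
    kv_status: Optional[str] = None


@dataclass
class ToyIntermediate:
    tensors: dict = field(default_factory=dict)
    kv_status: Optional[str] = None


EMPTY_OUTPUT = ToyRunnerOutput()


def finish_step(output: Union[ToyRunnerOutput, ToyIntermediate],
                send_to_next_stage) -> Optional[ToyRunnerOutput]:
    if isinstance(output, ToyRunnerOutput):
        return output                      # final output: nothing to forward
    send_to_next_stage(output.tensors)     # non-last rank: hand off tensors
    if output.kv_status is None:
        return None
    passthrough = copy.copy(EMPTY_OUTPUT)  # pass through KV-connector status
    passthrough.kv_status = output.kv_status
    return passthrough


sent: list = []
assert finish_step(ToyRunnerOutput(req_ids=["r0"]), sent.append).req_ids == ["r0"]
assert finish_step(ToyIntermediate(tensors={"h": 1}), sent.append) is None
assert sent == [{"h": 1}]
```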