diff --git a/tests/v1/engine/test_engine_core.py b/tests/v1/engine/test_engine_core.py
index e6f7ebf259..98265c6349 100644
--- a/tests/v1/engine/test_engine_core.py
+++ b/tests/v1/engine/test_engine_core.py
@@ -306,17 +306,17 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
 
     # Schedule Batch 1: (10, req0)
     assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 1
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 10
     # num_computed_tokens should have been updated immediately.
     assert engine_core.scheduler.requests[
         req0.request_id].num_computed_tokens == 10
 
     # Schedule Batch 2: (2, req0), (8, req1)
-    assert engine_core.step_with_batch_queue()[0] is None
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert engine_core.step_with_batch_queue()[0] == {}
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 2
     assert scheduler_output.num_scheduled_tokens["1"] == 8
     # num_computed_tokens should have been updated immediately.
@@ -325,42 +325,32 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
 
     assert engine_core.scheduler.get_num_unfinished_requests() == 2
 
-    # Batch queue is full. Finish Batch 1.
-    engine_core.step_with_batch_queue()
-
-    # Schedule Batch 3: (4, req1). Note that req0 cannot be scheduled
+    # Finish Batch 1 and schedule Batch 3: (4, req1).
+    # Note that req0 cannot be scheduled
     # because it is in the decoding stage now.
     engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    assert len(engine_core.batch_queue) == 1
+    scheduler_output = engine_core.batch_queue[-1][1]
    assert scheduler_output.num_scheduled_tokens["1"] == 4
 
-    # Batch queue is full. Finish Batch 2. Get first token of req0.
+    # Finish Batch 2. Get first token of req0.
+    # Schedule Batch 4: (1, req0).
     output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req0.request_id].num_tokens == 13
-
-    # Schedule Batch 4: (1, req0).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["0"] == 1
 
-    # Batch queue is full. Finish Batch 3. Get first token of req1.
+    # Finish Batch 3. Get first token of req1. Schedule Batch 5: (1, req1).
     output = engine_core.step_with_batch_queue()[0].get(0)
     assert output is not None
     assert len(output.outputs) == 1
     assert engine_core.scheduler.requests[req1.request_id].num_tokens == 13
-
-    # Schedule Batch 5: (1, req1).
-    engine_core.step_with_batch_queue()
-    assert engine_core.batch_queue.qsize() == 2
-    scheduler_output = engine_core.batch_queue.queue[-1][1]
+    scheduler_output = engine_core.batch_queue[-1][1]
     assert scheduler_output.num_scheduled_tokens["1"] == 1
 
     # Loop until req0 is finished.
-    step = 0
     req_id = 0
     expected_num_tokens = [
         engine_core.scheduler.requests["0"].num_tokens + 1,
@@ -368,19 +358,14 @@ def test_engine_core_concurrent_batches(monkeypatch: pytest.MonkeyPatch):
     ]
     while engine_core.scheduler.get_num_unfinished_requests() == 2:
         output = engine_core.step_with_batch_queue()[0]
-        if step % 2 == 0:
-            # Even steps consumes an output.
-            assert output is not None
-            assert len(output[0].outputs) == 1
-            if req_id in engine_core.scheduler.requests:
-                assert engine_core.scheduler.requests[
-                    req_id].num_tokens == expected_num_tokens[req_id]
-                expected_num_tokens[req_id] += 1
-            req_id = (req_id + 1) % 2
-        else:
-            # Odd steps schedules a new batch.
-            assert output is None
-        step += 1
+        # Every step consumes an output.
+        assert output is not None
+        assert len(output[0].outputs) == 1
+        if req_id in engine_core.scheduler.requests:
+            assert engine_core.scheduler.requests[
+                req_id].num_tokens == expected_num_tokens[req_id]
+            expected_num_tokens[req_id] += 1
+        req_id = (req_id + 1) % 2
 
 
 @multi_gpu_test(num_gpus=2)
diff --git a/tests/v1/test_async_llm_dp.py b/tests/v1/test_async_llm_dp.py
index c2610a87ac..32da58011b 100644
--- a/tests/v1/test_async_llm_dp.py
+++ b/tests/v1/test_async_llm_dp.py
@@ -75,9 +75,10 @@ async def generate(
     ],
 )
 @pytest.mark.parametrize("data_parallel_backend", ["mp", "ray"])
+@pytest.mark.parametrize("async_scheduling", [True, False])
 @pytest.mark.asyncio
-async def test_load(output_kind: RequestOutputKind,
-                    data_parallel_backend: str):
+async def test_load(output_kind: RequestOutputKind, data_parallel_backend: str,
+                    async_scheduling: bool):
 
     stats_loggers = {}
 
@@ -105,6 +106,7 @@ async def test_load(output_kind: RequestOutputKind,
         prompt = "This is a test of data parallel"
 
         engine_args.data_parallel_backend = data_parallel_backend
+        engine_args.async_scheduling = async_scheduling
         engine = AsyncLLM.from_engine_args(engine_args,
                                            stat_loggers=[SimpleStatsLogger])
         after.callback(engine.shutdown)
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 7abaffa54c..4b2a15afb6 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -10,6 +10,7 @@ import msgspec
 
 import vllm.platforms
 from vllm.config import ParallelConfig
+from vllm.distributed import get_pp_group
 from vllm.executor.msgspec_utils import decode_hook, encode_hook
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
@@ -136,6 +137,11 @@ try:
                     scheduler_output, intermediate_tensors)
                 if isinstance(output, IntermediateTensors):
                     output = scheduler_output, output
+                elif not get_pp_group().is_last_rank:
+                    # Case where there are no scheduled requests
+                    # but may still be finished requests.
+                    assert not output or not output.req_ids
+                    output = scheduler_output, None
                 return output
 
         def override_env_vars(self, vars: Dict[str, str]):
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 785cbc9d8d..922c06b44b 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -138,12 +138,12 @@ class EngineCore:
         # schedule and execute batches, and is required by pipeline parallelism
         # to eliminate pipeline bubbles.
         self.batch_queue_size = self.model_executor.max_concurrent_batches
-        self.batch_queue: Optional[queue.Queue[tuple[Future[ModelRunnerOutput],
-                                                     SchedulerOutput]]] = None
+        self.batch_queue: Optional[deque[tuple[Future[ModelRunnerOutput],
+                                               SchedulerOutput]]] = None
         if self.batch_queue_size > 1:
             logger.info("Batch queue is enabled with size %d",
                         self.batch_queue_size)
-            self.batch_queue = queue.Queue(self.batch_queue_size)
+            self.batch_queue = deque(maxlen=self.batch_queue_size)
 
         self.request_block_hasher: Optional[Callable[[Request],
                                                      list[BlockHash]]] = None
@@ -319,41 +319,43 @@ class EngineCore:
           batch in the job queue is finished.
        3. Update the scheduler from the output.
        """
-        assert self.batch_queue is not None
+        batch_queue = self.batch_queue
+        assert batch_queue is not None
 
-        engine_core_outputs = None
-        scheduler_output = None
        # Try to schedule a new batch if the batch queue is not full, but
        # the scheduler may return an empty batch if all requests are scheduled.
        # Note that this is not blocking.
-        if not self.batch_queue.full():
+        assert len(batch_queue) < self.batch_queue_size
+
+        model_executed = False
+        if self.scheduler.has_requests():
            scheduler_output = self.scheduler.schedule()
-            if scheduler_output.total_num_scheduled_tokens > 0:
-                future = self.model_executor.execute_model(scheduler_output)
-                self.batch_queue.put_nowait(
-                    (future, scheduler_output))  # type: ignore
+            future = self.model_executor.execute_model(scheduler_output)
+            batch_queue.appendleft(
+                (future, scheduler_output))  # type: ignore[arg-type]
 
-        scheduled_batch = (scheduler_output is not None
-                           and scheduler_output.total_num_scheduled_tokens > 0)
+            model_executed = scheduler_output.total_num_scheduled_tokens > 0
+            if model_executed and len(batch_queue) < self.batch_queue_size \
+                    and not batch_queue[-1][0].done():
+                # Don't block on next worker response unless the queue is full
+                # or there are no more requests to schedule.
+                return None, True
 
-        # If no more requests can be scheduled and the job queue is not empty,
-        # block until the first batch in the job queue is finished.
-        # TODO(comaniac): Ideally we should peek the first batch in the
-        # job queue to check if it's finished before scheduling a new batch,
-        # but peeking the first element in a queue is not thread-safe,
-        # so we need more work.
-        if not scheduled_batch and not self.batch_queue.empty():
-            future, scheduler_output = self.batch_queue.get_nowait()
+        elif not batch_queue:
+            # Queue is empty. We should not reach here since this method should
+            # only be called when the scheduler contains requests or the queue
+            # is non-empty.
+            return None, False
 
-            # Blocking until the first result is available.
-            model_output = self.execute_model_with_error_logging(
-                lambda _: future.result(), scheduler_output)
+        # Block until the next result is available.
+        future, scheduler_output = batch_queue.pop()
+        model_output = self.execute_model_with_error_logging(
+            lambda _: future.result(), scheduler_output)
 
-            self.batch_queue.task_done()
-            engine_core_outputs = (self.scheduler.update_from_output(
-                scheduler_output, model_output))
+        engine_core_outputs = self.scheduler.update_from_output(
+            scheduler_output, model_output)
 
-        return engine_core_outputs, scheduled_batch
+        return engine_core_outputs, model_executed
 
     def shutdown(self):
         self.structured_output_manager.clear_backend()
@@ -388,7 +390,7 @@ class EngineCore:
         return self.model_executor.is_sleeping
 
     def execute_dummy_batch(self):
-        self.model_executor.collective_rpc("execute_dummy_batch")
+        self.model_executor.execute_dummy_batch()
 
     def add_lora(self, lora_request: LoRARequest) -> bool:
         return self.model_executor.add_lora(lora_request)
@@ -733,7 +735,8 @@ class EngineCoreProc(EngineCore):
         """Exits when an engine step needs to be performed."""
 
         waited = False
-        while not self.engines_running and not self.scheduler.has_requests():
+        while not self.engines_running and not self.scheduler.has_requests() \
+                and not self.batch_queue:
             if logger.isEnabledFor(DEBUG) and self.input_queue.empty():
                 logger.debug("EngineCore waiting for work.")
                 waited = True
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 4be2f74177..68408a0b8a 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -81,12 +81,10 @@ class Executor(ExecutorBase):
         pass
 
     def determine_available_memory(self) -> list[int]:  # in bytes
-        output = self.collective_rpc("determine_available_memory")
-        return output
+        return self.collective_rpc("determine_available_memory")
 
     def get_kv_cache_specs(self) -> list[dict[str, KVCacheSpec]]:
-        output = self.collective_rpc("get_kv_cache_spec")
-        return output
+        return self.collective_rpc("get_kv_cache_spec")
 
     def execute_model(
         self,
@@ -96,6 +94,9 @@ class Executor(ExecutorBase):
                                      args=(scheduler_output, ))
         return output[0]
 
+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch")
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         output = self.collective_rpc("take_draft_token_ids")
         return output[0]
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index 15b88a2128..12e79ff165 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -191,6 +191,10 @@ class MultiprocExecutor(Executor):
                 outputs, self.output_rank)
         return self.kv_output_aggregator.aggregate(outputs, self.output_rank)
 
+    def execute_dummy_batch(self) -> None:
+        self.collective_rpc("execute_dummy_batch",
+                            unique_reply_rank=self.output_rank)
+
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
         # OPTIMIZATION: Get output only from a single worker (output_rank)
         outputs = self.collective_rpc("take_draft_token_ids",
@@ -242,12 +246,17 @@ class MultiprocExecutor(Executor):
             dequeue_timeout = None if deadline is None else (
                 deadline - time.monotonic())
 
-            if non_block:
+            if self.io_thread_pool is not None:
+                # We must consume worker_response_mq from a single thread.
                 result = self.io_thread_pool.submit(  # type: ignore
                     get_response, w, dequeue_timeout, self.shutdown_event)
-            else:
+                if not non_block:
+                    result = result.result()
+            elif not non_block:
                 result = get_response(w, dequeue_timeout)
-
+            else:
+                raise RuntimeError("non_block can only be used when"
+                                   " max_concurrent_batches > 1")
             responses.append(result)
 
         return responses
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index c252193313..2088bfff5b 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -354,36 +354,37 @@ class Worker(WorkerBase):
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
         intermediate_tensors = None
-        if not get_pp_group().is_first_rank:
+        forward_pass = scheduler_output.total_num_scheduled_tokens > 0
+        if forward_pass and not get_pp_group().is_first_rank:
             intermediate_tensors = IntermediateTensors(
                 get_pp_group().recv_tensor_dict(
                     all_gather_group=get_tp_group()))
 
         output = self.model_runner.execute_model(scheduler_output,
                                                  intermediate_tensors)
-
-        parallel_config = self.vllm_config.parallel_config
-        if parallel_config.distributed_executor_backend != "external_launcher" \
-            and not get_pp_group().is_last_rank:
-            assert isinstance(output, IntermediateTensors)
-            get_pp_group().send_tensor_dict(output.tensors,
-                                            all_gather_group=get_tp_group())
-
-            kv_connector_output = output.kv_connector_output
-            if not kv_connector_output:
-                return None
-
-            # In case of PP with kv transfer, we need to pass through the
-            # kv_connector_output
-            if (not kv_connector_output.finished_sending
-                    and not kv_connector_output.finished_recving):
-                return EMPTY_MODEL_RUNNER_OUTPUT
-
-            output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
-            output.kv_connector_output = kv_connector_output
+        if isinstance(output, ModelRunnerOutput):
             return output
 
-        assert isinstance(output, ModelRunnerOutput)
+        assert isinstance(output, IntermediateTensors)
+        parallel_config = self.vllm_config.parallel_config
+        assert parallel_config.distributed_executor_backend != (
+            "external_launcher") and not get_pp_group().is_last_rank
+
+        get_pp_group().send_tensor_dict(output.tensors,
+                                        all_gather_group=get_tp_group())
+
+        kv_connector_output = output.kv_connector_output
+        if not kv_connector_output:
+            return None
+
+        # In case of PP with kv transfer, we need to pass through the
+        # kv_connector_output
+        if (not kv_connector_output.finished_sending
+                and not kv_connector_output.finished_recving):
+            return EMPTY_MODEL_RUNNER_OUTPUT
+
+        output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT)
+        output.kv_connector_output = kv_connector_output
         return output
 
     def take_draft_token_ids(self) -> Optional[DraftTokenIds]:
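The core of this patch is the switch in `EngineCore.step_with_batch_queue` from a `queue.Queue` to a `deque`: new batches are pushed onto the left while there is room and the oldest in-flight batch has not finished, and the caller blocks on the right end only when the queue is full or nothing new can be scheduled. The standalone sketch below illustrates that queueing pattern under simplified assumptions; `MiniExecutor`, `run_batches`, and the `pending` list are hypothetical stand-ins for illustration only, not vLLM APIs, and this is not the actual implementation.

```python
# Minimal sketch of the deque-based batch pipelining pattern (not vLLM code).
from collections import deque
from concurrent.futures import Future, ThreadPoolExecutor


class MiniExecutor:
    """Hypothetical stand-in that runs a batch on a worker thread."""

    def __init__(self) -> None:
        self._pool = ThreadPoolExecutor(max_workers=1)

    def execute_model(self, scheduler_output: dict) -> Future:
        # Pretend the "model" just echoes what was scheduled.
        return self._pool.submit(lambda: {"ran": scheduler_output})


def run_batches(batches: list, executor: MiniExecutor,
                batch_queue_size: int = 2) -> list:
    """Schedule eagerly; block on the oldest in-flight batch only when the
    queue is full or there is nothing left to schedule."""
    batch_queue: deque = deque(maxlen=batch_queue_size)
    pending = deque(batches)
    results = []
    while pending or batch_queue:
        if pending and len(batch_queue) < batch_queue_size:
            scheduler_output = pending.popleft()
            future = executor.execute_model(scheduler_output)
            # Newest batch goes on the left; the oldest is popped from the right.
            batch_queue.appendleft((future, scheduler_output))
            if pending and len(batch_queue) < batch_queue_size \
                    and not batch_queue[-1][0].done():
                # Mirrors the early `return None, True`: keep scheduling
                # instead of blocking on the worker.
                continue
        # Block on the oldest in-flight batch, like `future.result()` above.
        future, _scheduler_output = batch_queue.pop()
        results.append(future.result())
    return results


if __name__ == "__main__":
    print(run_batches([{"batch": i} for i in range(4)], MiniExecutor()))
```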