diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 03de8d9b92b..9c90fe381bb 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -100,9 +100,8 @@ class PPTestSettings:
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            # only ray is supported for V1
-            distributed_backends=["mp", "ray", "ray"],
-            vllm_major_versions=["0", "0", "1"],
+            distributed_backends=["mp", "mp", "ray", "ray"],
+            vllm_major_versions=["0", "1", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        load_format=load_format),
@@ -350,6 +349,11 @@ def _compare_tp(
         # Temporary. Currently when zeromq + SPMD is used, it does not properly
         # terminate because of a Ray Compiled Graph issue.
         common_args.append("--disable-frontend-multiprocessing")
+    elif distributed_backend == "mp":
+        # Both V0/V1 of multiprocessing executor support PP
+        pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
+        }
     else:
         pp_env = None
 
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index d20ef68434c..3a10ed9d763 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -1338,11 +1338,10 @@ class EngineArgs:
                 and _warn_or_fallback("Engine in background thread")):
             return False
 
-        # PP is supported on V1 with Ray distributed executor,
-        # but off for MP distributed executor for now.
         if (self.pipeline_parallel_size > 1
-                and self.distributed_executor_backend != "ray"):
-            name = "Pipeline Parallelism without Ray distributed executor"
+                and self.distributed_executor_backend not in ["ray", "mp"]):
+            name = "Pipeline Parallelism without Ray distributed executor " \
+                   "or multiprocessing executor"
             _raise_or_fallback(feature_name=name, recommend_to_remove=False)
             return False
 
diff --git a/vllm/v1/executor/multiproc_executor.py b/vllm/v1/executor/multiproc_executor.py
index cb125bf4bf1..ff449901030 100644
--- a/vllm/v1/executor/multiproc_executor.py
+++ b/vllm/v1/executor/multiproc_executor.py
@@ -8,7 +8,7 @@ import threading
 import time
 import traceback
 import weakref
-from concurrent.futures import Future
+from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
 from enum import Enum, auto
 from functools import partial
@@ -53,10 +53,11 @@ class MultiprocExecutor(Executor):
 
         self.world_size = self.parallel_config.world_size
         tensor_parallel_size = self.parallel_config.tensor_parallel_size
-        assert self.world_size == tensor_parallel_size, (
+        pp_parallel_size = self.parallel_config.pipeline_parallel_size
+        assert self.world_size == tensor_parallel_size * pp_parallel_size, (
             f"world_size ({self.world_size}) must be equal to the "
-            f"tensor_parallel_size ({tensor_parallel_size}). "
-            f"Pipeline parallelism is not yet implemented in v1")
+            f"tensor_parallel_size ({tensor_parallel_size}) x pipeline"
+            f"_parallel_size ({pp_parallel_size}). ")
 
         # Set multiprocessing envs that are common to V0 and V1
         set_multiprocessing_worker_envs(self.parallel_config)
@@ -104,6 +105,17 @@ class MultiprocExecutor(Executor):
                     self._ensure_worker_termination(
                         [w.proc for w in unready_workers])
 
+        # For pipeline parallel, we use a thread pool for asynchronous
+        # execute_model.
+        self.io_thread_pool: Optional[ThreadPoolExecutor] = None
+        if self.max_concurrent_batches > 1:
+            # Note: must use only 1 IO thread to keep dequeue sequence
+            # from the response queue
+            self.io_thread_pool = ThreadPoolExecutor(
+                max_workers=1, thread_name_prefix="mp_exec_io")
+
+        self.output_rank = self._get_output_rank()
+
     def start_worker_monitor(self):
         workers = self.workers
         self_ref = weakref.ref(self)
@@ -145,7 +157,9 @@ class MultiprocExecutor(Executor):
     ) -> Union[ModelRunnerOutput, Future[ModelRunnerOutput]]:
         (output, ) = self.collective_rpc("execute_model",
                                          args=(scheduler_output, ),
-                                         rank0_reply_only=True,
+                                         unique_reply_rank=self.output_rank,
+                                         non_block=self.max_concurrent_batches
+                                         > 1,
                                          timeout=EXECUTE_MODEL_TIMEOUT_S)
         return output
 
@@ -154,7 +168,8 @@ class MultiprocExecutor(Executor):
                        timeout: Optional[float] = None,
                        args: tuple = (),
                        kwargs: Optional[dict] = None,
-                       rank0_reply_only: bool = False) -> list[Any]:
+                       non_block: bool = False,
+                       unique_reply_rank: Optional[int] = None) -> list[Any]:
         if self.is_failed:
             raise RuntimeError("Executor failed.")
 
@@ -171,22 +186,35 @@ class MultiprocExecutor(Executor):
                 send_method = cloudpickle.dumps(
                     method, protocol=pickle.HIGHEST_PROTOCOL)
             self.rpc_broadcast_mq.enqueue(
-                (send_method, args, kwargs, rank0_reply_only))
+                (send_method, args, kwargs, unique_reply_rank))
 
-            workers = (self.workers[0], ) if rank0_reply_only else self.workers
-            responses = [None] * len(workers)
-            for w in workers:
-                dequeue_timeout = None if deadline is None else (
-                    deadline - time.monotonic())
+            workers = (self.workers[unique_reply_rank],
+                       ) if unique_reply_rank is not None else self.workers
+            responses = []
+
+            def get_response(w: WorkerProcHandle,
+                             dequeue_timeout: Optional[float] = None,
+                             cancel_event: Optional[threading.Event] = None):
                 status, result = w.worker_response_mq.dequeue(
-                    timeout=dequeue_timeout, cancel=self.shutdown_event)
+                    timeout=dequeue_timeout, cancel=cancel_event)
 
                 if status != WorkerProc.ResponseStatus.SUCCESS:
                     raise RuntimeError(
                         f"Worker failed with error '{result}', please check the"
                         " stack trace above for the root cause")
+                return result
 
-                responses[w.rank] = result
+            for w in workers:
+                dequeue_timeout = None if deadline is None else (
+                    deadline - time.monotonic())
+
+                if non_block:
+                    result = self.io_thread_pool.submit(  # type: ignore
+                        get_response, w, dequeue_timeout, self.shutdown_event)
+                else:
+                    result = get_response(w, dequeue_timeout)
+
+                responses.append(result)
 
             return responses
         except TimeoutError as e:
@@ -225,6 +253,11 @@ class MultiprocExecutor(Executor):
         if not getattr(self, 'shutting_down', False):
             self.shutting_down = True
             self.shutdown_event.set()
+
+            if self.io_thread_pool is not None:
+                self.io_thread_pool.shutdown(wait=False, cancel_futures=True)
+                self.io_thread_pool = None
+
             for w in self.workers:
                 w.worker_response_mq = None
             self._ensure_worker_termination([w.proc for w in self.workers])
@@ -235,6 +268,22 @@ class MultiprocExecutor(Executor):
         self.collective_rpc("check_health", timeout=10)
         return
 
+    @property
+    def max_concurrent_batches(self) -> int:
+        return self.parallel_config.pipeline_parallel_size
+
+    def _get_output_rank(self) -> int:
+        # Only returns ModelRunnerOutput from TP rank=0 and PP rank=-1
+        # (the first TP worker of the last PP stage).
+        # Example:
+        # Assuming TP=8, PP=4, then the world_size=32
+        # 0-7, PP rank 0
+        # 8-15, PP rank 1
+        # 16-23, PP rank 2
+        # 24-31, PP rank 3
+        # so world_size - tp_size = 32 - 8 = 24 should be PP rank = -1 (i.e. 3)
+        return self.world_size - self.parallel_config.tensor_parallel_size
+
 
 @dataclass
 class UnreadyWorkerProcHandle:
@@ -280,12 +329,14 @@ class WorkerProc:
         all_kwargs: list[dict] = [
             {} for _ in range(vllm_config.parallel_config.world_size)
         ]
+        is_driver_worker = (
+            rank % vllm_config.parallel_config.tensor_parallel_size == 0)
         all_kwargs[rank] = {
             "vllm_config": vllm_config,
             "local_rank": local_rank,
             "rank": rank,
             "distributed_init_method": distributed_init_method,
-            "is_driver_worker": rank == 0,
+            "is_driver_worker": is_driver_worker,
         }
         wrapper.init_worker(all_kwargs)
         self.worker = wrapper
@@ -455,7 +506,7 @@ class WorkerProc:
     def worker_busy_loop(self):
         """Main busy loop for Multiprocessing Workers"""
         while True:
-            method, args, kwargs, rank0_only = self.rpc_broadcast_mq.dequeue()
+            method, args, kwargs, output_rank = self.rpc_broadcast_mq.dequeue()
 
             try:
                 if isinstance(method, str):
@@ -470,11 +521,11 @@ class WorkerProc:
                 logger.exception("WorkerProc hit an exception.")
                 # exception might not be serializable, so we convert it to
                 # string, only for logging purpose.
-                if not rank0_only or self.rank == 0:
+                if output_rank is None or self.rank == output_rank:
                     self.worker_response_mq.enqueue(
                         (WorkerProc.ResponseStatus.FAILURE, str(e)))
                 continue
 
-            if not rank0_only or self.rank == 0:
+            if output_rank is None or self.rank == output_rank:
                 self.worker_response_mq.enqueue(
                     (WorkerProc.ResponseStatus.SUCCESS, output))
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 97d8c91b465..8137cb6b9b6 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -1016,7 +1016,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
         self,
         scheduler_output: "SchedulerOutput",
         intermediate_tensors: Optional[IntermediateTensors] = None,
-    ) -> Union[ModelRunnerOutput, torch.Tensor]:
+    ) -> Union[ModelRunnerOutput, IntermediateTensors]:
         # Update KVConnector with the KVConnector metadata forward().
         if has_kv_transfer_group():
             get_kv_transfer_group().bind_connector_metadata(
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index ac6861f93a8..da2ecfc4bcc 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -15,11 +15,12 @@ from vllm.distributed import (ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.distributed.kv_transfer import ensure_kv_transfer_initialized
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
+from vllm.sequence import IntermediateTensors
 from vllm.utils import GiB_bytes
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
@@ -266,7 +267,22 @@ class Worker(WorkerBase):
         self,
         scheduler_output: "SchedulerOutput",
     ) -> Optional[ModelRunnerOutput]:
-        output = self.model_runner.execute_model(scheduler_output)
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = IntermediateTensors(
+                get_pp_group().recv_tensor_dict(
+                    all_gather_group=get_tp_group()))
+
+        output = self.model_runner.execute_model(scheduler_output,
+                                                 intermediate_tensors)
+
+        if not get_pp_group().is_last_rank:
+            assert isinstance(output, IntermediateTensors)
+            get_pp_group().send_tensor_dict(output.tensors,
+                                            all_gather_group=get_tp_group())
+            return None
+
+        assert isinstance(output, ModelRunnerOutput)
         return output if self.is_driver_worker else None
 
     def profile(self, is_start: bool = True):
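
For reference, a small standalone sketch (not part of the patch; the helper names are hypothetical) of the TP-major rank layout that `_get_output_rank` and the new `is_driver_worker` condition assume: ranks fill each pipeline stage with its tensor-parallel workers first, so the first TP worker of the last PP stage sits at `world_size - tp_size`.

# Illustration only: reproduces the layout described in _get_output_rank's comment.
def output_rank(tp_size: int, pp_size: int) -> int:
    """First TP worker of the last PP stage (the rank whose ModelRunnerOutput is returned)."""
    world_size = tp_size * pp_size
    return world_size - tp_size


def rank_to_pp_tp(rank: int, tp_size: int) -> tuple[int, int]:
    """Map a global rank to (pp_rank, tp_rank) under the TP-major layout."""
    return rank // tp_size, rank % tp_size


if __name__ == "__main__":
    # Matches the example in the comment: TP=8, PP=4 -> world_size=32,
    # output rank 24, which is PP rank 3 (the last stage), TP rank 0.
    assert output_rank(tp_size=8, pp_size=4) == 24
    assert rank_to_pp_tp(24, tp_size=8) == (3, 0)
    # Driver workers are TP rank 0 of every PP stage (rank % tp_size == 0),
    # mirroring the is_driver_worker change in WorkerProc.__init__.
    print([r for r in range(32) if r % 8 == 0])  # [0, 8, 16, 24]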