[P/D]kv_output_aggregator support P TP > D TP (#23917)

Signed-off-by: LCAIZJ <leichao139636@163.com>
Co-authored-by: leichao.lc <leichao.lc@antgroup.com>
This commit is contained in:
Chao Lei
2025-09-15 17:36:06 +08:00
committed by GitHub
parent a0d8b9738d
commit 8de261b04a
5 changed files with 21 additions and 5 deletions

View File

@ -355,3 +355,14 @@ class KVConnectorBase_V1(ABC):
raise TypeError("get_required_kvcache_layout should not be called "
"on the abstract base class")
return None
def get_finished_count(self) -> Optional[int]:
"""
Get the count of requests expected to complete send/receive operations
via this connector.
Returns:
int: expected sending or receiving completion count.
"""
return None

View File

@ -13,6 +13,7 @@ from typing_extensions import TypeVar
import vllm.platforms
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.layers.sampler import SamplerOutput
@ -54,6 +55,7 @@ class ExecutorBase(ABC):
self._init_executor()
self.is_sleeping = False
self.sleeping_tags: set[str] = set()
self.kv_output_aggregator = None
@abstractmethod
def _init_executor(self) -> None:
@ -252,6 +254,11 @@ class ExecutorBase(ABC):
exception."""
self.check_health()
def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
"""Init KVOutputAggregator"""
self.kv_output_aggregator = KVOutputAggregator(
finished_count or self.parallel_config.world_size)
class DistributedExecutorBase(ExecutorBase):
"""Abstract superclass of distributed executor implementations."""

View File

@ -128,6 +128,9 @@ class EngineCore:
log_stats=self.log_stats,
)
self.use_spec_decode = vllm_config.speculative_config is not None
if self.scheduler.connector is not None: # type: ignore
self.model_executor.init_kv_output_aggregator(
self.scheduler.connector.get_finished_count()) # type: ignore
self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
self.mm_receiver_cache = engine_receiver_cache_from_config(

View File

@ -26,7 +26,6 @@ from vllm.distributed import (destroy_distributed_environment,
destroy_model_parallel)
from vllm.distributed.device_communicators.shm_broadcast import (Handle,
MessageQueue)
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
get_pp_group, get_tp_group)
from vllm.executor.multiproc_worker_utils import (
@ -135,8 +134,6 @@ class MultiprocExecutor(Executor):
self.output_rank = self._get_output_rank()
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.kv_output_aggregator = KVOutputAggregator(
self.parallel_config.world_size)
def start_worker_monitor(self):
workers = self.workers

View File

@ -51,8 +51,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
# KV connector setup
self.has_connector = self.vllm_config.kv_transfer_config is not None
self.kv_output_aggregator = KVOutputAggregator(
self.parallel_config.world_size)
@property
def max_concurrent_batches(self) -> int: