Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[P/D] kv_output_aggregator support P TP > D TP (#23917)
Signed-off-by: LCAIZJ <leichao139636@163.com>
Co-authored-by: leichao.lc <leichao.lc@antgroup.com>
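In a disaggregated prefill/decode (P/D) deployment, "P TP > D TP" means the prefill instance runs with a larger tensor-parallel degree than the decode instance. A minimal numeric illustration of that case follows; the launch flags in the comments are an assumption about how such a deployment might be configured, not something taken from this commit.

    # Illustrative P/D sizing only; the serve flags below are an assumed setup.
    prefill_tp = 4  # e.g. prefill instance launched with --tensor-parallel-size 4
    decode_tp = 2   # e.g. decode instance launched with --tensor-parallel-size 2
    assert prefill_tp > decode_tp  # the "P TP > D TP" case this change supports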
@@ -355,3 +355,14 @@ class KVConnectorBase_V1(ABC):
         raise TypeError("get_required_kvcache_layout should not be called "
                         "on the abstract base class")
         return None
+
+    def get_finished_count(self) -> Optional[int]:
+        """
+        Get the count of requests expected to complete send/receive operations
+        via this connector.
+
+        Returns:
+            int: expected sending or receiving completion count.
+        """
+
+        return None
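The base class returns None so that callers fall back to a default (see the ExecutorBase hunk further down, where None becomes the executor's world size). A concrete connector can override the hook to report a different expectation. The class below is a purely hypothetical stand-in, not a real vLLM connector, and reporting the prefill TP size is an assumption made only for illustration.

    from typing import Optional

    # Hypothetical stand-in for a concrete connector; it does not implement
    # the full KVConnectorBase_V1 interface and its fields are invented.
    class ExampleHeteroTPConnector:

        def __init__(self, prefill_tp: int, decode_tp: int) -> None:
            # Assumed to be derived from the deployment's kv_transfer_config.
            self.prefill_tp = prefill_tp
            self.decode_tp = decode_tp

        def get_finished_count(self) -> Optional[int]:
            # With P TP > D TP, the number of workers expected to report a
            # finished send/receive need not equal the local world size, so
            # the connector states the expected count explicitly (here, as an
            # assumption, the prefill TP size).
            return self.prefill_tp


    connector = ExampleHeteroTPConnector(prefill_tp=4, decode_tp=2)
    print(connector.get_finished_count())  # 4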
@@ -13,6 +13,7 @@ from typing_extensions import TypeVar
 
 import vllm.platforms
 from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor.layers.sampler import SamplerOutput
@@ -54,6 +55,7 @@ class ExecutorBase(ABC):
         self._init_executor()
         self.is_sleeping = False
         self.sleeping_tags: set[str] = set()
+        self.kv_output_aggregator = None
 
     @abstractmethod
     def _init_executor(self) -> None:
@@ -252,6 +254,11 @@ class ExecutorBase(ABC):
         exception."""
         self.check_health()
 
+    def init_kv_output_aggregator(self, finished_count: Optional[int]) -> None:
+        """Init KVOutputAggregator"""
+        self.kv_output_aggregator = KVOutputAggregator(
+            finished_count or self.parallel_config.world_size)
+
 
 class DistributedExecutorBase(ExecutorBase):
     """Abstract superclass of distributed executor implementations."""
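A minimal sketch of the fallback this new method encodes: a connector-supplied count wins, while a falsy value (None or 0) falls back to the executor's own world size. The constant below stands in for self.parallel_config.world_size and exists only for illustration.

    from typing import Optional

    WORLD_SIZE = 2  # stand-in for self.parallel_config.world_size

    def expected_finished_count(finished_count: Optional[int]) -> int:
        # Same expression as in init_kv_output_aggregator above.
        return finished_count or WORLD_SIZE

    print(expected_finished_count(None))  # 2 -> defaults to world size
    print(expected_finished_count(4))     # 4 -> connector-supplied count wins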
@@ -128,6 +128,9 @@ class EngineCore:
             log_stats=self.log_stats,
         )
         self.use_spec_decode = vllm_config.speculative_config is not None
+        if self.scheduler.connector is not None:  # type: ignore
+            self.model_executor.init_kv_output_aggregator(
+                self.scheduler.connector.get_finished_count())  # type: ignore
 
         self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
         self.mm_receiver_cache = engine_receiver_cache_from_config(
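For context, the aggregation the executor-side object performs is roughly of this shape: each worker reports which request ids finished sending or receiving KV, and a request only counts as done once the expected number of workers has reported it. The class below is a simplified sketch of that idea, not vLLM's actual KVOutputAggregator.

    from collections import Counter

    class SimplifiedKVAggregator:
        """Counts per-request 'finished' reports until an expected count is reached."""

        def __init__(self, expected_count: int) -> None:
            self.expected_count = expected_count
            self._seen: Counter = Counter()

        def aggregate(self, per_worker_finished: list[set[str]]) -> set[str]:
            # Merge one round of per-worker reports; return requests that have
            # now been reported by the expected number of workers.
            done: set[str] = set()
            for finished_ids in per_worker_finished:
                for req_id in finished_ids:
                    self._seen[req_id] += 1
                    if self._seen[req_id] == self.expected_count:
                        done.add(req_id)
                        del self._seen[req_id]
            return done


    agg = SimplifiedKVAggregator(expected_count=2)
    print(agg.aggregate([{"req-1"}, set()]))   # set(): only 1 of 2 workers reported
    print(agg.aggregate([set(), {"req-1"}]))   # {'req-1'}: both expected workers reported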
@@ -26,7 +26,6 @@ from vllm.distributed import (destroy_distributed_environment,
                               destroy_model_parallel)
 from vllm.distributed.device_communicators.shm_broadcast import (Handle,
                                                                   MessageQueue)
-from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
 from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
                                              get_pp_group, get_tp_group)
 from vllm.executor.multiproc_worker_utils import (
@@ -135,8 +134,6 @@ class MultiprocExecutor(Executor):
 
         self.output_rank = self._get_output_rank()
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     def start_worker_monitor(self):
         workers = self.workers
@@ -51,8 +51,6 @@ class RayDistributedExecutor(RayDistributedExecutorV0, Executor):
 
         # KV connector setup
         self.has_connector = self.vllm_config.kv_transfer_config is not None
-        self.kv_output_aggregator = KVOutputAggregator(
-            self.parallel_config.world_size)
 
     @property
     def max_concurrent_batches(self) -> int: