mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix][Nixl][PD] Fix heterogenous TP (#22663)
Signed-off-by: NickLucche <nlucches@redhat.com> Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -4,13 +4,17 @@
|
||||
import importlib
|
||||
from typing import TYPE_CHECKING, Callable
|
||||
|
||||
# yapf: disable
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import (
|
||||
KVConnectorBase, KVConnectorBaseType)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.logger import init_logger
|
||||
|
||||
# yapf: enable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -42,17 +46,7 @@ class KVConnectorFactory:
|
||||
f"but found {envs.VLLM_USE_V1=}")
|
||||
|
||||
kv_transfer_config = config.kv_transfer_config
|
||||
connector_name = kv_transfer_config.kv_connector
|
||||
if connector_name in cls._registry:
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
else:
|
||||
connector_module_path = kv_transfer_config.kv_connector_module_path
|
||||
if connector_module_path is None:
|
||||
raise ValueError(
|
||||
f"Unsupported connector type: {connector_name}")
|
||||
connector_module = importlib.import_module(connector_module_path)
|
||||
connector_cls = getattr(connector_module, connector_name)
|
||||
assert issubclass(connector_cls, KVConnectorBase)
|
||||
connector_cls = cls.get_connector_class(kv_transfer_config)
|
||||
logger.info("Creating v1 connector with name: %s and engine_id: %s",
|
||||
connector_cls.__name__, kv_transfer_config.engine_id)
|
||||
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
|
||||
@ -65,6 +59,23 @@ class KVConnectorFactory:
|
||||
# We build separately to enforce strict separation
|
||||
return connector_cls(config, role)
|
||||
|
||||
@classmethod
def get_connector_class(
    cls, kv_transfer_config: "KVTransferConfig"
) -> "type[KVConnectorBaseType]":
    """Resolve the connector class named by a KV-transfer config.

    The name in ``kv_transfer_config.kv_connector`` is first looked up in
    ``cls._registry``; a registry entry is a zero-argument callable that
    returns the class. If the name is not registered, the module at
    ``kv_transfer_config.kv_connector_module_path`` is imported and the
    attribute with that name is returned.

    Args:
        kv_transfer_config: Config object providing ``kv_connector`` and
            ``kv_connector_module_path``.

    Returns:
        The connector class (not an instance).

    Raises:
        ValueError: If no connector name is set, or the name is not
            registered and no module path was supplied.
        AttributeError: If the module at the supplied path does not
            define an attribute with the connector name.
        ImportError: Propagated from ``importlib.import_module`` when the
            module path cannot be imported.
    """
    connector_name = kv_transfer_config.kv_connector
    # NOTE(review): kv_connector is presumably optional on the config;
    # fail early with a clear message instead of letting ``getattr`` raise
    # an unrelated TypeError on a ``None`` attribute name.
    if connector_name is None:
        raise ValueError("Connector name is not set in KVTransferConfig")

    if connector_name in cls._registry:
        # Registry values are loaders; calling one yields the class.
        return cls._registry[connector_name]()

    connector_module_path = kv_transfer_config.kv_connector_module_path
    if connector_module_path is None:
        raise ValueError(
            f"Unsupported connector type: {connector_name}")

    connector_module = importlib.import_module(connector_module_path)
    try:
        return getattr(connector_module, connector_name)
    except AttributeError as err:
        # Same exception type as before, but say which class was missing
        # from which module so misconfiguration is easy to diagnose.
        raise AttributeError(
            f"Class {connector_name} not found in "
            f"{connector_module_path}") from err
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
# The registration should not be done in each individual file, as we want to
|
||||
|
@ -13,8 +13,8 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, get_current_vllm_config
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1)
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
|
||||
|
||||
@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
|
||||
vllm_config = get_current_vllm_config()
|
||||
kv_config = vllm_config.kv_transfer_config
|
||||
if kv_config is not None:
|
||||
required_kvcache_layout = (
|
||||
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
|
||||
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
|
||||
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
|
||||
vllm_config)
|
||||
if required_kvcache_layout is not None:
|
||||
return required_kvcache_layout
|
||||
logger.info_once("Connectors do not specify a " \
|
||||
@ -143,6 +144,8 @@ class KVOutputAggregator:
|
||||
finished_recving = set[str]()
|
||||
for output in outputs:
|
||||
output = output.kv_connector_output
|
||||
if not output:
|
||||
continue
|
||||
update_finished_set(output.finished_sending,
|
||||
self._send_remaining_count, finished_sending)
|
||||
update_finished_set(output.finished_recving,
|
||||
|
Reference in New Issue
Block a user