[BugFix][Nixl][PD] Fix heterogenous TP (#22663)

Signed-off-by: NickLucche <nlucches@redhat.com>
Co-authored-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
Nicolò Lucchesi
2025-08-12 14:37:30 +02:00
committed by GitHub
parent 767e63b860
commit d030b01548
2 changed files with 31 additions and 17 deletions

View File

@ -4,13 +4,17 @@
import importlib
from typing import TYPE_CHECKING, Callable
# yapf: disable
import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.kv_transfer.kv_connector.base import (
KVConnectorBase, KVConnectorBaseType)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
from vllm.logger import init_logger
# yapf: enable
if TYPE_CHECKING:
from vllm.config import VllmConfig
from vllm.config import KVTransferConfig, VllmConfig
logger = init_logger(__name__)
@ -42,17 +46,7 @@ class KVConnectorFactory:
f"but found {envs.VLLM_USE_V1=}")
kv_transfer_config = config.kv_transfer_config
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
assert issubclass(connector_cls, KVConnectorBase)
connector_cls = cls.get_connector_class(kv_transfer_config)
logger.info("Creating v1 connector with name: %s and engine_id: %s",
connector_cls.__name__, kv_transfer_config.engine_id)
# NOTE(Kuntai): v1 connector is explicitly separated into two roles.
@ -65,6 +59,23 @@ class KVConnectorFactory:
# We build separately to enforce strict separation
return connector_cls(config, role)
@classmethod
def get_connector_class(
cls, kv_transfer_config: "KVTransferConfig"
) -> type[KVConnectorBaseType]:
"""Get the connector class by name."""
connector_name = kv_transfer_config.kv_connector
if connector_name in cls._registry:
connector_cls = cls._registry[connector_name]()
else:
connector_module_path = kv_transfer_config.kv_connector_module_path
if connector_module_path is None:
raise ValueError(
f"Unsupported connector type: {connector_name}")
connector_module = importlib.import_module(connector_module_path)
connector_cls = getattr(connector_module, connector_name)
return connector_cls
# Register various connectors here.
# The registration should not be done in each individual file, as we want to

View File

@ -13,8 +13,8 @@ import torch
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1)
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.logger import init_logger
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput
@ -106,8 +106,9 @@ def get_kv_connector_cache_layout():
vllm_config = get_current_vllm_config()
kv_config = vllm_config.kv_transfer_config
if kv_config is not None:
required_kvcache_layout = (
KVConnectorBase_V1.get_required_kvcache_layout(vllm_config))
connector_cls = KVConnectorFactory.get_connector_class(kv_config)
required_kvcache_layout = connector_cls.get_required_kvcache_layout(
vllm_config)
if required_kvcache_layout is not None:
return required_kvcache_layout
logger.info_once("Connectors do not specify a " \
@ -143,6 +144,8 @@ class KVOutputAggregator:
finished_recving = set[str]()
for output in outputs:
output = output.kv_connector_output
if not output:
continue
update_finished_set(output.finished_sending,
self._send_remaining_count, finished_sending)
update_finished_set(output.finished_recving,