[PD] Skip tp_size exchange with rank0 (#19413)

Signed-off-by: NickLucche <nlucches@redhat.com>
Nicolò Lucchesi authored on 2025-06-26 05:04:39 +02:00, committed by GitHub
parent 754b00edb3
commit 2582683566
2 changed files with 72 additions and 66 deletions

File 1 of 2: NixlConnector unit tests

@@ -7,6 +7,8 @@ from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
super().__init__(*args, **kwargs)
self._hand_shake_latency = hand_shake_latency
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
# Mimic slow _nixl_handshake, as well as bypass zmq communication.
time.sleep(self._hand_shake_latency)
# These should've been done in register_kv_caches(), called by
@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
agent_metadata=FakeNixlWrapper.AGENT_METADATA,
kv_caches_base_addr=[0],
num_blocks=1,
tp_size=1,
block_len=self.block_len,
attn_backend_name=self.backend_name,
))
),
remote_tp_size=remote_tp_size)
return {0: remote_agent_name}
@@ -233,6 +236,8 @@ class TestNixlHandshake:
"localhost",
"remote_port":
1234,
"remote_tp_size":
1,
})
connector.bind_connector_metadata(metadata)
@@ -259,13 +264,23 @@ class TestNixlHandshake:
@patch(
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
FakeNixlWrapper)
@pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
(1, 1),
(2, 1),
(4, 2),
(4, 4),
])
def test_async_load_kv(
self,
# dist_init is a fixture that initializes the distributed environment.
dist_init):
self,
# Fixture that initializes the distributed environment.
dist_init,
# Simulate consumer-producer TP sizes.
decode_tp_size,
prefill_tp_size):
"""Test that NixlConnector's start_load_kv should be non-blocking."""
vllm_config = create_vllm_config()
vllm_config.parallel_config.tensor_parallel_size = decode_tp_size
# Test worker role in decode server.
connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
@@ -280,6 +295,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": prefill_tp_size,
})
connector.bind_connector_metadata(metadata)
@@ -329,6 +345,7 @@ class TestNixlHandshake:
FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
"remote_host": "localhost",
"remote_port": 1234,
"remote_tp_size": 1,
})
connector.bind_connector_metadata(metadata)

File 2 of 2: vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py

@@ -62,7 +62,6 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
tp_size: int
block_len: int
attn_backend_name: str
@@ -73,7 +72,8 @@ class ReqMeta:
remote_block_ids: list[int]
remote_host: str
remote_port: int
remote_engine_id: EngineId
remote_engine_id: str
tp_size: int
class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
# P workers don't need to receive tp_size from proxy here.
tp_size=kv_transfer_params.get("tp_size", 1),
)
@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
tp_size=self.vllm_config.parallel_config.tensor_parallel_size)
class NixlConnectorWorker:
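Taken together, the two hunks above replace the old rank-0 metadata exchange: the prefill-side scheduler now advertises its own tensor_parallel_size in the kv_transfer_params it hands back, and the consumer-side metadata builder reads it with a default of 1. A minimal sketch of that round trip (plain Python, placeholder values; only the dict keys mirror the code above):

# Sketch only; host/port/engine-id values are placeholders.
# Prefill (P) side: the scheduler advertises its own TP size.
prefill_tp_size = 2
kv_transfer_params = dict(
    remote_engine_id="prefill-engine-0",   # placeholder engine id
    remote_host="10.0.0.1",                # placeholder side-channel host
    remote_port=5600,                      # placeholder side-channel base port
    tp_size=prefill_tp_size,               # new: replaces the rank-0 handshake
)

# Decode (D) side: NixlConnectorMetadata.add_new_req() reads it back,
# defaulting to 1 when no TP size was forwarded.
remote_tp_size = kv_transfer_params.get("tp_size", 1)
assert remote_tp_size == prefill_tp_size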
@@ -473,7 +475,8 @@ class NixlConnectorWorker:
"Connection listener got unexpected message %s", msg)
sock.send_multipart((identity, b"", encoded_data))
def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
def _nixl_handshake(self, host: str, port: int,
remote_tp_size: int) -> dict[int, str]:
"""Do a NIXL handshake with a remote instance."""
start_time = time.perf_counter()
@@ -482,7 +485,7 @@
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
def handshake(path: str, rank: int) -> str:
# Send query for the request.
with zmq_ctx(zmq.REQ, path) as sock:
sock.send(GET_META_MSG)
@@ -492,33 +495,25 @@ class NixlConnectorWorker:
got_metadata_time = time.perf_counter()
# Register Remote agent.
remote_agent_name = self.add_remote_agent(metadata, rank)
remote_agent_name = self.add_remote_agent(
metadata, rank, remote_tp_size)
setup_agent_time = time.perf_counter()
logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return metadata, remote_agent_name
return remote_agent_name
# Handshake with remote agent-rank0 first to get the tp_size of remote
path = make_zmq_path("tcp", host, port)
logger.debug("Querying master rank metadata on path: %s", path)
rank_to_agent_name: dict[int, str] = {}
metadata, rank_to_agent_name[0] = handshake(path, 0)
# Handshake only with the other TP remote the current local rank will
# Handshake only with the remote TP rank that current local rank will
# pull from. With homogeneous TP it happens to be the same rank_i.
tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
p_remote_rank = self.tp_rank // tp_ratio
if p_remote_rank > 0:
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
_, rank_to_agent_name[p_remote_rank] = handshake(
path, p_remote_rank)
return rank_to_agent_name
path = make_zmq_path("tcp", host, port + p_remote_rank)
logger.debug("Querying metadata on path: %s at remote rank %s", path,
p_remote_rank)
# Remote rank -> agent name.
return {p_remote_rank: handshake(path, p_remote_rank)}
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@@ -645,7 +640,6 @@ class NixlConnectorWorker:
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
tp_size=self.world_size,
block_len=self.block_len,
attn_backend_name=self.backend_name)
ready_event = threading.Event()
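With tp_size gone from NixlAgentMetadata, the payload each worker serves over the ZMQ side channel is reduced to cache-layout fields only. A hedged sketch of the object built above, assuming the remaining fields are engine_id plus those shown (all values are placeholders):

# Sketch only; tp_size is intentionally absent now.
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    NixlAgentMetadata)

meta = NixlAgentMetadata(
    engine_id="engine-0",               # assumed field name, placeholder value
    agent_metadata=b"nixl-agent-blob",  # opaque NIXL agent descriptor
    kv_caches_base_addr=[0x7F0000000000],
    num_blocks=1024,
    block_len=4096,
    attn_backend_name="FLASH_ATTN_VLLM_V1",  # placeholder backend name
)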
@@ -659,7 +653,8 @@
def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
remote_tp_rank: int = 0) -> str:
remote_tp_rank: int = 0,
remote_tp_size: int = 1) -> str:
"""
Add the remote NIXL agent and prepare the descriptors for reading cache
blocks from remote.
@@ -704,9 +699,9 @@ class NixlConnectorWorker:
return self._remote_agents[engine_id][remote_tp_rank]
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
assert self._tp_size[engine_id] == remote_tp_size
else:
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._tp_size[engine_id] = remote_tp_size
# We may eventually enable this after asserting equality in cache
# layout and close outputs.
assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -756,33 +751,31 @@ class NixlConnectorWorker:
# rank. With heterogeneous TP, prepare the descriptors by splitting the
# P KV cache along kv_head dim, of D worker's kv_head size (D>P).
# Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
p_remote_tp_rank = self.tp_rank // tp_ratio
# Only register the remote's descriptors if current rank pulls from it.
if p_remote_tp_rank == remote_tp_rank:
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.tp_rank % tp_ratio * self.block_len \
if not (self.use_mla or is_kv_replicated) else 0
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
# For each block, grab the heads chunk belonging to rank_i
# of size remote_nheads // tp_ratio, which correspond to
# self.block_len == remote_block_len//tp_ratio bytes.
addr = base_addr + block_offset + rank_offset
# (addr, len, device id)
blocks_data.append((addr, self.block_len, remote_tp_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and "
"local rank %s", len(blocks_data), engine_id, remote_tp_rank,
self.tp_rank)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
remote_agent_name, descs)
return remote_agent_name
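The de-indented block above now always registers descriptors for the single remote rank this worker pulls from. The address math slices every remote block along the kv-head dimension so that a decode rank only reads its share of the heads. A small sketch of that layout, assuming MLA is off and KV heads are not replicated (illustrative helper, not the connector's API):

# Mirrors block_offset = block_id * remote_block_len and
# rank_offset = (tp_rank % tp_ratio) * local_block_len from the code above.
def remote_block_addrs(base_addr: int, num_blocks: int, remote_block_len: int,
                       tp_ratio: int, local_rank: int) -> list[int]:
    # The local block covers remote_nheads // tp_ratio heads, i.e. 1/tp_ratio
    # of the remote block's bytes.
    local_block_len = remote_block_len // tp_ratio
    rank_offset = (local_rank % tp_ratio) * local_block_len
    return [
        base_addr + block_id * remote_block_len + rank_offset
        for block_id in range(num_blocks)
    ]

# Prefill TP=1, decode TP=2, 4096-byte remote blocks: rank 0 reads the first
# 2048 bytes of each block, rank 1 reads the second 2048 bytes.
assert remote_block_addrs(0x1000, 2, 4096, tp_ratio=2, local_rank=0) == [0x1000, 0x2000]
assert remote_block_addrs(0x1000, 2, 4096, tp_ratio=2, local_rank=1) == [0x1800, 0x2800]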
@@ -917,7 +910,7 @@ class NixlConnectorWorker:
if fut is None:
fut = self._handshake_initiation_executor.submit(
self._nixl_handshake, meta.remote_host,
meta.remote_port)
meta.remote_port, meta.tp_size)
self._handshake_futures[remote_engine_id] = fut
def done_callback(f: Future[dict[int, str]],
@@ -957,13 +950,9 @@ class NixlConnectorWorker:
remote_block_ids=meta.remote_block_ids,
)
def _read_blocks(
self,
local_block_ids: list[int],
remote_block_ids: list[int],
dst_engine_id: str,
request_id: str,
):
def _read_blocks(self, local_block_ids: list[int],
remote_block_ids: list[int], dst_engine_id: str,
request_id: str):
# NOTE(rob): having the staging blocks be on the READER side is
# not going to work well (since we will have to call rearrange tensors).
# after we detect the txn is complete (which means we cannot make the