Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

[PD] Skip tp_size exchange with rank0 (#19413)

Signed-off-by: NickLucche <nlucches@redhat.com>
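This commit removes the extra NIXL handshake that the decode worker used to perform with the remote rank 0 solely to learn the prefill instance's tensor-parallel size. tp_size is dropped from NixlAgentMetadata and instead travels with the request metadata (a new ReqMeta.tp_size filled in by the scheduler), so _nixl_handshake() and add_remote_agent() receive remote_tp_size directly and each worker handshakes only with the one remote rank it pulls KV blocks from. The diff touches the connector (vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector) and its unit tests. A condensed sketch of the resulting call on the decode side, with abridged names (the real code is in the hunks below):

from concurrent.futures import Future, ThreadPoolExecutor

def initiate_handshake(executor: ThreadPoolExecutor, nixl_handshake,
                       meta) -> Future:
    # Sketch only: the handshake is queued with the remote TP size carried by
    # the request metadata, not discovered via a rank-0 metadata query.
    return executor.submit(nixl_handshake, meta.remote_host, meta.remote_port,
                           meta.tp_size)
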
@@ -7,6 +7,8 @@ from collections import defaultdict
from typing import Optional
from unittest.mock import patch

import pytest

from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
    NixlConnectorWorker)
@@ -161,7 +163,8 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
        super().__init__(*args, **kwargs)
        self._hand_shake_latency = hand_shake_latency

    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
    def _nixl_handshake(self, host: str, port: int,
                        remote_tp_size: int) -> dict[int, str]:
        # Mimic slow _nixl_handshake, as well as bypass zmq communication.
        time.sleep(self._hand_shake_latency)
        # These should've been done in register_kv_caches(), called by
@@ -177,10 +180,10 @@ class FakeNixlConnectorWorker(NixlConnectorWorker):
                agent_metadata=FakeNixlWrapper.AGENT_METADATA,
                kv_caches_base_addr=[0],
                num_blocks=1,
                tp_size=1,
                block_len=self.block_len,
                attn_backend_name=self.backend_name,
            ))
            ),
            remote_tp_size=remote_tp_size)
        return {0: remote_agent_name}
@@ -233,6 +236,8 @@ class TestNixlHandshake:
                "localhost",
                "remote_port":
                1234,
                "remote_tp_size":
                1,
            })
        connector.bind_connector_metadata(metadata)
@@ -259,13 +264,23 @@ class TestNixlHandshake:
    @patch(
        "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
        FakeNixlWrapper)
    @pytest.mark.parametrize("decode_tp_size, prefill_tp_size", [
        (1, 1),
        (2, 1),
        (4, 2),
        (4, 4),
    ])
    def test_async_load_kv(
            self,
            # dist_init is a fixture that initializes the distributed environment.
            dist_init):
            self,
            # Fixture that initializes the distributed environment.
            dist_init,
            # Simulate consumer-producer TP sizes.
            decode_tp_size,
            prefill_tp_size):
        """Test that NixlConnector's start_load_kv should be non-blocking."""

        vllm_config = create_vllm_config()
        vllm_config.parallel_config.tensor_parallel_size = decode_tp_size

        # Test worker role in decode server.
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
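The new parametrization covers homogeneous (1->1, 4->4) and heterogeneous (2->1, 4->2) consumer-producer TP combinations. As a quick illustration (not part of the test), this is which prefill rank each decode rank ends up pulling from, using the same mapping the connector applies below (tp_ratio = decode_tp // prefill_tp, remote rank = local rank // tp_ratio):

# Expected decode-rank -> prefill-rank mapping for the parametrized pairs.
for decode_tp, prefill_tp in [(1, 1), (2, 1), (4, 2), (4, 4)]:
    tp_ratio = decode_tp // prefill_tp
    mapping = {rank: rank // tp_ratio for rank in range(decode_tp)}
    print(f"decode_tp={decode_tp} prefill_tp={prefill_tp} -> {mapping}")
    # e.g. decode_tp=4 prefill_tp=2 -> {0: 0, 1: 0, 2: 1, 3: 1}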
@@ -280,6 +295,7 @@ class TestNixlHandshake:
                FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                "remote_host": "localhost",
                "remote_port": 1234,
                "remote_tp_size": prefill_tp_size,
            })
        connector.bind_connector_metadata(metadata)
@@ -329,6 +345,7 @@ class TestNixlHandshake:
                FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                "remote_host": "localhost",
                "remote_port": 1234,
                "remote_tp_size": 1,
            })
        connector.bind_connector_metadata(metadata)
@@ -62,7 +62,6 @@ class NixlAgentMetadata(
    agent_metadata: bytes
    kv_caches_base_addr: list[int]
    num_blocks: int
    tp_size: int
    block_len: int
    attn_backend_name: str
@@ -73,7 +72,8 @@ class ReqMeta:
    remote_block_ids: list[int]
    remote_host: str
    remote_port: int
    remote_engine_id: EngineId
    remote_engine_id: str
    tp_size: int


class NixlConnectorMetadata(KVConnectorMetadata):
@@ -93,6 +93,8 @@ class NixlConnectorMetadata(KVConnectorMetadata):
            remote_engine_id=kv_transfer_params["remote_engine_id"],
            remote_host=kv_transfer_params["remote_host"],
            remote_port=kv_transfer_params["remote_port"],
            # P workers don't need to receive tp_size from proxy here.
            tp_size=kv_transfer_params.get("tp_size", 1),
        )
@@ -330,7 +332,7 @@ class NixlConnectorScheduler:
            remote_engine_id=self.engine_id,
            remote_host=self.side_channel_host,
            remote_port=self.side_channel_port,
        )
            tp_size=self.vllm_config.parallel_config.tensor_parallel_size)


class NixlConnectorWorker:
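The scheduler hunk above now advertises its own tensor_parallel_size alongside the side-channel address, and the consumer path shown earlier falls back to 1 when the key is absent, since P workers never receive it. A minimal sketch of that convention; only the key names and the default mirror the diff, the helpers themselves are hypothetical:

def make_kv_transfer_params(engine_id: str, host: str, port: int,
                            tp_size: int) -> dict:
    # Producer (prefill) side: advertise our TP size to the decode side.
    return dict(remote_engine_id=engine_id, remote_host=host,
                remote_port=port, tp_size=tp_size)

def remote_tp_size_from(kv_transfer_params: dict) -> int:
    # Consumer side: default to 1; prefill workers never get this key.
    return kv_transfer_params.get("tp_size", 1)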
@@ -473,7 +475,8 @@ class NixlConnectorWorker:
                    "Connection listener got unexpected message %s", msg)
            sock.send_multipart((identity, b"", encoded_data))

    def _nixl_handshake(self, host: str, port: int) -> dict[int, str]:
    def _nixl_handshake(self, host: str, port: int,
                        remote_tp_size: int) -> dict[int, str]:
        """Do a NIXL handshake with a remote instance."""

        start_time = time.perf_counter()
@@ -482,7 +485,7 @@ class NixlConnectorWorker:
        # a hack to keep us moving. We will switch when moving to etcd
        # or where we have a single ZMQ socket in the scheduler.

        def handshake(path: str, rank: int) -> tuple[NixlAgentMetadata, str]:
        def handshake(path: str, rank: int) -> str:
            # Send query for the request.
            with zmq_ctx(zmq.REQ, path) as sock:
                sock.send(GET_META_MSG)
@@ -492,33 +495,25 @@ class NixlConnectorWorker:
            got_metadata_time = time.perf_counter()

            # Register Remote agent.
            remote_agent_name = self.add_remote_agent(metadata, rank)
            remote_agent_name = self.add_remote_agent(
                metadata, rank, remote_tp_size)
            setup_agent_time = time.perf_counter()

            logger.debug("NIXL handshake: get metadata took: %s",
                         got_metadata_time - start_time)
            logger.debug("NIXL handshake: add agent took: %s",
                         setup_agent_time - got_metadata_time)
            return metadata, remote_agent_name
            return remote_agent_name

        # Handshake with remote agent-rank0 first to get the tp_size of remote
        path = make_zmq_path("tcp", host, port)
        logger.debug("Querying master rank metadata on path: %s", path)
        rank_to_agent_name: dict[int, str] = {}
        metadata, rank_to_agent_name[0] = handshake(path, 0)

        # Handshake only with the other TP remote the current local rank will
        # Handshake only with the remote TP rank that current local rank will
        # pull from. With homogeneous TP it happens to be the same rank_i.
        tp_ratio = self._tp_size[self.engine_id] // metadata.tp_size
        tp_ratio = self._tp_size[self.engine_id] // remote_tp_size
        p_remote_rank = self.tp_rank // tp_ratio
        if p_remote_rank > 0:
            path = make_zmq_path("tcp", host, port + p_remote_rank)
            logger.debug("Querying metadata on path: %s at remote rank %s",
                         path, p_remote_rank)
            _, rank_to_agent_name[p_remote_rank] = handshake(
                path, p_remote_rank)

        return rank_to_agent_name
        path = make_zmq_path("tcp", host, port + p_remote_rank)
        logger.debug("Querying metadata on path: %s at remote rank %s", path,
                     p_remote_rank)
        # Remote rank -> agent name.
        return {p_remote_rank: handshake(path, p_remote_rank)}

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Register the KV Cache data in nixl."""
@@ -645,7 +640,6 @@ class NixlConnectorWorker:
            agent_metadata=self.nixl_wrapper.get_agent_metadata(),
            kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
            num_blocks=self.num_blocks,
            tp_size=self.world_size,
            block_len=self.block_len,
            attn_backend_name=self.backend_name)
        ready_event = threading.Event()
@@ -659,7 +653,8 @@ class NixlConnectorWorker:

    def add_remote_agent(self,
                         nixl_agent_meta: NixlAgentMetadata,
                         remote_tp_rank: int = 0) -> str:
                         remote_tp_rank: int = 0,
                         remote_tp_size: int = 1) -> str:
        """
        Add the remote NIXL agent and prepare the descriptors for reading cache
        blocks from remote.
@@ -704,9 +699,9 @@ class NixlConnectorWorker:
            return self._remote_agents[engine_id][remote_tp_rank]

        if engine_id in self._tp_size:
            assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
            assert self._tp_size[engine_id] == remote_tp_size
        else:
            self._tp_size[engine_id] = nixl_agent_meta.tp_size
            self._tp_size[engine_id] = remote_tp_size
        # We may eventually enable this after asserting equality in cache
        # layout and close outputs.
        assert nixl_agent_meta.attn_backend_name == self.backend_name
@@ -756,33 +751,31 @@ class NixlConnectorWorker:
        # rank. With heterogeneous TP, prepare the descriptors by splitting the
        # P KV cache along kv_head dim, of D worker's kv_head size (D>P).
        # Eg. PTP1 DTP2 => P0 KV:[block0-KV_0 | block0-KV_1..].
        p_remote_tp_rank = self.tp_rank // tp_ratio
        # Only register the remote's descriptors if current rank pulls from it.
        if p_remote_tp_rank == remote_tp_rank:
            self.kv_caches_base_addr[
                engine_id] = nixl_agent_meta.kv_caches_base_addr
            rank_offset = self.tp_rank % tp_ratio * self.block_len \
                if not (self.use_mla or is_kv_replicated) else 0
            # Register all remote blocks, but only the corresponding kv heads.
            for base_addr in nixl_agent_meta.kv_caches_base_addr:
                for block_id in range(nixl_agent_meta.num_blocks):
                    block_offset = block_id * nixl_agent_meta.block_len
                    # For each block, grab the heads chunk belonging to rank_i
                    # of size remote_nheads // tp_ratio, which correspond to
                    # self.block_len == remote_block_len//tp_ratio bytes.
                    addr = base_addr + block_offset + rank_offset
                    # (addr, len, device id)
                    blocks_data.append((addr, self.block_len, remote_tp_rank))
            logger.debug(
                "Created %s blocks for dst engine %s with remote rank %s and "
                "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
                self.tp_rank)
        self.kv_caches_base_addr[
            engine_id] = nixl_agent_meta.kv_caches_base_addr
        rank_offset = self.tp_rank % tp_ratio * self.block_len \
            if not (self.use_mla or is_kv_replicated) else 0
        # Register all remote blocks, but only the corresponding kv heads.
        for base_addr in nixl_agent_meta.kv_caches_base_addr:
            for block_id in range(nixl_agent_meta.num_blocks):
                block_offset = block_id * nixl_agent_meta.block_len
                # For each block, grab the heads chunk belonging to rank_i
                # of size remote_nheads // tp_ratio, which correspond to
                # self.block_len == remote_block_len//tp_ratio bytes.
                addr = base_addr + block_offset + rank_offset
                # (addr, len, device id)
                blocks_data.append((addr, self.block_len, remote_tp_rank))
        logger.debug(
            "Created %s blocks for dst engine %s with remote rank %s and "
            "local rank %s", len(blocks_data), engine_id, remote_tp_rank,
            self.tp_rank)

            # Register with NIXL.
            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
            self.dst_xfer_side_handles[
                engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                    remote_agent_name, descs)
        # Register with NIXL.
        descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
        self.dst_xfer_side_handles[
            engine_id] = self.nixl_wrapper.prep_xfer_dlist(
                remote_agent_name, descs)

        return remote_agent_name
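Because the handshake now only ever reaches the rank this worker pulls from, the old guard (p_remote_tp_rank == remote_tp_rank) is gone and the descriptor registration runs unconditionally. A toy walk-through of the address arithmetic above for a heterogeneous case; all numbers are invented, and with MLA or replicated KV the rank_offset is simply 0:

# Prefill TP=1, decode TP=2 => tp_ratio = 2: each decode rank reads half of
# every remote block (its kv-head slice), as in the PTP1 DTP2 comment above.
remote_block_len = 4096                     # bytes per full remote block
local_block_len = remote_block_len // 2     # this rank's share (self.block_len)
base_addr = 0x10000                         # one remote KV cache region
num_blocks = 3
tp_rank, tp_ratio = 1, 2

rank_offset = tp_rank % tp_ratio * local_block_len   # 2048: second head slice
addrs = [base_addr + block_id * remote_block_len + rank_offset
         for block_id in range(num_blocks)]
assert [hex(a) for a in addrs] == ['0x10800', '0x11800', '0x12800']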
@@ -917,7 +910,7 @@ class NixlConnectorWorker:
            if fut is None:
                fut = self._handshake_initiation_executor.submit(
                    self._nixl_handshake, meta.remote_host,
                    meta.remote_port)
                    meta.remote_port, meta.tp_size)
                self._handshake_futures[remote_engine_id] = fut

                def done_callback(f: Future[dict[int, str]],
@@ -957,13 +950,9 @@ class NixlConnectorWorker:
            remote_block_ids=meta.remote_block_ids,
        )

    def _read_blocks(
        self,
        local_block_ids: list[int],
        remote_block_ids: list[int],
        dst_engine_id: str,
        request_id: str,
    ):
    def _read_blocks(self, local_block_ids: list[int],
                     remote_block_ids: list[int], dst_engine_id: str,
                     request_id: str):
        # NOTE(rob): having the staging blocks be on the READER side is
        # not going to work well (since we will have to call rearrange tensors).
        # after we detect the txn is complete (which means we cannot make the