Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
rshaw@neuralmagic.com
2025-04-06 14:07:43 +00:00
parent 45fa7f9b8e
commit 6de0982dd0
6 changed files with 1138 additions and 0 deletions

vllm/core/event_manager.py (new file)
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
import ctypes
import logging
import uuid
from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
from typing import Optional
from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
logger = logging.getLogger(__name__)
class DynamoResult:
OK = 0
ERR = 1
class KVCacheEventManager:
def __init__(self, namespace: str, component: str, worker_id: int,
lib_path: str, kv_block_size: int):
self.lib = None
try:
self.lib = ctypes.CDLL(lib_path)
self.lib.dynamo_llm_init.argtypes = [
c_char_p,
c_char_p,
c_int64,
c_uint32,
]
self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id, kv_block_size
)
if result == DynamoResult.OK:
logger.info(
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
)
else:
logger.info("KVCacheEventManager initialization failed!")
except Exception:
logger.error("Failed to load KV event library from %s", lib_path)
raise
self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t
self.lib.dynamo_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t
self.event_id_counter = 0
def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
block: PrefixCachingBlock):
token_ids_arr = (ctypes.c_uint32 *
len(block.token_ids))(*block.token_ids)
num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
if parent is not None else None)
# Publish the event
result = self.lib.dynamo_kv_event_publish_stored(
self.event_id_counter, # uint64_t event_id
token_ids_arr, # const uint32_t *token_ids
num_block_tokens, # const uintptr_t *num_block_tokens
block_hash, # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
0, # uint64_t lora_id
)
if result == DynamoResult.OK:
logger.debug(f"Store - Published KV Event: {block.content_hash}")
else:
logger.debug(
f"Store - Failed to Publish KV Event: {block.content_hash}")
self.event_id_counter += 1
def enqueue_removed_event(self, block_hash: PrefixHash):
result = self.lib.dynamo_kv_event_publish_removed(
self.event_id_counter,
(ctypes.c_uint64 * 1)(block_hash),
1,
)
if result == DynamoResult.OK:
logger.debug(f"Remove - Published KV Event: {block_hash}")
else:
logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
self.event_id_counter += 1
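A minimal usage sketch of the manager above (illustrative only: the namespace, component, library path, and block size are placeholder values, not taken from this commit, and `block` stands for any PrefixCachingBlock with a computed content hash):

    # Hypothetical values; the dynamo_llm shared library is assumed to exist at lib_path.
    manager = KVCacheEventManager(namespace="dynamo",
                                  component="vllm-worker",
                                  worker_id=0,
                                  lib_path="/opt/dynamo/libdynamo_llm.so",
                                  kv_block_size=16)
    manager.enqueue_stored_event(parent=None, block=block)   # on block store
    manager.enqueue_removed_event(block.content_hash)        # on block eviction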

@@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl
@triton.jit
def rearrange_kernel_read(
t1_ptr,
t2_ptr,
N,
B,
H,
C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
curr_n = offsets // block_size
curr_b = offsets // token_size % B
curr_h = offsets // C % H
curr_c = offsets % C
src_pos = offsets
tp_group = curr_h * d // H
dst_h = curr_h % (H // d)
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
dst_pos = tensor_subset_size * tp_group + tp_group_offset
tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos))
@triton.jit
def rearrange_kernel_write(
t1_ptr,
t2_ptr,
N,
B,
H,
C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
curr_n = offsets // block_size
curr_b = offsets // token_size % B
curr_h = offsets // C % H
curr_c = offsets % C
src_pos = offsets
tp_group = curr_h * d // H
dst_h = curr_h % (H // d)
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
dst_pos = tensor_subset_size * tp_group + tp_group_offset
tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
N, B, H, C = t1.shape
assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
assert H % d == 0, "H must be divisible by d"
block_size = B * H * C
token_size = H * C
tensor_size = N * block_size
tensor_subset_size = tensor_size // d
BLOCK_SIZE = 1024
grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
if direction == "read":
rearrange_kernel_read[grid](
t1, t2,
N, B, H, C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE=BLOCK_SIZE
)
elif direction == "write":
rearrange_kernel_write[grid](
t1, t2,
N, B, H, C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE=BLOCK_SIZE
)
else:
raise ValueError(f"Invalid direction: {direction}")

@@ -0,0 +1,379 @@
import torch
from typing import List, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
import msgspec
import time
import uuid
from collections import defaultdict
from .kv_rearrange import rearrange_tensors
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
class NixlMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
engine_id: str
agent_metadata: List[bytes]
kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
num_blocks: int
class DynamoNixlConnector:
def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
self.vllm_config = vllm_config
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer
self.num_layers = None
self.num_blocks = None
self.num_heads = None
self.block_len = None
self.kv_caches = None
self.kv_caches_base_addr = {}
self.kv_cache_shape = {}
self._registered_descs = []
self._remote_agents = {}
self.engine_id = engine_id
self.rank = rank
self._tp_size = {}
self.src_xfer_side_handles = {}
self.dst_xfer_side_handles = defaultdict(dict)
self.dst_num_blocks = {}
self._transfers = defaultdict(list)
self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
@property
def agent_name(self):
return self.nixl_wrapper.name
def register_kv_caches(self, kv_caches: List[torch.Tensor]):
_, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
self.num_layers = len(kv_caches)
self.num_blocks = num_blocks
self.num_heads = num_heads
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
for key_cache, value_cache in kv_caches:
base_addr = key_cache.data_ptr()
region_len = 2 * num_blocks * self.block_len
caches_data.append((base_addr, region_len, self.rank, ""))
kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
self._registered_descs.append(descs)
def get_agent_metadata(self):
return self.nixl_wrapper.get_agent_metadata()
def shutdown(self):
for descs_list in self._registered_descs:
self.nixl_wrapper.deregister_memory(descs_list)
for agent_names in self._remote_agents.values():
for agent_name in agent_names:
self.nixl_wrapper.remove_remote_agent(agent_name)
for src_xfer_side_handle in self.src_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
def _get_ranges(self, block_ids):
# Collapse a sorted list of block ids into inclusive ranges of consecutive ids.
# For example, block_ids [0, 1, 2, 4, 5, 6] yields [[0, 2], [4, 6]].
# The returned ranges are ordered by their starting block id.
ranges = []
for i in range(len(block_ids)):
if i == 0 or block_ids[i] != block_ids[i-1] + 1:
ranges.append([block_ids[i], block_ids[i]])
else:
ranges[-1][1] = block_ids[i]
return ranges
def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
if layer_ids == "all":
layer_ids = list(range(self.num_layers))
if block_ids == "all":
block_ids = list(range(self.num_blocks))
descs_ids = []
if i is not None:
num_blocks = self.num_blocks
for layer_id in layer_ids:
for is_value in [0, 1]:
staging_range_idx = 0
for block_id in block_ids:
if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
staging_range_idx += 1
start_offset = staging_ranges[staging_range_idx][0]
i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
else:
num_blocks = self.dst_num_blocks[engine_id]
for layer_id in layer_ids:
for is_value in [0, 1]:
for block_id in block_ids:
descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
return descs_ids
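# Worked example of the flat descriptor layout (illustrative numbers only, not
# taken from this commit): with num_blocks=4 and tp_multiplier=1, each layer
# contributes [K block 0..3, V block 0..3], so layer 1 / value / block 2 maps to
# index 1 * 2 * 4 + 1 * 4 + 2 = 14, matching the formula in the else branch above.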
def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
# Split src and dst ranges so that corresponding ranges have the same length.
# For example, if src_ranges is [[0, 2], [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]],
# the function returns ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]).
src_overlapping_ranges, dst_overlapping_ranges = [], []
original_src_ranges = []
org_src_range = tuple(src_ranges[0])
src_idx, dst_idx = 0, 0
while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
src_range = src_ranges[src_idx]
dst_range = dst_ranges[dst_idx]
# Calculate the length of each range
src_len = src_range[-1] - src_range[0] + 1
dst_len = dst_range[-1] - dst_range[0] + 1
# If ranges have the same length, add them directly
if src_len == dst_len:
src_overlapping_ranges.append([src_range[0], src_range[-1]])
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
original_src_ranges.append(org_src_range)
src_idx += 1
dst_idx += 1
if src_idx < len(src_ranges):
org_src_range = tuple(src_ranges[src_idx])
# If source range is longer, split it
elif src_len > dst_len:
src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
original_src_ranges.append(org_src_range)
# Update source range for next iteration
src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
dst_idx += 1
# If destination range is longer, split it
else: # src_len < dst_len
src_overlapping_ranges.append([src_range[0], src_range[-1]])
dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
original_src_ranges.append(org_src_range)
# Update destination range for next iteration
dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
src_idx += 1
if src_idx < len(src_ranges):
org_src_range = tuple(src_ranges[src_idx])
if return_original_src_ranges:
return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
return src_overlapping_ranges, dst_overlapping_ranges
def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id):
logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id)
assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids)
if len(local_block_ids) == 0:
logger.debug("No blocks to read")
return
start_time = time.perf_counter()
local_ranges = self._get_ranges(local_block_ids)
staging_ranges = self._get_ranges(staging_block_ids)
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
handles = []
logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
create_xfer_start_time = time.perf_counter()
for i in range(tp_multiplier):
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
remote_xfer_side_handle, remote_block_descs_ids,
"")
handles.append(handle)
status = self.nixl_wrapper.transfer(handle)
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
transfer_start_time = time.perf_counter()
for handle in handles:
while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
if status == "PROC":
time.sleep(0.001)
else:
raise RuntimeError("Read transfer failed with state %s", status)
# self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)
rearrange_start_time = time.perf_counter()
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
for kv_cache in self.kv_caches:
for cache in kv_cache:
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)
# hongkuanz: we send the first isl - 1 tokens to the prefill engine; the KV for the
# last token is computed during the first decode iteration.
# If isl is a multiple of tokens_per_block plus 1, the prefill engine will have
# one fewer block due to the missing last token.
remote_block_ids = remote_block_ids[:len(local_block_ids)]
assert len(staging_block_ids) == len(local_block_ids)
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
if len(local_block_ids) == 0:
logger.debug("No blocks to write")
for i in range(tp_multiplier):
self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg)
return
start_time = time.perf_counter()
local_ranges = self._get_ranges(local_block_ids)
staging_ranges = self._get_ranges(staging_block_ids)
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
for kv_cache in self.kv_caches:
for cache in kv_cache:
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
create_xfer_start_time = time.perf_counter()
# getting block descs ids
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
for i in range(tp_multiplier):
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids,
remote_xfer_side_handle, remote_block_descs_ids,
notify_msg)
self._transfers[notify_msg].append(handle)
status = self.nixl_wrapper.transfer(handle)
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
transfer_start_time = time.perf_counter()
logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000)
def get_notifs(self):
return self.nixl_wrapper.update_notifs()
def get_new_notifs(self):
return self.nixl_wrapper.get_new_notifs()
def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
self._tp_size[engine_id] = agent_tp
agent_names = []
for agent_meta in agent_metadata:
agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
agent_names.append(agent_name)
self._remote_agents[engine_id] = agent_names
self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
dst_block_len = self.block_len // tp_multiplier
if tp_multiplier not in self.src_xfer_side_handles:
# create descs and xfer side handles
blocks_data = []
for layer_id in range(self.num_layers):
for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for i in range(tp_multiplier):
tp_multiplier_offset = i * dst_block_len
blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs)
# create dst xfer side handles
self.dst_num_blocks[engine_id] = num_blocks
for i in range(tp_multiplier):
blocks_data = []
for layer_id in range(self.num_layers):
for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
for block_id in range(num_blocks):
block_offset = block_id * dst_block_len
blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
return agent_names
def get_done_transfers(self) -> List[str]:
done_req_ids = []
for req_id, handles in self._transfers.items():
running_reqs = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors?
continue
if xfer_state == "PROC":
running_reqs.append(handle)
else:
raise RuntimeError("Transfer failed with state %s", xfer_state)
if len(running_reqs) == 0:
done_req_ids.append(req_id)
else:
self._transfers[req_id] = running_reqs
return done_req_ids
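A simplified control-flow sketch of how a worker might drive the connector (illustrative only: how NixlMetadata is exchanged between engines, and the kv_caches and block id lists, are outside this file and assumed here):

    # Local setup: register caches and export metadata for the peer engine.
    connector = DynamoNixlConnector(vllm_config, engine_id="prefill-0", rank=0)
    connector.register_kv_caches(kv_caches)
    local_meta = NixlMetadata(
        engine_id=connector.engine_id,
        agent_metadata=[connector.get_agent_metadata()],
        kv_caches_base_addr=[connector.kv_caches_base_addr[connector.engine_id]],
        num_blocks=connector.num_blocks,
    )
    # remote_meta arrives out of band from the decode engine (transport not shown).
    connector.add_remote_agent(remote_meta.engine_id, remote_meta.agent_metadata,
                               agent_tp=2,
                               kv_caches_base_addr=remote_meta.kv_caches_base_addr,
                               num_blocks=remote_meta.num_blocks)
    # Push prefilled blocks to the decode engine and notify it with the request id.
    connector.write_blocks(local_block_ids, staging_block_ids, remote_block_ids,
                           remote_meta.engine_id, notify_msg=request_id)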

@@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
"""
Dynamo NCCL KV Cache Connector for Distributed Machine Learning Inference
The DynamoConnector transfers KV caches between a prefill vLLM worker (KV cache
producer) and a decode vLLM worker (KV cache consumer) using PyNcclPipe and the
DynamoNcclDataPlane, but the logic can be extended to support other pipes and
lookup buffers.
"""
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class DynamoConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
world_group,
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.rank = rank
if self.config.kv_connector != "DynamoNcclConnector":
raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class")
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
PyNcclPipe)
from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import (
DynamoNcclDataPlane)
logger.info(
"Initializing DynamoNcclConnector under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_data_pipe: PyNcclPipe
self.consumer_data_pipe: PyNcclPipe
self.producer_signal_pipe: PyNcclPipe
self.consumer_signal_pipe: PyNcclPipe
self._broadcast_and_enhance_kv_config(rank, config, world_group)
self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
self.tp_size = config.parallel_config.tensor_parallel_size
# 2 pipes for every rank in the world
if self.config.is_kv_producer:
port_offset_base = rank + 1
else:
port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
self.data_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.data_plane = DynamoNcclDataPlane(
data_pipe=self.data_pipe,
port=self._get_data_plane_port(self.global_kv_rank),
)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
is_deepseek = "deepseek" in model_config.architectures[0].lower()
if not is_deepseek:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
else:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(4.5 * hidden_size / num_attention_heads)
# seq_lens marks the newly prefilled tokens whose KV caches were just added to
# vLLM, so we send those KV caches to the decode instance.
# FIXME(Kuntai): This assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
for target_rank in range(self.config.tensor_parallel_multiplier):
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
head_start = target_rank * num_heads_per_rank
head_end = head_start + num_heads_per_rank
if not is_deepseek:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
else:
key_cache = kv_cache
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(torch.empty(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
decode_global_rank = decode_first_global_rank + target_rank
decode_port = self._get_data_plane_port(decode_global_rank)
partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
self._send(decode_hostname, decode_port, current_request_id, keys, values,
partial_hidden_or_intermediate_states)
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
start_pos_list = []
model_config = model_executable.model.config
is_deepseek = "deepseek" in model_config.architectures[0].lower()
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self._recv(current_request_id)
keys: torch.Tensor = ret[0]
values: torch.Tensor = ret[1]
hidden: torch.Tensor = ret[2]
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if not is_deepseek:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
else:
key_cache = kv_cache
copy_from = keys[i - model_executable.model.start_layer].to(
key_cache.device)
kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.data_pipe.close()
# self.data_plane.close()
@staticmethod
def parse_request_id(request_id: str) -> Tuple[str, int]:
# Regular expression to match the string hostname and integer decode_kv_rank
pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
decode_hostname = match.group(1)
decode_rank = int(match.group(2))
return decode_hostname, decode_rank
raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
remote_address = f"{hostname}:{port}"
self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
keys = self.data_plane.recv_tensor(f"{request_id}_keys")
values = self.data_plane.recv_tensor(f"{request_id}_values")
hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
return keys, values, hidden
def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
if kv_rank < config.kv_producers_parallel_size:
return kv_rank
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
if kv_rank <= config.kv_producers_parallel_size:
return kv_rank * config.kv_producers_tensor_parallel_size + rank
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
def _get_data_plane_port(self, global_kv_rank: int) -> int:
return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
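# Worked example (illustrative numbers only): with kv_port=14579,
# kv_producers_tensor_parallel_size=2 and global_kv_rank=3, the data plane for
# that rank listens on 14579 + 2 + 1 + 3 = 14585.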
def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
if rank == 0:
config_group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port,
rank=self.config.kv_rank,
world_size=self.config.kv_parallel_size,
)
parallel_configs = config_group.all_gather_obj({
"kv_role": self.config.kv_role,
"tensor_parallel_size": config.parallel_config.tensor_parallel_size,
"pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
})
logger.debug("parallel_configs: %s", parallel_configs)
kv_config_enhanced = {
"kv_producers_tensor_parallel_size": None,
"kv_consumers_tensor_parallel_size": None,
"kv_producers_pipeline_parallel_size": None,
"kv_consumers_pipeline_parallel_size": None,
"kv_producers_parallel_size": 0,
}
for parallel_config in parallel_configs:
kv_role = parallel_config["kv_role"]
assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
if kv_role == "kv_producer":
kv_config_enhanced["kv_producers_parallel_size"] += 1
if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
else:
assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
world_group.broadcast_object(kv_config_enhanced)
else:
kv_config_enhanced = world_group.broadcast_object()
logger.info("kv_config_enhanced: %s", kv_config_enhanced)
self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]

@@ -0,0 +1,124 @@
import logging
import threading
import typing
import zmq
import socket
import time
import torch
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
logger = logging.getLogger(__name__)
class DynamoNcclDataPlane:
def __init__(
self,
data_pipe: PyNcclPipe,
hostname: str = "",
port: int = 0,
) -> None:
self.data_pipe = data_pipe
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
self.store = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self.req_sockets = {}
logger.info(f"Rank {self.rank} connected to the server")
@property
def rank(self):
return self.data_pipe.kv_group_rank
def send_tensor(
self,
tensor: torch.Tensor,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
return self._send_tensor(tensor, tensor_id, remote_address)
def recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
ret = self._recv_tensor(tensor_id, remote_address)
return ret
def _send_tensor(
self,
tensor: torch.Tensor,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
if remote_address is None:
self.store[tensor_id] = tensor
else:
# tensor_shape = "_".join(str(dim) for dim in tensor.shape)
# tensor_dtype = str(tensor.dtype)
if remote_address not in self.req_sockets:
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
req_socket = self.req_sockets[remote_address]
# req_socket.connect(f"tcp://{remote_address}")
req_socket.send_string(f"PUT {self.rank} {tensor_id}")
dst_rank = req_socket.recv_string()
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
self.data_pipe.send_tensor(tensor, int(dst_rank))
def _recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
logger.debug(f"Rank {self.rank} receiving tensor")
if remote_address is not None:
raise NotImplementedError("Getting tensor from remote rank not implemented")
if tensor_id in self.store:
logger.debug(f"Popping tensor {tensor_id} from store")
future = self.store.pop(tensor_id)
tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
logger.debug(f"Rank {self.rank} received tensor")
return tensor
logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
time.sleep(0.001)
return self._recv_tensor(tensor_id, remote_address)
# raise NotImplementedError("Tensor not found in store")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
):
future = self.data_pipe.recv_tensor(rank)
logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
self.store[tensor_id] = future
def listen_for_requests(self):
while True:
cmd, rank, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
self.rep_socket.send_string(f"{self.rank}")
if cmd == "GET":
raise NotImplementedError("Getting tensor from remote rank not implemented")
elif cmd == "PUT":
rank = int(rank)
# shape = [int(dim) for dim in shape.split("_")]
# dtype = getattr(torch, dtype)
self._receive_tensor(tensor_id, rank)
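The control handshake is a single ZMQ REQ/REP round trip followed by an NCCL transfer; a sketch of the two ends (hostnames, ports, and the plane_a / plane_b objects are placeholders for two already-constructed DynamoNcclDataPlane instances):

    # Sender: asks the receiver's REP socket which kv_group_rank to NCCL-send to,
    # then pushes the tensor through the PyNcclPipe.
    plane_a.send_tensor(keys, "req-42_keys", remote_address="decode-0:14585")

    # Receiver: the listener thread registers a pending receive when the PUT
    # request arrives; recv_tensor polls the store until it is present.
    keys = plane_b.recv_tensor("req-42_keys")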

vllm/remote_prefill.py (new file)
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Callable, Optional, List
from enum import Enum
import msgspec
from vllm.sampling_params import SamplingParams
class RemotePrefillRequest(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
"""The request data of one remote prefill output of a request.
Args:
engine_id: The unique ID of the engine.
request_id: The unique ID of the request.
prompt_token_ids: The token IDs of the prompt.
sampling_params: The sampling parameters.
block_ids: The block IDs of the request.
computed_block_ids: The computed block IDs of the request.
"""
engine_id: str
request_id: str
prompt_token_ids: List[int]
sampling_params: SamplingParams
block_ids: List[int]
computed_block_ids: List[int]
class MemoryOpType(str, Enum):
WRITE = "WRITE"
READ = "READ"
class MemoryTransferRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The request data of one memory transfer output of a request.
Args:
request_id: The unique ID of the request.
"""
request_id: str
local_block_ids: List[int]
staging_block_ids: List[int]
remote_block_ids: List[int]
remote_engine_id: str
notify_msg: str
op_type: MemoryOpType
RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]
@dataclass
class RemotePrefillParams:
"""Remote prefill parameters for text generation."""
is_remote_prefill: bool = False
is_remote_decode: bool = False
decode_block_ids: Optional[List[int]] = None
decode_computed_block_ids: Optional[List[int]] = None
decode_engine_id: Optional[str] = None
remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
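A construction sketch for the structs above (all values are placeholders; how these objects are routed between engines is outside this file):

    req = RemotePrefillRequest(
        engine_id="prefill-engine-0",
        request_id="req-42",
        prompt_token_ids=[1, 2, 3, 4],
        sampling_params=SamplingParams(max_tokens=16),
        block_ids=[0, 1],
        computed_block_ids=[],
    )
    xfer = MemoryTransferRequest(
        request_id="req-42",
        local_block_ids=[0, 1],
        staging_block_ids=[2, 3],
        remote_block_ids=[7, 8],
        remote_engine_id="decode-engine-0",
        notify_msg="req-42",
        op_type=MemoryOpType.WRITE,
    )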