108
vllm/core/event_manager.py
Normal file
@@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
import ctypes
import logging
import uuid
from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
from typing import Optional

from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash

logger = logging.getLogger(__name__)


class DynamoResult:
    OK = 0
    ERR = 1


class KVCacheEventManager:

    def __init__(self, namespace: str, component: str, worker_id: int,
                 lib_path: str, kv_block_size: int):
        self.lib = None

        try:
            self.lib = ctypes.CDLL(lib_path)
            self.lib.dynamo_llm_init.argtypes = [
                c_char_p,
                c_char_p,
                c_int64,
                c_uint32,
            ]
            self.lib.dynamo_llm_init.restype = c_uint32

            result = self.lib.dynamo_llm_init(
                namespace.encode(), component.encode(), worker_id, kv_block_size
            )
            if result == DynamoResult.OK:
                logger.info(
                    "KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
                )
            else:
                logger.error("KVCacheEventManager initialization failed!")

        except Exception as e:
            logger.error("Failed to load %s: %s", lib_path, e)
            raise

        self.lib.dynamo_kv_event_publish_stored.argtypes = [
            ctypes.c_uint64,                   # event_id
            ctypes.POINTER(ctypes.c_uint32),   # token_ids
            ctypes.POINTER(ctypes.c_size_t),   # num_block_tokens
            ctypes.POINTER(ctypes.c_uint64),   # block_ids
            ctypes.c_size_t,                   # num_blocks
            ctypes.POINTER(ctypes.c_uint64),   # parent_hash
            ctypes.c_uint64,                   # lora_id
        ]
        self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32  # dynamo_llm_result_t

        self.lib.dynamo_kv_event_publish_removed.argtypes = [
            ctypes.c_uint64,                   # event_id
            ctypes.POINTER(ctypes.c_uint64),   # block_ids
            ctypes.c_size_t,                   # num_blocks
        ]
        self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32  # dynamo_llm_result_t

        self.event_id_counter = 0

    def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
                             block: PrefixCachingBlock):
        token_ids_arr = (ctypes.c_uint32 *
                         len(block.token_ids))(*block.token_ids)
        num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
        block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
        parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
                       if parent is not None else None)

        # Publish the event
        result = self.lib.dynamo_kv_event_publish_stored(
            self.event_id_counter,  # uint64_t event_id
            token_ids_arr,          # const uint32_t *token_ids
            num_block_tokens,       # const uintptr_t *num_block_tokens
            block_hash,             # const uint64_t *block_ids
            1,                      # uintptr_t num_blocks
            parent_hash,            # const uint64_t *parent_hash
            0,                      # uint64_t lora_id
        )

        if result == DynamoResult.OK:
            logger.debug(f"Store - Published KV Event: {block.content_hash}")
        else:
            logger.debug(
                f"Store - Failed to Publish KV Event: {block.content_hash}")

        self.event_id_counter += 1

    def enqueue_removed_event(self, block_hash: PrefixHash):
        result = self.lib.dynamo_kv_event_publish_removed(
            self.event_id_counter,
            (ctypes.c_uint64 * 1)(block_hash),
            1,
        )

        if result == DynamoResult.OK:
            logger.debug(f"Remove - Published KV Event: {block_hash}")
        else:
            logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")

        self.event_id_counter += 1
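Reviewer note: a hypothetical usage sketch (not part of the diff) showing how the allocator side might drive KVCacheEventManager. The library path, namespace/component names and the block object are placeholders, not values taken from this PR.

from vllm.core.event_manager import KVCacheEventManager

# Placeholders: lib_path, namespace/component and some_full_block are assumptions.
manager = KVCacheEventManager(namespace="dynamo",
                              component="vllm-worker",
                              worker_id=0,
                              lib_path="/opt/dynamo/libdynamo_llm.so",
                              kv_block_size=16)

# When a full prefix-caching block receives a content hash, publish a "stored"
# event; when it is evicted, publish a matching "removed" event.
manager.enqueue_stored_event(parent=None, block=some_full_block)
manager.enqueue_removed_event(block_hash=some_full_block.content_hash)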
110
vllm/distributed/device_communicators/kv_rearrange.py
Normal file
@@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl


@triton.jit
def rearrange_kernel_read(
    t1_ptr,
    t2_ptr,
    N,
    B,
    H,
    C,
    d,
    tensor_subset_size,
    block_size,
    token_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Gathers elements from the TP-grouped layout in t2 back into the natural
    # (N, B, H, C) layout in t1.
    pid = tl.program_id(0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail when the total element count is not a multiple of BLOCK_SIZE.
    mask = offsets < N * B * H * C

    curr_n = offsets // block_size
    curr_b = offsets // token_size % B
    curr_h = offsets // C % H
    curr_c = offsets % C

    src_pos = offsets

    tp_group = curr_h * d // H
    dst_h = curr_h % (H // d)
    tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c

    dst_pos = tensor_subset_size * tp_group + tp_group_offset

    vals = tl.load(t2_ptr + dst_pos, mask=mask, other=0)
    tl.store(t1_ptr + src_pos, vals, mask=mask)


@triton.jit
def rearrange_kernel_write(
    t1_ptr,
    t2_ptr,
    N,
    B,
    H,
    C,
    d,
    tensor_subset_size,
    block_size,
    token_size,
    BLOCK_SIZE: tl.constexpr,
):
    # Scatters elements from the natural (N, B, H, C) layout in t1 into the
    # TP-grouped layout in t2.
    pid = tl.program_id(0)

    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # Guard the tail when the total element count is not a multiple of BLOCK_SIZE.
    mask = offsets < N * B * H * C

    curr_n = offsets // block_size
    curr_b = offsets // token_size % B
    curr_h = offsets // C % H
    curr_c = offsets % C

    src_pos = offsets

    tp_group = curr_h * d // H
    dst_h = curr_h % (H // d)
    tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c

    dst_pos = tensor_subset_size * tp_group + tp_group_offset

    vals = tl.load(t1_ptr + src_pos, mask=mask, other=0)
    tl.store(t2_ptr + dst_pos, vals, mask=mask)


def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
    # Rearranges between the natural (N, B, H, C) layout in t1 and a layout in t2
    # where the H heads are split into d contiguous tensor-parallel groups.
    # direction == "write" fills t2 from t1; "read" fills t1 from t2.
    N, B, H, C = t1.shape

    assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
    assert H % d == 0, "H must be divisible by d"

    block_size = B * H * C
    token_size = H * C
    tensor_size = N * block_size
    tensor_subset_size = tensor_size // d

    BLOCK_SIZE = 1024
    grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)

    if direction == "read":
        rearrange_kernel_read[grid](
            t1, t2,
            N, B, H, C,
            d,
            tensor_subset_size,
            block_size,
            token_size,
            BLOCK_SIZE=BLOCK_SIZE
        )
    elif direction == "write":
        rearrange_kernel_write[grid](
            t1, t2,
            N, B, H, C,
            d,
            tensor_subset_size,
            block_size,
            token_size,
            BLOCK_SIZE=BLOCK_SIZE
        )
    else:
        raise ValueError(f"Invalid direction: {direction}")
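Reviewer note: a minimal CPU reference (not part of the diff) for what rearrange_tensors(t1, t2, d, "write") is expected to produce, assuming t1 is a contiguous (N, B, H, C) tensor; the function name below is illustrative only.

import torch

def rearrange_write_reference(t1: torch.Tensor, d: int) -> torch.Tensor:
    # Split the H heads into d contiguous tensor-parallel groups and lay the
    # groups out one after another, mirroring dst_pos in the kernels above.
    N, B, H, C = t1.shape
    groups = [t1[:, :, g * (H // d):(g + 1) * (H // d), :].reshape(-1)
              for g in range(d)]
    return torch.cat(groups).view(N, B, H, C)

# After rearrange_tensors(t1, t2, d, "write") on CUDA tensors, t2 should match
# rearrange_write_reference(t1, d); the "read" direction inverts this mapping.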
379
vllm/distributed/device_communicators/nixl.py
Normal file
@@ -0,0 +1,379 @@
import torch
from typing import List, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
import msgspec
import time
import uuid
from collections import defaultdict
from .kv_rearrange import rearrange_tensors

logger = init_logger(__name__)

# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
    from nixl._api import nixl_agent as NixlWrapper
    logger.info("NIXL is available")
except ImportError:
    logger.warning("NIXL is not available")
    NixlWrapper = None


class NixlMetadata(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    engine_id: str
    agent_metadata: List[bytes]
    kv_caches_base_addr: List[List[Tuple[int, int]]]  # per rank, per layer: (key, value) base addresses
    num_blocks: int


class DynamoNixlConnector:
    def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
        self.vllm_config = vllm_config
        if NixlWrapper is None:
            logger.error("NIXL is not available")
            raise RuntimeError("NIXL is not available")
        logger.info("Initializing NIXL wrapper")
        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)

        self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer

        self.num_layers = None
        self.num_blocks = None
        self.num_heads = None
        self.block_len = None
        self.kv_caches = None
        self.kv_caches_base_addr = {}
        self.kv_cache_shape = {}

        self._registered_descs = []
        self._remote_agents = {}
        self.engine_id = engine_id
        self.rank = rank
        self._tp_size = {}
        self.src_xfer_side_handles = {}
        self.dst_xfer_side_handles = defaultdict(dict)
        self.dst_num_blocks = {}

        self._transfers = defaultdict(list)

        self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size

    @property
    def agent_name(self):
        return self.nixl_wrapper.name

    def register_kv_caches(self, kv_caches: List[torch.Tensor]):
        _, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
        self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
        logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
        self.num_layers = len(kv_caches)
        self.num_blocks = num_blocks
        self.num_heads = num_heads
        self.kv_caches = kv_caches
        kv_caches_base_addr = []
        caches_data = []
        for key_cache, value_cache in kv_caches:
            base_addr = key_cache.data_ptr()
            region_len = 2 * num_blocks * self.block_len
            caches_data.append((base_addr, region_len, self.rank, ""))
            kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))

        self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr

        descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
        logger.debug("Registering descs: %s", caches_data)
        self.nixl_wrapper.register_memory(descs)
        self._registered_descs.append(descs)

    def get_agent_metadata(self):
        return self.nixl_wrapper.get_agent_metadata()

    def shutdown(self):
        for descs_list in self._registered_descs:
            self.nixl_wrapper.deregister_memory(descs_list)
        for agent_names in self._remote_agents.values():
            for agent_name in agent_names:
                self.nixl_wrapper.remove_remote_agent(agent_name)
        for src_xfer_side_handle in self.src_xfer_side_handles.values():
            self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
        for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
            for dst_xfer_side_handle in dst_xfer_side_handles.values():
                self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)

    def _get_ranges(self, block_ids):
        # Returns a list of inclusive ranges of contiguous block ids, sorted by the
        # starting block id. For example, [0, 1, 2, 4, 5, 6] -> [[0, 2], [4, 6]].
        ranges = []
        for i in range(len(block_ids)):
            if i == 0 or block_ids[i] != block_ids[i - 1] + 1:
                ranges.append([block_ids[i], block_ids[i]])
            else:
                ranges[-1][1] = block_ids[i]
        return ranges

    def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
        if layer_ids == "all":
            layer_ids = list(range(self.num_layers))
        if block_ids == "all":
            block_ids = list(range(self.num_blocks))

        descs_ids = []

        if i is not None:
            num_blocks = self.num_blocks
            for layer_id in layer_ids:
                for is_value in [0, 1]:
                    staging_range_idx = 0
                    for block_id in block_ids:
                        if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
                            staging_range_idx += 1
                        start_offset = staging_ranges[staging_range_idx][0]
                        i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
                        descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
        else:
            num_blocks = self.dst_num_blocks[engine_id]
            for layer_id in layer_ids:
                for is_value in [0, 1]:
                    for block_id in block_ids:
                        descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
        return descs_ids

    def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
        # Splits src and dst ranges so that corresponding ranges have the same length.
        # For example, if src_ranges is [[0, 2], [4, 8]] and dst_ranges is
        # [[1, 3], [5, 7], [9, 10]], the function returns
        # ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]).
        src_overlapping_ranges, dst_overlapping_ranges = [], []

        original_src_ranges = []
        org_src_range = tuple(src_ranges[0])

        src_idx, dst_idx = 0, 0
        while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
            src_range = src_ranges[src_idx]
            dst_range = dst_ranges[dst_idx]

            # Calculate the length of each range
            src_len = src_range[-1] - src_range[0] + 1
            dst_len = dst_range[-1] - dst_range[0] + 1

            # If ranges have the same length, add them directly
            if src_len == dst_len:
                src_overlapping_ranges.append([src_range[0], src_range[-1]])
                dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
                original_src_ranges.append(org_src_range)
                src_idx += 1
                dst_idx += 1
                if src_idx < len(src_ranges):
                    org_src_range = tuple(src_ranges[src_idx])
            # If source range is longer, split it
            elif src_len > dst_len:
                src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
                dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
                original_src_ranges.append(org_src_range)
                # Update source range for next iteration
                src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
                dst_idx += 1
            # If destination range is longer, split it
            else:  # src_len < dst_len
                src_overlapping_ranges.append([src_range[0], src_range[-1]])
                dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
                original_src_ranges.append(org_src_range)
                # Update destination range for next iteration
                dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
                src_idx += 1
                if src_idx < len(src_ranges):
                    org_src_range = tuple(src_ranges[src_idx])
        if return_original_src_ranges:
            return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
        return src_overlapping_ranges, dst_overlapping_ranges

    def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id):
        logger.debug("Reading %d blocks from %s into %s", len(local_block_ids), dst_engine_id, self.agent_name)

        assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids)

        if len(local_block_ids) == 0:
            logger.debug("No blocks to read")
            return

        start_time = time.perf_counter()

        local_ranges = self._get_ranges(local_block_ids)
        staging_ranges = self._get_ranges(staging_block_ids)

        local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)

        tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
        remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
        local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
        handles = []

        logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
        create_xfer_start_time = time.perf_counter()

        for i in range(tp_multiplier):
            staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
            assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
            remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
            handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
                                                         remote_xfer_side_handle, remote_block_descs_ids,
                                                         "")
            handles.append(handle)
            status = self.nixl_wrapper.transfer(handle)

        logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)

        transfer_start_time = time.perf_counter()

        for handle in handles:
            while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
                if status == "PROC":
                    time.sleep(0.001)
                else:
                    raise RuntimeError(f"Read transfer failed with state {status}")
                    # self.nixl_wrapper.abort_xfer(handle)  # TODO ptarasiewicz: why abort is throwing errors?

        logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)

        rearrange_start_time = time.perf_counter()

        for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
            logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
            for kv_cache in self.kv_caches:
                for cache in kv_cache:
                    rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")

        logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
        logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)

    def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
        logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)

        # hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
        # isl[-1] token is calculated in the first iteration in decode.
        # If isl equals a multiple of tokens_per_block + 1, the prefill engine will
        # have one less block due to the missing last token.
        remote_block_ids = remote_block_ids[:len(local_block_ids)]

        assert len(staging_block_ids) == len(local_block_ids)
        tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]

        if len(local_block_ids) == 0:
            logger.debug("No blocks to write")
            for i in range(tp_multiplier):
                self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg)
            return

        start_time = time.perf_counter()

        local_ranges = self._get_ranges(local_block_ids)
        staging_ranges = self._get_ranges(staging_block_ids)

        local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)

        for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
            logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
            for kv_cache in self.kv_caches:
                for cache in kv_cache:
                    rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")

        logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)

        create_xfer_start_time = time.perf_counter()

        # getting block descs ids
        remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
        local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]

        for i in range(tp_multiplier):
            staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
            assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
            remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
            handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids,
                                                         remote_xfer_side_handle, remote_block_descs_ids,
                                                         notify_msg)
            self._transfers[notify_msg].append(handle)
            status = self.nixl_wrapper.transfer(handle)

        logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)

        logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000)

    def get_notifs(self):
        return self.nixl_wrapper.update_notifs()

    def get_new_notifs(self):
        return self.nixl_wrapper.get_new_notifs()

    def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
        self._tp_size[engine_id] = agent_tp
        agent_names = []
        for agent_meta in agent_metadata:
            agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
            agent_names.append(agent_name)
        self._remote_agents[engine_id] = agent_names
        self.kv_caches_base_addr[engine_id] = kv_caches_base_addr

        tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
        assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"

        logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
        dst_block_len = self.block_len // tp_multiplier
        if tp_multiplier not in self.src_xfer_side_handles:
            # create descs and xfer side handles
            blocks_data = []
            for layer_id in range(self.num_layers):
                for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
                    for block_id in range(self.num_blocks):
                        block_offset = block_id * self.block_len
                        for i in range(tp_multiplier):
                            tp_multiplier_offset = i * dst_block_len
                            blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
            logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank)
            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
            self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs)

        # create dst xfer side handles
        self.dst_num_blocks[engine_id] = num_blocks
        for i in range(tp_multiplier):
            blocks_data = []
            for layer_id in range(self.num_layers):
                for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
                    for block_id in range(num_blocks):
                        block_offset = block_id * dst_block_len
                        blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
            logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
            descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
            self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)

        return agent_names

    def get_done_tranfers(self) -> List[str]:
        done_req_ids = []
        for req_id, handles in self._transfers.items():
            running_reqs = []
            for handle in handles:
                xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                if xfer_state == "DONE":
                    # self.nixl_wrapper.release_xfer_handle(handle)  # TODO ptarasiewicz: why abort is throwing errors?
                    continue
                if xfer_state == "PROC":
                    running_reqs.append(handle)
                else:
                    raise RuntimeError(f"Transfer failed with state {xfer_state}")
            if len(running_reqs) == 0:
                done_req_ids.append(req_id)
            else:
                self._transfers[req_id] = running_reqs
        return done_req_ids
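Reviewer note: a rough wiring sketch (not part of the diff) of how two DynamoNixlConnector instances are introduced to each other. It assumes a working NIXL install; the configs, KV cache tensors and block id lists are placeholders, and the metadata exchange would normally go over the control plane rather than in-process.

import time

prefill = DynamoNixlConnector(prefill_vllm_config, "prefill-engine", rank=0)
decode = DynamoNixlConnector(decode_vllm_config, "decode-engine", rank=0)
prefill.register_kv_caches(prefill_kv_caches)
decode.register_kv_caches(decode_kv_caches)

# One metadata entry and one base-address list per remote rank (TP 1 -> TP 1 here).
prefill.add_remote_agent(
    "decode-engine",
    agent_metadata=[decode.get_agent_metadata()],
    agent_tp=1,
    kv_caches_base_addr=[decode.kv_caches_base_addr["decode-engine"]],
    num_blocks=decode.num_blocks,
)

# Push prefilled blocks to the decode engine, then poll until the notify id shows up.
prefill.write_blocks(local_ids, staging_ids, remote_ids, "decode-engine", "req-0")
while "req-0" not in prefill.get_done_tranfers():
    time.sleep(0.001)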
350
vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py
Normal file
@@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
"""
Dynamo KV Cache Connector for Distributed Machine Learning Inference

The DynamoConnector transfers KV caches between a prefill vLLM worker (KV cache
producer) and a decode vLLM worker (KV cache consumer) using PyNcclPipe and the
DynamoNcclDataPlane.

The logic can be extended to support other pipes and lookup buffers.
"""
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import torch

from vllm import _custom_ops as ops
from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
    SimpleBuffer)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

logger = init_logger(__name__)


class DynamoConnector(KVConnectorBase):

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
        world_group,
    ):

        self.config = config.kv_transfer_config
        self.tp_size = config.parallel_config.tensor_parallel_size
        self.rank = rank

        if self.config.kv_connector != "DynamoNcclConnector":
            raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class")

        from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
            PyNcclPipe)
        from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import (
            DynamoNcclDataPlane)

        logger.info(
            "Initializing DynamoNcclConnector under kv_transfer_config %s",
            self.config)

        self.lookup_buffer_size = self.config.kv_buffer_size

        self.producer_data_pipe: PyNcclPipe
        self.consumer_data_pipe: PyNcclPipe
        self.producer_signal_pipe: PyNcclPipe
        self.consumer_signal_pipe: PyNcclPipe

        self._broadcast_and_enhance_kv_config(rank, config, world_group)

        self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
        self.tp_size = config.parallel_config.tensor_parallel_size

        # 2 pipes for every rank in the world
        if self.config.is_kv_producer:
            port_offset_base = rank + 1
        else:
            port_offset_base = rank // self.config.tensor_parallel_multiplier + 1

        self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
        self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)

        self.data_pipe = PyNcclPipe(
            kv_group_rank=self.kv_group_rank,
            local_rank=local_rank,
            config=self.config,
            port_offset=port_offset_base,
        )

        self.data_plane = DynamoNcclDataPlane(
            data_pipe=self.data_pipe,
            port=self._get_data_plane_port(self.global_kv_rank),
        )

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        request_ids = list(model_input.request_ids_to_seq_ids.keys())

        model_config = model_executable.model.config
        is_deepseek = "deepseek" in model_config.architectures[0].lower()
        if not is_deepseek:
            num_heads = int(model_config.num_key_value_heads / self.tp_size)
            hidden_size = model_config.hidden_size
            num_attention_heads = model_config.num_attention_heads
            head_size = int(hidden_size / num_attention_heads)
        else:
            num_heads = int(model_config.num_key_value_heads / self.tp_size)
            hidden_size = model_config.hidden_size
            num_attention_heads = model_config.num_attention_heads
            head_size = int(4.5 * hidden_size / num_attention_heads)

        # query_lens contains new KV caches that are added to vLLM,
        # so we send them to the decode instance.
        # FIXME(Kuntai): This assumes that all requests are prefill.
        for idx, slen in enumerate(seq_lens):
            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen
            current_tokens = input_tokens_tensor[start_pos:end_pos]
            current_request_id = request_ids[idx]
            decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
            decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)

            for target_rank in range(self.config.tensor_parallel_multiplier):

                keys, values = [], []

                for layer_id in range(start_layer, end_layer):
                    kv_cache = kv_caches[layer_id - start_layer]

                    current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

                    num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
                    head_start = target_rank * num_heads_per_rank
                    head_end = head_start + num_heads_per_rank

                    if not is_deepseek:
                        key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
                        value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
                        keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
                        values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
                    else:
                        key_cache = kv_cache
                        keys.append(key_cache[current_slot_mapping].unsqueeze(0))
                        values.append(torch.empty(0))

                keys = torch.cat(keys, dim=0)
                values = torch.cat(values, dim=0)

                decode_global_rank = decode_first_global_rank + target_rank
                decode_port = self._get_data_plane_port(decode_global_rank)
                partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
                self._send(decode_hostname, decode_port, current_request_id, keys, values,
                           partial_hidden_or_intermediate_states)

        logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:

        # When bypass_model_exec is set to False, it means that at least for one
        # request its corresponding KV cache or hidden state is missing.
        # In this case we need to do prefilling to recompute missing KV cache
        # and hidden states.
        bypass_model_exec = True

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
        request_ids = list(model_input.request_ids_to_seq_ids.keys())

        hidden_or_intermediate_states_for_one_req = []

        input_tokens_list = []
        start_pos_list = []

        model_config = model_executable.model.config
        is_deepseek = "deepseek" in model_config.architectures[0].lower()

        # enumerate different requests
        # FIXME(Kuntai): This impl assumes that all requests are prefill.
        for idx, slen in enumerate(seq_lens):

            start_pos = sum(seq_lens[:idx])
            end_pos = start_pos + slen
            current_tokens = input_tokens_tensor[start_pos:end_pos]
            current_request_id = request_ids[idx]
            num_tokens = slen

            # collecting data for rebuilding the input
            input_tokens_list.append(current_tokens)
            start_pos_list.append(start_pos)

            ret = self._recv(current_request_id)
            keys: torch.Tensor = ret[0]
            values: torch.Tensor = ret[1]
            hidden: torch.Tensor = ret[2]

            # put received KV caches into paged memory
            for i in range(model_executable.model.start_layer,
                           model_executable.model.end_layer):

                kv_cache = kv_caches[i - model_executable.model.start_layer]
                layer = model_executable.model.layers[i]

                if not is_deepseek:
                    key_cache, value_cache = kv_cache[0], kv_cache[1]
                    ops.reshape_and_cache_flash(
                        keys[i - model_executable.model.start_layer].to(
                            key_cache.device),
                        values[i - model_executable.model.start_layer].to(
                            value_cache.device),
                        key_cache,
                        value_cache,
                        slot_mapping[start_pos:end_pos],
                        layer.self_attn.attn.kv_cache_dtype,
                        layer.self_attn.attn._k_scale,
                        layer.self_attn.attn._v_scale,
                    )
                else:
                    key_cache = kv_cache
                    copy_from = keys[i - model_executable.model.start_layer].to(
                        key_cache.device)
                    kv_cache[slot_mapping[start_pos:end_pos]] = copy_from

            hidden_or_intermediate_states_for_one_req.append(hidden)

        if not bypass_model_exec:
            # Some of the KV cache is not retrieved.
            # Here we will fall back to normal model forwarding,
            # but optionally you can adjust model_input so that you only do
            # prefilling on those tokens that are missing KV caches.
            logger.debug(
                "[rank%d]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None

        else:
            logger.debug(
                "[rank%d]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        self.data_pipe.close()
        # self.data_plane.close()

    @staticmethod
    def parse_request_id(request_id: str) -> Tuple[str, int]:
        # Regular expression to match the string hostname and integer decode_kv_rank
        pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"

        # Use re.search to find the pattern in the request_id
        match = re.search(pattern, request_id)
        if match:
            # Extract the hostname and rank
            decode_hostname = match.group(1)
            decode_rank = int(match.group(2))

            return decode_hostname, decode_rank
        raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")

    def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
        remote_address = f"{hostname}:{port}"
        self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
        self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
        self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)

    def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        keys = self.data_plane.recv_tensor(f"{request_id}_keys")
        values = self.data_plane.recv_tensor(f"{request_id}_values")
        hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
        return keys, values, hidden

    def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
        if kv_rank < config.kv_producers_parallel_size:
            return kv_rank

        kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
        return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier

    def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
        if kv_rank <= config.kv_producers_parallel_size:
            return kv_rank * config.kv_producers_tensor_parallel_size + rank

        kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
        return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank

    def _get_data_plane_port(self, global_kv_rank: int) -> int:
        return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank

    def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
        if rank == 0:
            config_group = StatelessProcessGroup.create(
                host=self.config.kv_ip,
                port=self.config.kv_port,
                rank=self.config.kv_rank,
                world_size=self.config.kv_parallel_size,
            )
            parallel_configs = config_group.all_gather_obj({
                "kv_role": self.config.kv_role,
                "tensor_parallel_size": config.parallel_config.tensor_parallel_size,
                "pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
            })
            logger.debug("parallel_configs: %s", parallel_configs)
            kv_config_enhanced = {
                "kv_producers_tensor_parallel_size": None,
                "kv_consumers_tensor_parallel_size": None,
                "kv_producers_pipeline_parallel_size": None,
                "kv_consumers_pipeline_parallel_size": None,
                "kv_producers_parallel_size": 0,
            }
            for parallel_config in parallel_configs:
                kv_role = parallel_config["kv_role"]
                assert parallel_config["pipeline_parallel_size"] == 1, "Only pipeline parallel size 1 is supported for kv transfer instances"

                if kv_role == "kv_producer":
                    kv_config_enhanced["kv_producers_parallel_size"] += 1
                if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
                    kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
                    kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
                else:
                    assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
                    assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
            world_group.broadcast_object(kv_config_enhanced)
        else:
            kv_config_enhanced = world_group.broadcast_object()
            logger.info("kv_config_enhanced: %s", kv_config_enhanced)

        self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
        self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
        self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
        self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
        self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
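Reviewer note: a small, self-contained illustration (not part of the diff) of the request id suffix that parse_request_id expects; the base id and hostname are arbitrary.

from vllm.distributed.kv_transfer.kv_connector.dynamo_connector import DynamoConnector

request_id = "cmpl-123___decode_hostname_decode-node-0___decode_kv_rank_2"
hostname, kv_rank = DynamoConnector.parse_request_id(request_id)
assert (hostname, kv_rank) == ("decode-node-0", 2)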
124
vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py
Normal file
@@ -0,0 +1,124 @@
import logging
import threading
import typing
import zmq
import socket
import time
import torch

from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe


logger = logging.getLogger(__name__)


class DynamoNcclDataPlane:
    def __init__(
        self,
        data_pipe: PyNcclPipe,
        hostname: str = "",
        port: int = 0,
    ) -> None:

        self.data_pipe = data_pipe
        if not hostname:
            hostname = socket.gethostname()
        if port == 0:
            raise ValueError("Port cannot be 0")
        self._hostname = hostname
        self._port = port
        self.store = {}
        self.context = zmq.Context()
        self.rep_socket = self.context.socket(zmq.REP)
        logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
        self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
        self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
        self._listener_thread.start()
        self.req_sockets = {}
        logger.info(f"Rank {self.rank} started the request listener")

    @property
    def rank(self):
        return self.data_pipe.kv_group_rank

    def send_tensor(
        self,
        tensor: torch.Tensor,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ):
        logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
        return self._send_tensor(tensor, tensor_id, remote_address)

    def recv_tensor(
        self,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ) -> torch.Tensor:
        ret = self._recv_tensor(tensor_id, remote_address)
        return ret

    def _send_tensor(
        self,
        tensor: torch.Tensor,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ):
        logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
        if remote_address is None:
            self.store[tensor_id] = tensor
        else:
            # tensor_shape = "_".join(str(dim) for dim in tensor.shape)
            # tensor_dtype = str(tensor.dtype)
            if remote_address not in self.req_sockets:
                self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
                self.req_sockets[remote_address].connect(f"tcp://{remote_address}")

            req_socket = self.req_sockets[remote_address]
            # req_socket.connect(f"tcp://{remote_address}")
            req_socket.send_string(f"PUT {self.rank} {tensor_id}")
            dst_rank = req_socket.recv_string()
            logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
            self.data_pipe.send_tensor(tensor, int(dst_rank))

    def _recv_tensor(
        self,
        tensor_id: str,
        remote_address: typing.Optional[str] = None,
    ) -> torch.Tensor:
        logger.debug(f"Rank {self.rank} receiving tensor")
        if remote_address is not None:
            raise NotImplementedError("Getting tensor from remote rank not implemented")
        # Poll until the listener thread has stored the future for this tensor id.
        # A loop is used instead of recursion to avoid hitting the recursion limit
        # while waiting.
        while tensor_id not in self.store:
            logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
            time.sleep(0.001)
        logger.debug(f"Popping tensor {tensor_id} from store")
        future = self.store.pop(tensor_id)
        tensor = future.result()  # TODO ptarasiewicz: we should run other requests instead of waiting
        logger.debug(f"Rank {self.rank} received tensor")
        return tensor

    def _receive_tensor(
        self,
        tensor_id: str,
        rank: int,
    ):
        future = self.data_pipe.recv_tensor(rank)
        logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
        self.store[tensor_id] = future

    def listen_for_requests(self):
        while True:
            cmd, rank, tensor_id = self.rep_socket.recv_string().split()
            logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
            self.rep_socket.send_string(f"{self.rank}")
            if cmd == "GET":
                raise NotImplementedError("Getting tensor from remote rank not implemented")
            elif cmd == "PUT":
                rank = int(rank)
                # shape = [int(dim) for dim in shape.split("_")]
                # dtype = getattr(torch, dtype)
                self._receive_tensor(tensor_id, rank)
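Reviewer note: a self-contained sketch (not part of the diff) of the control-plane handshake DynamoNcclDataPlane uses: the sender announces "PUT <rank> <tensor_id>" on a REQ socket, the receiver answers with its own kv group rank, and the tensor itself then moves over the NCCL pipe (omitted here). The port is arbitrary.

import zmq

ctx = zmq.Context()
rep = ctx.socket(zmq.REP)           # receiver side (listen_for_requests)
rep.bind("tcp://127.0.0.1:5599")
req = ctx.socket(zmq.REQ)           # sender side (_send_tensor)
req.connect("tcp://127.0.0.1:5599")

req.send_string("PUT 0 cmpl-123_keys")
cmd, src_rank, tensor_id = rep.recv_string().split()
rep.send_string("1")                # receiver replies with its rank
dst_rank = req.recv_string()
print(cmd, src_rank, tensor_id, "->", dst_rank)  # PUT 0 cmpl-123_keys -> 1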
67
vllm/remote_prefill.py
Normal file
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Callable, Optional, List
from enum import Enum

import msgspec

from vllm.sampling_params import SamplingParams


class RemotePrefillRequest(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    """The data for one remote prefill request.

    Args:
        engine_id: The unique ID of the engine.
        request_id: The unique ID of the request.
        prompt_token_ids: The token IDs of the prompt.
        sampling_params: The sampling parameters.
        block_ids: The block IDs of the request.
        computed_block_ids: The computed block IDs of the request.
    """
    engine_id: str
    request_id: str
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    block_ids: List[int]
    computed_block_ids: List[int]


class MemoryOpType(str, Enum):
    WRITE = "WRITE"
    READ = "READ"


class MemoryTransferRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """The data for one memory transfer request.

    Args:
        request_id: The unique ID of the request.
    """
    request_id: str
    local_block_ids: List[int]
    staging_block_ids: List[int]
    remote_block_ids: List[int]
    remote_engine_id: str
    notify_msg: str
    op_type: MemoryOpType


RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]


@dataclass
class RemotePrefillParams:
    """Remote prefill parameters for text generation."""
    is_remote_prefill: bool = False
    is_remote_decode: bool = False
    decode_block_ids: Optional[List[int]] = None
    decode_computed_block_ids: Optional[List[int]] = None
    decode_engine_id: Optional[str] = None
    remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
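Reviewer note: an illustration (not part of the diff) of how a decode-side caller might fill in RemotePrefillParams; the block ids and the callback body are placeholders.

from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest

def send_to_prefill(request: RemotePrefillRequest) -> None:
    ...  # e.g. hand the request to the control plane (placeholder)

params = RemotePrefillParams(
    is_remote_decode=True,
    decode_block_ids=[0, 1, 2],
    decode_computed_block_ids=[0],
    decode_engine_id="decode-engine",
    remote_prefill_request_callback=send_to_prefill,
)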