Compare commits

...

2 Commits

Author SHA1 Message Date
6de0982dd0 added
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-04-06 14:07:43 +00:00
45fa7f9b8e updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-04-05 13:43:03 +00:00
32 changed files with 2332 additions and 267 deletions

View File

@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel):
# The KV connector for vLLM to transmit KV caches between vLLM instances.
kv_connector: Optional[str] = None
# Whether to use NIXL prepped xfer for KV cache transfer.
use_prepped_xfer: bool = True
# The device used by kv connector to buffer the KV cache.
# Currently only support 'cuda'.
kv_buffer_device: Optional[str] = "cuda"
@ -2629,7 +2632,7 @@ class KVTransferConfig(BaseModel):
kv_buffer_size: float = 1e9
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
# are 'kv_producer', 'kv_consumer', and 'both'.
# are 'kv_producer', 'kv_consumer', and 'kv_both'.
kv_role: Optional[str] = None
# The rank of this vLLM instance in the KV cache transfer. Typical value:
@ -2647,6 +2650,14 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
# This does not need to be set by the user. It is set by the connector.
kv_producers_parallel_size: Optional[int] = None
kv_producers_tensor_parallel_size: Optional[int] = None
kv_producers_pipeline_parallel_size: Optional[int] = None
kv_consumers_tensor_parallel_size: Optional[int] = None
kv_consumers_pipeline_parallel_size: Optional[int] = None
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -2680,11 +2691,16 @@ class KVTransferConfig(BaseModel):
f"Supported roles are `kv_producer`, `kv_consumer`, "
f"and `kv_both`")
if self.kv_connector is not None and self.kv_role is None:
if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None:
raise ValueError("Please specify kv_disagg_role when kv_connector "
"is set, supported roles are `kv_producer`, "
"`kv_consumer`, and `kv_both`")
if self.use_prepped_xfer is False:
logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.")
self.use_prepped_xfer = True
@property
def is_kv_transfer_instance(self) -> bool:
return self.kv_connector is not None and \
@ -2694,6 +2710,8 @@ class KVTransferConfig(BaseModel):
def need_kv_parallel_group(self) -> bool:
# for those database-based connector, vLLM does not need to create
# parallel group, and in that case the kv parallel size will be 1.
if self.kv_connector == "DynamoNixlConnector":
return False
return self.kv_connector is not None and self.kv_parallel_size > 1
@property
@ -2706,6 +2724,18 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
@property
def tensor_parallel_multiplier(self) -> int:
return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size
@property
def kv_consumers_parallel_size(self) -> int:
return self.kv_parallel_size - self.kv_producers_parallel_size
@property
def kv_world_size(self) -> int:
return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier
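As a worked example of these derived sizes (illustrative numbers only, not values from this diff), consider two prefill producers at TP=1 feeding one decode consumer at TP=4:

# Illustrative configuration: 2 producer (prefill) instances at TP=1,
# 1 consumer (decode) instance at TP=4, kv_parallel_size=3.
kv_producers_parallel_size = 2
kv_producers_tensor_parallel_size = 1
kv_consumers_tensor_parallel_size = 4
kv_parallel_size = 3

tensor_parallel_multiplier = (kv_consumers_tensor_parallel_size
                              // kv_producers_tensor_parallel_size)         # 4
kv_consumers_parallel_size = kv_parallel_size - kv_producers_parallel_size  # 1
kv_world_size = (kv_producers_parallel_size
                 + kv_consumers_parallel_size * tensor_parallel_multiplier) # 2 + 1 * 4 = 6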
class CompilationLevel:
# constants for the levels of the compilation process

View File

@ -6,6 +6,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
DeviceAwareBlockAllocator)
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
from vllm.core.event_manager import KVCacheEventManager
from vllm.platforms import current_platform
from vllm.utils import Device
@ -28,6 +29,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
num_gpu_blocks: int,
num_cpu_blocks: int,
block_size: int,
event_manager: Optional[KVCacheEventManager] = None,
) -> DeviceAwareBlockAllocator:
"""Creates a CpuGpuBlockAllocator instance with the specified
configuration.
@ -64,6 +66,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
cpu_block_ids = block_ids[num_gpu_blocks:]
if allocator_type == "naive":
assert event_manager is None, "Event API not supported with naive allocator."
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
create_block=NaiveBlock, # type: ignore
num_blocks=num_gpu_blocks,
@ -82,12 +85,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
num_blocks=num_gpu_blocks,
block_size=block_size,
block_ids=gpu_block_ids,
event_manager=event_manager,
)
cpu_allocator = PrefixCachingBlockAllocator(
num_blocks=num_cpu_blocks,
block_size=block_size,
block_ids=cpu_block_ids,
event_manager=event_manager,
)
else:
raise ValueError(f"Unknown allocator type {allocator_type=}")
@ -95,10 +100,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
return CpuGpuBlockAllocator(
cpu_block_allocator=cpu_allocator,
gpu_block_allocator=gpu_allocator,
event_manager=event_manager,
)
def __init__(self, cpu_block_allocator: BlockAllocator,
gpu_block_allocator: BlockAllocator):
gpu_block_allocator: BlockAllocator,
event_manager: Optional[KVCacheEventManager] = None,):
assert not (
cpu_block_allocator.all_block_ids
& gpu_block_allocator.all_block_ids
@ -108,6 +115,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
Device.CPU: cpu_block_allocator,
Device.GPU: gpu_block_allocator,
}
self.event_manager = event_manager
self._swap_mapping: Dict[int, int] = {}
self._null_block: Optional[Block] = None

View File

@ -2,7 +2,7 @@
from collections import deque
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
import heapq
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
get_all_blocks_recursively)
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
@ -38,7 +38,7 @@ class NaiveBlockAllocator(BlockAllocator):
if block_ids is None:
block_ids = range(num_blocks)
self._free_block_indices: Deque[BlockId] = deque(block_ids)
self._free_block_indices: List[BlockId] = list(block_ids)
self._all_block_indices = frozenset(block_ids)
assert len(self._all_block_indices) == num_blocks
@ -134,7 +134,8 @@ class NaiveBlockAllocator(BlockAllocator):
if not self._free_block_indices:
raise BlockAllocator.NoFreeBlocksError()
block_id = self._free_block_indices.popleft()
block_id = heapq.heappop(self._free_block_indices)
# TODO: figure out why block_id is sometimes None
self._refcounter.incr(block_id)
return block_id
@ -148,7 +149,7 @@ class NaiveBlockAllocator(BlockAllocator):
refcount = self._refcounter.decr(block_id)
if refcount == 0:
self._free_block_indices.appendleft(block_id)
heapq.heappush(self._free_block_indices, block_id)
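The change above replaces the FIFO deque of free block ids with a min-heap, so allocation always returns the lowest free block id instead of the least recently freed one. A standalone sketch of that difference, separate from the allocator itself:

import heapq
from collections import deque

freed_order = [5, 1, 3]        # block ids in the order they were freed

fifo = deque(freed_order)
fifo.popleft()                 # -> 5: the deque hands back the earliest-freed id

heap = list(freed_order)
heapq.heapify(heap)            # the allocator starts from range(num_blocks), already heap-ordered
heapq.heappop(heap)            # -> 1: the heap always hands back the smallest free id
heapq.heappush(heap, 0)        # freed ids re-enter in id order, as in _free_block_id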
def free(self, block: Block, keep_block_object: bool = False) -> None:
# Release the physical block id

View File

@ -4,7 +4,7 @@ import sys
from bisect import bisect_left
from os.path import commonprefix
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
Tuple)
Tuple, TYPE_CHECKING)
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
get_all_blocks_recursively)
@ -23,6 +23,9 @@ PrefixHash = int
# then we know this block hasn't been accessed yet.
_DEFAULT_LAST_ACCESSED_TIME = -1
if TYPE_CHECKING:
from vllm.core.event_manager import KVCacheEventManager
logger = init_logger(__name__)
@ -80,6 +83,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
block_size: int,
block_ids: Optional[Iterable[int]] = None,
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
event_manager: Optional["KVCacheEventManager"] = None,
):
if block_ids is None:
block_ids = range(num_blocks)
@ -131,6 +135,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
self.metric_data = CacheMetricData()
self.event_manager = event_manager
# Implements Block.Factory.
def _create_block(
self,
prev_block: Optional[Block],
@ -337,6 +344,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
assert self._refcounter.get(_block_id) == 0
assert _block_id == block_id
if self.event_manager:
self.event_manager.enqueue_removed_event(content_hash_to_evict)
self._cached_blocks.pop(content_hash_to_evict)
self._refcounter.incr(block_id)
@ -513,6 +523,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
# Mark this block as touched so that it can be marked as
# computed after the entire batch of sequences are scheduled.
self._touched_blocks.add(block.block_id)
if self.event_manager:
self.event_manager.enqueue_stored_event(block.prev_block, block)
return block.block_id
# Reuse the cached content hash
@ -579,9 +593,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
# Mark all touched blocks as computed.
for block_id in self._touched_blocks:
self._block_tracker[block_id].computed = True
self._touched_blocks.clear()
for block_id in block_ids:
if block_id in self._touched_blocks:
logger.debug("Mark block as computed: %s", block_id)
self._block_tracker[block_id].computed = True
self._touched_blocks.remove(block_id)
def _track_block_id(self, block_id: Optional[BlockId],
computed: bool) -> None:

View File

@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
LastAccessBlocksTracker)
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
from vllm.core.event_manager import KVCacheEventManager
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE,
VLLM_WORKER_ID)
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
from vllm.utils import Device
@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def __init__(
self,
model_name: str,
block_size: int,
num_gpu_blocks: int,
num_cpu_blocks: int,
@ -91,11 +95,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
self.watermark_blocks = int(watermark * num_gpu_blocks)
kv_event_manager_params = [
VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE,
VLLM_KV_COMPONENT
]
set_kv_event_manager_params = len(
[param for param in kv_event_manager_params if param is not None])
if set_kv_event_manager_params == len(kv_event_manager_params):
self.event_manager = KVCacheEventManager(
namespace=VLLM_KV_NAMESPACE,
component=VLLM_KV_COMPONENT,
worker_id=VLLM_WORKER_ID,
lib_path=VLLM_KV_CAPI_PATH,
kv_block_size=block_size)
else:
self.event_manager = None
self.block_allocator = CpuGpuBlockAllocator.create(
allocator_type="prefix_caching" if enable_caching else "naive",
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
block_size=block_size,
event_manager=self.event_manager,
)
self.block_tables: Dict[SeqId, BlockTable] = {}
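The event manager is only constructed when all four environment variables are present; a hedged sketch of enabling it, where the library path and the namespace/component values are placeholders rather than values from this diff:

import os

# Illustrative values only; these must be set in the environment before the
# vLLM process starts. If any of the four is missing, self.event_manager
# stays None and no KV cache events are published.
os.environ["VLLM_WORKER_ID"] = "0"                                    # integer worker id
os.environ["VLLM_KV_CAPI_PATH"] = "/opt/dynamo/lib/libdynamo_llm.so"  # placeholder C API path
os.environ["VLLM_KV_NAMESPACE"] = "dynamo"                            # placeholder namespace
os.environ["VLLM_KV_COMPONENT"] = "vllm"                              # placeholder component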
@ -108,7 +130,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
def can_allocate(self,
seq_group: SequenceGroup,
num_lookahead_slots: int = 0) -> AllocStatus:
num_lookahead_slots: int = 0,
is_remote_decode: bool = False) -> AllocStatus:
# FIXME(woosuk): Here we assume that all sequences in the group share
# the same prompt. This may not be true for preempted sequences.
@ -121,6 +144,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
num_lookahead_slots=num_lookahead_slots,
)
# if remote decode, we need to allocate twice as many blocks for staging
if is_remote_decode:
num_required_blocks *= 2
if seq_group.is_encoder_decoder():
encoder_seq = seq_group.get_encoder_seq()
assert encoder_seq is not None

vllm/core/event_manager.py (new file, 108 lines)
View File

@ -0,0 +1,108 @@
# SPDX-License-Identifier: Apache-2.0
import ctypes
import logging
import uuid
from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
from typing import Optional
from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
logger = logging.getLogger(__name__)
class DynamoResult:
OK = 0
ERR = 1
class KVCacheEventManager:
def __init__(self, namespace: str, component: str, worker_id: int,
lib_path: str, kv_block_size: int):
self.lib = None
try:
self.lib = ctypes.CDLL(lib_path)
self.lib.dynamo_llm_init.argtypes = [
c_char_p,
c_char_p,
c_int64,
c_uint32,
]
self.lib.dynamo_llm_init.restype = c_uint32
result = self.lib.dynamo_llm_init(
namespace.encode(), component.encode(), worker_id, kv_block_size
)
if result == DynamoResult.OK:
logger.info(
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
)
else:
logger.info("KVCacheEventManager initialization failed!")
except Exception as e:
print(f"Failed to load {lib_path}")
raise e
self.lib.dynamo_kv_event_publish_stored.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint32), # token_ids
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
ctypes.POINTER(ctypes.c_uint64), # parent_hash
ctypes.c_uint64, # lora_id
]
self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t
self.lib.dynamo_kv_event_publish_removed.argtypes = [
ctypes.c_uint64, # event_id
ctypes.POINTER(ctypes.c_uint64), # block_ids
ctypes.c_size_t, # num_blocks
]
self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t
self.event_id_counter = 0
def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
block: PrefixCachingBlock):
token_ids_arr = (ctypes.c_uint32 *
len(block.token_ids))(*block.token_ids)
num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
if parent is not None else None)
# Publish the event
result = self.lib.dynamo_kv_event_publish_stored(
self.event_id_counter, # uint64_t event_id
token_ids_arr, # const uint32_t *token_ids
num_block_tokens, # const uintptr_t *num_block_tokens
block_hash, # const uint64_t *block_ids
1, # uintptr_t num_blocks
parent_hash, # const uint64_t *parent_hash
0, # uint64_t lora_id
)
if result == DynamoResult.OK:
logger.debug(f"Store - Published KV Event: {block.content_hash}")
else:
logger.debug(
f"Store - Failed to Publish KV Event: {block.content_hash}")
self.event_id_counter += 1
def enqueue_removed_event(self, block_hash: PrefixHash):
result = self.lib.dynamo_kv_event_publish_removed(
self.event_id_counter,
(ctypes.c_uint64 * 1)(block_hash),
1,
)
if result == DynamoResult.OK:
logger.debug(f"Remove - Published KV Event: {block_hash}")
else:
logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
self.event_id_counter += 1
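A minimal usage sketch of the event manager above, assuming the Dynamo C API library exists at a placeholder path; in practice the blocks are supplied by PrefixCachingBlockAllocator rather than constructed by hand:

# Illustrative only; lib_path is a placeholder.
manager = KVCacheEventManager(
    namespace="dynamo",
    component="vllm-worker",
    worker_id=0,
    lib_path="/opt/dynamo/lib/libdynamo_llm.so",
    kv_block_size=16,
)

# The prefix-caching allocator then drives it:
#   manager.enqueue_stored_event(parent_block, block)   # when a block is cached
#   manager.enqueue_removed_event(content_hash)         # when a cached block is evicted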

View File

@ -4,22 +4,22 @@ import enum
import os
import random
import time
import copy
from collections import deque
from dataclasses import dataclass, field
from typing import Callable, Deque, Dict, Iterable, List, Optional
from typing import Sequence as GenericSequence
from typing import Set, Tuple, Union
from typing import Set, Tuple, Union, Any
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
SequenceGroupMetadata, SequenceGroupMetadataDelta,
SequenceStatus)
SequenceStatus, SequenceStage)
from vllm.utils import Device, PyObjectCache
logger = init_logger(__name__)
# Test-only. If configured, decode is preempted with
@ -285,6 +285,7 @@ class SchedulerPrefillOutputs:
# Ignored sequence groups.
ignored_seq_groups: List[SequenceGroup]
num_lookahead_slots: int
num_remote_prefill_groups: int
@classmethod
def create_empty(cls) -> "SchedulerPrefillOutputs":
@ -292,6 +293,7 @@ class SchedulerPrefillOutputs:
seq_groups=[],
ignored_seq_groups=[],
num_lookahead_slots=0,
num_remote_prefill_groups=0,
)
@ -325,12 +327,14 @@ class Scheduler:
def __init__(
self,
model_config: ModelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
lora_config: Optional[LoRAConfig],
pipeline_parallel_size: int = 1,
output_proc_callback: Optional[Callable] = None,
) -> None:
self.model_config = model_config
self.scheduler_config = scheduler_config
self.cache_config = cache_config
# Note for LoRA scheduling: the current policy is extremely
@ -356,6 +360,7 @@ class Scheduler:
# Create the block space manager.
self.block_manager = BlockSpaceManagerImpl(
model_name=self.model_config.served_model_name,
block_size=self.cache_config.block_size,
num_gpu_blocks=num_gpu_blocks,
num_cpu_blocks=num_cpu_blocks,
@ -371,6 +376,16 @@ class Scheduler:
# Sequence groups in the SWAPPED state.
# Contain decode requests that are swapped out.
self.swapped: Deque[SequenceGroup] = deque()
# Sequence groups in the REMOTE_PREFILLING state.
# Contain requests that are being prefilled by a remote worker.
self.remote_prefilling: Deque[SequenceGroup] = deque()
# Contain requests that are being prefilled by a local worker.
self.prefill_sending: Deque[SequenceGroup] = deque()
self._remote_prefill_outputs: Dict[str, int] = {}
# Sequence groups finished requests ids since last step iteration.
# It lets the model know that any state associated with these requests
# can and must be released after the current step.
@ -501,7 +516,7 @@ class Scheduler:
def has_unfinished_seqs(self) -> bool:
return len(self.waiting) != 0 or len(self.running) != 0 or len(
self.swapped) != 0
self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
def get_prefix_cache_hit_rate(self, device: Device) -> float:
return self.block_manager.get_prefix_cache_hit_rate(device)
@ -523,6 +538,8 @@ class Scheduler:
budget: SchedulingBudget,
curr_loras: Optional[Set[int]],
enable_chunking: bool = False,
finished_prefills: Optional[Set[str]] = None,
finished_transfers: Optional[Set[str]] = None
) -> SchedulerRunningOutputs:
"""Schedule sequence groups that are running.
@ -537,6 +554,8 @@ class Scheduler:
chunked number of tokens are scheduled if
`budget.num_batched_tokens` has not enough capacity to schedule
all tokens.
finished_remote_prefill_request_ids: Set of request ids of remote
prefills that have finished.
Returns:
SchedulerRunningOutputs.
@ -566,6 +585,38 @@ class Scheduler:
preempted: List[SequenceGroup] = ret.preempted
swapped_out: List[SequenceGroup] = ret.swapped_out
remote_prefilling_queue = self.remote_prefilling
leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque()
while remote_prefilling_queue:
seq_group = remote_prefilling_queue.popleft()
if seq_group.request_id not in finished_prefills:
leftover_remote_prefilling_sequences.append(seq_group)
continue
else:
finished_prefills.remove(seq_group.request_id)
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
# Prefill computed all but the last prompt token; the decode worker processes that last token and generates the first output token.
seq_group.update_num_computed_tokens(seq.get_len() - 1)
seq.status = SequenceStatus.RUNNING
seq.data._stage = SequenceStage.DECODE
self.running.appendleft(seq_group)
remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
remote_transfers_queue = self.prefill_sending
leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
while remote_transfers_queue:
seq_group = remote_transfers_queue.popleft()
if seq_group.request_id not in finished_transfers:
leftover_remote_transfers_sequences.append(seq_group)
else:
finished_transfers.remove(seq_group.request_id)
assert len(seq_group.seqs) == 1
seq = seq_group.seqs[0]
self.free_seq(seq)
remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
running_queue = self.running
assert len(self._async_stopped) == 0
while running_queue:
@ -925,6 +976,7 @@ class Scheduler:
seq_groups: List[ScheduledSequenceGroup] = []
waiting_queue = self.waiting
num_remote_prefill_groups = 0
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
while self._passed_delay(time.time()) and waiting_queue:
@ -961,8 +1013,10 @@ class Scheduler:
True, enable_chunking)
# If the sequence group cannot be allocated, stop.
is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode
can_allocate = self.block_manager.can_allocate(
seq_group, num_lookahead_slots=num_lookahead_slots)
seq_group, num_lookahead_slots=num_lookahead_slots,
is_remote_decode=is_remote_decode)
if can_allocate == AllocStatus.LATER:
break
elif can_allocate == AllocStatus.NEVER:
@ -1008,7 +1062,18 @@ class Scheduler:
if curr_loras is not None and lora_int_id > 0:
curr_loras.add(lora_int_id)
waiting_queue.popleft()
self._allocate_and_set_running(seq_group)
seq_group_copy = copy.deepcopy(seq_group)
seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group)
num_remote_prefill_groups += is_remote_prefill
if is_remote_decode:
logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
self.prefill_sending.append(seq_group_copy)
if enable_chunking and self.scheduler_config.is_multi_step:
blocks_to_copy: List[Tuple[int, int]] = []
@ -1046,9 +1111,11 @@ class Scheduler:
seq_groups=seq_groups,
ignored_seq_groups=ignored_seq_groups,
num_lookahead_slots=self._get_num_lookahead_slots(
is_prefill=True, enable_chunking=enable_chunking))
is_prefill=True, enable_chunking=enable_chunking),
num_remote_prefill_groups=num_remote_prefill_groups
)
def _schedule_default(self) -> SchedulerOutputs:
def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests.
The current policy is designed to optimize the throughput. First,
@ -1066,9 +1133,13 @@ class Scheduler:
for seq_group in self.running:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = set(
for seq_group in self.remote_prefilling:
budget.add_num_seqs(seq_group.request_id,
seq_group.get_max_num_running_seqs())
curr_loras = (set(
seq_group.lora_int_id for seq_group in self.running
if seq_group.lora_int_id > 0) if self.lora_enabled else None
if seq_group.lora_int_id > 0) if self.lora_enabled else None)
prefills = SchedulerPrefillOutputs.create_empty()
running_scheduled = SchedulerRunningOutputs.create_empty()
@ -1090,7 +1161,9 @@ class Scheduler:
if len(prefills.seq_groups) == 0:
running_scheduled = self._schedule_running(budget,
curr_loras,
enable_chunking=False)
enable_chunking=False,
finished_prefills=finished_prefills,
finished_transfers=finished_transfers)
# If any sequence group is preempted, do not swap in any sequence
# group. because it means there's no slot for new running requests.
@ -1106,7 +1179,12 @@ class Scheduler:
self.waiting.extendleft(running_scheduled.preempted)
# Update new running requests.
if len(prefills.seq_groups) > 0:
self.running.extend([s.seq_group for s in prefills.seq_groups])
for s in prefills.seq_groups:
seq_group = s.seq_group
if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
self.remote_prefilling.append(seq_group)
else:
self.running.append(seq_group)
self.running.extend(running_scheduled.decode_seq_groups_list)
@ -1248,12 +1326,14 @@ class Scheduler:
len(running_scheduled.swapped_out)),
)
def _schedule(self) -> SchedulerOutputs:
def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
"""Schedule queued requests."""
if self.scheduler_config.chunked_prefill_enabled:
if finished_prefills or finished_transfers:
raise ValueError("Chunked prefill does not support remote prefills")
return self._schedule_chunked_prefill()
else:
return self._schedule_default()
return self._schedule_default(finished_prefills, finished_transfers)
def _can_append_slots(self, seq_group: SequenceGroup,
enable_chunking: bool) -> bool:
@ -1287,14 +1367,16 @@ class Scheduler:
return no_single_seq
def schedule(
self
self,
finished_prefills: Optional[Set[str]] = None,
finished_transfers: Optional[Set[str]] = None
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
# Schedule sequence groups.
# This function call changes the internal states of the scheduler
# such as self.running, self.swapped, and self.waiting.
scheduler_start_time = time.perf_counter()
scheduler_outputs: SchedulerOutputs = self._schedule()
scheduler_start_time = time.perf_counter()
scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
now = time.time()
if not self.cache_config.enable_prefix_caching:
@ -1333,7 +1415,8 @@ class Scheduler:
encoder_seq_data = None
cross_block_table = None
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING)
for seq in running_or_remote_prefilling_seqs:
seq_id = seq.seq_id
seq_data[seq_id] = seq.data
block_tables[seq_id] = self.block_manager.get_block_table(seq)
@ -1342,7 +1425,9 @@ class Scheduler:
if self.cache_config.enable_prefix_caching:
common_computed_block_nums = (
self.block_manager.get_common_computed_block_ids(
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
running_or_remote_prefilling_seqs
)
)
do_sample = True
is_prompt = seq_group.is_prefill()
@ -1364,9 +1449,30 @@ class Scheduler:
< seqs[0].data.get_len()):
do_sample = False
is_remote_prefill = False
if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
is_remote_prefill = True
logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums)
if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
# Since we know that prefill is scheduled we can
# assume that the blocks computed on decode
# will be fetched by the time we run prefill
logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids)
if seq_group.remote_prefill_params.decode_computed_block_ids:
computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids)
prefill_block_ids = block_tables[seq_group.seqs[0].seq_id]
prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)]
assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't support prefix caching"
common_computed_block_nums = prefill_fetched_block_ids
# It assumes the scheduled_seq_groups is ordered by
# prefill < decoding.
if is_first_prefill or not self.scheduler_config.send_delta_data:
logger.debug("Assinged blocks: %s", block_tables)
seq_group_metadata = SequenceGroupMetadata(
request_id=seq_group.request_id,
is_prompt=is_prompt,
@ -1392,6 +1498,7 @@ class Scheduler:
if scheduler_outputs.num_prefill_groups > 0 else None,
mm_processor_kwargs=seq_group.mm_processor_kwargs,
prompt_adapter_request=seq_group.prompt_adapter_request,
do_remote_prefill=is_remote_prefill,
)
else:
# When SPMD mode is enabled, we only send delta data except for
@ -1490,11 +1597,17 @@ class Scheduler:
self._async_stopped.clear()
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool:
self.block_manager.allocate(seq_group)
is_remote_prefill = False
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
seq.status = SequenceStatus.RUNNING
if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
seq.status = SequenceStatus.REMOTE_PREFILLING
is_remote_prefill = True
else:
seq.status = SequenceStatus.RUNNING
return is_remote_prefill
def _append_slots(self,
seq_group: SequenceGroup,
blocks_to_copy: List[Tuple[int, int]],

View File

@ -0,0 +1,110 @@
import torch
import triton
import triton.language as tl
@triton.jit
def rearrange_kernel_read(
t1_ptr,
t2_ptr,
N,
B,
H,
C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
curr_n = offsets // block_size
curr_b = offsets // token_size % B
curr_h = offsets // C % H
curr_c = offsets % C
src_pos = offsets
tp_group = curr_h * d // H
dst_h = curr_h % (H // d)
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
dst_pos = tensor_subset_size * tp_group + tp_group_offset
tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos))
@triton.jit
def rearrange_kernel_write(
t1_ptr,
t2_ptr,
N,
B,
H,
C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
curr_n = offsets // block_size
curr_b = offsets // token_size % B
curr_h = offsets // C % H
curr_c = offsets % C
src_pos = offsets
tp_group = curr_h * d // H
dst_h = curr_h % (H // d)
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
dst_pos = tensor_subset_size * tp_group + tp_group_offset
tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
N, B, H, C = t1.shape
assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
assert H % d == 0, "H must be divisible by d"
block_size = B * H * C
token_size = H * C
tensor_size = N * block_size
tensor_subset_size = tensor_size // d
BLOCK_SIZE = 1024
grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
if direction == "read":
rearrange_kernel_read[grid](
t1, t2,
N, B, H, C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE=BLOCK_SIZE
)
elif direction == "write":
rearrange_kernel_write[grid](
t1, t2,
N, B, H, C,
d,
tensor_subset_size,
block_size,
token_size,
BLOCK_SIZE=BLOCK_SIZE
)
else:
raise ValueError(f"Invalid direction: {direction}")

View File

@ -0,0 +1,379 @@
import torch
from typing import List, Tuple
from vllm.config import VllmConfig
from vllm.logger import init_logger
import msgspec
import time
import uuid
from collections import defaultdict
from .kv_rearrange import rearrange_tensors
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
class NixlMetadata(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
engine_id: str
agent_metadata: List[bytes]
kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
num_blocks: int
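NixlMetadata is a msgspec struct, presumably so the prefill and decode engines can exchange agent metadata compactly (the transport itself is not shown in this diff). A hedged round-trip sketch using msgspec's msgpack codec, with placeholder payload values:

import msgspec

meta = NixlMetadata(
    engine_id="engine-0",                                      # placeholder
    agent_metadata=[b"nixl-agent-bytes"],                      # placeholder per-rank agent blobs
    kv_caches_base_addr=[[(0x7F0000000000, 0x7F0000100000)]],  # (key, value) base addrs per layer
    num_blocks=1024,
)

packed = msgspec.msgpack.encode(meta)
restored = msgspec.msgpack.decode(packed, type=NixlMetadata)
assert restored.engine_id == meta.engine_id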
class DynamoNixlConnector:
def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
self.vllm_config = vllm_config
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer
self.num_layers = None
self.num_blocks = None
self.num_heads = None
self.block_len = None
self.kv_caches = None
self.kv_caches_base_addr = {}
self.kv_cache_shape = {}
self._registered_descs = []
self._remote_agents = {}
self.engine_id = engine_id
self.rank = rank
self._tp_size = {}
self.src_xfer_side_handles = {}
self.dst_xfer_side_handles = defaultdict(dict)
self.dst_num_blocks = {}
self._transfers = defaultdict(list)
self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
@property
def agent_name(self):
return self.nixl_wrapper.name
def register_kv_caches(self, kv_caches: List[torch.Tensor]):
_, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
self.num_layers = len(kv_caches)
self.num_blocks = num_blocks
self.num_heads = num_heads
self.kv_caches = kv_caches
kv_caches_base_addr = []
caches_data = []
for key_cache, value_cache in kv_caches:
base_addr = key_cache.data_ptr()
region_len = 2 * num_blocks * self.block_len
caches_data.append((base_addr, region_len, self.rank, ""))
kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
self._registered_descs.append(descs)
def get_agent_metadata(self):
return self.nixl_wrapper.get_agent_metadata()
def shutdown(self):
for descs_list in self._registered_descs:
self.nixl_wrapper.deregister_memory(descs_list)
for agent_names in self._remote_agents.values():
for agent_name in agent_names:
self.nixl_wrapper.remove_remote_agent(agent_name)
for src_xfer_side_handle in self.src_xfer_side_handles.values():
self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
for dst_xfer_side_handle in dst_xfer_side_handles.values():
self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
def _get_ranges(self, block_ids):
# Group consecutive block ids into inclusive [start, end] ranges.
# For example, block_ids [0, 1, 2, 4, 5, 6] yields [[0, 2], [4, 6]].
# Ranges follow the order of the input ids; a gap in the sequence simply
# starts a new range.
ranges = []
for i in range(len(block_ids)):
if i == 0 or block_ids[i] != block_ids[i-1] + 1:
ranges.append([block_ids[i], block_ids[i]])
else:
ranges[-1][1] = block_ids[i]
return ranges
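The comment's example can be checked directly; a tiny standalone version of the same grouping logic, outside the class for illustration:

def group_contiguous(block_ids):
    # Same grouping as _get_ranges: start a new [lo, hi] range whenever the
    # current id does not extend the previous one by exactly 1.
    ranges = []
    for i, bid in enumerate(block_ids):
        if i == 0 or bid != block_ids[i - 1] + 1:
            ranges.append([bid, bid])
        else:
            ranges[-1][1] = bid
    return ranges

assert group_contiguous([0, 1, 2, 4, 5, 6]) == [[0, 2], [4, 6]]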
def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
if layer_ids == "all":
layer_ids = list(range(self.num_layers))
if block_ids == "all":
block_ids = list(range(self.num_blocks))
descs_ids = []
if i is not None:
num_blocks = self.num_blocks
for layer_id in layer_ids:
for is_value in [0, 1]:
staging_range_idx = 0
for block_id in block_ids:
if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
staging_range_idx += 1
start_offset = staging_ranges[staging_range_idx][0]
i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
else:
num_blocks = self.dst_num_blocks[engine_id]
for layer_id in layer_ids:
for is_value in [0, 1]:
for block_id in block_ids:
descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
return descs_ids
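For the remote branch (i is None), the descriptor id is a plain row-major flattening over (layer, key/value, block); a worked check under small assumed sizes:

# Assumed sizes for illustration: 4 blocks on the remote engine.
num_blocks = 4

def remote_desc_id(layer_id, is_value, block_id):
    # Mirrors: layer_id * 2 * num_blocks + is_value * num_blocks + block_id
    return layer_id * 2 * num_blocks + is_value * num_blocks + block_id

assert remote_desc_id(0, 0, 0) == 0    # layer 0, keys, block 0
assert remote_desc_id(0, 1, 0) == 4    # layer 0 values start after its 4 key blocks
assert remote_desc_id(1, 0, 2) == 10   # layer 1 starts at 2 * 4 = 8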
def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
# Split src and dst ranges so that corresponding ranges have equal lengths.
# For example, if src_ranges is [[0, 2], [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]],
# the function returns ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]]).
src_overlapping_ranges, dst_overlapping_ranges = [], []
original_src_ranges = []
org_src_range = tuple(src_ranges[0])
src_idx, dst_idx = 0, 0
while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
src_range = src_ranges[src_idx]
dst_range = dst_ranges[dst_idx]
# Calculate the length of each range
src_len = src_range[-1] - src_range[0] + 1
dst_len = dst_range[-1] - dst_range[0] + 1
# If ranges have the same length, add them directly
if src_len == dst_len:
src_overlapping_ranges.append([src_range[0], src_range[-1]])
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
original_src_ranges.append(org_src_range)
src_idx += 1
dst_idx += 1
if src_idx < len(src_ranges):
org_src_range = tuple(src_ranges[src_idx])
# If source range is longer, split it
elif src_len > dst_len:
src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
original_src_ranges.append(org_src_range)
# Update source range for next iteration
src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
dst_idx += 1
# If destination range is longer, split it
else: # src_len < dst_len
src_overlapping_ranges.append([src_range[0], src_range[-1]])
dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
original_src_ranges.append(org_src_range)
# Update destination range for next iteration
dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
src_idx += 1
if src_idx < len(src_ranges):
org_src_range = tuple(src_ranges[src_idx])
if return_original_src_ranges:
return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
return src_overlapping_ranges, dst_overlapping_ranges
def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id):
logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id)
assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids)
if len(local_block_ids) == 0:
logger.debug("No blocks to read")
return
start_time = time.perf_counter()
local_ranges = self._get_ranges(local_block_ids)
staging_ranges = self._get_ranges(staging_block_ids)
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
handles = []
logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
create_xfer_start_time = time.perf_counter()
for i in range(tp_multiplier):
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
remote_xfer_side_handle, remote_block_descs_ids,
"")
handles.append(handle)
status = self.nixl_wrapper.transfer(handle)
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
transfer_start_time = time.perf_counter()
for handle in handles:
while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
if status == "PROC":
time.sleep(0.001)
else:
raise RuntimeError("Read transfer failed with state %s", status)
# self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)
rearrange_start_time = time.perf_counter()
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
for kv_cache in self.kv_caches:
for cache in kv_cache:
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)
# hongkuanz: we send isl[:-1] tokens to the prefill engine; the KV for the
# last token is computed in the first decode iteration.
# If isl equals a multiple of tokens_per_block plus 1, the prefill engine
# ends up with one block fewer because of that missing last token.
remote_block_ids = remote_block_ids[:len(local_block_ids)]
assert len(staging_block_ids) == len(local_block_ids)
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
if len(local_block_ids) == 0:
logger.debug("No blocks to write")
for i in range(tp_multiplier):
self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg)
return
start_time = time.perf_counter()
local_ranges = self._get_ranges(local_block_ids)
staging_ranges = self._get_ranges(staging_block_ids)
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
for kv_cache in self.kv_caches:
for cache in kv_cache:
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
create_xfer_start_time = time.perf_counter()
# getting block descs ids
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
for i in range(tp_multiplier):
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids,
remote_xfer_side_handle, remote_block_descs_ids,
notify_msg)
self._transfers[notify_msg].append(handle)
status = self.nixl_wrapper.transfer(handle)
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
transfer_start_time = time.perf_counter()
logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000)
def get_notifs(self):
return self.nixl_wrapper.update_notifs()
def get_new_notifs(self):
return self.nixl_wrapper.get_new_notifs()
def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
self._tp_size[engine_id] = agent_tp
agent_names = []
for agent_meta in agent_metadata:
agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
agent_names.append(agent_name)
self._remote_agents[engine_id] = agent_names
self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
dst_block_len = self.block_len // tp_multiplier
if tp_multiplier not in self.src_xfer_side_handles:
# create descs and xfer side handles
blocks_data = []
for layer_id in range(self.num_layers):
for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for i in range(tp_multiplier):
tp_multiplier_offset = i * dst_block_len
blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs)
# create dst xfer side handles
self.dst_num_blocks[engine_id] = num_blocks
for i in range(tp_multiplier):
blocks_data = []
for layer_id in range(self.num_layers):
for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
for block_id in range(num_blocks):
block_offset = block_id * dst_block_len
blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
return agent_names
def get_done_tranfers(self) -> List[str]:
done_req_ids = []
for req_id, handles in self._transfers.items():
running_reqs = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
# self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors?
continue
if xfer_state == "PROC":
running_reqs.append(handle)
else:
raise RuntimeError("Transfer failed with state %s", xfer_state)
if len(running_reqs) == 0:
done_req_ids.append(req_id)
else:
self._transfers[req_id] = running_reqs
return done_req_ids

View File

@ -0,0 +1,350 @@
# SPDX-License-Identifier: Apache-2.0
"""
Simple KV Cache Connector for Distributed Machine Learning Inference
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
logger = init_logger(__name__)
class DynamoConnector(KVConnectorBase):
def __init__(
self,
rank: int,
local_rank: int,
config: VllmConfig,
world_group,
):
self.config = config.kv_transfer_config
self.tp_size = config.parallel_config.tensor_parallel_size
self.rank = rank
if self.config.kv_connector != "DynamoNcclConnector":
raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class")
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
PyNcclPipe)
from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import (
DynamoNcclDataPlane)
logger.info(
"Initializing DynamoNcclConnector under kv_transfer_config %s",
self.config)
self.lookup_buffer_size = self.config.kv_buffer_size
self.producer_data_pipe: PyNcclPipe
self.consumer_data_pipe: PyNcclPipe
self.producer_signal_pipe: PyNcclPipe
self.consumer_signal_pipe: PyNcclPipe
self._broadcast_and_enhance_kv_config(rank, config, world_group)
self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
self.tp_size = config.parallel_config.tensor_parallel_size
# 2 pipes for every rank in the world
if self.config.is_kv_producer:
port_offset_base = rank + 1
else:
port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
self.data_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.data_plane = DynamoNcclDataPlane(
data_pipe=self.data_pipe,
port=self._get_data_plane_port(self.global_kv_rank),
)
def send_kv_caches_and_hidden_states(
self,
model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor],
hidden_or_intermediate_states: Union[torch.Tensor,
IntermediateTensors],
) -> None:
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
is_deepseek = "deepseek" in model_config.architectures[0].lower()
if not is_deepseek:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
else:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(4.5 * hidden_size / num_attention_heads)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
# FIXME(Kuntai): This assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
for target_rank in range(self.config.tensor_parallel_multiplier):
keys, values = [], []
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
head_start = target_rank * num_heads_per_rank
head_end = head_start + num_heads_per_rank
if not is_deepseek:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
else:
key_cache = kv_cache
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(torch.empty(0))
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
decode_global_rank = decode_first_global_rank + target_rank
decode_port = self._get_data_plane_port(decode_global_rank)
partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
self._send(decode_hostname, decode_port, current_request_id, keys, values,
partial_hidden_or_intermediate_states)
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
def recv_kv_caches_and_hidden_states(
self, model_executable: torch.nn.Module,
model_input: "ModelInputForGPUWithSamplingMetadata",
kv_caches: List[torch.Tensor]
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
"ModelInputForGPUWithSamplingMetadata"]:
# When bypass_model_exec is set to False, it means that at least for one
# request its corresponding KV cache or hidden state is missing.
# In this case we need to do prefilling to recompute missing KV cache
# and hidden states.
bypass_model_exec = True
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
hidden_or_intermediate_states_for_one_req = []
input_tokens_list = []
start_pos_list = []
model_config = model_executable.model.config
is_deepseek = "deepseek" in model_config.architectures[0].lower()
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self._recv(current_request_id)
keys: torch.Tensor = ret[0]
values: torch.Tensor = ret[1]
hidden: torch.Tensor = ret[2]
# put received KV caches into paged memory
for i in range(model_executable.model.start_layer,
model_executable.model.end_layer):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
if not is_deepseek:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
else:
key_cache = kv_cache
copy_from = keys[i - model_executable.model.start_layer].to(
key_cache.device)
kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
hidden_or_intermediate_states_for_one_req.append(hidden)
if not bypass_model_exec:
# Some of the KV cache is not retrieved
# Here we will fall back to normal model forwarding
# But optionally you can adjust model_input so that you only do
# prefilling on those tokens that are missing KV caches.
logger.debug(
"[rank%d]: Failed to receive all KVs and hidden "
"states, redo model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = None
else:
logger.debug(
"[rank%d]: Successfully received all KVs and hidden "
"states, skip model forwarding.", torch.distributed.get_rank())
hidden_or_intermediate_states = torch.cat(
hidden_or_intermediate_states_for_one_req, dim=0)
return hidden_or_intermediate_states, bypass_model_exec, model_input
def close(self):
self.data_pipe.close()
# self.data_plane.close()
@staticmethod
def parse_request_id(request_id: str) -> Tuple[str, int]:
# Regular expression to match the string hostname and integer decode_kv_rank
pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
decode_hostname = match.group(1)
decode_rank = int(match.group(2))
return decode_hostname, decode_rank
raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
remote_address = f"{hostname}:{port}"
self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
keys = self.data_plane.recv_tensor(f"{request_id}_keys")
values = self.data_plane.recv_tensor(f"{request_id}_values")
hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
return keys, values, hidden
def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
if kv_rank < config.kv_producers_parallel_size:
return kv_rank
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
if kv_rank <= config.kv_producers_parallel_size:
return kv_rank * config.kv_producers_tensor_parallel_size + rank
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
def _get_data_plane_port(self, global_kv_rank: int) -> int:
return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
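A worked example of the rank and port arithmetic above, under assumed sizes of 2 producers at TP=1, one consumer at TP=2, and the default kv_port of 14579 (all numbers illustrative):

kv_port = 14579
kv_producers_parallel_size = 2
kv_producers_tensor_parallel_size = 1
kv_consumers_tensor_parallel_size = 2

def global_kv_rank(kv_rank, rank):
    # Mirrors _get_global_kv_rank above.
    if kv_rank <= kv_producers_parallel_size:
        return kv_rank * kv_producers_tensor_parallel_size + rank
    consumer = kv_rank - kv_producers_parallel_size
    return (kv_producers_parallel_size * kv_producers_tensor_parallel_size
            + consumer * kv_consumers_tensor_parallel_size + rank)

def data_plane_port(g):
    # Mirrors _get_data_plane_port above.
    return kv_port + kv_producers_tensor_parallel_size + 1 + g

assert data_plane_port(global_kv_rank(0, 0)) == 14581   # producer 0
assert data_plane_port(global_kv_rank(1, 0)) == 14582   # producer 1
assert data_plane_port(global_kv_rank(2, 1)) == 14584   # consumer instance, local TP rank 1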
def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
if rank == 0:
config_group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port,
rank=self.config.kv_rank,
world_size=self.config.kv_parallel_size,
)
parallel_configs = config_group.all_gather_obj({
"kv_role": self.config.kv_role,
"tensor_parallel_size": config.parallel_config.tensor_parallel_size,
"pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
})
logger.debug("parallel_configs: %s", parallel_configs)
kv_config_enhanced = {
"kv_producers_tensor_parallel_size": None,
"kv_consumers_tensor_parallel_size": None,
"kv_producers_pipeline_parallel_size": None,
"kv_consumers_pipeline_parallel_size": None,
"kv_producers_parallel_size": 0,
}
for parallel_config in parallel_configs:
kv_role = parallel_config["kv_role"]
assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
if kv_role == "kv_producer":
kv_config_enhanced["kv_producers_parallel_size"] += 1
if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
else:
assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
world_group.broadcast_object(kv_config_enhanced)
else:
kv_config_enhanced = world_group.broadcast_object()
logger.info("kv_config_enhanced: %s", kv_config_enhanced)
self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
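# Illustrative sketch, assuming the same topology as above (2 producers at TP=1,
# one consumer at TP=2, kv_parallel_size=3): rank 0 would broadcast roughly this.
example_kv_config_enhanced = {
    "kv_producers_tensor_parallel_size": 1,
    "kv_consumers_tensor_parallel_size": 2,
    "kv_producers_pipeline_parallel_size": 1,
    "kv_consumers_pipeline_parallel_size": 1,
    "kv_producers_parallel_size": 2,
}
# from which the derived properties become tensor_parallel_multiplier == 2,
# kv_consumers_parallel_size == 3 - 2 == 1 and kv_world_size == 2 + 1 * 2 == 4.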

View File

@ -27,13 +27,13 @@ class KVConnectorFactory:
@classmethod
def create_connector(cls, rank: int, local_rank: int,
config: "VllmConfig") -> KVConnectorBase:
config: "VllmConfig", world_group) -> KVConnectorBase:
connector_name = config.kv_transfer_config.kv_connector
if connector_name not in cls._registry:
raise ValueError(f"Unsupported connector type: {connector_name}")
connector_cls = cls._registry[connector_name]()
return connector_cls(rank, local_rank, config)
return connector_cls(rank, local_rank, config, world_group)
# Register various connectors here.
@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnector",
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
"SimpleConnector")
KVConnectorFactory.register_connector(
"DynamoNcclConnector",
"vllm.distributed.kv_transfer.kv_connector.dynamo_connector",
"DynamoConnector")

View File

@ -8,13 +8,15 @@ MooncakePipe.
But the logic can be extended to support other pipe and lookup buffer.
"""
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import torch
from vllm import _custom_ops as ops
from vllm.config import VllmConfig
from vllm.config import VllmConfig, KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.distributed.utils import StatelessProcessGroup
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
SimpleBuffer)
from vllm.logger import init_logger
@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
rank: int,
local_rank: int,
config: VllmConfig,
world_group,
):
self.config = config.kv_transfer_config
@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
# 2 pipes for every rank in the world
port_offset_base = 2 * rank
self._broadcast_and_enhance_kv_config(rank, config, world_group)
self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
self.tp_size = config.parallel_config.tensor_parallel_size
# 2 pipes for every rank in the world
if self.config.is_kv_producer:
port_offset_base = 2 * rank + 1
else:
port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
# In disaggregated prefill, the prefill vLLM only uses send pipe
# and the decode vLLM only uses recv pipe
if self.config.is_kv_producer:
if self.config.kv_connector == "PyNcclConnector":
self.producer_data_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.producer_signal_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
# its recv pipe to the send pipe of KV producer
if self.config.kv_connector == "PyNcclConnector":
self.consumer_data_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base,
)
self.consumer_signal_pipe = PyNcclPipe(
kv_group_rank=self.kv_group_rank,
local_rank=local_rank,
config=self.config,
port_offset=port_offset_base + 1,
@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
self.config.kv_buffer_size,
)
def select(self, input_tokens: Optional[torch.Tensor],
def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
assert self.consumer_buffer is not None, "Please initialize the "\
"consumer buffer before calling select."
return self.consumer_buffer.drop_select(input_tokens, roi)
return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
assert self.producer_buffer is not None, "Please initialize the "\
"producer buffer before calling insert."
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
def send_kv_caches_and_hidden_states(
self,
@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
start_layer = model_executable.model.start_layer
end_layer = model_executable.model.end_layer
request_ids = list(model_input.request_ids_to_seq_ids.keys())
model_config = model_executable.model.config
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
is_deepseek = "deepseek" in model_config.architectures[0].lower()
if not is_deepseek:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(hidden_size / num_attention_heads)
else:
num_heads = int(model_config.num_key_value_heads / self.tp_size)
hidden_size = model_config.hidden_size
num_attention_heads = model_config.num_attention_heads
head_size = int(4.5 * hidden_size / num_attention_heads)
# query_lens contains new KV caches that are added to vLLM.
# so we will send them to decode instance
@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
_, decode_kv_rank = self.parse_request_id(current_request_id)
starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
keys, values = [], []
for target_rank in range(self.config.tensor_parallel_multiplier):
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
keys, values = [], []
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
for layer_id in range(start_layer, end_layer):
kv_cache = kv_caches[layer_id - start_layer]
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(value_cache[current_slot_mapping].unsqueeze(0))
num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
head_start = target_rank * num_heads_per_rank
head_end = head_start + num_heads_per_rank
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
if not is_deepseek:
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
else:
key_cache = kv_cache
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
values.append(torch.empty(0))
self.insert(current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
keys = torch.cat(keys, dim=0)
values = torch.cat(values, dim=0)
self.insert(starting_kv_group_rank, target_rank, current_tokens,
torch.ones_like(current_tokens,
dtype=bool), keys, values,
hidden_or_intermediate_states[start_pos:end_pos])
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
input_tokens_tensor = model_input.input_tokens
seq_lens = model_input.attn_metadata.seq_lens
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
request_ids = list(model_input.request_ids_to_seq_ids.keys())
hidden_or_intermediate_states_for_one_req = []
@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
num_computed_tokens_list = []
start_pos_list = []
model_config = model_executable.model.config
is_deepseek = "deepseek" in model_config.architectures[0].lower()
# enumerate different requests
# FIXME(Kuntai): This impl assumes that all requests are prefill.
for idx, slen in enumerate(seq_lens):
@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
start_pos = sum(seq_lens[:idx])
end_pos = start_pos + slen
current_tokens = input_tokens_tensor[start_pos:end_pos]
current_request_id = request_ids[idx]
prefill_rank, _ = self.parse_request_id(current_request_id)
num_tokens = slen
# collecting data for rebuilding the input
input_tokens_list.append(current_tokens)
start_pos_list.append(start_pos)
ret = self.select(current_tokens,
ret = self.select(prefill_rank, current_tokens,
torch.ones_like(current_tokens, dtype=bool))
if ret[0] is None:
# didn't find any match.
@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
kv_cache = kv_caches[i - model_executable.model.start_layer]
layer = model_executable.model.layers[i]
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
if not is_deepseek:
key_cache, value_cache = kv_cache[0], kv_cache[1]
ops.reshape_and_cache_flash(
keys[i - model_executable.model.start_layer].to(
key_cache.device),
values[i - model_executable.model.start_layer].to(
value_cache.device),
key_cache,
value_cache,
slot_mapping[start_pos:end_pos],
layer.self_attn.attn.kv_cache_dtype,
layer.self_attn.attn._k_scale,
layer.self_attn.attn._v_scale,
)
else:
key_cache = kv_cache
copy_from = keys[i - model_executable.model.start_layer].to(
key_cache.device)
kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
hidden_or_intermediate_states_for_one_req.append(hidden)
@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
# close the data_pipe.
pass
@staticmethod
def parse_request_id(request_id):
# Regular expression to match the ranks
pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
# Use re.search to find the pattern in the request_id
match = re.search(pattern, request_id)
if match:
# Extract the ranks
prefill_rank = int(match.group(1))
decode_rank = int(match.group(2))
return prefill_rank, decode_rank
else:
return None, None
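# Illustrative sketch with made-up values: the request-id convention
# parse_request_id expects on this connector (the prefix and ranks are assumptions).
import re
example_rid = "cmpl-7___prefill_kv_rank_0___decode_kv_rank_2"
m = re.search(r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)", example_rid)
assert m and (int(m.group(1)), int(m.group(2))) == (0, 2)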
def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
if kv_rank < config.kv_producers_parallel_size:
return kv_rank
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
if rank == 0:
if self.config.kv_connector == "PyNcclConnector":
config_group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port,
rank=self.config.kv_rank,
world_size=self.config.kv_parallel_size,
)
parallel_configs = config_group.all_gather_obj({
"kv_role": self.config.kv_role,
"tensor_parallel_size": config.parallel_config.tensor_parallel_size,
"pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
})
logger.debug("parallel_configs: %s", parallel_configs)
kv_config_enhanced = {
"kv_producers_tensor_parallel_size": None,
"kv_consumers_tensor_parallel_size": None,
"kv_producers_pipeline_parallel_size": None,
"kv_consumers_pipeline_parallel_size": None,
"kv_producers_parallel_size": 0,
}
for parallel_config in parallel_configs:
kv_role = parallel_config["kv_role"]
assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
if kv_role == "kv_producer":
kv_config_enhanced["kv_producers_parallel_size"] += 1
if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
else:
assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
world_group.broadcast_object(kv_config_enhanced)
else:
raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch")
else:
kv_config_enhanced = world_group.broadcast_object()
logger.info("kv_config_enhanced: %s", kv_config_enhanced)
self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]

View File

@ -12,7 +12,8 @@
import threading
import time
from collections import deque
from typing import Deque, List, Optional, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Deque, List, Optional, Union, Dict
import torch
@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase):
self.buffer_lock = threading.Lock()
self.signal_pipe = signal_pipe
self.data_pipe = data_pipe
self.request_handling_thread: Optional[threading.Thread] = None
self.request_handling_thread: Optional[ThreadPoolExecutor] = None
self.normal_signal = torch.tensor([0], device="cpu")
self.end_signal = None
@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase):
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
# tokens_roi_recver: tokens and roi of the consumer (query)
tokens_sender = tokens_roi_sender[0]
tokens_recver = tokens_roi_recver[0]
roi_sender = tokens_roi_sender[1]
roi_recver = tokens_roi_recver[1]
target_rank_sender = tokens_roi_sender[0]
target_rank_recver = tokens_roi_recver[0]
if target_rank_sender.item() != target_rank_recver.item():
return 0
tokens_sender = tokens_roi_sender[1]
tokens_recver = tokens_roi_recver[1]
roi_sender = tokens_roi_sender[2]
roi_recver = tokens_roi_recver[2]
if tokens_recver is None:
# consumer sends an empty request
@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase):
return 0
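# Illustrative sketch with made-up tensors: after this change both the buffered
# item and the incoming query are prefixed with a target-rank tensor, so matching
# compares ranks first and only then falls through to token/roi matching (only the
# first three elements matter here).
import torch
tokens_roi_sender = [torch.tensor(1), torch.tensor([11, 12, 13]), torch.ones(3)]
tokens_roi_recver = [torch.tensor(1), torch.tensor([11, 12, 13]), torch.ones(3)]
assert tokens_roi_sender[0].item() == tokens_roi_recver[0].item()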
def _send_tensor_and_dec_size(self,
tensor: Optional[torch.Tensor]) -> None:
def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor],
target_rank: int) -> None:
assert tensor is not None, "Use self.data_pipe.send(None) instead"
self.buffer_size -= tensor.element_size() * tensor.numel()
if tensor.dtype == torch.bool:
tensor = tensor.float()
self.data_pipe.send_tensor(tensor)
self.data_pipe.send_tensor(tensor, target_rank)
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase):
raise AssertionError(f"Unknown data type {type(data)}")
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor):
@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase):
if isinstance(hidden, torch.Tensor):
hidden = hidden.clone()
buffer_item = [input_tokens, roi, key, value, hidden]
buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden]
with self.buffer_lock:
for data in buffer_item:
@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase):
def _is_end_signal(self, signal):
return signal is None
def drop_select_handler(self):
def drop_select_handler(self, rank: int):
try:
while True:
signal = self.signal_pipe.recv_tensor()
if self._is_end_signal(signal):
logger.info("Received end signal!")
break
signal = self.signal_pipe.recv_tensor(rank)
if self._is_end_signal(signal):
logger.info("Received end signal!")
return
target_kv_rank = self.data_pipe.recv_tensor(rank)
# assert target_rank.item() == rank, "Target rank does not match "\
# "the rank of the drop-select handler"
input_tokens = self.data_pipe.recv_tensor(rank)
roi = self.data_pipe.recv_tensor(rank)
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [target_kv_rank, input_tokens, roi]
input_tokens = self.data_pipe.recv_tensor()
matched_length = 0
roi = self.data_pipe.recv_tensor()
assert roi is not None, "Please provide the roi when sending "\
"drop-select request"
roi = (roi > 0.5)
tokens_roi_recver = [input_tokens, roi]
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:
matched_length = 0
for _ in range(len(self.buffer)):
# perform input tokens and roi matching
# FIXME: this matching is O(n), ideally it should be O(1)
# but this buffer size won't (and shouldn't) be too large so
# the fix is not urgent.
with self.buffer_lock:
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
for _ in range(len(self.buffer)):
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
target_rank = matched_item[0].item()
for tensor in matched_item[1:]:
self._send_tensor_and_dec_size(tensor, rank)
temp_length = self._matches(self.buffer[0],
tokens_roi_recver)
if temp_length > 0:
matched_length = temp_length
break
# rotate the element we just accessed to the end
self.buffer.rotate(-1)
if matched_length > 0:
# need to clone the tensor
# in case the tensor is freed before sending finishes
matched_item = self.buffer.popleft()
for tensor in matched_item:
self._send_tensor_and_dec_size(tensor)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None)
else:
# no match, just send None
for _ in range(5):
self.data_pipe.send_tensor(None, rank)
except RuntimeError as e:
if 'Connection closed by peer' not in str(e):
@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase):
logger.debug("Closing drop_select_handler")
def drop_select(
self, input_tokens: Optional[torch.Tensor],
self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor],
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
assert self.request_handling_thread is None, \
assert not self.request_handling_thread, \
"drop_select should be called by the KV cache consumer "\
"(e.g. the decode vLLM instance)"
@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase):
if isinstance(roi, torch.Tensor):
roi = roi.clone().float()
self.signal_pipe.send_tensor(self.normal_signal)
self.data_pipe.send_tensor(input_tokens)
self.data_pipe.send_tensor(roi)
self.signal_pipe.send_tensor(self.normal_signal, rank)
input_tokens = self.data_pipe.recv_tensor()
roi = self.data_pipe.recv_tensor()
self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
self.data_pipe.send_tensor(input_tokens, rank)
self.data_pipe.send_tensor(roi, rank)
input_tokens = self.data_pipe.recv_tensor(rank)
roi = self.data_pipe.recv_tensor(rank)
if roi is not None:
# convert from float tensor to bool tensor
# as PyNccl does not support sending bool tensor
roi = (roi > 0.5)
key = self.data_pipe.recv_tensor()
value = self.data_pipe.recv_tensor()
hidden = self.data_pipe.recv_tensor()
key = self.data_pipe.recv_tensor(rank)
value = self.data_pipe.recv_tensor(rank)
hidden = self.data_pipe.recv_tensor(rank)
return [input_tokens, roi, key, value, hidden]
def full_handler(self):
time.sleep(0.001)
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
key: torch.Tensor, value: torch.Tensor,
hidden: torch.Tensor) -> None:
@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
while self.buffer_size > self.buffer_size_threshold:
self.full_handler()
self._add_to_buffer(input_tokens, roi, key, value, hidden)
self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
# when calling the insert, the current process is a sender
# need to launch the request handler and start listening to request.
target_rank_global = target_rank + kv_group_rank
if self.request_handling_thread is None:
self.request_handling_thread = threading.Thread(
target=self.drop_select_handler)
self.request_handling_thread.start()
self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
def close(self):
if hasattr(self, "request_handling_thread"
) and self.request_handling_thread is not None:
self.request_handling_thread.join()
if hasattr(self, "request_handling_thread") and self.request_handling_thread:
self.request_handling_thread.shutdown()
else:
# TODO: have a explicit close signal and have a explicit way to

View File

@ -23,7 +23,7 @@ class KVPipeBase(ABC):
"""
@abstractmethod
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
"""Send a tensor, or None, via the pipe.
Need to support sending None -- important for error handling.
@ -41,7 +41,7 @@ class KVPipeBase(ABC):
raise NotImplementedError
@abstractmethod
def recv_tensor(self) -> Optional[torch.Tensor]:
def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""Receive a tensor (can be None) from the pipeline.
Returns:

View File

@ -0,0 +1,124 @@
import logging
import threading
import typing
import zmq
import socket
import time
import torch
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
logger = logging.getLogger(__name__)
class DynamoNcclDataPlane:
def __init__(
self,
data_pipe: PyNcclPipe,
hostname: str = "",
port: int = 0,
) -> None:
self.data_pipe = data_pipe
if not hostname:
hostname = socket.gethostname()
if port == 0:
raise ValueError("Port cannot be 0")
self._hostname = hostname
self._port = port
self.store = {}
self.context = zmq.Context()
self.rep_socket = self.context.socket(zmq.REP)
logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
self._listener_thread.start()
self.req_sockets = {}
logger.info(f"Rank {self.rank} connected to the server")
@property
def rank(self):
return self.data_pipe.kv_group_rank
def send_tensor(
self,
tensor: torch.Tensor,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
return self._send_tensor(tensor, tensor_id, remote_address)
def recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
ret = self._recv_tensor(tensor_id, remote_address)
return ret
def _send_tensor(
self,
tensor: torch.Tensor,
tensor_id: str,
remote_address: typing.Optional[str] = None,
):
logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
if remote_address is None:
self.store[tensor_id] = tensor
else:
# tensor_shape = "_".join(str(dim) for dim in tensor.shape)
# tensor_dtype = str(tensor.dtype)
if remote_address not in self.req_sockets:
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
req_socket = self.req_sockets[remote_address]
# req_socket.connect(f"tcp://{remote_address}")
req_socket.send_string(f"PUT {self.rank} {tensor_id}")
dst_rank = req_socket.recv_string()
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
self.data_pipe.send_tensor(tensor, int(dst_rank))
def _recv_tensor(
self,
tensor_id: str,
remote_address: typing.Optional[str] = None,
) -> torch.Tensor:
logger.debug(f"Rank {self.rank} receiving tensor")
if remote_address is not None:
raise NotImplementedError("Getting tensor from remote rank not implemented")
if tensor_id in self.store:
logger.debug(f"Popping tensor {tensor_id} from store")
future = self.store.pop(tensor_id)
tensor = future.result()  # TODO ptarasiewicz: we should run other requests instead of waiting
logger.debug(f"Rank {self.rank} received tensor")
return tensor
logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
time.sleep(0.001)
return self._recv_tensor(tensor_id, remote_address)
# raise NotImplementedError("Tensor not found in store")
def _receive_tensor(
self,
tensor_id: str,
rank: int,
):
future = self.data_pipe.recv_tensor(rank)
logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
self.store[tensor_id] = future
def listen_for_requests(self):
while True:
cmd, rank, tensor_id = self.rep_socket.recv_string().split()
logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
self.rep_socket.send_string(f"{self.rank}")
if cmd == "GET":
raise NotImplementedError("Getting tensor from remote rank not implemented")
elif cmd == "PUT":
rank = int(rank)
# shape = [int(dim) for dim in shape.split("_")]
# dtype = getattr(torch, dtype)
self._receive_tensor(tensor_id, rank)
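# Illustrative sketch with made-up values: the ZMQ control message exchanged above
# is a plain string of the form "PUT <src_rank> <tensor_id>"; the tensor payload
# itself travels over the NCCL pipe.
example_msg = "PUT 3 cmpl-7_keys"
cmd, src_rank, tensor_id = example_msg.split()
assert (cmd, int(src_rank), tensor_id) == ("PUT", 3, "cmpl-7_keys")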

View File

@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
METADATA_DTYPE = torch.int64
def __init__(self,
kv_group_rank: int,
local_rank: int,
config: KVTransferConfig,
device: Optional[str] = None,
port_offset: int = 0):
self.config = config
self.local_rank = local_rank
self.kv_rank = self.config.kv_rank
self.kv_group_rank = kv_group_rank
self.kv_parallel_size = self.config.kv_parallel_size
self.kv_world_size = self.config.kv_world_size
if device is None:
self.device = self._select_device(self.config.kv_buffer_device)
else:
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
rank=self.kv_group_rank,
world_size=self.kv_world_size,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
impl = self._get_device_send_recv_impl(self.group)
self.device_send_func, self.device_recv_func = impl
# set target rank
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
# transportation-related variables
self.transport_thread: Optional[ThreadPoolExecutor] = None
@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase):
dtype=metadata["dtype"],
device=self.device)
def _send_metadata(self, metadata: Metadata):
def _send_metadata(self, metadata: Metadata, target_rank: int):
"""
Send the metadata dictionary to the target rank.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape".
"""
self.group.send_obj(metadata, self.target_rank_for_send)
self.group.send_obj(metadata, target_rank)
def _recv_metadata(self) -> Metadata:
def _recv_metadata(self, src_rank: int) -> Metadata:
"""
Receive the metadata dictionary from the target rank.
@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase):
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
return self.group.recv_obj(src_rank)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
target rank.
@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase):
being sent.
"""
metadata = self._make_metadata(tensor)
self._send_metadata(metadata)
self._send_metadata(metadata, target_rank)
if tensor is not None:
self.device_send_func(tensor.to(self.device),
self.target_rank_for_send)
target_rank)
def _recv_impl(self) -> Optional[torch.Tensor]:
def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
the target rank.
@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase):
Returns:
- buffer: The received tensor, or None if no tensor is received.
"""
metadata = self._recv_metadata()
metadata = self._recv_metadata(src_rank)
if metadata["dtype"] is None:
return None
buffer = self._prepare_recv_buffer(metadata)
self.device_recv_func(buffer, self.target_rank_for_recv)
self.device_recv_func(buffer, src_rank)
return buffer
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
tensor_size: int) -> None:
tensor_size: int,
target_rank: int) -> None:
"""
Wrapper for _send_impl to handle exceptions and update buffer size.
"""
try:
self._send_impl(tensor)
self._send_impl(tensor, target_rank)
with self.buffer_size_lock:
self.buffer_size -= tensor_size
@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase):
logger.debug("KV cache transfer pipe is full. Waiting...")
time.sleep(0.05)
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase):
Parameters:
- tensor: The tensor to send, or None if no tensor is being sent.
"""
logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank)
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
@ -241,32 +243,39 @@ class PyNcclPipe(KVPipeBase):
with self.buffer_size_lock:
self.buffer_size += tensor_size
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
tensor_size)
future = self.transport_thread.submit(self.send_tensor_wrapper, tensor,
tensor_size,
target_rank)
return future
def recv_tensor(self) -> Optional[torch.Tensor]:
def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
"""
Receives a tensor and its metadata from the source rank. Blocking call.
Returns:
- tensor: The received tensor, or None if no tensor is received.
"""
logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank)
if self.transport_thread is None:
self.transport_thread = ThreadPoolExecutor(max_workers=1)
future = self.transport_thread.submit(self._recv_impl)
future = self.transport_thread.submit(self._recv_impl, src_rank)
try:
tensor = future.result()
except Exception as e:
logger.error("Encountering exception in KV receiving thread")
logger.error("%s", e)
logger.error("My device: %s", self.device)
import traceback
traceback.print_exc()
raise e
return future
return tensor
# try:
# tensor = future.result()
# except Exception as e:
# logger.error("Encountering exception in KV receiving thread")
# logger.error("%s", e)
# logger.error("My device: %s", self.device)
# import traceback
# traceback.print_exc()
# raise e
# return tensor
def close(self):
"""

View File

@ -35,6 +35,7 @@ class KVTransferAgent:
rank: int,
local_rank: int,
config: "VllmConfig",
world_group,
):
self.config = config
@ -47,7 +48,7 @@ class KVTransferAgent:
"TransferAgent should only be used when kv_connector is set."
self.connector = KVConnectorFactory.create_connector(
rank, local_rank, config)
rank, local_rank, config, world_group)
def send_kv_caches_and_hidden_states(
self,

View File

@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
_KV_TRANSFER = kv_transfer.KVTransferAgent(
rank=get_world_group().rank,
local_rank=get_world_group().local_rank,
config=vllm_config)
config=vllm_config,
world_group=get_world_group())
def ensure_model_parallel_initialized(

View File

@ -2,13 +2,17 @@
import copy
import time
import pickle
import uuid
from collections import Counter as collectionsCounter
from collections import deque
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
List, Mapping, NamedTuple, Optional)
List, Mapping, NamedTuple, Optional, Tuple)
from typing import Sequence as GenericSequence
from typing import Set, Type, Union, cast, overload
@ -60,6 +64,9 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
from vllm.version import __version__ as VLLM_VERSION
from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType
from vllm.distributed.device_communicators.nixl import NixlMetadata
logger = init_logger(__name__)
_LOCAL_LOGGING_INTERVAL_SEC = 5
@ -90,7 +97,7 @@ class OutputData(NamedTuple):
# outputs from multiple steps.
is_first_step_output: Optional[bool]
skip: List[int]
remote_prefill_requests: Optional[List[RemotePrefillRequest]]
class SchedulerContext:
@ -104,11 +111,14 @@ class SchedulerContext:
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
self.remote_prefill_requests: List[RemotePrefillRequest] = []
def append_output(self, outputs: List[SamplerOutput],
seq_group_metadata_list: List[SequenceGroupMetadata],
scheduler_outputs: SchedulerOutputs, is_async: bool,
is_last_step: bool,
is_first_step_output: Optional[bool]):
is_first_step_output: Optional[bool],
remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None):
self.output_queue.append(
OutputData(outputs=outputs,
seq_group_metadata_list=seq_group_metadata_list,
@ -116,7 +126,9 @@ class SchedulerContext:
is_async=is_async,
is_last_step=is_last_step,
is_first_step_output=is_first_step_output,
skip=[]))
skip=[],
remote_prefill_requests=remote_prefill_requests))
class LLMEngine:
@ -348,7 +360,7 @@ class LLMEngine:
# GPU and CPU blocks, which are profiled in the distributed executor.
self.scheduler = [
Scheduler(
self.scheduler_config, self.cache_config, self.lora_config,
self.model_config, self.scheduler_config, self.cache_config, self.lora_config,
self.parallel_config.pipeline_parallel_size,
self.async_callbacks[v_id]
if self.model_config.use_async_output_proc else None)
@ -405,6 +417,40 @@ class LLMEngine:
self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}
self.engine_id = str(uuid.uuid4())
self._nixl_agents_names: Optional[List[str]] = None
if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
self._nixl_agents_names = self._initialize_nixl()
self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
self._finished_prefills = set()
self._finished_transfers = set()
@property
def is_nixl_initialized(self) -> bool:
return getattr(self, "_nixl_agents_names", None) is not None
def get_nixl_metadata(self) -> NixlMetadata:
if not self.is_nixl_initialized:
raise RuntimeError("Nixl is not initialized")
agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata")
kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr")
return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks)
def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]:
if not self.is_nixl_initialized:
raise RuntimeError("Nixl is not initialized")
engine_id = nixl_metadata.engine_id
agents_metadata = nixl_metadata.agent_metadata
kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
num_blocks = nixl_metadata.num_blocks
return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks))
def _initialize_nixl(self) -> List[bytes]:
agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,))
return agents_names
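# Hedged sketch of the metadata handshake between two engines built with
# DynamoNixlConnector; the engine objects and the direction of the exchange are
# assumptions, not taken from the patch:
#   decode_meta = decode_engine.get_nixl_metadata()
#   # NixlMetadata carries engine_id, agent_metadata, kv_caches_base_addr, num_blocks
#   prefill_engine.add_remote_nixl_metadata(decode_meta)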
def _initialize_kv_caches(self) -> None:
"""Initialize the KV cache in the worker(s).
@ -500,6 +546,8 @@ class LLMEngine:
# Shutdown model executor when engine is garbage collected
# Use getattr since __init__ can fail before the field is set
if model_executor := getattr(self, "model_executor", None):
if self.is_nixl_initialized:
model_executor.collective_rpc("shutdown_nixl")
model_executor.shutdown()
def get_tokenizer_group(
@ -552,11 +600,14 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest],
trace_headers: Optional[Mapping[str, str]] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Optional[SequenceGroup]:
"""Add a processed request to the engine's request pool.
return the created sequence group.
"""
if isinstance(params, SamplingParams) and params.n > 1:
if remote_prefill_params is not None:
raise ValueError("Remote prefill params are not supported for multi-step sampling")
ParallelSampleSequenceGroup.add_request(
request_id,
self,
@ -574,6 +625,8 @@ class LLMEngine:
# Create the sequences.
block_size = self.cache_config.block_size
seq_id = next(self.seq_counter)
if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
next(self.seq_counter) # empty sequence for staging
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
if is_encoder_decoder_inputs(processed_inputs):
@ -584,7 +637,7 @@ class LLMEngine:
encoder_inputs = None
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
lora_request, prompt_adapter_request)
lora_request, prompt_adapter_request, remote_prefill_params)
encoder_seq = (None if encoder_inputs is None else Sequence(
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
@ -601,8 +654,12 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
priority=priority,
remote_prefill_params=remote_prefill_params,
)
elif isinstance(params, PoolingParams):
if remote_prefill_params is not None:
raise ValueError("Remote prefill params are not supported for pooling")
seq_group = self._create_sequence_group_with_pooling(
request_id,
seq,
@ -673,6 +730,7 @@ class LLMEngine:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@ -765,6 +823,7 @@ class LLMEngine:
prompt_adapter_request=prompt_adapter_request,
trace_headers=trace_headers,
priority=priority,
remote_prefill_params=remote_prefill_params,
)
def _validate_token_prompt(self, prompt: PromptType,
@ -799,6 +858,7 @@ class LLMEngine:
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
encoder_seq: Optional[Sequence] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> SequenceGroup:
"""Creates a SequenceGroup with SamplingParams."""
max_logprobs = self.get_model_config().max_logprobs
@ -829,7 +889,9 @@ class LLMEngine:
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
encoder_seq=encoder_seq,
priority=priority)
priority=priority,
remote_prefill_params=remote_prefill_params
)
return seq_group
@ -995,11 +1057,11 @@ class LLMEngine:
# When we process only one request, no pop is required
# (since later we will process all of the rest)
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0]
else:
(outputs, seq_group_metadata_list, scheduler_outputs, is_async,
is_last_step, is_first_step_output,
skip) = ctx.output_queue.popleft()
skip, remote_prefill_requests) = ctx.output_queue.popleft()
# Sanity check
assert len(seq_group_metadata_list) == len(
@ -1325,15 +1387,55 @@ class LLMEngine:
# Clear outputs for each new scheduler iteration
ctx.request_outputs.clear()
ctx.remote_prefill_requests.clear()
# Skip the scheduler if there are any remaining steps in the seq groups.
# This ensures that the scheduler is only called again when the current
# batch has completed.
remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = []
running_seq_group_metadata_list: List[SequenceGroupMetadata] = []
remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = []
running_scheduled_seq_groups: List[ScheduledSequenceGroup] = []
if not self._has_remaining_steps(seq_group_metadata_list):
# Schedule iteration
(seq_group_metadata_list, scheduler_outputs,
allow_async_output_proc
) = self.scheduler[virtual_engine].schedule()
) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)
# Separate remote prefill and running seq groups
for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups):
if seq_group_metadata.do_remote_prefill:
remote_prefill_seq_group_metadata_list.append(seq_group_metadata)
remote_prefill_scheduled_seq_groups.append(scheduled_seq_group)
else:
running_seq_group_metadata_list.append(seq_group_metadata)
running_scheduled_seq_groups.append(scheduled_seq_group)
seq_group_metadata_list = running_seq_group_metadata_list
scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups
# Send remote prefill requests before model execution
for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups):
assert len(scheduled_seq_group.seq_group.seqs) == 1
assert self._nixl_agents_names
seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
block_table = seq_group_metadata.block_tables[seq_id]
if len(block_table) == len(seq_group_metadata.computed_block_nums):
logger.debug("No blocks to prefill")
self._finished_prefills.add(seq_group_metadata.request_id)
continue
remote_prefill_request = RemotePrefillRequest(
request_id=seq_group_metadata.request_id,
# prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway
prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks)
sampling_params=scheduled_seq_group.seq_group.sampling_params,
block_ids=block_table,
engine_id=self.engine_id,
computed_block_ids=seq_group_metadata.computed_block_nums,
)
scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request)
ctx.seq_group_metadata_list = seq_group_metadata_list
ctx.scheduler_outputs = scheduler_outputs
@ -1383,9 +1485,46 @@ class LLMEngine:
execute_model_req.async_callback = self.async_callbacks[
virtual_engine]
outputs = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# After model execution, the KV blocks for this request need to be transferred
# between the prefill and decode engines, so attach the memory transfer requests here.
memory_transfer_reqs = []
for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list):
remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params
if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
assert len(scheduled_seq_group.seq_group.seqs) == 1
req_id = scheduled_seq_group.seq_group.request_id
seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
block_table = seq_group_metadata.block_tables[seq_id]
staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]
num_computed_blocks = len(seq_group_metadata.computed_block_nums)
computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks]
if computed_decode_block_ids:
kv_recv_req = MemoryTransferRequest(
request_id=req_id,
local_block_ids=block_table[:num_computed_blocks],
staging_block_ids=staging_block_ids[:num_computed_blocks],
remote_block_ids=computed_decode_block_ids,
remote_engine_id=remote_prefill_params.decode_engine_id,
notify_msg=req_id,
op_type=MemoryOpType.READ
)
memory_transfer_reqs.append(kv_recv_req)
kv_send_req = MemoryTransferRequest(
request_id=req_id,
local_block_ids=block_table[num_computed_blocks:],
staging_block_ids=staging_block_ids[num_computed_blocks:],
remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:],
remote_engine_id=remote_prefill_params.decode_engine_id,
notify_msg=req_id,
op_type=MemoryOpType.WRITE
)
memory_transfer_reqs.append(kv_send_req)
execute_model_req.memory_transfer_requests = memory_transfer_reqs
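# Illustrative sketch with made-up block ids and an assumed prefix of 2
# already-computed blocks, showing how the lists above are split.
example_block_table = [10, 11, 12, 13]      # prefill-side blocks for the request
example_decode_blocks = [70, 71, 72, 73]    # blocks allocated on the decode engine
example_num_computed = 2
read_remote = example_decode_blocks[:example_num_computed]    # paired with the READ request above
write_remote = example_decode_blocks[example_num_computed:]   # paired with the WRITE request above
assert (read_remote, write_remote) == ([70, 71], [72, 73])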
outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
# We need to do this here so that last step's sampled_token_ids can
# be passed to the next iteration for PP.
if self.scheduler_config.is_multi_step:
@ -1396,7 +1535,26 @@ class LLMEngine:
if len(ctx.output_queue) > 0:
self._process_model_outputs(ctx=ctx)
# No outputs in this case
outputs = []
execute_model_req = ExecuteModelRequest(
seq_group_metadata_list=[],
blocks_to_swap_in=[],
blocks_to_swap_out=[],
blocks_to_copy=[])
outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
execute_model_req=execute_model_req)
for req_id, notif_count in request_notif_counter.items():
self._request_notif_counter[req_id] += notif_count
if self._request_notif_counter[req_id] > -1:
self._finished_prefills.add(req_id)
del self._request_notif_counter[req_id]
for req_id, done_count in request_done_counter.items():
self._request_done_counter[req_id] += done_count
if self._request_done_counter[req_id] > -1:
self._finished_transfers.add(req_id)
del self._request_done_counter[req_id]
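# Illustrative sketch, assuming a TP size of 2: the counters start at
# -tensor_parallel_size, so a request only counts as finished once every TP rank
# has reported, i.e. once the counter climbs above -1.
from collections import defaultdict
example_counter = defaultdict(lambda: -2)
for _ in range(2):          # one notification per prefill TP rank
    example_counter["req-1"] += 1
assert example_counter["req-1"] > -1   # -> "req-1" joins _finished_prefills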
# Finish the current step for all the sequence groups.
if self.scheduler_config.is_multi_step:
@ -1456,7 +1614,7 @@ class LLMEngine:
# queued control plane messages, such as add/remove lora adapters.
logger.debug("Stopping remote worker execution loop.")
self.model_executor.stop_remote_worker_execution_loop()
return ctx.request_outputs
def _has_remaining_steps(

View File

@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.utils import deprecate_kwargs
from vllm.remote_prefill import RemotePrefillParams
from vllm.distributed.device_communicators.nixl import NixlMetadata
VLLM_RPC_SUCCESS_STR = "SUCCESS"
IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
IPC_DATA_EXT = "_data_socket"
IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket"
IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket"
IPC_METRICS_EXT = "_metrics_socket"
class MQEngineDeadError(RuntimeError):
@ -36,6 +40,7 @@ class RPCProcessRequest:
trace_headers: Optional[Mapping[str, str]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
priority: int = 0
remote_prefill_params: Optional[RemotePrefillParams] = None
@overload
def __init__(
@ -78,6 +83,7 @@ class RPCProcessRequest:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None, # DEPRECATED
) -> None:
@ -95,7 +101,7 @@ class RPCProcessRequest:
self.trace_headers = trace_headers
self.prompt_adapter_request = prompt_adapter_request
self.priority = priority
self.remote_prefill_params = remote_prefill_params
@dataclass
class RPCError:
@ -116,7 +122,7 @@ class RPCStartupRequest(Enum):
@dataclass
class RPCStartupResponse:
tracing_enabled: bool
nixl_metadata: Optional[bytes] = None
class RPCUProfileRequest(Enum):
START_PROFILE = 1
@ -157,3 +163,13 @@ def ENGINE_DEAD_ERROR(
return MQEngineDeadError(
"Engine loop is not running. Inspect the stacktrace to "
f"find the original error: {repr(error)}.")
@dataclass
class KvMetrics:
request_active_slots: int
request_total_slots: int
kv_active_blocks: int
kv_total_blocks: int
num_requests_waiting: int
gpu_cache_usage_perc: float
gpu_prefix_cache_hit_rate: float
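# Illustrative construction with made-up values; the engine side presumably fills
# these from its scheduler and cache statistics before pickling them onto the
# metrics socket that the client polls.
example_metrics = KvMetrics(
    request_active_slots=4,
    request_total_slots=256,
    kv_active_blocks=1024,
    kv_total_blocks=8192,
    num_requests_waiting=0,
    gpu_cache_usage_perc=12.5,
    gpu_prefix_cache_hit_rate=0.42,
)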

View File

@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
Optional, Union, cast, overload)
import cloudpickle
import msgspec
import psutil
import zmq
import zmq.asyncio
@ -19,20 +20,23 @@ from vllm import PoolingParams
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.metrics import Stats
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.async_llm_engine import (
build_guided_decoding_logits_processor_async)
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT,
RPC_REQUEST_T,
VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest,
IPC_METRICS_EXT,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
RPCUProfileRequest, KvMetrics)
from vllm.engine.protocol import EngineClient
# yapf: enable
from vllm.envs import VLLM_RPC_TIMEOUT
@ -46,6 +50,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from vllm.utils import deprecate_kwargs
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
from vllm.distributed.device_communicators.nixl import NixlMetadata
logger = init_logger(__name__)
@ -91,6 +97,7 @@ class MQLLMEngineClient(EngineClient):
self._errored_with: Optional[BaseException] = None
# Get the configs.
self.vllm_config = engine_config
self.model_config = engine_config.model_config
self.decoding_config = engine_config.decoding_config
@ -115,6 +122,10 @@ class MQLLMEngineClient(EngineClient):
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
# Metrics.
self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL)
self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
@ -129,8 +140,27 @@ class MQLLMEngineClient(EngineClient):
# Loop to check health of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.health_loop: Optional[asyncio.Task] = None
# Loop to check metrics of the LLMEngine periodically.
# Started after the MQLLMEngine is ready.
self.metrics_loop: Optional[asyncio.Task] = None
self.metrics_publisher = None
self._engine_process = psutil.Process(engine_pid)
self.nixl_metadata: Optional[NixlMetadata] = None
self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL)
self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH)
self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {}
if self.using_nixl_connector:
self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
@property
def using_nixl_connector(self) -> bool:
return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector"
@staticmethod
def is_unsupported_config(engine_args: AsyncEngineArgs):
# Pipeline parallel not yet supported
@ -180,6 +210,61 @@ class MQLLMEngineClient(EngineClient):
except Exception as e:
self._set_errored(e)
async def run_remote_prefill_request_handler_loop(self):
try:
while True:
if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT):
frames = await self.remote_prefill_request_socket.recv(copy=False)
remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest)
await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request)
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.")
async def run_metrics_loop(self, timeout: int):
"""Background loop that continually checks to ensure the engine process
is still alive.
"""
try:
while True:
# Check if the engine process is running:
if not self._engine_process.is_running() or (
self._engine_process.status() == psutil.STATUS_ZOMBIE):
# NB: is_running() returns True for zombies
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) "
"died."))
break
if await self.metrics_socket.poll(timeout=timeout):
# Metrics received - check the message
message: Frame = await self.metrics_socket.recv(copy=False)
metrics = pickle.loads(message.buffer)
if self.metrics_publisher is not None and isinstance(
metrics, KvMetrics
):
self.metrics_publisher.publish(metrics.request_active_slots,
metrics.request_total_slots,
metrics.kv_active_blocks,
metrics.kv_total_blocks,
metrics.num_requests_waiting,
metrics.gpu_cache_usage_perc,
metrics.gpu_prefix_cache_hit_rate)
logger.debug("Metrics successful.")
# TODO: Investigate sending whole stats object
except asyncio.CancelledError:
logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
except psutil.NoSuchProcess:
self._set_errored(
RuntimeError(
f"Engine process (pid {self._engine_process.pid}) died."))
except Exception as e:
self._set_errored(e)
async def run_output_handler_loop(self):
"""Get RequestOutputs from Engine and stream to Request Queues"""
@ -278,12 +363,26 @@ class MQLLMEngineClient(EngineClient):
# Wait until server is ready.
response = await self._wait_for_server_rpc(socket)
if response.nixl_metadata is not None:
assert self.using_nixl_connector
self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata)
self.tracing_flag = response.tracing_enabled
# Start health_loop.
if self.health_loop is None:
self.health_loop = asyncio.create_task(
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
if self.using_nixl_connector:
self.remote_prefill_loop = asyncio.create_task(
self.run_remote_prefill_request_handler_loop())
# Start metrics_loop.
if self.metrics_loop is None:
self.metrics_loop = asyncio.create_task(
self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT))
def close(self):
"""Destroy the ZeroMQ Context."""
@ -293,6 +392,8 @@ class MQLLMEngineClient(EngineClient):
# Cancel background tasks.
if self.health_loop is not None:
self.health_loop.cancel()
if self.metrics_loop is not None:
self.metrics_loop.cancel()
if self.output_loop is not None:
self.output_loop.cancel()
@ -415,6 +516,9 @@ class MQLLMEngineClient(EngineClient):
"""
if self._errored_with is not None:
raise self._errored_with
async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata):
await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False)
@property
def is_running(self) -> bool:
@ -473,6 +577,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
*,
inputs: Optional[PromptType] = None # DEPRECATED
) -> AsyncGenerator[RequestOutput, None]:
@ -502,7 +607,8 @@ class MQLLMEngineClient(EngineClient):
return self._process_request(prompt, sampling_params, request_id,
lora_request, trace_headers,
prompt_adapter_request, priority)
prompt_adapter_request, priority,
remote_prefill_params)
@overload
def encode(
@ -586,6 +692,7 @@ class MQLLMEngineClient(EngineClient):
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
PoolingRequestOutput, None]]:
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
@ -630,6 +737,12 @@ class MQLLMEngineClient(EngineClient):
else:
lp_bytes = None
if remote_prefill_params is not None:
# Callbacks cannot be pickled: keep the callback client-side, keyed by
# request_id, and strip it from the params before serialization.
self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback
remote_prefill_params.remote_prefill_request_callback = None
request_bytes = pickle.dumps(
RPCProcessRequest(
prompt=prompt,
@ -639,11 +752,11 @@ class MQLLMEngineClient(EngineClient):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=priority,
remote_prefill_params=remote_prefill_params,
))
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,)
await self.input_socket.send_multipart(parts, copy=False)
# 4) Stream the RequestOutputs from the output queue. Note
@ -705,3 +818,6 @@ class MQLLMEngineClient(EngineClient):
# Raise on error, otherwise happily return None
if isinstance(request_output, BaseException):
raise request_output
def set_metrics_publisher(self, metrics_publisher):
self.metrics_publisher = metrics_publisher
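A minimal sketch of a compatible metrics publisher, assuming only that `set_metrics_publisher` expects an object whose `publish(...)` takes the seven values in the order used by `run_metrics_loop` above; the commented wiring line is illustrative.

class LoggingMetricsPublisher:
    """Toy publisher that prints the KV metrics pushed by the engine."""

    def publish(self, request_active_slots, request_total_slots,
                kv_active_blocks, kv_total_blocks, num_requests_waiting,
                gpu_cache_usage_perc, gpu_prefix_cache_hit_rate):
        print(f"slots {request_active_slots}/{request_total_slots} | "
              f"kv blocks {kv_active_blocks}/{kv_total_blocks} | "
              f"waiting {num_requests_waiting} | "
              f"cache {gpu_cache_usage_perc:.1%} | "
              f"prefix hits {gpu_prefix_cache_hit_rate:.1%}")

# client.set_metrics_publisher(LoggingMetricsPublisher())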

View File

@ -3,35 +3,115 @@
import pickle
import signal
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union
from typing import Iterator, List, Optional, Union, Dict
import cloudpickle
import time
import zmq
import msgspec
from vllm import AsyncEngineArgs, SamplingParams
from vllm.engine.llm_engine import LLMEngine
# yapf conflicts with isort for this block
# yapf: disable
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
IPC_HEALTH_EXT, IPC_INPUT_EXT,
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
REQUEST_OUTPUTS_T,
VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT,
RPCAbortRequest,
IPC_OUTPUT_EXT, IPC_METRICS_EXT,
RPCAdapterLoadedResponse, RPCError,
RPCLoadAdapterRequest,
RPCProcessRequest,
RPCResetPrefixCacheRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
RPCUProfileRequest, IPC_REMOTE_NIXL_METADATA_EXT,
KvMetrics)
# yapf: enable
from vllm.logger import init_logger
from vllm.outputs import RequestOutput
from vllm.usage.usage_lib import UsageContext
from vllm.remote_prefill import RemotePrefillRequest
from vllm.distributed.device_communicators.nixl import NixlMetadata
from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
from dataclasses import dataclass, field
logger = init_logger(__name__)
POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
class KvStatLogger(StatLoggerBase):
def __init__(
self,
max_num_seqs: int,
num_total_gpu_blocks: int,
metrics_socket
):
# Must query the initialized scheduler for the configured maximums
self.request_total_slots = max_num_seqs
self.kv_total_blocks = num_total_gpu_blocks
self.metrics_socket = metrics_socket
# Publish an initial, zeroed KV metrics snapshot.
self._send_kv_metrics(0, 0, 0, 0.0, 0.0)
def log(self, stats: Stats) -> None:
self._send_kv_metrics(
stats.num_running_sys,
int(stats.gpu_cache_usage_sys * self.kv_total_blocks),
stats.num_waiting_sys,
stats.gpu_cache_usage_sys,
stats.gpu_prefix_cache_hit_rate
)
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
pass
def _send_kv_metrics(
self,
active_slots,
active_kv_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
):
if not self.metrics_socket.closed:
metrics_bytes = pickle.dumps(
KvMetrics(
active_slots,
self.request_total_slots,
active_kv_blocks,
self.kv_total_blocks,
num_requests_waiting,
gpu_cache_usage_perc,
gpu_prefix_cache_hit_rate,
)
)
self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
# TODO: Send entire stats object to the client
# class StatLogger(StatLoggerBase):
# def __init__(
# self,
# metrics_socket
# ):
# self.metrics_socket = metrics_socket
# def log(self, stats: Stats) -> None:
# self._send_metrics(stats)
# def info(self, type: str, obj: SupportsMetricsInfo) -> None:
# pass
# def _send_metrics(self, stats: Stats):
# if not self.metrics_socket.closed:
# metrics_bytes = pickle.dumps(stats)
# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
class MQLLMEngine:
"""A multiprocessing wrapper for :class:`LLMEngine`.
@ -94,12 +174,37 @@ class MQLLMEngine:
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
# Send metrics back to client.
self.metrics_socket = self.ctx.socket(zmq.constants.PUSH)
self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}")
# IPC path for the data socket.
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
# Error state.
self._errored_with: Optional[BaseException] = None
self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH)
self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL)
if self.engine.is_nixl_initialized:
self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
# Attach logger for continuous metrics publishing
self.kv_stat_logger = KvStatLogger(
self.engine.scheduler_config.max_num_seqs,
self.engine.cache_config.num_gpu_blocks,
self.metrics_socket
)
self.engine.add_logger("kv_metrics", self.kv_stat_logger)
# TODO investigate sending whole stats object
# self.general_stat_logger = StatLogger(
# self.metrics_socket
# )
# self.engine.add_logger("general_metrics", self.general_stat_logger)
@property
def dead_error(self) -> BaseException:
if self._errored_with is not None:
@ -171,8 +276,17 @@ class MQLLMEngine:
# Handle the query from the Client.
if request == RPCStartupRequest.IS_SERVER_READY:
tracing_enabled = self.engine.is_tracing_enabled()
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
# Send nixl metadata to the client
if self.engine.is_nixl_initialized:
nixl_metadata = self.engine.get_nixl_metadata()
encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata)
response = RPCStartupResponse(
tracing_enabled=tracing_enabled,
nixl_metadata=encoded_nixl_metadata)
else:
response = RPCStartupResponse(
tracing_enabled=tracing_enabled)
except Exception as e:
response = e
@ -185,6 +299,7 @@ class MQLLMEngine:
while True:
if not self.engine.has_unfinished_requests():
logger.debug("No unfinished requests")
# Poll until there is work to do.
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
# When there's no work, check on engine health and send
@ -220,6 +335,13 @@ class MQLLMEngine:
def handle_new_input(self):
"""Handle new input from the socket"""
try:
if self.engine.is_nixl_initialized:
while self.remote_nixl_metadata_socket.poll(timeout=0) != 0:
frames = self.remote_nixl_metadata_socket.recv(copy=False)
nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata)
logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id)
self.engine.add_remote_nixl_metadata(nixl_metadata)
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
@ -262,6 +384,11 @@ class MQLLMEngine:
self._send_outputs(rpc_err)
try:
if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill:
def remote_prefill_request_callback(request: RemotePrefillRequest):
logger.debug("Sending remote prefill request: %s", request.request_id)
self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False)
request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback
self.engine.add_request(
request_id=request_id,
prompt=request.prompt,
@ -269,7 +396,9 @@ class MQLLMEngine:
lora_request=request.lora_request,
trace_headers=request.trace_headers,
prompt_adapter_request=request.prompt_adapter_request,
priority=request.priority)
priority=request.priority,
remote_prefill_params=request.remote_prefill_params,
)
if self.log_requests:
logger.info("Added request %s.", request.request_id)

View File

@ -34,6 +34,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
from vllm.sequence import Logprob
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
from vllm.remote_prefill import RemotePrefillParams
logger = init_logger(__name__)
@ -112,6 +113,7 @@ class OpenAIServingChat(OpenAIServing):
self,
request: ChatCompletionRequest,
raw_request: Optional[Request] = None,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
ErrorResponse]:
"""
@ -243,6 +245,7 @@ class OpenAIServingChat(OpenAIServing):
trace_headers=trace_headers,
prompt_adapter_request=prompt_adapter_request,
priority=request.priority,
remote_prefill_params=remote_prefill_params,
)
generators.append(generator)
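A short sketch of how a caller of `create_chat_completion` might thread remote prefill parameters through on the decode side; `forward_to_prefill` and `openai_serving_chat` are placeholders, not part of this diff.

from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest

async def forward_to_prefill(req: RemotePrefillRequest) -> None:
    ...  # hand the request off to a prefill worker over your transport

remote_prefill_params = RemotePrefillParams(
    is_remote_prefill=True,
    remote_prefill_request_callback=forward_to_prefill,
)

# generator = await openai_serving_chat.create_chat_completion(
#     request, raw_request, remote_prefill_params=remote_prefill_params)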

View File

@ -87,6 +87,10 @@ if TYPE_CHECKING:
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
VLLM_RAY_BUNDLE_INDICES: str = ""
VLLM_KV_CAPI_PATH: Optional[str] = None
VLLM_KV_NAMESPACE: Optional[str] = None
VLLM_KV_COMPONENT: Optional[str] = None
VLLM_WORKER_ID: Optional[int] = None
def get_default_cache_root():
@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# models the alignment is already naturally aligned to 256 bytes.
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
# Path to the C API Library
"VLLM_KV_CAPI_PATH":
lambda: os.environ.get("VLLM_KV_CAPI_PATH", None),
# Identifiers to publish KV related information
"VLLM_KV_NAMESPACE":
lambda: os.environ.get("VLLM_KV_NAMESPACE", None),
"VLLM_KV_COMPONENT":
lambda: os.environ.get("VLLM_KV_COMPONENT", None),
# Worker ID used for identifying workers in distributed settings
"VLLM_WORKER_ID":
lambda: int(os.getenv("VLLM_WORKER_ID", "0"))
if "VLLM_WORKER_ID" in os.environ else None,
}
# end-env-vars-definition
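A quick illustration of how the new variables surface through `vllm.envs`; the path value is an arbitrary example, not a shipped artifact.

import os

# Example values; set these before vllm.envs is queried.
os.environ["VLLM_KV_CAPI_PATH"] = "/tmp/example_kv_capi.so"
os.environ["VLLM_WORKER_ID"] = "3"

import vllm.envs as envs

print(envs.VLLM_KV_CAPI_PATH)  # "/tmp/example_kv_capi.so"
print(envs.VLLM_WORKER_ID)     # 3 (int); None when VLLM_WORKER_ID is unset
print(envs.VLLM_KV_NAMESPACE)  # None unless explicitly exported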

View File

@ -585,6 +585,8 @@ class DeepseekV2Model(nn.Module):
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

View File

@ -6,16 +6,16 @@ from typing import Dict, Generic, List, MutableSequence, Optional
from typing import Sequence as GenericSequence
from typing import Union
import msgspec
import torch
from typing_extensions import TypeVar, deprecated
from vllm.lora.request import LoRARequest
from vllm.multimodal.inputs import MultiModalPlaceholderDict
from vllm.sampling_params import RequestOutputKind
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
SequenceGroup, SequenceGroupBase, SequenceStatus)
@dataclass
class CompletionOutput:
"""The output data of one completion output of a request.

vllm/remote_prefill.py (new file, 67 lines)
View File

@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Callable, Optional, List
from enum import Enum
import msgspec
from vllm.sampling_params import SamplingParams
class RemotePrefillRequest(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
# required for @cached_property.
dict=True):
"""The request data of one remote prefill output of a request.
Args:
engine_id: The unique ID of the engine.
request_id: The unique ID of the request.
prompt_token_ids: The token IDs of the prompt.
sampling_params: The sampling parameters.
block_ids: The block IDs of the request.
computed_block_ids: The computed block IDs of the request.
"""
engine_id: str
request_id: str
prompt_token_ids: List[int]
sampling_params: SamplingParams
block_ids: List[int]
computed_block_ids: List[int]
class MemoryOpType(str, Enum):
WRITE = "WRITE"
READ = "READ"
class MemoryTransferRequest(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True): # type: ignore[call-arg]
"""The request data of one memory transfer output of a request.
Args:
request_id: The unique ID of the request.
"""
request_id: str
local_block_ids: List[int]
staging_block_ids: List[int]
remote_block_ids: List[int]
remote_engine_id: str
notify_msg: str
op_type: MemoryOpType
RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]
@dataclass
class RemotePrefillParams:
"""Remote prefill parameters for text generation."""
is_remote_prefill: bool = False
is_remote_decode: bool = False
decode_block_ids: Optional[List[int]] = None
decode_computed_block_ids: Optional[List[int]] = None
decode_engine_id: Optional[str] = None
remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
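A small sketch of the serialization round-trip these structs are built for, mirroring the msgspec.msgpack encode/decode calls used on the IPC sockets above; the concrete values are illustrative.

import msgspec

from vllm.remote_prefill import RemotePrefillRequest
from vllm.sampling_params import SamplingParams

req = RemotePrefillRequest(
    engine_id="prefill-0",
    request_id="req-0",
    prompt_token_ids=[1, 2, 3],
    sampling_params=SamplingParams(max_tokens=8),
    block_ids=[0, 1],
    computed_block_ids=[],
)

# Same msgspec.msgpack encode/decode pair as used on the sockets.
payload = msgspec.msgpack.encode(req)
decoded = msgspec.msgpack.decode(payload, type=RemotePrefillRequest)
assert decoded.request_id == "req-0"
assert decoded.sampling_params.max_tokens == 8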

View File

@ -83,7 +83,7 @@ class RequestOutputKind(Enum):
DELTA = 1
# Do not return intermediate RequestOutputs
FINAL_ONLY = 2
class SamplingParams(
msgspec.Struct,

View File

@ -20,6 +20,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import RequestOutputKind, SamplingParams
from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
@ -59,13 +60,14 @@ class SequenceStatus(enum.IntEnum):
"""Status of a sequence."""
WAITING = 0
RUNNING = 1
SWAPPED = 2
# Note: anything after SWAPPED (2) will be considered
REMOTE_PREFILLING = 2
SWAPPED = 3
# Note: anything after SWAPPED (3) will be considered
# as a finished status.
FINISHED_STOPPED = 3
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_STOPPED = 4
FINISHED_LENGTH_CAPPED = 5
FINISHED_ABORTED = 6
FINISHED_IGNORED = 7
@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
@ -409,6 +411,7 @@ class Sequence:
eos_token_id: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> None:
self.seq_id = seq_id
self.inputs = SingletonInputsAdapter(inputs)
@ -416,7 +419,7 @@ class Sequence:
self.eos_token_id = eos_token_id
self.lora_request = lora_request
self.prompt_adapter_request = prompt_adapter_request
self.remote_prefill_params = remote_prefill_params
self.data = SequenceData.from_seqs(self.prompt_token_ids)
self.output_logprobs: SampleLogprobs = []
self.output_text = ""
@ -639,6 +642,7 @@ class SequenceGroup:
trace_headers: OpenTelemetry trace headers.
prompt_adapter_request: Prompt Adapter request.
priority: User-defined priority of the request.
remote_prefill_params: Remote prefill parameters.
"""
def __init__(
@ -654,6 +658,7 @@ class SequenceGroup:
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
priority: int = 0,
remote_prefill_params: Optional[RemotePrefillParams] = None,
) -> None:
self.request_id = request_id
self.seqs = seqs
@ -678,7 +683,7 @@ class SequenceGroup:
self.encoder_seq = encoder_seq
self.trace_headers = trace_headers
self.priority = priority
self.remote_prefill_params = remote_prefill_params
self.cached_request_output = None
@property
@ -927,6 +932,9 @@ class SequenceGroupMetadata(
query tokens for prefill, we don't need sampling.
token_chunk_size: The number of tokens to be processed (per sequence).
None if chunking is not required.
do_remote_prefill: True if remote prefill is required.
do_remote_decode: True if remote decode is required.
decode_memory_desc: The memory descriptor for the decoder blocks.
lora_request: LoRA request.
computed_block_nums: The block numbers that are already computed,
used in prefix caching.
@ -966,6 +974,9 @@ class SequenceGroupMetadata(
cross_block_table: Optional[List[int]] = None
prompt_adapter_request: Optional[PromptAdapterRequest] = None
token_chunk_size: Optional[int] = None
do_remote_prefill: bool = False
do_remote_decode: bool = False
decode_memory_desc: Optional[bytes] = None
### Stateful fields that are lazily defined. ###
# The number of speculative tokens adopted in this request.
@ -1310,6 +1321,8 @@ class ExecuteModelRequest(
last_sampled_token_ids: Optional[torch.Tensor] = None
# Async callback
async_callback: Optional[Callable] = None
# The memory transfer requests.
memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None
@property
def is_first_multi_step(self) -> bool:

View File

@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.vllm_config.kv_transfer_config is None:
return False
if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
return False
prefill_meta = model_input.attn_metadata.prefill_metadata
@ -1849,6 +1852,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
if self.vllm_config.kv_transfer_config is None:
return False
if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
return False
prefill_meta = model_input.attn_metadata.prefill_metadata

View File

@ -2,7 +2,7 @@
"""A GPU worker class."""
import gc
import os
from typing import Dict, List, Optional, Set, Tuple, Type, Union
from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any
import torch
import torch.distributed
@ -31,6 +31,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
from vllm.worker.pooling_model_runner import PoolingModelRunner
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
WorkerInput)
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
from vllm.remote_prefill import MemoryOpType
logger = init_logger(__name__)
@ -306,6 +309,46 @@ class Worker(LocalOrDistributedWorkerBase):
self._init_cache_engine()
self._warm_up_model()
def initialize_nixl(self, engine_id: str) -> List[bytes]:
# TODO ptarasiewicz nixl can also support DRAM
assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
return self.nixl_connector.agent_name
def get_nixl_agent_metadata(self) -> bytes:
assert self.nixl_connector is not None, "Nixl connector is not initialized"
return self.nixl_connector.get_agent_metadata()
def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str:
assert self.nixl_connector is not None, "Nixl connector is not initialized"
agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank?
return agent_name
def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
assert self.nixl_connector is not None, "Nixl connector is not initialized"
return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id]
def _read_blocks(self, worker_input: WorkerInput) -> None:
for i, op_type in enumerate(worker_input.op_type):
if op_type == MemoryOpType.READ:
self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i])
def _write_blocks(self, worker_input: WorkerInput) -> None:
if not self.is_driver_worker:
torch.cuda.synchronize()  # Ensure the blocks are ready; on the driver worker the transfer happens after sampling, so no synchronization is needed there.
for i, op_type in enumerate(worker_input.op_type):
if op_type == MemoryOpType.WRITE:
self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i])
def shutdown_nixl(self) -> None:
assert self.nixl_connector is not None, "Nixl connector is not initialized"
self.nixl_connector.shutdown()
def _init_cache_engine(self):
assert self.cache_config.num_gpu_blocks is not None
self.cache_engine = [
@ -367,6 +410,8 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
device=self.device,
dtype=torch.int64).view(-1, 2)
mem_transfer_reqs = execute_model_req.memory_transfer_requests or []
return WorkerInput(
num_seq_groups=num_seq_groups,
@ -375,6 +420,12 @@ class Worker(LocalOrDistributedWorkerBase):
blocks_to_copy=blocks_to_copy,
virtual_engine=virtual_engine,
num_steps=num_steps,
local_block_ids=[r.local_block_ids for r in mem_transfer_reqs],
staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs],
remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs],
notify_msg=[r.notify_msg for r in mem_transfer_reqs],
op_type=[r.op_type for r in mem_transfer_reqs],
)
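For reference, a sketch of the kind of MemoryTransferRequest that would populate `memory_transfer_requests` on an ExecuteModelRequest and be fanned out into the WorkerInput fields above; the IDs and notify message are illustrative.

from vllm.remote_prefill import MemoryOpType, MemoryTransferRequest

xfer = MemoryTransferRequest(
    request_id="req-0",
    local_block_ids=[4, 5],
    staging_block_ids=[0, 1],
    remote_block_ids=[17, 18],
    remote_engine_id="decode-0",
    notify_msg="req-0",
    op_type=MemoryOpType.WRITE,
)

# execute_model_req.memory_transfer_requests = [xfer]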
@torch.inference_mode()

View File

@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
import cloudpickle
import torch
import torch.nn as nn
from collections import defaultdict
from vllm.config import (ObservabilityConfig, VllmConfig,
set_current_vllm_config)
@ -23,6 +24,9 @@ from vllm.utils import (enable_trace_function_call_for_thread,
from vllm.worker.model_runner_base import (BroadcastableModelInput,
ModelRunnerBase,
ModelRunnerInputBase)
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
from vllm.remote_prefill import MemoryOpType
logger = init_logger(__name__)
@ -53,6 +57,8 @@ class WorkerBase(ABC):
from vllm.platforms import current_platform
self.current_platform = current_platform
self.nixl_connector: Optional[DynamoNixlConnector] = None
@abstractmethod
def init_device(self) -> None:
"""Initialize device state, such as loading the model or other on-device
@ -216,6 +222,13 @@ class WorkerInput:
virtual_engine: int = 0
num_steps: int = 1
local_block_ids: Optional[List[List[int]]] = None
staging_block_ids: Optional[List[List[int]]] = None
remote_block_ids: Optional[List[List[int]]] = None
remote_engine_id: Optional[List[str]] = None
notify_msg: Optional[List[str]] = None
op_type: Optional[List[MemoryOpType]] = None
@classmethod
def from_broadcasted_tensor_dict(
cls: Type["WorkerInput"],
@ -232,6 +245,12 @@ class WorkerInput:
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
virtual_engine=tensor_dict["virtual_engine"],
num_steps=tensor_dict.pop("num_steps"),
local_block_ids=tensor_dict.pop("local_block_ids"),
staging_block_ids=tensor_dict.pop("staging_block_ids"),
remote_block_ids=tensor_dict.pop("remote_block_ids"),
remote_engine_id=tensor_dict.pop("remote_engine_id"),
notify_msg=tensor_dict.pop("notify_msg"),
op_type=tensor_dict.pop("op_type"),
)
def as_broadcastable_tensor_dict(
@ -246,6 +265,12 @@ class WorkerInput:
"blocks_to_copy": self.blocks_to_copy,
"virtual_engine": self.virtual_engine,
"num_steps": self.num_steps,
"local_block_ids": self.local_block_ids,
"staging_block_ids": self.staging_block_ids,
"remote_block_ids": self.remote_block_ids,
"remote_engine_id": self.remote_engine_id,
"notify_msg": self.notify_msg,
"op_type": self.op_type,
}
return tensor_dict
@ -316,13 +341,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
return None
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
model_input = (
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
if worker_input.num_seq_groups > 0:
model_input = (
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
broadcast_data))
kwargs = extract_previous_hidden_states(broadcast_data)
kwargs = extract_previous_hidden_states(broadcast_data)
return model_input, worker_input, kwargs
return model_input, worker_input, kwargs
else:
return None, worker_input, {}
def _get_driver_input_and_broadcast(
self, execute_model_req: ExecuteModelRequest
@ -396,49 +424,88 @@ class LocalOrDistributedWorkerBase(WorkerBase):
self.execute_worker(worker_input)
# If there is no input, we don't need to execute the model.
if worker_input.num_seq_groups == 0:
return []
if worker_input.num_seq_groups > 0:
intermediate_tensors = None
orig_model_execute_time = 0.0
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
self._read_blocks(worker_input)
intermediate_tensors = None
orig_model_execute_time = 0.0
if not get_pp_group().is_first_rank:
intermediate_tensors = IntermediateTensors(
get_pp_group().recv_tensor_dict(
all_gather_group=get_tp_group()))
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
orig_model_execute_time = intermediate_tensors.tensors.get(
"model_execute_time", torch.tensor(0)).item()
and self.observability_config.collect_model_execute_time
and output is not None):
for o in output:
o.model_execute_time = (orig_model_execute_time +
model_execute_time)
output = self.model_runner.execute_model(
model_input=model_input,
kv_caches=self.kv_cache[worker_input.virtual_engine]
if self.kv_cache is not None else None,
intermediate_tensors=intermediate_tensors,
num_steps=num_steps,
**kwargs,
)
self._write_blocks(worker_input)
model_execute_time = time.perf_counter() - start_time
if not get_pp_group().is_last_rank:
# output is IntermediateTensors
assert isinstance(output, IntermediateTensors)
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time):
output.tensors["model_execute_time"] = torch.tensor(
model_execute_time + orig_model_execute_time)
get_pp_group().send_tensor_dict(output.tensors,
all_gather_group=get_tp_group())
return [None]
if (self.observability_config is not None
and self.observability_config.collect_model_execute_time
and output is not None):
for o in output:
o.model_execute_time = (orig_model_execute_time +
model_execute_time)
else:
output = []
# collect kv transfer notifications from non driver workers
if self.nixl_connector is not None:
new_notifs = self.nixl_connector.get_new_notifs()
rank = get_tp_group().rank
all_new_notifs = [new_notifs]
if rank > 0:
get_tp_group().send_object(new_notifs, dst=0)
else:
for i in range(1, get_tp_group().world_size):
all_new_notifs.append(get_tp_group().recv_object(src=i))
request_notif_counter = defaultdict(int)
for notifs in all_new_notifs:
for req_ids in notifs.values():
for req_id in req_ids:
request_notif_counter[req_id] += 1
if request_notif_counter:
logger.debug("Request notif counter: %s", request_notif_counter)
request_done_counter = defaultdict(int)
for req_id in self.nixl_connector.get_done_tranfers():
request_done_counter[req_id] += 1
else:
request_notif_counter = {}
request_done_counter = {}
# output is List[SamplerOutput]
return output
return output, request_notif_counter, request_done_counter
def _read_blocks(self, worker_input: WorkerInput) -> None:
pass
def _write_blocks(self, worker_input: WorkerInput) -> None:
pass
def _execute_model_spmd(
self,