Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits: v0.10.0rc1...dynamo-pat (2 commits: 6de0982dd0, 45fa7f9b8e)
@ -2620,6 +2620,9 @@ class KVTransferConfig(BaseModel):
|
||||
# The KV connector for vLLM to transmit KV caches between vLLM instances.
|
||||
kv_connector: Optional[str] = None
|
||||
|
||||
# Whether to use NIXL prepped xfer for KV cache transfer.
|
||||
use_prepped_xfer: bool = True
|
||||
|
||||
# The device used by kv connector to buffer the KV cache.
|
||||
# Currently only support 'cuda'.
|
||||
kv_buffer_device: Optional[str] = "cuda"
|
||||
@ -2629,7 +2632,7 @@ class KVTransferConfig(BaseModel):
|
||||
kv_buffer_size: float = 1e9
|
||||
|
||||
# Whether this vLLM instance produces, consumes KV cache, or both. Choices
|
||||
# are 'kv_producer', 'kv_consumer', and 'both'.
|
||||
# are 'kv_producer', 'kv_consumer', and 'kv_both'.
|
||||
kv_role: Optional[str] = None
|
||||
|
||||
# The rank of this vLLM instance in the KV cache transfer. Typical value:
|
||||
@ -2647,6 +2650,14 @@ class KVTransferConfig(BaseModel):
|
||||
# The KV connector port, used to build distributed connection
|
||||
kv_port: int = 14579
|
||||
|
||||
|
||||
# This does not need to be set by the user. It is set by the connector.
|
||||
kv_producers_parallel_size: Optional[int] = None
|
||||
kv_producers_tensor_parallel_size: Optional[int] = None
|
||||
kv_producers_pipeline_parallel_size: Optional[int] = None
|
||||
kv_consumers_tensor_parallel_size: Optional[int] = None
|
||||
kv_consumers_pipeline_parallel_size: Optional[int] = None
|
||||
|
||||
def compute_hash(self) -> str:
|
||||
"""
|
||||
WARNING: Whenever a new field is added to this config,
|
||||
@ -2680,11 +2691,16 @@ class KVTransferConfig(BaseModel):
|
||||
f"Supported roles are `kv_producer`, `kv_consumer`, "
|
||||
f"and `kv_both`")
|
||||
|
||||
if self.kv_connector is not None and self.kv_role is None:
|
||||
if self.kv_connector is not None and self.kv_connector != "DynamoNixlConnector" and self.kv_role is None:
|
||||
raise ValueError("Please specify kv_disagg_role when kv_connector "
|
||||
"is set, supported roles are `kv_producer`, "
|
||||
"`kv_consumer`, and `kv_both`")
|
||||
|
||||
if self.use_prepped_xfer is False:
|
||||
logger.warning("`use_prepped_xfer` parameter is deprecated. All transfers will be done using prepped xfer.")
|
||||
self.use_prepped_xfer = True
|
||||
|
||||
|
||||
@property
|
||||
def is_kv_transfer_instance(self) -> bool:
|
||||
return self.kv_connector is not None and \
|
||||
@ -2694,6 +2710,8 @@ class KVTransferConfig(BaseModel):
|
||||
def need_kv_parallel_group(self) -> bool:
|
||||
# for those database-based connector, vLLM does not need to create
|
||||
# parallel group, and in that case the kv parallel size will be 1.
|
||||
if self.kv_connector == "DynamoNixlConnector":
|
||||
return False
|
||||
return self.kv_connector is not None and self.kv_parallel_size > 1
|
||||
|
||||
@property
|
||||
@ -2706,6 +2724,18 @@ class KVTransferConfig(BaseModel):
|
||||
return self.kv_connector is not None and \
|
||||
self.kv_role in ["kv_consumer", "kv_both"]
|
||||
|
||||
@property
|
||||
def tensor_parallel_multiplier(self) -> int:
|
||||
return self.kv_consumers_tensor_parallel_size // self.kv_producers_tensor_parallel_size
|
||||
|
||||
@property
|
||||
def kv_consumers_parallel_size(self) -> int:
|
||||
return self.kv_parallel_size - self.kv_producers_parallel_size
|
||||
|
||||
@property
|
||||
def kv_world_size(self) -> int:
|
||||
return self.kv_producers_parallel_size + self.kv_consumers_parallel_size * self.tensor_parallel_multiplier
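As a rough illustration (hypothetical numbers, not taken from the diff), the sizing properties above combine like this for one producer instance at TP=2 and two consumer instances at TP=4:

# Hypothetical sizing, mirroring the property definitions above.
kv_parallel_size = 3                      # 1 producer + 2 consumers
kv_producers_parallel_size = 1
kv_producers_tensor_parallel_size = 2
kv_consumers_tensor_parallel_size = 4

tensor_parallel_multiplier = kv_consumers_tensor_parallel_size // kv_producers_tensor_parallel_size   # 2
kv_consumers_parallel_size = kv_parallel_size - kv_producers_parallel_size                            # 2
kv_world_size = kv_producers_parallel_size + kv_consumers_parallel_size * tensor_parallel_multiplier  # 1 + 2 * 2 = 5

print(tensor_parallel_multiplier, kv_consumers_parallel_size, kv_world_size)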
class CompilationLevel:
|
||||
# constants for the levels of the compilation process
|
||||
|
@ -6,6 +6,7 @@ from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId,
|
||||
DeviceAwareBlockAllocator)
|
||||
from vllm.core.block.naive_block import NaiveBlock, NaiveBlockAllocator
|
||||
from vllm.core.block.prefix_caching_block import PrefixCachingBlockAllocator
|
||||
from vllm.core.event_manager import KVCacheEventManager
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import Device
|
||||
|
||||
@ -28,6 +29,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
block_size: int,
|
||||
event_manager: Optional[KVCacheEventManager] = None,
|
||||
) -> DeviceAwareBlockAllocator:
|
||||
"""Creates a CpuGpuBlockAllocator instance with the specified
|
||||
configuration.
|
||||
@ -64,6 +66,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
cpu_block_ids = block_ids[num_gpu_blocks:]
|
||||
|
||||
if allocator_type == "naive":
|
||||
assert event_manager is None, "Event API not supported with naive allocator."
|
||||
gpu_allocator: BlockAllocator = NaiveBlockAllocator(
|
||||
create_block=NaiveBlock, # type: ignore
|
||||
num_blocks=num_gpu_blocks,
|
||||
@ -82,12 +85,14 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
num_blocks=num_gpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=gpu_block_ids,
|
||||
event_manager=event_manager,
|
||||
)
|
||||
|
||||
cpu_allocator = PrefixCachingBlockAllocator(
|
||||
num_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
block_ids=cpu_block_ids,
|
||||
event_manager=event_manager,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown allocator type {allocator_type=}")
|
||||
@ -95,10 +100,12 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
return CpuGpuBlockAllocator(
|
||||
cpu_block_allocator=cpu_allocator,
|
||||
gpu_block_allocator=gpu_allocator,
|
||||
event_manager=event_manager,
|
||||
)
|
||||
|
||||
def __init__(self, cpu_block_allocator: BlockAllocator,
|
||||
gpu_block_allocator: BlockAllocator):
|
||||
gpu_block_allocator: BlockAllocator,
|
||||
event_manager: Optional[KVCacheEventManager] = None,):
|
||||
assert not (
|
||||
cpu_block_allocator.all_block_ids
|
||||
& gpu_block_allocator.all_block_ids
|
||||
@ -108,6 +115,7 @@ class CpuGpuBlockAllocator(DeviceAwareBlockAllocator):
|
||||
Device.CPU: cpu_block_allocator,
|
||||
Device.GPU: gpu_block_allocator,
|
||||
}
|
||||
self.event_manager = event_manager
|
||||
|
||||
self._swap_mapping: Dict[int, int] = {}
|
||||
self._null_block: Optional[Block] = None
|
||||
|
@ -2,7 +2,7 @@
|
||||
|
||||
from collections import deque
|
||||
from typing import Deque, FrozenSet, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import heapq
|
||||
from vllm.core.block.common import (BlockPool, CopyOnWriteTracker, RefCounter,
|
||||
get_all_blocks_recursively)
|
||||
from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device
|
||||
@ -38,7 +38,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
|
||||
self._free_block_indices: Deque[BlockId] = deque(block_ids)
|
||||
self._free_block_indices: List[BlockId] = list(block_ids)
|
||||
self._all_block_indices = frozenset(block_ids)
|
||||
assert len(self._all_block_indices) == num_blocks
|
||||
|
||||
@ -134,7 +134,8 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
if not self._free_block_indices:
|
||||
raise BlockAllocator.NoFreeBlocksError()
|
||||
|
||||
block_id = self._free_block_indices.popleft()
|
||||
block_id = heapq.heappop(self._free_block_indices)
|
||||
# TODO: figure out why block_id is sometimes None
|
||||
self._refcounter.incr(block_id)
|
||||
return block_id
|
||||
|
||||
@ -148,7 +149,7 @@ class NaiveBlockAllocator(BlockAllocator):
|
||||
|
||||
refcount = self._refcounter.decr(block_id)
|
||||
if refcount == 0:
|
||||
self._free_block_indices.appendleft(block_id)
|
||||
heapq.heappush(self._free_block_indices, block_id)
|
||||
|
||||
def free(self, block: Block, keep_block_object: bool = False) -> None:
|
||||
# Release the physical block id
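The deque-to-heap change in this hunk alters which free block is handed out next: a min-heap always returns the lowest free block id, while the old deque returned blocks in insertion order. A minimal standalone sketch (hypothetical ids, not vLLM code):

import heapq
from collections import deque

free_deque = deque([7, 3, 9])     # old behaviour: FIFO over freed ids
free_heap = [7, 3, 9]
heapq.heapify(free_heap)          # new behaviour: min-heap over free ids

print(free_deque.popleft())       # 7 -- whatever was freed first
print(heapq.heappop(free_heap))   # 3 -- always the smallest free block id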
@ -4,7 +4,7 @@ import sys
|
||||
from bisect import bisect_left
|
||||
from os.path import commonprefix
|
||||
from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set,
|
||||
Tuple)
|
||||
Tuple, TYPE_CHECKING)
|
||||
|
||||
from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker,
|
||||
get_all_blocks_recursively)
|
||||
@ -23,6 +23,9 @@ PrefixHash = int
|
||||
# then we know this block hasn't been accessed yet.
|
||||
_DEFAULT_LAST_ACCESSED_TIME = -1
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.core.event_manager import KVCacheEventManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@ -80,6 +83,7 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
block_size: int,
|
||||
block_ids: Optional[Iterable[int]] = None,
|
||||
eviction_policy: EvictionPolicy = EvictionPolicy.LRU,
|
||||
event_manager: Optional["KVCacheEventManager"] = None,
|
||||
):
|
||||
if block_ids is None:
|
||||
block_ids = range(num_blocks)
|
||||
@ -131,6 +135,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
|
||||
self.metric_data = CacheMetricData()
|
||||
|
||||
self.event_manager = event_manager
|
||||
|
||||
# Implements Block.Factory.
|
||||
def _create_block(
|
||||
self,
|
||||
prev_block: Optional[Block],
|
||||
@ -337,6 +344,9 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
assert self._refcounter.get(_block_id) == 0
|
||||
assert _block_id == block_id
|
||||
|
||||
if self.event_manager:
|
||||
self.event_manager.enqueue_removed_event(content_hash_to_evict)
|
||||
|
||||
self._cached_blocks.pop(content_hash_to_evict)
|
||||
|
||||
self._refcounter.incr(block_id)
|
||||
@ -513,6 +523,10 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
# Mark this block as touched so that it can be marked as
|
||||
# computed after the entire batch of sequences are scheduled.
|
||||
self._touched_blocks.add(block.block_id)
|
||||
|
||||
if self.event_manager:
|
||||
self.event_manager.enqueue_stored_event(block.prev_block, block)
|
||||
|
||||
return block.block_id
|
||||
|
||||
# Reuse the cached content hash
|
||||
@ -579,9 +593,11 @@ class PrefixCachingBlockAllocator(BlockAllocator):
|
||||
|
||||
def mark_blocks_as_computed(self, block_ids: List[int]) -> None:
|
||||
# Mark all touched blocks as computed.
|
||||
for block_id in self._touched_blocks:
|
||||
self._block_tracker[block_id].computed = True
|
||||
self._touched_blocks.clear()
|
||||
for block_id in block_ids:
|
||||
if block_id in self._touched_blocks:
|
||||
logger.debug("Mark block as computed: %s", block_id)
|
||||
self._block_tracker[block_id].computed = True
|
||||
self._touched_blocks.remove(block_id)
|
||||
|
||||
def _track_block_id(self, block_id: Optional[BlockId],
|
||||
computed: bool) -> None:
|
||||
|
@ -10,7 +10,10 @@ from vllm.core.block.interfaces import Block
|
||||
from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker,
|
||||
LastAccessBlocksTracker)
|
||||
from vllm.core.block.utils import check_no_caching_or_swa_for_blockmgr_encdec
|
||||
from vllm.core.event_manager import KVCacheEventManager
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.envs import (VLLM_KV_CAPI_PATH, VLLM_KV_COMPONENT, VLLM_KV_NAMESPACE,
|
||||
VLLM_WORKER_ID)
|
||||
from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
|
||||
from vllm.utils import Device
|
||||
|
||||
@ -60,6 +63,7 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
block_size: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
@ -91,11 +95,29 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
|
||||
self.watermark_blocks = int(watermark * num_gpu_blocks)
|
||||
|
||||
kv_event_manager_params = [
|
||||
VLLM_WORKER_ID, VLLM_KV_CAPI_PATH, VLLM_KV_NAMESPACE,
|
||||
VLLM_KV_COMPONENT
|
||||
]
|
||||
set_kv_event_manager_params = len(
|
||||
[param for param in kv_event_manager_params if param is not None])
|
||||
|
||||
if set_kv_event_manager_params == len(kv_event_manager_params):
|
||||
self.event_manager = KVCacheEventManager(
|
||||
namespace=VLLM_KV_NAMESPACE,
|
||||
component=VLLM_KV_COMPONENT,
|
||||
worker_id=VLLM_WORKER_ID,
|
||||
lib_path=VLLM_KV_CAPI_PATH,
|
||||
kv_block_size=block_size)
|
||||
else:
|
||||
self.event_manager = None
|
||||
|
||||
self.block_allocator = CpuGpuBlockAllocator.create(
|
||||
allocator_type="prefix_caching" if enable_caching else "naive",
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
block_size=block_size,
|
||||
event_manager=self.event_manager,
|
||||
)
|
||||
|
||||
self.block_tables: Dict[SeqId, BlockTable] = {}
|
||||
@ -108,7 +130,8 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
|
||||
def can_allocate(self,
|
||||
seq_group: SequenceGroup,
|
||||
num_lookahead_slots: int = 0) -> AllocStatus:
|
||||
num_lookahead_slots: int = 0,
|
||||
is_remote_decode: bool = False) -> AllocStatus:
|
||||
# FIXME(woosuk): Here we assume that all sequences in the group share
|
||||
# the same prompt. This may not be true for preempted sequences.
|
||||
|
||||
@ -121,6 +144,10 @@ class SelfAttnBlockSpaceManager(BlockSpaceManager):
|
||||
num_lookahead_slots=num_lookahead_slots,
|
||||
)
|
||||
|
||||
# if remote decode, we need to allocate twice as many blocks for staging
|
||||
if is_remote_decode:
|
||||
num_required_blocks *= 2
|
||||
|
||||
if seq_group.is_encoder_decoder():
|
||||
encoder_seq = seq_group.get_encoder_seq()
|
||||
assert encoder_seq is not None
|
||||
|
vllm/core/event_manager.py (new file, 108 lines)
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import ctypes
|
||||
import logging
|
||||
import uuid
|
||||
from ctypes import c_char_p, c_size_t, c_uint32, c_void_p, c_int64
|
||||
from typing import Optional
|
||||
|
||||
from vllm.core.block.prefix_caching_block import PrefixCachingBlock, PrefixHash
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamoResult:
|
||||
OK = 0
|
||||
ERR = 1
|
||||
|
||||
|
||||
class KVCacheEventManager:
|
||||
|
||||
def __init__(self, namespace: str, component: str, worker_id: int,
|
||||
lib_path: str, kv_block_size: int):
|
||||
self.lib = None
|
||||
|
||||
try:
|
||||
self.lib = ctypes.CDLL(lib_path)
|
||||
self.lib.dynamo_llm_init.argtypes = [
|
||||
c_char_p,
|
||||
c_char_p,
|
||||
c_int64,
|
||||
c_uint32,
|
||||
]
|
||||
self.lib.dynamo_llm_init.restype = c_uint32
|
||||
|
||||
result = self.lib.dynamo_llm_init(
|
||||
namespace.encode(), component.encode(), worker_id, kv_block_size
|
||||
)
|
||||
if result == DynamoResult.OK:
|
||||
logger.info(
|
||||
"KVCacheEventManager initialized successfully. Ready to publish KV Cache Events"
|
||||
)
|
||||
else:
|
||||
logger.info("KVCacheEventManager initialization failed!")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Failed to load {lib_path}")
|
||||
raise e
|
||||
|
||||
self.lib.dynamo_kv_event_publish_stored.argtypes = [
|
||||
ctypes.c_uint64, # event_id
|
||||
ctypes.POINTER(ctypes.c_uint32), # token_ids
|
||||
ctypes.POINTER(ctypes.c_size_t), # num_block_tokens
|
||||
ctypes.POINTER(ctypes.c_uint64), # block_ids
|
||||
ctypes.c_size_t, # num_blocks
|
||||
ctypes.POINTER(ctypes.c_uint64), # parent_hash
|
||||
ctypes.c_uint64, # lora_id
|
||||
]
|
||||
self.lib.dynamo_kv_event_publish_stored.restype = ctypes.c_uint32 # dynamo_llm_result_t
|
||||
|
||||
self.lib.dynamo_kv_event_publish_removed.argtypes = [
|
||||
ctypes.c_uint64, # event_id
|
||||
ctypes.POINTER(ctypes.c_uint64), # block_ids
|
||||
ctypes.c_size_t, # num_blocks
|
||||
]
|
||||
self.lib.dynamo_kv_event_publish_removed.restype = ctypes.c_uint32 # dynamo_llm_result_t
|
||||
|
||||
self.event_id_counter = 0
|
||||
|
||||
def enqueue_stored_event(self, parent: Optional[PrefixCachingBlock],
|
||||
block: PrefixCachingBlock):
|
||||
token_ids_arr = (ctypes.c_uint32 *
|
||||
len(block.token_ids))(*block.token_ids)
|
||||
num_block_tokens = (ctypes.c_size_t * 1)(len(block.token_ids))
|
||||
block_hash = (ctypes.c_uint64 * 1)(block.content_hash)
|
||||
parent_hash = ((ctypes.c_uint64 * 1)(parent.content_hash)
|
||||
if parent is not None else None)
|
||||
|
||||
# Publish the event
|
||||
result = self.lib.dynamo_kv_event_publish_stored(
|
||||
self.event_id_counter, # uint64_t event_id
|
||||
token_ids_arr, # const uint32_t *token_ids
|
||||
num_block_tokens, # const uintptr_t *num_block_tokens
|
||||
block_hash, # const uint64_t *block_ids
|
||||
1, # uintptr_t num_blocks
|
||||
parent_hash, # const uint64_t *parent_hash
|
||||
0, # uint64_t lora_id
|
||||
)
|
||||
|
||||
if result == DynamoResult.OK:
|
||||
logger.debug(f"Store - Published KV Event: {block.content_hash}")
|
||||
else:
|
||||
logger.debug(
|
||||
f"Store - Failed to Publish KV Event: {block.content_hash}")
|
||||
|
||||
self.event_id_counter += 1
|
||||
|
||||
def enqueue_removed_event(self, block_hash: PrefixHash):
|
||||
result = self.lib.dynamo_kv_event_publish_removed(
|
||||
self.event_id_counter,
|
||||
(ctypes.c_uint64 * 1)(block_hash),
|
||||
1,
|
||||
)
|
||||
|
||||
if result == DynamoResult.OK:
|
||||
logger.debug(f"Remove - Published KV Event: {block_hash}")
|
||||
else:
|
||||
logger.debug(f"Remove - Failed to Publish KV Event: {block_hash}")
|
||||
|
||||
self.event_id_counter += 1
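A minimal usage sketch for the new event manager, assuming the dynamo_llm shared library is present at the path normally supplied via VLLM_KV_CAPI_PATH (the namespace, component and worker id below are hypothetical):

import os
from vllm.core.event_manager import KVCacheEventManager

manager = KVCacheEventManager(
    namespace="dynamo",                         # hypothetical VLLM_KV_NAMESPACE
    component="vllm-worker",                    # hypothetical VLLM_KV_COMPONENT
    worker_id=0,                                # hypothetical VLLM_WORKER_ID
    lib_path=os.environ["VLLM_KV_CAPI_PATH"],   # path to the dynamo_llm C library
    kv_block_size=16,
)

# The prefix-caching allocator then calls these hooks as blocks are cached or evicted:
# manager.enqueue_stored_event(parent_block, block)
# manager.enqueue_removed_event(block_hash)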
@ -4,22 +4,22 @@ import enum
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
import copy
|
||||
from collections import deque
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Deque, Dict, Iterable, List, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Tuple, Union
|
||||
from typing import Set, Tuple, Union, Any
|
||||
|
||||
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.config import ModelConfig, CacheConfig, LoRAConfig, SchedulerConfig
|
||||
from vllm.core.interfaces import AllocStatus, BlockSpaceManager
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
|
||||
SequenceGroupMetadata, SequenceGroupMetadataDelta,
|
||||
SequenceStatus)
|
||||
SequenceStatus, SequenceStage)
|
||||
from vllm.utils import Device, PyObjectCache
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Test-only. If configured, decode is preempted with
|
||||
@ -285,6 +285,7 @@ class SchedulerPrefillOutputs:
|
||||
# Ignored sequence groups.
|
||||
ignored_seq_groups: List[SequenceGroup]
|
||||
num_lookahead_slots: int
|
||||
num_remote_prefill_groups: int
|
||||
|
||||
@classmethod
|
||||
def create_empty(cls) -> "SchedulerPrefillOutputs":
|
||||
@ -292,6 +293,7 @@ class SchedulerPrefillOutputs:
|
||||
seq_groups=[],
|
||||
ignored_seq_groups=[],
|
||||
num_lookahead_slots=0,
|
||||
num_remote_prefill_groups=0,
|
||||
)
|
||||
|
||||
|
||||
@ -325,12 +327,14 @@ class Scheduler:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_config: ModelConfig,
|
||||
scheduler_config: SchedulerConfig,
|
||||
cache_config: CacheConfig,
|
||||
lora_config: Optional[LoRAConfig],
|
||||
pipeline_parallel_size: int = 1,
|
||||
output_proc_callback: Optional[Callable] = None,
|
||||
) -> None:
|
||||
self.model_config = model_config
|
||||
self.scheduler_config = scheduler_config
|
||||
self.cache_config = cache_config
|
||||
# Note for LoRA scheduling: the current policy is extremely
|
||||
@ -356,6 +360,7 @@ class Scheduler:
|
||||
|
||||
# Create the block space manager.
|
||||
self.block_manager = BlockSpaceManagerImpl(
|
||||
model_name=self.model_config.served_model_name,
|
||||
block_size=self.cache_config.block_size,
|
||||
num_gpu_blocks=num_gpu_blocks,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
@ -371,6 +376,16 @@ class Scheduler:
|
||||
# Sequence groups in the SWAPPED state.
|
||||
# Contain decode requests that are swapped out.
|
||||
self.swapped: Deque[SequenceGroup] = deque()
|
||||
|
||||
# Sequence groups in the REMOTE_PREFILLING state.
|
||||
# Contain requests that are being prefilled by a remote worker.
|
||||
self.remote_prefilling: Deque[SequenceGroup] = deque()
|
||||
# Contain requests that are being prefilled by a local worker.
|
||||
self.prefill_sending: Deque[SequenceGroup] = deque()
|
||||
|
||||
self._remote_prefill_outputs: Dict[str, int] = {}
|
||||
|
||||
|
||||
# Sequence groups finished requests ids since last step iteration.
|
||||
# It lets the model know that any state associated with these requests
|
||||
# can and must be released after the current step.
|
||||
@ -501,7 +516,7 @@ class Scheduler:
|
||||
|
||||
def has_unfinished_seqs(self) -> bool:
|
||||
return len(self.waiting) != 0 or len(self.running) != 0 or len(
|
||||
self.swapped) != 0
|
||||
self.swapped) != 0 or len(self.remote_prefilling) != 0 or len(self.prefill_sending) != 0
|
||||
|
||||
def get_prefix_cache_hit_rate(self, device: Device) -> float:
|
||||
return self.block_manager.get_prefix_cache_hit_rate(device)
|
||||
@ -523,6 +538,8 @@ class Scheduler:
|
||||
budget: SchedulingBudget,
|
||||
curr_loras: Optional[Set[int]],
|
||||
enable_chunking: bool = False,
|
||||
finished_prefills: Optional[Set[str]] = None,
|
||||
finished_transfers: Optional[Set[str]] = None
|
||||
) -> SchedulerRunningOutputs:
|
||||
"""Schedule sequence groups that are running.
|
||||
|
||||
@ -537,6 +554,8 @@ class Scheduler:
|
||||
chunked number of tokens are scheduled if
|
||||
`budget.num_batched_tokens` has not enough capacity to schedule
|
||||
all tokens.
|
||||
finished_remote_prefill_request_ids: Set of request ids of remote
|
||||
prefills that have finished.
|
||||
|
||||
Returns:
|
||||
SchedulerRunningOutputs.
|
||||
@ -566,6 +585,38 @@ class Scheduler:
|
||||
preempted: List[SequenceGroup] = ret.preempted
|
||||
swapped_out: List[SequenceGroup] = ret.swapped_out
|
||||
|
||||
remote_prefilling_queue = self.remote_prefilling
|
||||
leftover_remote_prefilling_sequences: Deque[SequenceGroup] = deque()
|
||||
while remote_prefilling_queue:
|
||||
seq_group = remote_prefilling_queue.popleft()
|
||||
if seq_group.request_id not in finished_prefills:
|
||||
leftover_remote_prefilling_sequences.append(seq_group)
|
||||
continue
|
||||
|
||||
else:
|
||||
finished_prefills.remove(seq_group.request_id)
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
# prefill computed all but the last token, so the first new token must be decoded on the decode worker
|
||||
seq_group.update_num_computed_tokens(seq.get_len() - 1)
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
seq.data._stage = SequenceStage.DECODE
|
||||
self.running.appendleft(seq_group)
|
||||
remote_prefilling_queue.extendleft(leftover_remote_prefilling_sequences)
|
||||
|
||||
remote_transfers_queue = self.prefill_sending
|
||||
leftover_remote_transfers_sequences: Deque[SequenceGroup] = deque()
|
||||
while remote_transfers_queue:
|
||||
seq_group = remote_transfers_queue.popleft()
|
||||
if seq_group.request_id not in finished_transfers:
|
||||
leftover_remote_transfers_sequences.append(seq_group)
|
||||
else:
|
||||
finished_transfers.remove(seq_group.request_id)
|
||||
assert len(seq_group.seqs) == 1
|
||||
seq = seq_group.seqs[0]
|
||||
self.free_seq(seq)
|
||||
remote_transfers_queue.extendleft(leftover_remote_transfers_sequences)
|
||||
|
||||
running_queue = self.running
|
||||
assert len(self._async_stopped) == 0
|
||||
while running_queue:
|
||||
@ -925,6 +976,7 @@ class Scheduler:
|
||||
seq_groups: List[ScheduledSequenceGroup] = []
|
||||
|
||||
waiting_queue = self.waiting
|
||||
num_remote_prefill_groups = 0
|
||||
|
||||
leftover_waiting_sequences: Deque[SequenceGroup] = deque()
|
||||
while self._passed_delay(time.time()) and waiting_queue:
|
||||
@ -961,8 +1013,10 @@ class Scheduler:
|
||||
True, enable_chunking)
|
||||
|
||||
# If the sequence group cannot be allocated, stop.
|
||||
is_remote_decode = seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode
|
||||
can_allocate = self.block_manager.can_allocate(
|
||||
seq_group, num_lookahead_slots=num_lookahead_slots)
|
||||
seq_group, num_lookahead_slots=num_lookahead_slots,
|
||||
is_remote_decode=is_remote_decode)
|
||||
if can_allocate == AllocStatus.LATER:
|
||||
break
|
||||
elif can_allocate == AllocStatus.NEVER:
|
||||
@ -1008,7 +1062,18 @@ class Scheduler:
|
||||
if curr_loras is not None and lora_int_id > 0:
|
||||
curr_loras.add(lora_int_id)
|
||||
waiting_queue.popleft()
|
||||
self._allocate_and_set_running(seq_group)
|
||||
|
||||
seq_group_copy = copy.deepcopy(seq_group)
|
||||
seq_group_copy.seqs[0].seq_id = seq_group.seqs[0].seq_id + 1
|
||||
|
||||
logger.debug("Allocating and setting running or remote prefill for seq_group %s", seq_group.request_id)
|
||||
logger.debug("Seq id: %s", seq_group.seqs[0].seq_id)
|
||||
is_remote_prefill = self._allocate_and_set_running_or_remote_prefill(seq_group)
|
||||
num_remote_prefill_groups += is_remote_prefill
|
||||
if is_remote_decode:
|
||||
logger.debug("Seq id: %s", seq_group_copy.seqs[0].seq_id)
|
||||
self._allocate_and_set_running_or_remote_prefill(seq_group_copy)
|
||||
self.prefill_sending.append(seq_group_copy)
|
||||
|
||||
if enable_chunking and self.scheduler_config.is_multi_step:
|
||||
blocks_to_copy: List[Tuple[int, int]] = []
|
||||
@ -1046,9 +1111,11 @@ class Scheduler:
|
||||
seq_groups=seq_groups,
|
||||
ignored_seq_groups=ignored_seq_groups,
|
||||
num_lookahead_slots=self._get_num_lookahead_slots(
|
||||
is_prefill=True, enable_chunking=enable_chunking))
|
||||
is_prefill=True, enable_chunking=enable_chunking),
|
||||
num_remote_prefill_groups=num_remote_prefill_groups
|
||||
)
|
||||
|
||||
def _schedule_default(self) -> SchedulerOutputs:
|
||||
def _schedule_default(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
|
||||
"""Schedule queued requests.
|
||||
|
||||
The current policy is designed to optimize the throughput. First,
|
||||
@ -1066,9 +1133,13 @@ class Scheduler:
|
||||
for seq_group in self.running:
|
||||
budget.add_num_seqs(seq_group.request_id,
|
||||
seq_group.get_max_num_running_seqs())
|
||||
curr_loras = set(
|
||||
for seq_group in self.remote_prefilling:
|
||||
budget.add_num_seqs(seq_group.request_id,
|
||||
seq_group.get_max_num_running_seqs())
|
||||
|
||||
curr_loras = (set(
|
||||
seq_group.lora_int_id for seq_group in self.running
|
||||
if seq_group.lora_int_id > 0) if self.lora_enabled else None
|
||||
if seq_group.lora_int_id > 0) if self.lora_enabled else None)
|
||||
|
||||
prefills = SchedulerPrefillOutputs.create_empty()
|
||||
running_scheduled = SchedulerRunningOutputs.create_empty()
|
||||
@ -1090,7 +1161,9 @@ class Scheduler:
|
||||
if len(prefills.seq_groups) == 0:
|
||||
running_scheduled = self._schedule_running(budget,
|
||||
curr_loras,
|
||||
enable_chunking=False)
|
||||
enable_chunking=False,
|
||||
finished_prefills=finished_prefills,
|
||||
finished_transfers=finished_transfers)
|
||||
|
||||
# If any sequence group is preempted, do not swap in any sequence
|
||||
# group. because it means there's no slot for new running requests.
|
||||
@ -1106,7 +1179,12 @@ class Scheduler:
|
||||
self.waiting.extendleft(running_scheduled.preempted)
|
||||
# Update new running requests.
|
||||
if len(prefills.seq_groups) > 0:
|
||||
self.running.extend([s.seq_group for s in prefills.seq_groups])
|
||||
for s in prefills.seq_groups:
|
||||
seq_group = s.seq_group
|
||||
if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
|
||||
self.remote_prefilling.append(seq_group)
|
||||
else:
|
||||
self.running.append(seq_group)
|
||||
|
||||
self.running.extend(running_scheduled.decode_seq_groups_list)
|
||||
|
||||
@ -1248,12 +1326,14 @@ class Scheduler:
|
||||
len(running_scheduled.swapped_out)),
|
||||
)
|
||||
|
||||
def _schedule(self) -> SchedulerOutputs:
|
||||
def _schedule(self, finished_prefills: Optional[Set[str]] = None, finished_transfers: Optional[Set[str]] = None) -> SchedulerOutputs:
|
||||
"""Schedule queued requests."""
|
||||
if self.scheduler_config.chunked_prefill_enabled:
|
||||
if finished_prefills or finished_transfers:
|
||||
raise ValueError("Chunked prefill does not support remote prefills")
|
||||
return self._schedule_chunked_prefill()
|
||||
else:
|
||||
return self._schedule_default()
|
||||
return self._schedule_default(finished_prefills, finished_transfers)
|
||||
|
||||
def _can_append_slots(self, seq_group: SequenceGroup,
|
||||
enable_chunking: bool) -> bool:
|
||||
@ -1287,14 +1367,16 @@ class Scheduler:
|
||||
return no_single_seq
|
||||
|
||||
def schedule(
|
||||
self
|
||||
self,
|
||||
finished_prefills: Optional[Set[str]] = None,
|
||||
finished_transfers: Optional[Set[str]] = None
|
||||
) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]:
|
||||
# Schedule sequence groups.
|
||||
# This function call changes the internal states of the scheduler
|
||||
# such as self.running, self.swapped, and self.waiting.
|
||||
scheduler_start_time = time.perf_counter()
|
||||
|
||||
scheduler_outputs: SchedulerOutputs = self._schedule()
|
||||
scheduler_start_time = time.perf_counter()
|
||||
scheduler_outputs: SchedulerOutputs = self._schedule(finished_prefills, finished_transfers)
|
||||
now = time.time()
|
||||
|
||||
if not self.cache_config.enable_prefix_caching:
|
||||
@ -1333,7 +1415,8 @@ class Scheduler:
|
||||
encoder_seq_data = None
|
||||
cross_block_table = None
|
||||
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING):
|
||||
running_or_remote_prefilling_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) + seq_group.get_seqs(status=SequenceStatus.REMOTE_PREFILLING)
|
||||
for seq in running_or_remote_prefilling_seqs:
|
||||
seq_id = seq.seq_id
|
||||
seq_data[seq_id] = seq.data
|
||||
block_tables[seq_id] = self.block_manager.get_block_table(seq)
|
||||
@ -1342,7 +1425,9 @@ class Scheduler:
|
||||
if self.cache_config.enable_prefix_caching:
|
||||
common_computed_block_nums = (
|
||||
self.block_manager.get_common_computed_block_ids(
|
||||
seq_group.get_seqs(status=SequenceStatus.RUNNING)))
|
||||
running_or_remote_prefilling_seqs
|
||||
)
|
||||
)
|
||||
|
||||
do_sample = True
|
||||
is_prompt = seq_group.is_prefill()
|
||||
@ -1364,9 +1449,30 @@ class Scheduler:
|
||||
< seqs[0].data.get_len()):
|
||||
do_sample = False
|
||||
|
||||
is_remote_prefill = False
|
||||
if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
|
||||
is_remote_prefill = True
|
||||
logger.debug("Remote prefill, computed block nums: %s", common_computed_block_nums)
|
||||
if is_first_prefill and seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_decode:
|
||||
block_tables[seq_group.seqs[0].seq_id + 1] = self.block_manager.block_tables[seq.seq_id + 1].physical_block_ids
|
||||
|
||||
# Since we know that prefill is scheduled we can
|
||||
# assume that the blocks computed on decode
|
||||
# will be fetched by the time we run prefill
|
||||
logger.debug("Computed decode blocks: %s", seq_group.remote_prefill_params.decode_computed_block_ids)
|
||||
if seq_group.remote_prefill_params.decode_computed_block_ids:
|
||||
computed_block_ids = set(seq_group.remote_prefill_params.decode_computed_block_ids)
|
||||
prefill_block_ids = block_tables[seq_group.seqs[0].seq_id]
|
||||
prefill_fetched_block_ids = [prefill_block_ids[i] for i, block_id in enumerate(seq_group.remote_prefill_params.decode_block_ids) if block_id in computed_block_ids and i < len(prefill_block_ids)]
|
||||
|
||||
assert len(common_computed_block_nums) == 0, "common_computed_block_nums should be empty for remote prefill as it doesn't support prefix caching"
|
||||
common_computed_block_nums = prefill_fetched_block_ids
|
||||
|
||||
|
||||
# It assumes the scheduled_seq_groups is ordered by
|
||||
# prefill < decoding.
|
||||
if is_first_prefill or not self.scheduler_config.send_delta_data:
|
||||
logger.debug("Assinged blocks: %s", block_tables)
|
||||
seq_group_metadata = SequenceGroupMetadata(
|
||||
request_id=seq_group.request_id,
|
||||
is_prompt=is_prompt,
|
||||
@ -1392,6 +1498,7 @@ class Scheduler:
|
||||
if scheduler_outputs.num_prefill_groups > 0 else None,
|
||||
mm_processor_kwargs=seq_group.mm_processor_kwargs,
|
||||
prompt_adapter_request=seq_group.prompt_adapter_request,
|
||||
do_remote_prefill=is_remote_prefill,
|
||||
)
|
||||
else:
|
||||
# When SPMD mode is enabled, we only send delta data except for
|
||||
@ -1490,11 +1597,17 @@ class Scheduler:
|
||||
|
||||
self._async_stopped.clear()
|
||||
|
||||
def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
|
||||
def _allocate_and_set_running_or_remote_prefill(self, seq_group: SequenceGroup) -> bool:
|
||||
self.block_manager.allocate(seq_group)
|
||||
is_remote_prefill = False
|
||||
for seq in seq_group.get_seqs(status=SequenceStatus.WAITING):
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
|
||||
if seq_group.remote_prefill_params is not None and seq_group.remote_prefill_params.is_remote_prefill:
|
||||
seq.status = SequenceStatus.REMOTE_PREFILLING
|
||||
is_remote_prefill = True
|
||||
else:
|
||||
seq.status = SequenceStatus.RUNNING
|
||||
return is_remote_prefill
|
||||
|
||||
def _append_slots(self,
|
||||
seq_group: SequenceGroup,
|
||||
blocks_to_copy: List[Tuple[int, int]],
|
||||
|
vllm/distributed/device_communicators/kv_rearrange.py (new file, 110 lines)
@ -0,0 +1,110 @@
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@triton.jit
|
||||
def rearrange_kernel_read(
|
||||
t1_ptr,
|
||||
t2_ptr,
|
||||
N,
|
||||
B,
|
||||
H,
|
||||
C,
|
||||
d,
|
||||
tensor_subset_size,
|
||||
block_size,
|
||||
token_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
curr_n = offsets // block_size
|
||||
curr_b = offsets // token_size % B
|
||||
curr_h = offsets // C % H
|
||||
curr_c = offsets % C
|
||||
|
||||
src_pos = offsets
|
||||
|
||||
tp_group = curr_h * d // H
|
||||
dst_h = curr_h % (H // d)
|
||||
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
|
||||
|
||||
dst_pos = tensor_subset_size * tp_group + tp_group_offset
|
||||
|
||||
tl.store(t1_ptr + src_pos, tl.load(t2_ptr + dst_pos))
|
||||
|
||||
@triton.jit
|
||||
def rearrange_kernel_write(
|
||||
t1_ptr,
|
||||
t2_ptr,
|
||||
N,
|
||||
B,
|
||||
H,
|
||||
C,
|
||||
d,
|
||||
tensor_subset_size,
|
||||
block_size,
|
||||
token_size,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(0)
|
||||
|
||||
block_start = pid * BLOCK_SIZE
|
||||
offsets = block_start + tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
curr_n = offsets // block_size
|
||||
curr_b = offsets // token_size % B
|
||||
curr_h = offsets // C % H
|
||||
curr_c = offsets % C
|
||||
|
||||
src_pos = offsets
|
||||
|
||||
tp_group = curr_h * d // H
|
||||
dst_h = curr_h % (H // d)
|
||||
tp_group_offset = curr_n * (block_size // d) + curr_b * (H // d) * C + dst_h * C + curr_c
|
||||
|
||||
dst_pos = tensor_subset_size * tp_group + tp_group_offset
|
||||
|
||||
tl.store(t2_ptr + dst_pos, tl.load(t1_ptr + src_pos))
|
||||
|
||||
|
||||
|
||||
def rearrange_tensors(t1: torch.Tensor, t2: torch.Tensor, d: int, direction: str):
|
||||
N, B, H, C = t1.shape
|
||||
|
||||
assert t2.shape == (N, B, H, C), "Destination tensor must have same shape as source"
|
||||
assert H % d == 0, "H must be divisible by d"
|
||||
|
||||
block_size = B * H * C
|
||||
token_size = H * C
|
||||
tensor_size = N * block_size
|
||||
tensor_subset_size = tensor_size // d
|
||||
|
||||
BLOCK_SIZE = 1024
|
||||
grid = ((N * B * H * C + BLOCK_SIZE - 1) // BLOCK_SIZE,)
|
||||
|
||||
if direction == "read":
|
||||
rearrange_kernel_read[grid](
|
||||
t1, t2,
|
||||
N, B, H, C,
|
||||
d,
|
||||
tensor_subset_size,
|
||||
block_size,
|
||||
token_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE
|
||||
)
|
||||
elif direction == "write":
|
||||
rearrange_kernel_write[grid](
|
||||
t1, t2,
|
||||
N, B, H, C,
|
||||
d,
|
||||
tensor_subset_size,
|
||||
block_size,
|
||||
token_size,
|
||||
BLOCK_SIZE=BLOCK_SIZE
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid direction: {direction}")
vllm/distributed/device_communicators/nixl.py (new file, 379 lines)
@ -0,0 +1,379 @@
|
||||
import torch
|
||||
from typing import List, Tuple
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
import msgspec
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from .kv_rearrange import rearrange_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
try:
|
||||
from nixl._api import nixl_agent as NixlWrapper
|
||||
logger.info("NIXL is available")
|
||||
except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
class NixlMetadata(
|
||||
msgspec.Struct,
|
||||
omit_defaults=True, # type: ignore[call-arg]
|
||||
# required for @cached_property.
|
||||
dict=True):
|
||||
engine_id: str
|
||||
agent_metadata: List[bytes]
|
||||
kv_caches_base_addr: List[List[Tuple[int, int]]] # base address for each rank for each layer for keys and values
|
||||
num_blocks: int
|
||||
|
||||
|
||||
class DynamoNixlConnector:
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str, rank: int):
|
||||
self.vllm_config = vllm_config
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
raise RuntimeError("NIXL is not available")
|
||||
logger.info("Initializing NIXL wrapper")
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||
|
||||
self.use_prepped_xfer = vllm_config.kv_transfer_config.use_prepped_xfer
|
||||
|
||||
self.num_layers = None
|
||||
self.num_blocks = None
|
||||
self.num_heads = None
|
||||
self.block_len = None
|
||||
self.kv_caches = None
|
||||
self.kv_caches_base_addr = {}
|
||||
self.kv_cache_shape = {}
|
||||
|
||||
self._registered_descs = []
|
||||
self._remote_agents = {}
|
||||
self.engine_id = engine_id
|
||||
self.rank = rank
|
||||
self._tp_size = {}
|
||||
self.src_xfer_side_handles = {}
|
||||
self.dst_xfer_side_handles = defaultdict(dict)
|
||||
self.dst_num_blocks = {}
|
||||
|
||||
self._transfers = defaultdict(list)
|
||||
|
||||
|
||||
self._tp_size[engine_id] = vllm_config.parallel_config.tensor_parallel_size
|
||||
|
||||
|
||||
@property
|
||||
def agent_name(self):
|
||||
return self.nixl_wrapper.name
|
||||
|
||||
def register_kv_caches(self, kv_caches: List[torch.Tensor]):
|
||||
_, num_blocks, block_size, num_heads, head_dim = kv_caches[0].shape
|
||||
self.block_len = block_size * num_heads * head_dim * kv_caches[0].element_size()
|
||||
logger.debug("Per layer kv cache size: %s", kv_caches[0].shape)
|
||||
self.num_layers = len(kv_caches)
|
||||
self.num_blocks = num_blocks
|
||||
self.num_heads = num_heads
|
||||
self.kv_caches = kv_caches
|
||||
kv_caches_base_addr = []
|
||||
caches_data = []
|
||||
for key_cache, value_cache in kv_caches:
|
||||
base_addr = key_cache.data_ptr()
|
||||
region_len = 2 * num_blocks * self.block_len
|
||||
caches_data.append((base_addr, region_len, self.rank, ""))
|
||||
kv_caches_base_addr.append((key_cache.data_ptr(), value_cache.data_ptr()))
|
||||
|
||||
self.kv_caches_base_addr[self.engine_id] = kv_caches_base_addr
|
||||
|
||||
descs = self.nixl_wrapper.get_reg_descs(caches_data, "VRAM")
|
||||
logger.debug("Registering descs: %s", caches_data)
|
||||
self.nixl_wrapper.register_memory(descs)
|
||||
self._registered_descs.append(descs)
|
||||
|
||||
def get_agent_metadata(self):
|
||||
return self.nixl_wrapper.get_agent_metadata()
|
||||
|
||||
def shutdown(self):
|
||||
for descs_list in self._registered_descs:
|
||||
self.nixl_wrapper.deregister_memory(descs_list)
|
||||
for agent_names in self._remote_agents.values():
|
||||
for agent_name in agent_names:
|
||||
self.nixl_wrapper.remove_remote_agent(agent_name)
|
||||
for src_xfer_side_handle in self.src_xfer_side_handles.values():
|
||||
self.nixl_wrapper.release_dlist_handle(src_xfer_side_handle)
|
||||
for dst_xfer_side_handles in self.dst_xfer_side_handles.values():
|
||||
for dst_xfer_side_handle in dst_xfer_side_handles.values():
|
||||
self.nixl_wrapper.delete_xfer_side(dst_xfer_side_handle)
|
||||
|
||||
def _get_ranges(self, block_ids):
|
||||
# Group the given block ids into contiguous ranges.
# For example, [0, 1, 2, 4, 5, 6] becomes [[0, 2], [4, 6]].
# The ranges are sorted by their starting block id; the input is
# assumed to be sorted in increasing order.
|
||||
ranges = []
|
||||
for i in range(len(block_ids)):
|
||||
if i == 0 or block_ids[i] != block_ids[i-1] + 1:
|
||||
ranges.append([block_ids[i], block_ids[i]])
|
||||
else:
|
||||
ranges[-1][1] = block_ids[i]
|
||||
return ranges
|
||||
|
||||
def _get_block_descs_ids(self, engine_id, layer_ids, block_ids, i=None, tp_multiplier=1, staging_ranges=None):
|
||||
|
||||
if layer_ids == "all":
|
||||
layer_ids = list(range(self.num_layers))
|
||||
if block_ids == "all":
|
||||
block_ids = list(range(self.num_blocks))
|
||||
|
||||
descs_ids = []
|
||||
|
||||
|
||||
if i is not None:
|
||||
num_blocks = self.num_blocks
|
||||
for layer_id in layer_ids:
|
||||
for is_value in [0, 1]:
|
||||
staging_range_idx = 0
|
||||
for block_id in block_ids:
|
||||
if block_id > staging_ranges[staging_range_idx][1] or block_id < staging_ranges[staging_range_idx][0]:
|
||||
staging_range_idx += 1
|
||||
start_offset = staging_ranges[staging_range_idx][0]
|
||||
i_offset = i * (staging_ranges[staging_range_idx][-1] - start_offset + 1)
|
||||
descs_ids.append(layer_id * 2 * num_blocks * tp_multiplier + is_value * num_blocks * tp_multiplier + start_offset * tp_multiplier + i_offset + (block_id - start_offset))
|
||||
else:
|
||||
num_blocks = self.dst_num_blocks[engine_id]
|
||||
for layer_id in layer_ids:
|
||||
for is_value in [0, 1]:
|
||||
for block_id in block_ids:
|
||||
descs_ids.append(layer_id * 2 * num_blocks + is_value * num_blocks + block_id)
|
||||
return descs_ids
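# Illustrative sketch (not part of the connector): the else branch above lays
# remote descriptors out as layer -> (key, value) -> block. For a toy layout
# of 3 layers and 10 blocks per layer:
def _toy_desc_id(layer_id, is_value, block_id, num_blocks=10):
    return layer_id * 2 * num_blocks + is_value * num_blocks + block_id

assert _toy_desc_id(0, 0, 4) == 4     # layer 0, key cache, block 4
assert _toy_desc_id(0, 1, 4) == 14    # layer 0, value cache, block 4
assert _toy_desc_id(2, 1, 9) == 59    # last of 3 * 2 * 10 = 60 descriptors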
def _get_same_length_ranges(self, src_ranges, dst_ranges, return_original_src_ranges=False):
|
||||
# This function should return a list of ranges for both src and dst so that corresponding ranges are the same length
|
||||
# For example, if src_ranges is [[0, 2] [4, 8]] and dst_ranges is [[1, 3], [5, 7], [9, 10]]
|
||||
# The function should return ([[0, 2], [4, 6], [7, 8]], [[1, 3], [5, 7], [9, 10]])
|
||||
src_overlapping_ranges, dst_overlapping_ranges = [], []
|
||||
|
||||
original_src_ranges = []
|
||||
org_src_range = tuple(src_ranges[0])
|
||||
|
||||
src_idx, dst_idx = 0, 0
|
||||
while src_idx < len(src_ranges) and dst_idx < len(dst_ranges):
|
||||
src_range = src_ranges[src_idx]
|
||||
dst_range = dst_ranges[dst_idx]
|
||||
|
||||
# Calculate the length of each range
|
||||
src_len = src_range[-1] - src_range[0] + 1
|
||||
dst_len = dst_range[-1] - dst_range[0] + 1
|
||||
|
||||
# If ranges have the same length, add them directly
|
||||
if src_len == dst_len:
|
||||
src_overlapping_ranges.append([src_range[0], src_range[-1]])
|
||||
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
|
||||
original_src_ranges.append(org_src_range)
|
||||
src_idx += 1
|
||||
dst_idx += 1
|
||||
if src_idx < len(src_ranges):
|
||||
org_src_range = tuple(src_ranges[src_idx])
|
||||
# If source range is longer, split it
|
||||
elif src_len > dst_len:
|
||||
src_overlapping_ranges.append([src_range[0], src_range[0] + dst_len - 1])
|
||||
dst_overlapping_ranges.append([dst_range[0], dst_range[-1]])
|
||||
original_src_ranges.append(org_src_range)
|
||||
# Update source range for next iteration
|
||||
src_ranges[src_idx] = [src_range[0] + dst_len, src_range[-1]]
|
||||
dst_idx += 1
|
||||
# If destination range is longer, split it
|
||||
else: # src_len < dst_len
|
||||
src_overlapping_ranges.append([src_range[0], src_range[-1]])
|
||||
dst_overlapping_ranges.append([dst_range[0], dst_range[0] + src_len - 1])
|
||||
original_src_ranges.append(org_src_range)
|
||||
# Update destination range for next iteration
|
||||
dst_ranges[dst_idx] = [dst_range[0] + src_len, dst_range[-1]]
|
||||
src_idx += 1
|
||||
if src_idx < len(src_ranges):
|
||||
org_src_range = tuple(src_ranges[src_idx])
|
||||
if return_original_src_ranges:
|
||||
return src_overlapping_ranges, dst_overlapping_ranges, original_src_ranges
|
||||
return src_overlapping_ranges, dst_overlapping_ranges
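# Illustrative check of the example above (plain lists, no vLLM state):
# src [[0, 2], [4, 8]] and dst [[1, 3], [5, 7], [9, 10]] are split so that each
# paired range covers the same number of blocks.
_toy_src = [[0, 2], [4, 6], [7, 8]]
_toy_dst = [[1, 3], [5, 7], [9, 10]]
for _s, _d in zip(_toy_src, _toy_dst):
    assert _s[1] - _s[0] == _d[1] - _d[0]   # equal-length pairing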
def read_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id):
|
||||
logger.debug("Reading %d blocks from %s to %s", len(local_block_ids), self.agent_name, dst_engine_id)
|
||||
|
||||
assert len(local_block_ids) == len(staging_block_ids) == len(remote_block_ids)
|
||||
|
||||
if len(local_block_ids) == 0:
|
||||
logger.debug("No blocks to read")
|
||||
return
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
local_ranges = self._get_ranges(local_block_ids)
|
||||
staging_ranges = self._get_ranges(staging_block_ids)
|
||||
|
||||
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
|
||||
|
||||
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
|
||||
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
|
||||
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
|
||||
handles = []
|
||||
|
||||
logger.debug("Time to get block descs ids: %s ms", (time.perf_counter() - start_time) * 1000)
|
||||
create_xfer_start_time = time.perf_counter()
|
||||
|
||||
for i in range(tp_multiplier):
|
||||
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
|
||||
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
|
||||
handle = self.nixl_wrapper.make_prepped_xfer("READ", local_xfer_side_handle, staging_block_descs_ids,
|
||||
remote_xfer_side_handle, remote_block_descs_ids,
|
||||
"")
|
||||
handles.append(handle)
|
||||
status = self.nixl_wrapper.transfer(handle)
|
||||
|
||||
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
|
||||
|
||||
transfer_start_time = time.perf_counter()
|
||||
|
||||
for handle in handles:
|
||||
while (status := self.nixl_wrapper.check_xfer_state(handle)) != "DONE":
|
||||
if status == "PROC":
|
||||
time.sleep(0.001)
|
||||
else:
|
||||
raise RuntimeError("Read transfer failed with state %s", status)
|
||||
# self.nixl_wrapper.abort_xfer(handle) # TODO ptarasiewicz: why abort is throwing errors?
|
||||
|
||||
logger.debug("Time to transfer: %s ms", (time.perf_counter() - transfer_start_time) * 1000)
|
||||
|
||||
rearrange_start_time = time.perf_counter()
|
||||
|
||||
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
|
||||
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
|
||||
for kv_cache in self.kv_caches:
|
||||
for cache in kv_cache:
|
||||
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "read")
|
||||
|
||||
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - rearrange_start_time) * 1000)
|
||||
logger.debug("Total time for read: %s ms", (time.perf_counter() - start_time) * 1000)
|
||||
|
||||
def write_blocks(self, local_block_ids, staging_block_ids, remote_block_ids, dst_engine_id, notify_msg):
|
||||
logger.debug("Writing %d blocks to %s from %s with notify message %s", len(local_block_ids), dst_engine_id, self.agent_name, notify_msg)
|
||||
|
||||
# hongkuanz: we send isl[:-1] tokens to the prefill where the kv for the last
|
||||
# isl[-1] token is calculated in the first iteration in decode.
|
||||
# If isl equals a multiple of tokens_per_block + 1, the prefill engine will have
# one less block due to the missing last token.
|
||||
remote_block_ids = remote_block_ids[:len(local_block_ids)]
|
||||
|
||||
assert len(staging_block_ids) == len(local_block_ids)
|
||||
tp_multiplier = self._tp_size[dst_engine_id] // self._tp_size[self.engine_id]
|
||||
|
||||
if len(local_block_ids) == 0:
|
||||
logger.debug("No blocks to write")
|
||||
for i in range(tp_multiplier):
|
||||
self.nixl_wrapper.send_notif(self._remote_agents[dst_engine_id][self.rank * tp_multiplier + i], notify_msg)
|
||||
return
|
||||
|
||||
start_time = time.perf_counter()
|
||||
|
||||
local_ranges = self._get_ranges(local_block_ids)
|
||||
staging_ranges = self._get_ranges(staging_block_ids)
|
||||
|
||||
local_rearranging_ranges, staging_rearranging_ranges = self._get_same_length_ranges(local_ranges, staging_ranges)
|
||||
|
||||
for local_range, staging_range in zip(local_rearranging_ranges, staging_rearranging_ranges):
|
||||
logger.debug("Rearranging tensors for cache: %s, local_range: %s, staging_range: %s", self.kv_caches[0].shape, local_range, staging_range)
|
||||
for kv_cache in self.kv_caches:
|
||||
for cache in kv_cache:
|
||||
rearrange_tensors(cache[local_range[0]:local_range[1] + 1], cache[staging_range[0]:staging_range[1] + 1], tp_multiplier, "write")
|
||||
|
||||
logger.debug("Time to rearrange tensors: %s ms", (time.perf_counter() - start_time) * 1000)
|
||||
|
||||
create_xfer_start_time = time.perf_counter()
|
||||
|
||||
# getting block descs ids
|
||||
remote_block_descs_ids = self._get_block_descs_ids(dst_engine_id, "all", remote_block_ids)
|
||||
local_xfer_side_handle = self.src_xfer_side_handles[tp_multiplier]
|
||||
|
||||
for i in range(tp_multiplier):
|
||||
staging_block_descs_ids = self._get_block_descs_ids(self.engine_id, "all", staging_block_ids, i=i, tp_multiplier=tp_multiplier, staging_ranges=staging_rearranging_ranges)
|
||||
assert len(staging_block_descs_ids) == len(remote_block_descs_ids)
|
||||
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id][i]
|
||||
handle = self.nixl_wrapper.make_prepped_xfer("WRITE", local_xfer_side_handle, staging_block_descs_ids,
|
||||
remote_xfer_side_handle, remote_block_descs_ids,
|
||||
notify_msg)
|
||||
self._transfers[notify_msg].append(handle)
|
||||
status = self.nixl_wrapper.transfer(handle)
|
||||
|
||||
logger.debug("Time to create xfer: %s ms", (time.perf_counter() - create_xfer_start_time) * 1000)
|
||||
|
||||
transfer_start_time = time.perf_counter()
|
||||
logger.debug("Total time for write: %s ms", (time.perf_counter() - start_time) * 1000)
|
||||
|
||||
def get_notifs(self):
|
||||
return self.nixl_wrapper.update_notifs()
|
||||
|
||||
def get_new_notifs(self):
|
||||
return self.nixl_wrapper.get_new_notifs()
|
||||
|
||||
def add_remote_agent(self, engine_id, agent_metadata, agent_tp, kv_caches_base_addr, num_blocks):
|
||||
self._tp_size[engine_id] = agent_tp
|
||||
agent_names = []
|
||||
for agent_meta in agent_metadata:
|
||||
agent_name = self.nixl_wrapper.add_remote_agent(agent_meta)
|
||||
agent_names.append(agent_name)
|
||||
self._remote_agents[engine_id] = agent_names
|
||||
self.kv_caches_base_addr[engine_id] = kv_caches_base_addr
|
||||
|
||||
tp_multiplier = self._tp_size[engine_id] // self._tp_size[self.engine_id]
|
||||
assert tp_multiplier > 0, f"Decode TP cannot be smaller than prefill TP, got {self._tp_size[engine_id]} and {self._tp_size[self.engine_id]}"
|
||||
|
||||
logger.debug("Creating src xfer side handles for engine %s, tp_multiplier: %s", engine_id, tp_multiplier)
|
||||
dst_block_len = self.block_len // tp_multiplier
|
||||
if tp_multiplier not in self.src_xfer_side_handles:
|
||||
# create descs and xfer side handles
|
||||
blocks_data = []
|
||||
for layer_id in range(self.num_layers):
|
||||
for base_addr in self.kv_caches_base_addr[self.engine_id][layer_id]:
|
||||
for block_id in range(self.num_blocks):
|
||||
block_offset = block_id * self.block_len
|
||||
for i in range(tp_multiplier):
|
||||
tp_multiplier_offset = i * dst_block_len
|
||||
blocks_data.append((base_addr + block_offset + tp_multiplier_offset, dst_block_len, self.rank))
|
||||
logger.debug("Created %s blocks for src engine %s and rank %s", len(blocks_data), self.engine_id, self.rank * tp_multiplier + i)
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
self.src_xfer_side_handles[tp_multiplier] = self.nixl_wrapper.prep_xfer_dlist("", descs)
|
||||
|
||||
# create dst xfer side handles
|
||||
self.dst_num_blocks[engine_id] = num_blocks
|
||||
for i in range(tp_multiplier):
|
||||
blocks_data = []
|
||||
for layer_id in range(self.num_layers):
|
||||
for base_addr in self.kv_caches_base_addr[engine_id][self.rank * tp_multiplier + i][layer_id]:
|
||||
for block_id in range(num_blocks):
|
||||
block_offset = block_id * dst_block_len
|
||||
blocks_data.append((base_addr + block_offset, dst_block_len, self.rank * tp_multiplier + i))
|
||||
logger.debug("Created %s blocks for dst engine %s and rank %s", len(blocks_data), engine_id, self.rank * tp_multiplier + i)
|
||||
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
|
||||
self.dst_xfer_side_handles[engine_id][i] = self.nixl_wrapper.prep_xfer_dlist(self._remote_agents[engine_id][self.rank * tp_multiplier + i], descs)
|
||||
|
||||
return agent_names
|
||||
|
||||
def get_done_tranfers(self) -> List[str]:
|
||||
done_req_ids = []
|
||||
for req_id, handles in self._transfers.items():
|
||||
running_reqs = []
|
||||
for handle in handles:
|
||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||
if xfer_state == "DONE":
|
||||
# self.nixl_wrapper.release_xfer_handle(handle) # TODO ptarasiewicz: why abort is throwing errors?
|
||||
continue
|
||||
if xfer_state == "PROC":
|
||||
running_reqs.append(handle)
|
||||
else:
|
||||
raise RuntimeError("Transfer failed with state %s", xfer_state)
|
||||
if len(running_reqs) == 0:
|
||||
done_req_ids.append(req_id)
|
||||
else:
|
||||
self._transfers[req_id] = running_reqs
|
||||
return done_req_ids
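A rough end-to-end sketch of how the connector above is driven (hypothetical wiring: it assumes NIXL is installed, register_kv_caches has already been called, and the remote NixlMetadata was exchanged out of band):

import time
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector, NixlMetadata

def send_prefilled_blocks(connector: DynamoNixlConnector, remote_meta: NixlMetadata,
                          remote_tp: int, local_blocks, staging_blocks,
                          remote_blocks, request_id: str) -> None:
    # Handshake (once per remote engine): register its agents and KV layout.
    connector.add_remote_agent(remote_meta.engine_id, remote_meta.agent_metadata,
                               remote_tp, remote_meta.kv_caches_base_addr,
                               remote_meta.num_blocks)
    # Push the prefilled KV blocks and notify the remote side with the request id.
    connector.write_blocks(local_blocks, staging_blocks, remote_blocks,
                           remote_meta.engine_id, request_id)
    # Poll until all WRITE transfers tagged with this request id have completed.
    while request_id not in connector.get_done_tranfers():
        time.sleep(0.001)

The reverse direction uses read_blocks with the same staging and rearrange bookkeeping.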
vllm/distributed/kv_transfer/kv_connector/dynamo_connector.py (new file, 350 lines)
@ -0,0 +1,350 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
"""
|
||||
Simple KV Cache Connector for Distributed Machine Learning Inference
|
||||
|
||||
The SimpleConnector transfers KV caches between prefill vLLM worker (KV cache
|
||||
producer) and decode vLLM worker (KV cache consumer) using PyNcclPipe or
|
||||
MooncakePipe.
|
||||
|
||||
But the logic can be extended to support other pipe and lookup buffer.
|
||||
"""
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig, KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
||||
SimpleBuffer)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class DynamoConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
world_group,
|
||||
):
|
||||
|
||||
self.config = config.kv_transfer_config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
|
||||
if self.config.kv_connector != "DynamoNcclConnector":
|
||||
raise NotImplementedError("Only DynamoNcclConnector is supported by the DynamoConnector class")
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import (
|
||||
PyNcclPipe)
|
||||
from vllm.distributed.kv_transfer.kv_pipe.dynamo_nccl_pipe import (
|
||||
DynamoNcclDataPlane)
|
||||
|
||||
logger.info(
|
||||
"Initializing DynamoNcclConnector under kv_transfer_config %s",
|
||||
self.config)
|
||||
|
||||
self.lookup_buffer_size = self.config.kv_buffer_size
|
||||
|
||||
self.producer_data_pipe: PyNcclPipe
|
||||
self.consumer_data_pipe: PyNcclPipe
|
||||
self.producer_signal_pipe: PyNcclPipe
|
||||
self.consumer_signal_pipe: PyNcclPipe
|
||||
|
||||
self._broadcast_and_enhance_kv_config(rank, config, world_group)
|
||||
|
||||
self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
|
||||
# 2 pipes for every rank in the world
|
||||
if self.config.is_kv_producer:
|
||||
port_offset_base = rank + 1
|
||||
else:
|
||||
port_offset_base = rank // self.config.tensor_parallel_multiplier + 1
|
||||
|
||||
|
||||
self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
|
||||
self.global_kv_rank = self._get_global_kv_rank(self.config.kv_rank, rank, self.config)
|
||||
|
||||
self.data_pipe = PyNcclPipe(
|
||||
kv_group_rank=self.kv_group_rank,
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
|
||||
self.data_plane = DynamoNcclDataPlane(
|
||||
data_pipe=self.data_pipe,
|
||||
port=self._get_data_plane_port(self.global_kv_rank),
|
||||
)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor,
|
||||
IntermediateTensors],
|
||||
) -> None:
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
|
||||
model_config = model_executable.model.config
|
||||
is_deepseek = "deepseek" in model_config.architectures[0].lower()
|
||||
if not is_deepseek:
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
else:
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(4.5 * hidden_size / num_attention_heads)
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
# FIXME(Kuntai): This assume that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
current_request_id = request_ids[idx]
|
||||
decode_hostname, decode_kv_rank = self.parse_request_id(current_request_id)
|
||||
decode_first_global_rank = self._get_global_kv_rank(decode_kv_rank, self.rank * self.config.tensor_parallel_multiplier, self.config)
|
||||
|
||||
for target_rank in range(self.config.tensor_parallel_multiplier):
|
||||
|
||||
keys, values = [], []
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
|
||||
head_start = target_rank * num_heads_per_rank
|
||||
head_end = head_start + num_heads_per_rank
|
||||
|
||||
if not is_deepseek:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
|
||||
else:
|
||||
key_cache = kv_cache
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(torch.empty(0))
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
|
||||
decode_global_rank = decode_first_global_rank + target_rank
|
||||
decode_port = self._get_data_plane_port(decode_global_rank)
|
||||
partial_hidden_or_intermediate_states = hidden_or_intermediate_states[start_pos:end_pos]
|
||||
self._send(decode_hostname, decode_port, current_request_id, keys, values,
|
||||
partial_hidden_or_intermediate_states)
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
|
||||
# When bypass_model_exec is set to False, it means that at least for one
|
||||
# request its corresponding KV cache or hidden state is missing.
|
||||
# In this case we need to do prefilling to recompute missing KV cache
|
||||
# and hidden states.
|
||||
bypass_model_exec = True
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
model_config = model_executable.model.config
|
||||
is_deepseek = "deepseek" in model_config.architectures[0].lower()
|
||||
|
||||
# enumerate different requests
|
||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
current_request_id = request_ids[idx]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
ret = self._recv(current_request_id)
|
||||
keys: torch.Tensor = ret[0]
|
||||
values: torch.Tensor = ret[1]
|
||||
hidden: torch.Tensor = ret[2]
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(model_executable.model.start_layer,
|
||||
model_executable.model.end_layer):
|
||||
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
if not is_deepseek:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys[i - model_executable.model.start_layer].to(
|
||||
key_cache.device),
|
||||
values[i - model_executable.model.start_layer].to(
|
||||
value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
else:
|
||||
key_cache = kv_cache
|
||||
copy_from = keys[i - model_executable.model.start_layer].to(
    key_cache.device)
kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.debug(
|
||||
"[rank%d]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"[rank%d]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self):
|
||||
self.data_pipe.close()
|
||||
# self.data_plane.close()
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id: str) -> Tuple[str, int]:
|
||||
# Regular expression to match the string hostname and integer decode_kv_rank
|
||||
pattern = r"___decode_hostname_(.*)___decode_kv_rank_(\d+)"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
if match:
|
||||
# Extract the ranks
|
||||
decode_hostname = match.group(1)
|
||||
decode_rank = int(match.group(2))
|
||||
|
||||
return decode_hostname, decode_rank
|
||||
raise ValueError(f"Request id {request_id} does not contain hostname and decode_kv_rank")
|
||||
|
||||
def _send(self, hostname: str, port: int, request_id: str, keys: torch.Tensor, values: torch.Tensor, hidden: torch.Tensor):
|
||||
remote_address = f"{hostname}:{port}"
|
||||
self.data_plane.send_tensor(keys, f"{request_id}_keys", remote_address)
|
||||
self.data_plane.send_tensor(values, f"{request_id}_values", remote_address)
|
||||
self.data_plane.send_tensor(hidden, f"{request_id}_hidden", remote_address)
|
||||
|
||||
def _recv(self, request_id: str) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
keys = self.data_plane.recv_tensor(f"{request_id}_keys")
|
||||
values = self.data_plane.recv_tensor(f"{request_id}_values")
|
||||
hidden = self.data_plane.recv_tensor(f"{request_id}_hidden")
|
||||
return keys, values, hidden
|
||||
|
||||
def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
|
||||
if kv_rank < config.kv_producers_parallel_size:
|
||||
return kv_rank
|
||||
|
||||
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
|
||||
return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
|
||||
|
||||
|
||||
def _get_global_kv_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
|
||||
if kv_rank <= config.kv_producers_parallel_size:
|
||||
return kv_rank * config.kv_producers_tensor_parallel_size + rank
|
||||
|
||||
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
|
||||
return config.kv_producers_parallel_size * config.kv_producers_tensor_parallel_size + kv_consumer_rank * config.kv_consumers_tensor_parallel_size + rank
|
||||
|
||||
|
||||
def _get_data_plane_port(self, global_kv_rank: int) -> int:
|
||||
return self.config.kv_port + self.config.kv_producers_tensor_parallel_size + 1 + global_kv_rank
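To make the rank and port arithmetic concrete, here is a worked example under an assumed topology of two TP-1 producers feeding one TP-2 consumer (kv_port 14579); the helper functions only mirror the formulas above and are not part of the patch:

# Worked example (assumed topology): 2 producers with TP=1, 1 consumer with
# TP=2, so tensor_parallel_multiplier == 2.
kv_producers_parallel_size = 2
kv_producers_tensor_parallel_size = 1
kv_consumers_tensor_parallel_size = 2
kv_port = 14579

def global_kv_rank(kv_rank: int, rank: int) -> int:
    # mirrors _get_global_kv_rank above
    if kv_rank <= kv_producers_parallel_size:
        return kv_rank * kv_producers_tensor_parallel_size + rank
    kv_consumer_rank = kv_rank - kv_producers_parallel_size
    return (kv_producers_parallel_size * kv_producers_tensor_parallel_size
            + kv_consumer_rank * kv_consumers_tensor_parallel_size + rank)

def data_plane_port(global_rank: int) -> int:
    # mirrors _get_data_plane_port above
    return kv_port + kv_producers_tensor_parallel_size + 1 + global_rank

# Producers (kv_rank 0 and 1, local rank 0) listen on 14581 and 14582; the
# consumer's two TP workers (kv_rank 2, ranks 0 and 1) listen on 14583 and 14584.
assert [data_plane_port(global_kv_rank(k, r))
        for k, r in [(0, 0), (1, 0), (2, 0), (2, 1)]] == [14581, 14582, 14583, 14584]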
|
||||
|
||||
def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
|
||||
if rank == 0:
|
||||
config_group = StatelessProcessGroup.create(
|
||||
host=self.config.kv_ip,
|
||||
port=self.config.kv_port,
|
||||
rank=self.config.kv_rank,
|
||||
world_size=self.config.kv_parallel_size,
|
||||
)
|
||||
parallel_configs = config_group.all_gather_obj({
|
||||
"kv_role": self.config.kv_role,
|
||||
"tensor_parallel_size": config.parallel_config.tensor_parallel_size,
|
||||
"pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
|
||||
})
|
||||
logger.debug("parallel_configs: %s", parallel_configs)
|
||||
kv_config_enhanced = {
|
||||
"kv_producers_tensor_parallel_size": None,
|
||||
"kv_consumers_tensor_parallel_size": None,
|
||||
"kv_producers_pipeline_parallel_size": None,
|
||||
"kv_consumers_pipeline_parallel_size": None,
|
||||
"kv_producers_parallel_size": 0,
|
||||
}
|
||||
for parallel_config in parallel_configs:
|
||||
kv_role = parallel_config["kv_role"]
|
||||
assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
|
||||
|
||||
if kv_role == "kv_producer":
|
||||
kv_config_enhanced["kv_producers_parallel_size"] += 1
|
||||
if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
|
||||
kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
|
||||
kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
|
||||
else:
|
||||
assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
|
||||
assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
|
||||
world_group.broadcast_object(kv_config_enhanced)
|
||||
else:
|
||||
kv_config_enhanced = world_group.broadcast_object()
|
||||
logger.info("kv_config_enhanced: %s", kv_config_enhanced)
|
||||
|
||||
self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
|
||||
self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
|
||||
self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
|
||||
self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
|
||||
self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
|
@ -27,13 +27,13 @@ class KVConnectorFactory:
|
||||
|
||||
@classmethod
|
||||
def create_connector(cls, rank: int, local_rank: int,
|
||||
config: "VllmConfig") -> KVConnectorBase:
|
||||
config: "VllmConfig", world_group) -> KVConnectorBase:
|
||||
connector_name = config.kv_transfer_config.kv_connector
|
||||
if connector_name not in cls._registry:
|
||||
raise ValueError(f"Unsupported connector type: {connector_name}")
|
||||
|
||||
connector_cls = cls._registry[connector_name]()
|
||||
return connector_cls(rank, local_rank, config)
|
||||
return connector_cls(rank, local_rank, config, world_group)
|
||||
|
||||
|
||||
# Register various connectors here.
|
||||
@ -48,3 +48,8 @@ KVConnectorFactory.register_connector(
|
||||
"MooncakeConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.simple_connector",
|
||||
"SimpleConnector")
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"DynamoNcclConnector",
|
||||
"vllm.distributed.kv_transfer.kv_connector.dynamo_connector",
|
||||
"DynamoConnector")
|
||||
|
@ -8,13 +8,15 @@ MooncakePipe.
|
||||
|
||||
But the logic can be extended to support other pipe and lookup buffer.
|
||||
"""
|
||||
import re
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config import VllmConfig, KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.distributed.kv_transfer.kv_lookup_buffer.simple_buffer import (
|
||||
SimpleBuffer)
|
||||
from vllm.logger import init_logger
|
||||
@ -33,6 +35,7 @@ class SimpleConnector(KVConnectorBase):
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
world_group,
|
||||
):
|
||||
|
||||
self.config = config.kv_transfer_config
|
||||
@ -71,20 +74,31 @@ class SimpleConnector(KVConnectorBase):
|
||||
self.producer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
self.consumer_signal_pipe: Union[PyNcclPipe, MooncakePipe]
|
||||
|
||||
# 2 pipes for every rank in the world
|
||||
port_offset_base = 2 * rank
|
||||
self._broadcast_and_enhance_kv_config(rank, config, world_group)
|
||||
|
||||
self.kv_group_rank = self._get_kv_group_rank(self.config.kv_rank, rank, self.config)
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
|
||||
# 2 pipes for every rank in the world
|
||||
if self.config.is_kv_producer:
|
||||
port_offset_base = 2 * rank + 1
|
||||
else:
|
||||
port_offset_base = 2 * (rank // self.config.tensor_parallel_multiplier) + 1
|
||||
|
||||
self.local_kv_rank = rank % self.config.tensor_parallel_multiplier
|
||||
# In disaggregated prefill, the prefill vLLM only uses send pipe
|
||||
# and the decode vLLM only uses recv pipe
|
||||
if self.config.is_kv_producer:
|
||||
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.producer_data_pipe = PyNcclPipe(
|
||||
kv_group_rank=self.kv_group_rank,
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.producer_signal_pipe = PyNcclPipe(
|
||||
kv_group_rank=self.kv_group_rank,
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
@ -108,11 +122,13 @@ class SimpleConnector(KVConnectorBase):
|
||||
# its recv pipe to the send pipe of KV producer
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
self.consumer_data_pipe = PyNcclPipe(
|
||||
kv_group_rank=self.kv_group_rank,
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base,
|
||||
)
|
||||
self.consumer_signal_pipe = PyNcclPipe(
|
||||
kv_group_rank=self.kv_group_rank,
|
||||
local_rank=local_rank,
|
||||
config=self.config,
|
||||
port_offset=port_offset_base + 1,
|
||||
@ -131,21 +147,25 @@ class SimpleConnector(KVConnectorBase):
|
||||
self.config.kv_buffer_size,
|
||||
)
|
||||
|
||||
def select(self, input_tokens: Optional[torch.Tensor],
|
||||
def select(self, source_rank: int, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
||||
|
||||
logger.info("Selecting KV caches and hidden states for source rank %d", source_rank)
|
||||
|
||||
assert self.consumer_buffer is not None, "Please initialize the "\
|
||||
"consumer buffer before calling select."
|
||||
return self.consumer_buffer.drop_select(input_tokens, roi)
|
||||
return self.consumer_buffer.drop_select(source_rank, self.local_kv_rank, input_tokens, roi)
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
logger.info("Inserting KV caches and hidden states for kv_group_rank %d, target rank %d", kv_group_rank, target_rank)
|
||||
|
||||
assert self.producer_buffer is not None, "Please initialize the "\
|
||||
"producer buffer before calling insert."
|
||||
|
||||
self.producer_buffer.insert(input_tokens, roi, key, value, hidden)
|
||||
self.producer_buffer.insert(kv_group_rank, target_rank, input_tokens, roi, key, value, hidden)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
@ -161,12 +181,20 @@ class SimpleConnector(KVConnectorBase):
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
is_deepseek = "deepseek" in model_config.architectures[0].lower()
|
||||
if not is_deepseek:
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
else:
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(4.5 * hidden_size / num_attention_heads)
|
||||
|
||||
# query_lens contains new KV caches that are added to vLLM.
|
||||
# so we will send them to decode instance
|
||||
@ -175,27 +203,40 @@ class SimpleConnector(KVConnectorBase):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
current_request_id = request_ids[idx]
|
||||
_, decode_kv_rank = self.parse_request_id(current_request_id)
|
||||
starting_kv_group_rank = self._get_kv_group_rank(decode_kv_rank, 0, self.config)
|
||||
|
||||
keys, values = [], []
|
||||
for target_rank in range(self.config.tensor_parallel_multiplier):
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
keys, values = [], []
|
||||
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping].unsqueeze(0))
|
||||
num_heads_per_rank = num_heads // self.config.tensor_parallel_multiplier
|
||||
head_start = target_rank * num_heads_per_rank
|
||||
head_end = head_start + num_heads_per_rank
|
||||
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
if not is_deepseek:
|
||||
key_cache = kv_cache[0].reshape(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].reshape(-1, num_heads, head_size)
|
||||
keys.append(key_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
|
||||
values.append(value_cache[current_slot_mapping, head_start:head_end].unsqueeze(0))
|
||||
else:
|
||||
key_cache = kv_cache
|
||||
keys.append(key_cache[current_slot_mapping].unsqueeze(0))
|
||||
values.append(torch.empty(0))
|
||||
|
||||
self.insert(current_tokens,
|
||||
torch.ones_like(current_tokens,
|
||||
dtype=bool), keys, values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
keys = torch.cat(keys, dim=0)
|
||||
values = torch.cat(values, dim=0)
|
||||
|
||||
self.insert(starting_kv_group_rank, target_rank, current_tokens,
|
||||
torch.ones_like(current_tokens,
|
||||
dtype=bool), keys, values,
|
||||
hidden_or_intermediate_states[start_pos:end_pos])
|
||||
|
||||
logger.debug("[rank%d]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
@ -215,6 +256,7 @@ class SimpleConnector(KVConnectorBase):
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
request_ids = list(model_input.request_ids_to_seq_ids.keys())
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
@ -222,6 +264,9 @@ class SimpleConnector(KVConnectorBase):
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
model_config = model_executable.model.config
|
||||
is_deepseek = "deepseek" in model_config.architectures[0].lower()
|
||||
|
||||
# enumerate different requests
|
||||
# FIXME(Kuntai): This impl assumes that all requests are prefill.
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
@ -229,13 +274,15 @@ class SimpleConnector(KVConnectorBase):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
current_request_id = request_ids[idx]
|
||||
prefill_rank, _ = self.parse_request_id(current_request_id)
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
ret = self.select(current_tokens,
|
||||
ret = self.select(prefill_rank, current_tokens,
|
||||
torch.ones_like(current_tokens, dtype=bool))
|
||||
if ret[0] is None:
|
||||
# didn't find any match.
|
||||
@ -267,19 +314,25 @@ class SimpleConnector(KVConnectorBase):
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
layer = model_executable.model.layers[i]
|
||||
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys[i - model_executable.model.start_layer].to(
|
||||
key_cache.device),
|
||||
values[i - model_executable.model.start_layer].to(
|
||||
value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
if not is_deepseek:
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
ops.reshape_and_cache_flash(
|
||||
keys[i - model_executable.model.start_layer].to(
|
||||
key_cache.device),
|
||||
values[i - model_executable.model.start_layer].to(
|
||||
value_cache.device),
|
||||
key_cache,
|
||||
value_cache,
|
||||
slot_mapping[start_pos:end_pos],
|
||||
layer.self_attn.attn.kv_cache_dtype,
|
||||
layer.self_attn.attn._k_scale,
|
||||
layer.self_attn.attn._v_scale,
|
||||
)
|
||||
else:
|
||||
key_cache = kv_cache
|
||||
copy_from = keys[i - model_executable.model.start_layer].to(
    key_cache.device)
kv_cache[slot_mapping[start_pos:end_pos]] = copy_from
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(hidden)
|
||||
|
||||
@ -312,3 +365,77 @@ class SimpleConnector(KVConnectorBase):
|
||||
# MooncakePipe reuses data_pipe for signal_pipe, so we only have to
|
||||
# close the data_pipe.
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def parse_request_id(request_id):
|
||||
# Regular expression to match the ranks
|
||||
pattern = r"___prefill_kv_rank_(\d+)___decode_kv_rank_(\d+)"
|
||||
|
||||
# Use re.search to find the pattern in the request_id
|
||||
match = re.search(pattern, request_id)
|
||||
|
||||
if match:
|
||||
# Extract the ranks
|
||||
prefill_rank = int(match.group(1))
|
||||
decode_rank = int(match.group(2))
|
||||
|
||||
return prefill_rank, decode_rank
|
||||
else:
|
||||
return None, None
|
||||
|
||||
|
||||
|
||||
def _get_kv_group_rank(self, kv_rank: int, rank: int, config: KVTransferConfig) -> int:
|
||||
if kv_rank < config.kv_producers_parallel_size:
|
||||
return kv_rank
|
||||
|
||||
kv_consumer_rank = kv_rank - config.kv_producers_parallel_size
|
||||
return config.kv_producers_parallel_size + kv_consumer_rank * config.tensor_parallel_multiplier + rank % config.tensor_parallel_multiplier
|
||||
|
||||
def _broadcast_and_enhance_kv_config(self, rank: int, config: VllmConfig, world_group):
|
||||
if rank == 0:
|
||||
if self.config.kv_connector == "PyNcclConnector":
|
||||
config_group = StatelessProcessGroup.create(
|
||||
host=self.config.kv_ip,
|
||||
port=self.config.kv_port,
|
||||
rank=self.config.kv_rank,
|
||||
world_size=self.config.kv_parallel_size,
|
||||
)
|
||||
parallel_configs = config_group.all_gather_obj({
|
||||
"kv_role": self.config.kv_role,
|
||||
"tensor_parallel_size": config.parallel_config.tensor_parallel_size,
|
||||
"pipeline_parallel_size": config.parallel_config.pipeline_parallel_size,
|
||||
})
|
||||
logger.debug("parallel_configs: %s", parallel_configs)
|
||||
kv_config_enhanced = {
|
||||
"kv_producers_tensor_parallel_size": None,
|
||||
"kv_consumers_tensor_parallel_size": None,
|
||||
"kv_producers_pipeline_parallel_size": None,
|
||||
"kv_consumers_pipeline_parallel_size": None,
|
||||
"kv_producers_parallel_size": 0,
|
||||
}
|
||||
for parallel_config in parallel_configs:
|
||||
kv_role = parallel_config["kv_role"]
|
||||
assert parallel_config["pipeline_parallel_size"] == 1, f"Only pipeline parallel size 1 is supported for kv transfer instances"
|
||||
|
||||
if kv_role == "kv_producer":
|
||||
kv_config_enhanced["kv_producers_parallel_size"] += 1
|
||||
if kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] is None:
|
||||
kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] = parallel_config["tensor_parallel_size"]
|
||||
kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] = parallel_config["pipeline_parallel_size"]
|
||||
else:
|
||||
assert kv_config_enhanced[f"{kv_role}s_tensor_parallel_size"] == parallel_config["tensor_parallel_size"], f"All kv {kv_role}s should have the same tensor parallel size"
|
||||
assert kv_config_enhanced[f"{kv_role}s_pipeline_parallel_size"] == parallel_config["pipeline_parallel_size"], f"All kv {kv_role}s should have the same pipeline parallel size"
|
||||
world_group.broadcast_object(kv_config_enhanced)
|
||||
|
||||
else:
|
||||
raise NotImplementedError("MooncakeConnector is not supported in Dynamo patch")
|
||||
else:
|
||||
kv_config_enhanced = world_group.broadcast_object()
|
||||
logger.info("kv_config_enhanced: %s", kv_config_enhanced)
|
||||
|
||||
self.config.kv_producers_tensor_parallel_size = kv_config_enhanced["kv_producers_tensor_parallel_size"]
|
||||
self.config.kv_consumers_tensor_parallel_size = kv_config_enhanced["kv_consumers_tensor_parallel_size"]
|
||||
self.config.kv_producers_pipeline_parallel_size = kv_config_enhanced["kv_producers_pipeline_parallel_size"]
|
||||
self.config.kv_consumers_pipeline_parallel_size = kv_config_enhanced["kv_consumers_pipeline_parallel_size"]
|
||||
self.config.kv_producers_parallel_size = kv_config_enhanced["kv_producers_parallel_size"]
|
||||
|
@ -12,7 +12,8 @@
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from typing import Deque, List, Optional, Union
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Deque, List, Optional, Union, Dict
|
||||
|
||||
import torch
|
||||
|
||||
@ -46,7 +47,7 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
self.buffer_lock = threading.Lock()
|
||||
self.signal_pipe = signal_pipe
|
||||
self.data_pipe = data_pipe
|
||||
self.request_handling_thread: Optional[threading.Thread] = None
|
||||
self.request_handling_thread: Optional[ThreadPoolExecutor] = None
|
||||
|
||||
self.normal_signal = torch.tensor([0], device="cpu")
|
||||
self.end_signal = None
|
||||
@ -57,10 +58,16 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
# tokens_roi_sender: tokens and roi of the producer (in the buffer)
|
||||
# tokens_roi_recver: tokens and roi of the consumer (query)
|
||||
|
||||
tokens_sender = tokens_roi_sender[0]
|
||||
tokens_recver = tokens_roi_recver[0]
|
||||
roi_sender = tokens_roi_sender[1]
|
||||
roi_recver = tokens_roi_recver[1]
|
||||
target_rank_sender = tokens_roi_sender[0]
|
||||
target_rank_recver = tokens_roi_recver[0]
|
||||
|
||||
if target_rank_sender.item() != target_rank_recver.item():
|
||||
return 0
|
||||
|
||||
tokens_sender = tokens_roi_sender[1]
|
||||
tokens_recver = tokens_roi_recver[1]
|
||||
roi_sender = tokens_roi_sender[2]
|
||||
roi_recver = tokens_roi_recver[2]
|
||||
|
||||
if tokens_recver is None:
|
||||
# consumer sends an empty request
|
||||
@ -80,14 +87,14 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
return 0
|
||||
|
||||
def _send_tensor_and_dec_size(self,
|
||||
tensor: Optional[torch.Tensor]) -> None:
|
||||
def _send_tensor_and_dec_size(self, tensor: Optional[torch.Tensor],
|
||||
target_rank: int) -> None:
|
||||
|
||||
assert tensor is not None, "Use self.data_pipe.send(None) instead"
|
||||
self.buffer_size -= tensor.element_size() * tensor.numel()
|
||||
if tensor.dtype == torch.bool:
|
||||
tensor = tensor.float()
|
||||
self.data_pipe.send_tensor(tensor)
|
||||
self.data_pipe.send_tensor(tensor, target_rank)
|
||||
|
||||
def _get_element_size(self, data: Optional[Union[List, torch.Tensor]]):
|
||||
|
||||
@ -100,7 +107,7 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
|
||||
raise AssertionError(f"Unknown data type {type(data)}")
|
||||
|
||||
def _add_to_buffer(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
def _add_to_buffer(self, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor):
|
||||
|
||||
@ -115,7 +122,7 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
if isinstance(hidden, torch.Tensor):
|
||||
hidden = hidden.clone()
|
||||
|
||||
buffer_item = [input_tokens, roi, key, value, hidden]
|
||||
buffer_item = [torch.tensor(target_rank), input_tokens, roi, key, value, hidden]
|
||||
|
||||
with self.buffer_lock:
|
||||
for data in buffer_item:
|
||||
@ -125,53 +132,54 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
def _is_end_signal(self, signal):
|
||||
return signal is None
|
||||
|
||||
def drop_select_handler(self):
|
||||
def drop_select_handler(self, rank: int):
|
||||
|
||||
try:
|
||||
|
||||
while True:
|
||||
signal = self.signal_pipe.recv_tensor()
|
||||
if self._is_end_signal(signal):
|
||||
logger.info("Received end signal!")
|
||||
break
|
||||
signal = self.signal_pipe.recv_tensor(rank)
|
||||
if self._is_end_signal(signal):
|
||||
logger.info("Received end signal!")
|
||||
return
|
||||
target_kv_rank = self.data_pipe.recv_tensor(rank)
|
||||
# assert target_rank.item() == rank, "Target rank does not match"\
|
||||
# "the rank of the drop-select handler"
|
||||
input_tokens = self.data_pipe.recv_tensor(rank)
|
||||
roi = self.data_pipe.recv_tensor(rank)
|
||||
assert roi is not None, "Please provide the roi when sending "\
|
||||
"drop-select request"
|
||||
roi = (roi > 0.5)
|
||||
tokens_roi_recver = [target_kv_rank, input_tokens, roi]
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
matched_length = 0
|
||||
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
assert roi is not None, "Please provide the roi when sending "\
|
||||
"drop-select request"
|
||||
roi = (roi > 0.5)
|
||||
tokens_roi_recver = [input_tokens, roi]
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
with self.buffer_lock:
|
||||
|
||||
matched_length = 0
|
||||
for _ in range(len(self.buffer)):
|
||||
|
||||
# perform input tokens and roi matching
|
||||
# FIXME: this matching is O(n), ideally it should be O(1)
|
||||
# but this buffer size won't (and shouldn't) be too large so
|
||||
# the fix is not urgent.
|
||||
with self.buffer_lock:
|
||||
temp_length = self._matches(self.buffer[0],
|
||||
tokens_roi_recver)
|
||||
if temp_length > 0:
|
||||
matched_length = temp_length
|
||||
break
|
||||
# rotate the element we just accessed to the end
|
||||
self.buffer.rotate(-1)
|
||||
|
||||
for _ in range(len(self.buffer)):
|
||||
if matched_length > 0:
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
target_rank = matched_item[0].item()
|
||||
for tensor in matched_item[1:]:
|
||||
self._send_tensor_and_dec_size(tensor, rank)
|
||||
|
||||
temp_length = self._matches(self.buffer[0],
|
||||
tokens_roi_recver)
|
||||
if temp_length > 0:
|
||||
matched_length = temp_length
|
||||
break
|
||||
# rotate the element we just accessed to the end
|
||||
self.buffer.rotate(-1)
|
||||
|
||||
if matched_length > 0:
|
||||
# need to clone the tensor
|
||||
# in case the tensor is freed before sending finishes
|
||||
matched_item = self.buffer.popleft()
|
||||
for tensor in matched_item:
|
||||
self._send_tensor_and_dec_size(tensor)
|
||||
|
||||
else:
|
||||
# no match, just send None
|
||||
for _ in range(5):
|
||||
self.data_pipe.send_tensor(None)
|
||||
else:
|
||||
# no match, just send None
|
||||
for _ in range(5):
|
||||
self.data_pipe.send_tensor(None, rank)
|
||||
|
||||
except RuntimeError as e:
|
||||
if 'Connection closed by peer' not in str(e):
|
||||
@ -180,10 +188,10 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
logger.debug("Closing drop_select_handler")
|
||||
|
||||
def drop_select(
|
||||
self, input_tokens: Optional[torch.Tensor],
|
||||
self, rank: int, kv_rank: int, input_tokens: Optional[torch.Tensor],
|
||||
roi: Optional[torch.Tensor]) -> List[Optional[torch.Tensor]]:
|
||||
|
||||
assert self.request_handling_thread is None, \
|
||||
assert not self.request_handling_thread, \
|
||||
"drop_select should be called by the KV cache consumer "\
|
||||
"(e.g. the decode vLLM instance)"
|
||||
|
||||
@ -192,26 +200,28 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
if isinstance(roi, torch.Tensor):
|
||||
roi = roi.clone().float()
|
||||
|
||||
self.signal_pipe.send_tensor(self.normal_signal)
|
||||
self.data_pipe.send_tensor(input_tokens)
|
||||
self.data_pipe.send_tensor(roi)
|
||||
self.signal_pipe.send_tensor(self.normal_signal, rank)
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor()
|
||||
roi = self.data_pipe.recv_tensor()
|
||||
self.data_pipe.send_tensor(torch.tensor(kv_rank), rank)
|
||||
self.data_pipe.send_tensor(input_tokens, rank)
|
||||
self.data_pipe.send_tensor(roi, rank)
|
||||
|
||||
input_tokens = self.data_pipe.recv_tensor(rank)
|
||||
roi = self.data_pipe.recv_tensor(rank)
|
||||
if roi is not None:
|
||||
# convert from float tensor to bool tensor
|
||||
# as PyNccl does not support sending bool tensor
|
||||
roi = (roi > 0.5)
|
||||
key = self.data_pipe.recv_tensor()
|
||||
value = self.data_pipe.recv_tensor()
|
||||
hidden = self.data_pipe.recv_tensor()
|
||||
key = self.data_pipe.recv_tensor(rank)
|
||||
value = self.data_pipe.recv_tensor(rank)
|
||||
hidden = self.data_pipe.recv_tensor(rank)
|
||||
|
||||
return [input_tokens, roi, key, value, hidden]
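A short sketch of the consumer-side call, assuming a `buffer` built as above; the reply always carries five tensors in the order (input_tokens, roi, key, value, hidden), or five Nones when nothing matched:

import torch

def fetch_prefilled_kv(buffer, prefill_rank, local_kv_rank, tokens):
    # Sketch (assumed caller): ask the prefill instance that served this
    # request for its buffered KV entry.
    roi = torch.ones_like(tokens, dtype=bool)
    input_tokens, roi, key, value, hidden = buffer.drop_select(
        prefill_rank, local_kv_rank, tokens, roi)
    if input_tokens is None:
        return None  # no match buffered yet; caller falls back to local prefill
    return key, value, hidden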
|
||||
|
||||
def full_handler(self):
|
||||
time.sleep(0.001)
|
||||
|
||||
def insert(self, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
def insert(self, kv_group_rank: int, target_rank: int, input_tokens: torch.Tensor, roi: torch.Tensor,
|
||||
key: torch.Tensor, value: torch.Tensor,
|
||||
hidden: torch.Tensor) -> None:
|
||||
|
||||
@ -222,20 +232,19 @@ class SimpleBuffer(KVLookupBufferBase):
|
||||
while self.buffer_size > self.buffer_size_threshold:
|
||||
self.full_handler()
|
||||
|
||||
self._add_to_buffer(input_tokens, roi, key, value, hidden)
|
||||
self._add_to_buffer(target_rank, input_tokens, roi, key, value, hidden)
|
||||
|
||||
# when calling the insert, the current process is a sender
|
||||
# need to launch the request handler and start listening to request.
|
||||
target_rank_global = target_rank + kv_group_rank
|
||||
if self.request_handling_thread is None:
|
||||
self.request_handling_thread = threading.Thread(
|
||||
target=self.drop_select_handler)
|
||||
self.request_handling_thread.start()
|
||||
self.request_handling_thread = ThreadPoolExecutor(max_workers=1)
|
||||
self.request_handling_thread.submit(self.drop_select_handler, target_rank_global)
|
||||
|
||||
def close(self):
|
||||
|
||||
if hasattr(self, "request_handling_thread"
|
||||
) and self.request_handling_thread is not None:
|
||||
self.request_handling_thread.join()
|
||||
if hasattr(self, "request_handling_thread") and self.request_handling_thread:
|
||||
self.request_handling_thread.shutdown()
|
||||
|
||||
else:
|
||||
# TODO: have an explicit close signal and have an explicit way to
|
||||
|
@ -23,7 +23,7 @@ class KVPipeBase(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int = 0) -> None:
|
||||
"""Send a tensor, or None, via the pipe.
|
||||
|
||||
Need to support sending None -- important for error handling.
|
||||
@ -41,7 +41,7 @@ class KVPipeBase(ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
|
||||
"""Receive a tensor (can be None) from the pipeline.
|
||||
|
||||
Returns:
|
||||
|
vllm/distributed/kv_transfer/kv_pipe/dynamo_nccl_pipe.py (new file, 124 lines)
@ -0,0 +1,124 @@
|
||||
import logging
|
||||
import threading
|
||||
import typing
|
||||
import zmq
|
||||
import socket
|
||||
import time
|
||||
import torch
|
||||
|
||||
from vllm.distributed.kv_transfer.kv_pipe.pynccl_pipe import PyNcclPipe
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DynamoNcclDataPlane:
|
||||
def __init__(
|
||||
self,
|
||||
data_pipe: PyNcclPipe,
|
||||
hostname: str = "",
|
||||
port: int = 0,
|
||||
) -> None:
|
||||
|
||||
self.data_pipe = data_pipe
|
||||
if not hostname:
|
||||
hostname = socket.gethostname()
|
||||
if port == 0:
|
||||
raise ValueError("Port cannot be 0")
|
||||
self._hostname = hostname
|
||||
self._port = port
|
||||
self.store = {}
|
||||
self.context = zmq.Context()
|
||||
self.rep_socket = self.context.socket(zmq.REP)
|
||||
logger.info(f"Rank {self.rank} binding to {self._hostname}:{self._port}")
|
||||
self.rep_socket.bind(f"tcp://{self._hostname}:{self._port}")
|
||||
self._listener_thread = threading.Thread(target=self.listen_for_requests, daemon=True)
|
||||
self._listener_thread.start()
|
||||
self.req_sockets = {}
|
||||
logger.info(f"Rank {self.rank} connected to the server")
|
||||
|
||||
@property
|
||||
def rank(self):
|
||||
return self.data_pipe.kv_group_rank
|
||||
|
||||
def send_tensor(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
):
|
||||
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to {remote_address}")
|
||||
return self._send_tensor(tensor, tensor_id, remote_address)
|
||||
|
||||
def recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
ret = self._recv_tensor(tensor_id, remote_address)
|
||||
return ret
|
||||
|
||||
def _send_tensor(
|
||||
self,
|
||||
tensor: torch.Tensor,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
):
|
||||
logger.debug(f"Rank {self.rank} storing tensor with id {tensor_id} of shape {tensor.shape} and dtype {tensor.dtype}")
|
||||
if remote_address is None:
|
||||
self.store[tensor_id] = tensor
|
||||
else:
|
||||
# tensor_shape = "_".join(str(dim) for dim in tensor.shape)
|
||||
# tensor_dtype = str(tensor.dtype)
|
||||
if remote_address not in self.req_sockets:
|
||||
self.req_sockets[remote_address] = self.context.socket(zmq.REQ)
|
||||
self.req_sockets[remote_address].connect(f"tcp://{remote_address}")
|
||||
|
||||
req_socket = self.req_sockets[remote_address]
|
||||
# req_socket.connect(f"tcp://{remote_address}")
|
||||
req_socket.send_string(f"PUT {self.rank} {tensor_id}")
|
||||
dst_rank = req_socket.recv_string()
|
||||
logger.debug(f"Rank {self.rank} sending tensor {tensor_id} to rank {dst_rank}")
|
||||
self.data_pipe.send_tensor(tensor, int(dst_rank))
|
||||
|
||||
def _recv_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
remote_address: typing.Optional[str] = None,
|
||||
) -> torch.Tensor:
|
||||
logger.debug(f"Rank {self.rank} receiving tensor")
|
||||
if remote_address is not None:
|
||||
raise NotImplementedError("Getting tensor from remote rank not implemented")
|
||||
if tensor_id in self.store:
|
||||
logger.debug(f"Popping tensor {tensor_id} from store")
|
||||
future = self.store.pop(tensor_id)
|
||||
tensor = future.result() # TODO ptarasiewicz we should run other request instead of wait
|
||||
logger.debug(f"Rank {self.rank} received tensor")
|
||||
return tensor
|
||||
|
||||
logger.debug(f"Rank {self.rank} waiting for tensor {tensor_id}")
|
||||
time.sleep(0.001)
|
||||
return self._recv_tensor(tensor_id, remote_address)
|
||||
# raise NotImplementedError("Tensor not found in store")
|
||||
|
||||
def _receive_tensor(
|
||||
self,
|
||||
tensor_id: str,
|
||||
rank: int,
|
||||
):
|
||||
future = self.data_pipe.recv_tensor(rank)
|
||||
logger.debug(f"Rank {self.rank} storing tensor {tensor_id} in store")
|
||||
self.store[tensor_id] = future
|
||||
|
||||
def listen_for_requests(self):
|
||||
while True:
|
||||
cmd, rank, tensor_id = self.rep_socket.recv_string().split()
|
||||
logger.debug(f"Rank {self.rank} received request for tensor {tensor_id}")
|
||||
self.rep_socket.send_string(f"{self.rank}")
|
||||
if cmd == "GET":
|
||||
raise NotImplementedError("Getting tensor from remote rank not implemented")
|
||||
elif cmd == "PUT":
|
||||
rank = int(rank)
|
||||
# shape = [int(dim) for dim in shape.split("_")]
|
||||
# dtype = getattr(torch, dtype)
|
||||
self._receive_tensor(tensor_id, rank)
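The control messages are plain strings over ZMQ while the payload travels over the NCCL pipe; a sketch of the handshake, with the pipe, tensor, hostname, and port as placeholders:

# Sketch of the PUT handshake implemented above (all values are placeholders):
#   sender REQ  -> "PUT <sender_group_rank> req-0_keys"
#   receiver REP-> "<receiver_group_rank>"
#   sender then calls data_pipe.send_tensor(tensor, receiver_group_rank);
#   receiver posts data_pipe.recv_tensor(sender_group_rank) and stores a future.
plane = DynamoNcclDataPlane(data_pipe=my_pipe, port=15000)  # my_pipe: a PyNcclPipe
plane.send_tensor(keys, "req-0_keys", remote_address="decode-node-3:15000")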
|
@ -45,33 +45,33 @@ class PyNcclPipe(KVPipeBase):
|
||||
METADATA_DTYPE = torch.int64
|
||||
|
||||
def __init__(self,
|
||||
kv_group_rank: int,
|
||||
local_rank: int,
|
||||
config: KVTransferConfig,
|
||||
device: Optional[str] = None,
|
||||
port_offset: int = 0):
|
||||
self.config = config
|
||||
self.local_rank = local_rank
|
||||
self.kv_rank = self.config.kv_rank
|
||||
self.kv_group_rank = kv_group_rank
|
||||
self.kv_parallel_size = self.config.kv_parallel_size
|
||||
self.kv_world_size = self.config.kv_world_size
|
||||
if device is None:
|
||||
self.device = self._select_device(self.config.kv_buffer_device)
|
||||
else:
|
||||
self.device = self._select_device(device)
|
||||
|
||||
# build distributed connection and send/recv implementation
|
||||
logger.info("Creating process group for kv transfer with rank %d and world size %d, ip: %s, port: %d", self.kv_group_rank, self.kv_world_size, self.config.kv_ip, self.config.kv_port + port_offset)
|
||||
self.group = StatelessProcessGroup.create(
|
||||
host=self.config.kv_ip,
|
||||
port=self.config.kv_port + port_offset,
|
||||
rank=self.kv_rank,
|
||||
world_size=self.kv_parallel_size,
|
||||
rank=self.kv_group_rank,
|
||||
world_size=self.kv_world_size,
|
||||
)
|
||||
# add a barrier to make sure the connection is initiated properly
|
||||
self.group.barrier()
|
||||
impl = self._get_device_send_recv_impl(self.group)
|
||||
self.device_send_func, self.device_recv_func = impl
|
||||
# set target rank
|
||||
self.target_rank_for_send = (self.kv_rank + 1) % self.kv_parallel_size
|
||||
self.target_rank_for_recv = (self.kv_rank - 1) % self.kv_parallel_size
|
||||
|
||||
# transportation-related variables
|
||||
self.transport_thread: Optional[ThreadPoolExecutor] = None
|
||||
@ -145,16 +145,16 @@ class PyNcclPipe(KVPipeBase):
|
||||
dtype=metadata["dtype"],
|
||||
device=self.device)
|
||||
|
||||
def _send_metadata(self, metadata: Metadata):
|
||||
def _send_metadata(self, metadata: Metadata, target_rank: int):
|
||||
"""
|
||||
Send the metadata dictionary to the target rank.
|
||||
|
||||
Parameters:
|
||||
- metadata: A dictionary with keys "dtype" and "shape".
|
||||
"""
|
||||
self.group.send_obj(metadata, self.target_rank_for_send)
|
||||
self.group.send_obj(metadata, target_rank)
|
||||
|
||||
def _recv_metadata(self) -> Metadata:
|
||||
def _recv_metadata(self, src_rank: int) -> Metadata:
|
||||
"""
|
||||
Receive the metadata dictionary from the target rank.
|
||||
|
||||
@ -162,9 +162,9 @@ class PyNcclPipe(KVPipeBase):
|
||||
- metadata: A dictionary with keys "dtype" and "shape" describing
|
||||
the tensor.
|
||||
"""
|
||||
return self.group.recv_obj(self.target_rank_for_recv)
|
||||
return self.group.recv_obj(src_rank)
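The object channel always carries a small metadata dict ahead of the NCCL payload; a sketch of its shape (the tensor shape shown is illustrative):

import torch

# Metadata sent by _send_metadata / consumed by _recv_metadata:
metadata = {"dtype": torch.float16, "shape": torch.Size([2, 28, 16, 8, 64])}
# A None tensor is encoded as {"dtype": None, "shape": None}, which lets the
# receiver return None without posting an NCCL receive.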
|
||||
|
||||
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
def _send_impl(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
|
||||
"""
|
||||
The actual implementation of sending the tensor and its metadata to the
|
||||
target rank.
|
||||
@ -174,12 +174,12 @@ class PyNcclPipe(KVPipeBase):
|
||||
being sent.
|
||||
"""
|
||||
metadata = self._make_metadata(tensor)
|
||||
self._send_metadata(metadata)
|
||||
self._send_metadata(metadata, target_rank)
|
||||
if tensor is not None:
|
||||
self.device_send_func(tensor.to(self.device),
|
||||
self.target_rank_for_send)
|
||||
target_rank)
|
||||
|
||||
def _recv_impl(self) -> Optional[torch.Tensor]:
|
||||
def _recv_impl(self, src_rank: int) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
The actual implementation of receiving a tensor and its metadata from
|
||||
the target rank.
|
||||
@ -187,21 +187,22 @@ class PyNcclPipe(KVPipeBase):
|
||||
Returns:
|
||||
- buffer: The received tensor, or None if no tensor is received.
|
||||
"""
|
||||
metadata = self._recv_metadata()
|
||||
metadata = self._recv_metadata(src_rank)
|
||||
if metadata["dtype"] is None:
|
||||
return None
|
||||
buffer = self._prepare_recv_buffer(metadata)
|
||||
self.device_recv_func(buffer, self.target_rank_for_recv)
|
||||
self.device_recv_func(buffer, src_rank)
|
||||
|
||||
return buffer
|
||||
|
||||
def send_tensor_wrapper(self, tensor: Optional[torch.Tensor],
|
||||
tensor_size: int) -> None:
|
||||
tensor_size: int,
|
||||
target_rank: int) -> None:
|
||||
"""
|
||||
Wrapper for _send_impl to handle exceptions and update buffer size.
|
||||
"""
|
||||
try:
|
||||
self._send_impl(tensor)
|
||||
self._send_impl(tensor, target_rank)
|
||||
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size -= tensor_size
|
||||
@ -220,7 +221,7 @@ class PyNcclPipe(KVPipeBase):
|
||||
logger.debug("KV cache transfer pipe is full. Waiting...")
|
||||
time.sleep(0.05)
|
||||
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
|
||||
def send_tensor(self, tensor: Optional[torch.Tensor], target_rank: int) -> None:
|
||||
"""
|
||||
Sends a tensor and its metadata to the destination rank in a
|
||||
non-blocking way.
|
||||
@ -228,6 +229,7 @@ class PyNcclPipe(KVPipeBase):
|
||||
Parameters:
|
||||
- tensor: The tensor to send, or None if no tensor is being sent.
|
||||
"""
|
||||
logger.debug("Rank %d sending tensor of shape %s dtype %s to rank %d", self.kv_group_rank, tensor.shape if tensor is not None else "None", tensor.dtype if tensor is not None else "None", target_rank)
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
@ -241,32 +243,39 @@ class PyNcclPipe(KVPipeBase):
|
||||
with self.buffer_size_lock:
|
||||
self.buffer_size += tensor_size
|
||||
|
||||
self.transport_thread.submit(self.send_tensor_wrapper, tensor,
|
||||
tensor_size)
|
||||
future = self.transport_thread.submit(self.send_tensor_wrapper, tensor,
|
||||
tensor_size,
|
||||
target_rank)
|
||||
return future
|
||||
|
||||
def recv_tensor(self) -> Optional[torch.Tensor]:
|
||||
def recv_tensor(self, src_rank: int) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
Receives a tensor and its metadata from the source rank. Blocking call.
|
||||
|
||||
Returns:
|
||||
- tensor: The received tensor, or None if no tensor is received.
|
||||
"""
|
||||
|
||||
logger.debug("Rank %d receiving tensor from rank %d", self.kv_group_rank, src_rank)
|
||||
|
||||
if self.transport_thread is None:
|
||||
self.transport_thread = ThreadPoolExecutor(max_workers=1)
|
||||
|
||||
future = self.transport_thread.submit(self._recv_impl)
|
||||
future = self.transport_thread.submit(self._recv_impl, src_rank)
|
||||
|
||||
try:
|
||||
tensor = future.result()
|
||||
except Exception as e:
|
||||
logger.error("Encountering exception in KV receiving thread")
|
||||
logger.error("%s", e)
|
||||
logger.error("My device: %s", self.device)
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
return future
|
||||
|
||||
return tensor
|
||||
# try:
|
||||
# tensor = future.result()
|
||||
# except Exception as e:
|
||||
# logger.error("Encountering exception in KV receiving thread")
|
||||
# logger.error("%s", e)
|
||||
# logger.error("My device: %s", self.device)
|
||||
# import traceback
|
||||
# traceback.print_exc()
|
||||
# raise e
|
||||
|
||||
# return tensor
|
||||
|
||||
def close(self):
|
||||
"""
|
||||
|
@ -35,6 +35,7 @@ class KVTransferAgent:
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: "VllmConfig",
|
||||
world_group,
|
||||
):
|
||||
|
||||
self.config = config
|
||||
@ -47,7 +48,7 @@ class KVTransferAgent:
|
||||
"TransferAgent should only be used when kv_connector is set."
|
||||
|
||||
self.connector = KVConnectorFactory.create_connector(
|
||||
rank, local_rank, config)
|
||||
rank, local_rank, config, world_group)
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self,
|
||||
|
@ -1085,7 +1085,8 @@ def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
|
||||
_KV_TRANSFER = kv_transfer.KVTransferAgent(
|
||||
rank=get_world_group().rank,
|
||||
local_rank=get_world_group().local_rank,
|
||||
config=vllm_config)
|
||||
config=vllm_config,
|
||||
world_group=get_world_group())
|
||||
|
||||
|
||||
def ensure_model_parallel_initialized(
|
||||
|
@ -2,13 +2,17 @@
|
||||
|
||||
import copy
|
||||
import time
|
||||
import pickle
|
||||
import uuid
|
||||
from collections import Counter as collectionsCounter
|
||||
from collections import deque
|
||||
from collections import defaultdict
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from functools import partial
|
||||
from typing import (TYPE_CHECKING, Callable, ClassVar, Deque, Dict, Iterable,
|
||||
List, Mapping, NamedTuple, Optional)
|
||||
List, Mapping, NamedTuple, Optional, Tuple)
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Set, Type, Union, cast, overload
|
||||
|
||||
@ -60,6 +64,9 @@ from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
|
||||
usage_message)
|
||||
from vllm.utils import Counter, Device, deprecate_kwargs, weak_bind
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
from vllm.remote_prefill import RemotePrefillRequest, RemotePrefillParams, MemoryTransferRequest, MemoryOpType
|
||||
from vllm.distributed.device_communicators.nixl import NixlMetadata
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
_LOCAL_LOGGING_INTERVAL_SEC = 5
|
||||
@ -90,7 +97,7 @@ class OutputData(NamedTuple):
|
||||
# outputs from multiple steps.
|
||||
is_first_step_output: Optional[bool]
|
||||
skip: List[int]
|
||||
|
||||
remote_prefill_requests: Optional[List[RemotePrefillRequest]]
|
||||
|
||||
class SchedulerContext:
|
||||
|
||||
@ -104,11 +111,14 @@ class SchedulerContext:
|
||||
|
||||
self.multi_step_stream_outputs: bool = multi_step_stream_outputs
|
||||
|
||||
self.remote_prefill_requests: List[RemotePrefillRequest] = []
|
||||
|
||||
def append_output(self, outputs: List[SamplerOutput],
|
||||
seq_group_metadata_list: List[SequenceGroupMetadata],
|
||||
scheduler_outputs: SchedulerOutputs, is_async: bool,
|
||||
is_last_step: bool,
|
||||
is_first_step_output: Optional[bool]):
|
||||
is_first_step_output: Optional[bool],
|
||||
remote_prefill_requests: Optional[List[RemotePrefillRequest]] = None):
|
||||
self.output_queue.append(
|
||||
OutputData(outputs=outputs,
|
||||
seq_group_metadata_list=seq_group_metadata_list,
|
||||
@ -116,7 +126,9 @@ class SchedulerContext:
|
||||
is_async=is_async,
|
||||
is_last_step=is_last_step,
|
||||
is_first_step_output=is_first_step_output,
|
||||
skip=[]))
|
||||
skip=[],
|
||||
remote_prefill_requests=remote_prefill_requests))
|
||||
|
||||
|
||||
|
||||
class LLMEngine:
|
||||
@ -348,7 +360,7 @@ class LLMEngine:
|
||||
# GPU and CPU blocks, which are profiled in the distributed executor.
|
||||
self.scheduler = [
|
||||
Scheduler(
|
||||
self.scheduler_config, self.cache_config, self.lora_config,
|
||||
self.model_config, self.scheduler_config, self.cache_config, self.lora_config,
|
||||
self.parallel_config.pipeline_parallel_size,
|
||||
self.async_callbacks[v_id]
|
||||
if self.model_config.use_async_output_proc else None)
|
||||
@@ -405,6 +417,40 @@ class LLMEngine:

        self.seq_id_to_seq_group: Dict[str, SequenceGroupBase] = {}

        self.engine_id = str(uuid.uuid4())
        self._nixl_agents_names: Optional[List[str]] = None
        if self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
            self._nixl_agents_names = self._initialize_nixl()

        self._request_notif_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
        self._request_done_counter = defaultdict(lambda: -self.parallel_config.tensor_parallel_size)
        self._finished_prefills = set()
        self._finished_transfers = set()

    @property
    def is_nixl_initialized(self) -> bool:
        return getattr(self, "_nixl_agents_names", None) is not None

    def get_nixl_metadata(self) -> NixlMetadata:
        if not self.is_nixl_initialized:
            raise RuntimeError("Nixl is not initialized")
        agent_metadata = self.model_executor.collective_rpc("get_nixl_agent_metadata")
        kv_caches_base_addr = self.model_executor.collective_rpc("get_nixl_kv_caches_base_addr")
        return NixlMetadata(engine_id=self.engine_id, agent_metadata=agent_metadata, kv_caches_base_addr=kv_caches_base_addr, num_blocks=self.cache_config.num_gpu_blocks)

    def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata) -> List[str]:
        if not self.is_nixl_initialized:
            raise RuntimeError("Nixl is not initialized")
        engine_id = nixl_metadata.engine_id
        agents_metadata = nixl_metadata.agent_metadata
        kv_caches_base_addr = nixl_metadata.kv_caches_base_addr
        num_blocks = nixl_metadata.num_blocks
        return self.model_executor.collective_rpc("add_remote_nixl_metadata", args=(engine_id, agents_metadata, kv_caches_base_addr, num_blocks))

    def _initialize_nixl(self) -> List[bytes]:
        agents_names = self.model_executor.collective_rpc("initialize_nixl", args=(self.engine_id,))
        return agents_names

    def _initialize_kv_caches(self) -> None:
        """Initialize the KV cache in the worker(s).

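The two methods above are the engine-level handshake for KV transfer: each engine advertises its own NixlMetadata and registers the metadata of its peers. A minimal sketch of how a prefill engine and a decode engine could be wired together (illustrative only; `prefill_engine` and `decode_engine` are assumed to be LLMEngine instances built with kv_connector="DynamoNixlConnector", so that _initialize_nixl() already ran in __init__):

# Hedged sketch, not part of the patch.
prefill_meta = prefill_engine.get_nixl_metadata()   # NixlMetadata of the prefill side
decode_meta = decode_engine.get_nixl_metadata()     # NixlMetadata of the decode side

# Each side registers the other's agents, KV-cache base addresses and block count.
prefill_engine.add_remote_nixl_metadata(decode_meta)
decode_engine.add_remote_nixl_metadata(prefill_meta)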
||||
@ -500,6 +546,8 @@ class LLMEngine:
|
||||
# Shutdown model executor when engine is garbage collected
|
||||
# Use getattr since __init__ can fail before the field is set
|
||||
if model_executor := getattr(self, "model_executor", None):
|
||||
if self.is_nixl_initialized:
|
||||
model_executor.collective_rpc("shutdown_nixl")
|
||||
model_executor.shutdown()
|
||||
|
||||
def get_tokenizer_group(
|
||||
@ -552,11 +600,14 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest],
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> Optional[SequenceGroup]:
|
||||
"""Add a processed request to the engine's request pool.
|
||||
return the created sequence group.
|
||||
"""
|
||||
if isinstance(params, SamplingParams) and params.n > 1:
|
||||
if remote_prefill_params is not None:
|
||||
raise ValueError("Remote prefill params are not supported for multi-step sampling")
|
||||
ParallelSampleSequenceGroup.add_request(
|
||||
request_id,
|
||||
self,
|
||||
@ -574,6 +625,8 @@ class LLMEngine:
|
||||
# Create the sequences.
|
||||
block_size = self.cache_config.block_size
|
||||
seq_id = next(self.seq_counter)
|
||||
if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
|
||||
next(self.seq_counter) # empty sequence for staging
|
||||
eos_token_id = self.input_preprocessor.get_eos_token_id(lora_request)
|
||||
|
||||
if is_encoder_decoder_inputs(processed_inputs):
|
||||
@ -584,7 +637,7 @@ class LLMEngine:
|
||||
encoder_inputs = None
|
||||
|
||||
seq = Sequence(seq_id, decoder_inputs, block_size, eos_token_id,
|
||||
lora_request, prompt_adapter_request)
|
||||
lora_request, prompt_adapter_request, remote_prefill_params)
|
||||
|
||||
encoder_seq = (None if encoder_inputs is None else Sequence(
|
||||
seq_id, encoder_inputs, block_size, eos_token_id, lora_request,
|
||||
@ -601,8 +654,12 @@ class LLMEngine:
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
priority=priority,
|
||||
remote_prefill_params=remote_prefill_params,
|
||||
)
|
||||
elif isinstance(params, PoolingParams):
|
||||
if remote_prefill_params is not None:
|
||||
raise ValueError("Remote prefill params are not supported for pooling")
|
||||
seq_group = self._create_sequence_group_with_pooling(
|
||||
request_id,
|
||||
seq,
|
||||
@ -673,6 +730,7 @@ class LLMEngine:
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
@ -765,6 +823,7 @@ class LLMEngine:
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
trace_headers=trace_headers,
|
||||
priority=priority,
|
||||
remote_prefill_params=remote_prefill_params,
|
||||
)
|
||||
|
||||
def _validate_token_prompt(self, prompt: PromptType,
|
||||
@ -799,6 +858,7 @@ class LLMEngine:
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
encoder_seq: Optional[Sequence] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> SequenceGroup:
|
||||
"""Creates a SequenceGroup with SamplingParams."""
|
||||
max_logprobs = self.get_model_config().max_logprobs
|
||||
@ -829,7 +889,9 @@ class LLMEngine:
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
encoder_seq=encoder_seq,
|
||||
priority=priority)
|
||||
priority=priority,
|
||||
remote_prefill_params=remote_prefill_params
|
||||
)
|
||||
|
||||
return seq_group
|
||||
|
||||
@@ -995,11 +1057,11 @@ class LLMEngine:
            # When we process only one request, no pop is required
            # (since later we will process all of the rest)
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output, skip) = ctx.output_queue[0]
             is_last_step, is_first_step_output, skip, remote_prefill_requests) = ctx.output_queue[0]
        else:
            (outputs, seq_group_metadata_list, scheduler_outputs, is_async,
             is_last_step, is_first_step_output,
             skip) = ctx.output_queue.popleft()
             skip, remote_prefill_requests) = ctx.output_queue.popleft()

        # Sanity check
        assert len(seq_group_metadata_list) == len(
||||
@@ -1325,15 +1387,55 @@ class LLMEngine:

        # Clear outputs for each new scheduler iteration
        ctx.request_outputs.clear()
        ctx.remote_prefill_requests.clear()

        # Skip the scheduler if there are any remaining steps in the seq groups.
        # This ensures that the scheduler is only called again when the current
        # batch has completed.
        remote_prefill_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        running_seq_group_metadata_list: List[SequenceGroupMetadata] = []
        remote_prefill_scheduled_seq_groups: List[ScheduledSequenceGroup] = []
        running_scheduled_seq_groups: List[ScheduledSequenceGroup] = []

        if not self._has_remaining_steps(seq_group_metadata_list):
            # Schedule iteration

            (seq_group_metadata_list, scheduler_outputs,
             allow_async_output_proc
             ) = self.scheduler[virtual_engine].schedule()
             ) = self.scheduler[virtual_engine].schedule(self._finished_prefills, self._finished_transfers)


            # Separate remote prefill and running seq groups
            for seq_group_metadata, scheduled_seq_group in zip(seq_group_metadata_list, scheduler_outputs.scheduled_seq_groups):
                if seq_group_metadata.do_remote_prefill:
                    remote_prefill_seq_group_metadata_list.append(seq_group_metadata)
                    remote_prefill_scheduled_seq_groups.append(scheduled_seq_group)
                else:
                    running_seq_group_metadata_list.append(seq_group_metadata)
                    running_scheduled_seq_groups.append(scheduled_seq_group)

            seq_group_metadata_list = running_seq_group_metadata_list
            scheduler_outputs.scheduled_seq_groups = running_scheduled_seq_groups

            # Send remote prefill requests before model execution
            for seq_group_metadata, scheduled_seq_group in zip(remote_prefill_seq_group_metadata_list, remote_prefill_scheduled_seq_groups):
                assert len(scheduled_seq_group.seq_group.seqs) == 1
                assert self._nixl_agents_names
                seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
                block_table = seq_group_metadata.block_tables[seq_id]
                if len(block_table) == len(seq_group_metadata.computed_block_nums):
                    logger.debug("No blocks to prefill")
                    self._finished_prefills.add(seq_group_metadata.request_id)
                    continue
                remote_prefill_request = RemotePrefillRequest(
                    request_id=seq_group_metadata.request_id,
                    # prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids[:-1], # last one will be decoded on decode for sampling anyway
                    prompt_token_ids=scheduled_seq_group.seq_group.seqs[0].inputs.prompt_token_ids, # TODO ptarasiewicz do not send the last token when NIXL fixes send notif (needed for writing 0 blocks)
                    sampling_params=scheduled_seq_group.seq_group.sampling_params,
                    block_ids=block_table,
                    engine_id=self.engine_id,
                    computed_block_ids=seq_group_metadata.computed_block_nums,
                )
                scheduled_seq_group.seq_group.remote_prefill_params.remote_prefill_request_callback(remote_prefill_request)

            ctx.seq_group_metadata_list = seq_group_metadata_list
            ctx.scheduler_outputs = scheduler_outputs
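A note on the callback invoked above: `remote_prefill_request_callback` is whatever the caller attached to `RemotePrefillParams`; in the multiprocessing engine it is replaced by a function that forwards the request over a ZMQ socket (see the MQLLMEngine changes further down). A minimal hand-rolled callback, assuming you drive the engine in-process and only want to queue the requests yourself, could look like this (hypothetical example, not part of the patch):

import queue

from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest

# Hypothetical in-process queue; in the real flow the MQLLMEngine pushes the
# msgspec-encoded request over the remote-prefill IPC socket instead.
prefill_requests: "queue.Queue[RemotePrefillRequest]" = queue.Queue()

def remote_prefill_request_callback(request: RemotePrefillRequest) -> None:
    # Called by LLMEngine.step() before model execution for every sequence
    # group that still has uncomputed prefill blocks.
    prefill_requests.put(request)

params = RemotePrefillParams(
    is_remote_prefill=True,
    remote_prefill_request_callback=remote_prefill_request_callback,
)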
||||
@@ -1383,9 +1485,46 @@ class LLMEngine:
            execute_model_req.async_callback = self.async_callbacks[
                virtual_engine]

            outputs = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
            # After model execution, the KV blocks still need to be transferred
            # from the prefill engine to the decode engine.
            memory_transfer_reqs = []
            for scheduled_seq_group, seq_group_metadata in zip(scheduler_outputs.scheduled_seq_groups, seq_group_metadata_list):
                remote_prefill_params = scheduled_seq_group.seq_group.remote_prefill_params
                if remote_prefill_params is not None and remote_prefill_params.is_remote_decode:
                    assert len(scheduled_seq_group.seq_group.seqs) == 1
                    req_id = scheduled_seq_group.seq_group.request_id
                    seq_id = scheduled_seq_group.seq_group.seqs[0].seq_id
                    block_table = seq_group_metadata.block_tables[seq_id]
                    staging_block_ids = seq_group_metadata.block_tables[seq_id + 1]

                    num_computed_blocks = len(seq_group_metadata.computed_block_nums)
                    computed_decode_block_ids = remote_prefill_params.decode_block_ids[:num_computed_blocks]

                    if computed_decode_block_ids:
                        kv_recv_req = MemoryTransferRequest(
                            request_id=req_id,
                            local_block_ids=block_table[:num_computed_blocks],
                            staging_block_ids=staging_block_ids[:num_computed_blocks],
                            remote_block_ids=computed_decode_block_ids,
                            remote_engine_id=remote_prefill_params.decode_engine_id,
                            notify_msg=req_id,
                            op_type=MemoryOpType.READ
                        )
                        memory_transfer_reqs.append(kv_recv_req)

                    kv_send_req = MemoryTransferRequest(
                        request_id=req_id,
                        local_block_ids=block_table[num_computed_blocks:],
                        staging_block_ids=staging_block_ids[num_computed_blocks:],
                        remote_block_ids=remote_prefill_params.decode_block_ids[num_computed_blocks:],
                        remote_engine_id=remote_prefill_params.decode_engine_id,
                        notify_msg=req_id,
                        op_type=MemoryOpType.WRITE
                    )
                    memory_transfer_reqs.append(kv_send_req)
            execute_model_req.memory_transfer_requests = memory_transfer_reqs

            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
                execute_model_req=execute_model_req)
            # We need to do this here so that last step's sampled_token_ids can
            # be passed to the next iteration for PP.
            if self.scheduler_config.is_multi_step:
@@ -1396,7 +1535,26 @@ class LLMEngine:
            if len(ctx.output_queue) > 0:
                self._process_model_outputs(ctx=ctx)
            # No outputs in this case
            outputs = []
            execute_model_req = ExecuteModelRequest(
                seq_group_metadata_list=[],
                blocks_to_swap_in=[],
                blocks_to_swap_out=[],
                blocks_to_copy=[])

            outputs, request_notif_counter, request_done_counter = self.model_executor.execute_model(
                execute_model_req=execute_model_req)

        for req_id, notif_count in request_notif_counter.items():
            self._request_notif_counter[req_id] += notif_count
            if self._request_notif_counter[req_id] > -1:
                self._finished_prefills.add(req_id)
                del self._request_notif_counter[req_id]

        for req_id, done_count in request_done_counter.items():
            self._request_done_counter[req_id] += done_count
            if self._request_done_counter[req_id] > -1:
                self._finished_transfers.add(req_id)
                del self._request_done_counter[req_id]
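The two counters above start at `-tensor_parallel_size` (see the `defaultdict` initializers added to `__init__`), so a request is only marked finished once every tensor-parallel worker has reported its NIXL notification or completed transfer. A small, self-contained illustration of the bookkeeping, assuming TP=2 (plain Python, independent of vLLM):

from collections import defaultdict

tensor_parallel_size = 2  # assumption for the example
notif_counter = defaultdict(lambda: -tensor_parallel_size)
finished = set()

# Step 1: only one of the two TP workers has sent its notification yet.
for req_id, count in {"req-0": 1}.items():
    notif_counter[req_id] += count          # -2 + 1 = -1, not finished yet
    if notif_counter[req_id] > -1:
        finished.add(req_id)

# Step 2: the second worker reports; the counter crosses -1 and the
# request is promoted to the finished set.
for req_id, count in {"req-0": 1}.items():
    notif_counter[req_id] += count          # -1 + 1 = 0 > -1
    if notif_counter[req_id] > -1:
        finished.add(req_id)
        del notif_counter[req_id]

assert finished == {"req-0"}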
|
||||
# Finish the current step for all the sequence groups.
|
||||
if self.scheduler_config.is_multi_step:
|
||||
@ -1456,7 +1614,7 @@ class LLMEngine:
|
||||
# queued control plane messages, such as add/remove lora adapters.
|
||||
logger.debug("Stopping remote worker execution loop.")
|
||||
self.model_executor.stop_remote_worker_execution_loop()
|
||||
|
||||
|
||||
return ctx.request_outputs
|
||||
|
||||
def _has_remaining_steps(
|
||||
|
@ -14,13 +14,17 @@ from vllm.outputs import RequestOutput
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.utils import deprecate_kwargs
|
||||
|
||||
from vllm.remote_prefill import RemotePrefillParams
|
||||
from vllm.distributed.device_communicators.nixl import NixlMetadata
|
||||
VLLM_RPC_SUCCESS_STR = "SUCCESS"
|
||||
|
||||
IPC_INPUT_EXT = "_input_socket"
|
||||
IPC_OUTPUT_EXT = "_output_socket"
|
||||
IPC_HEALTH_EXT = "_health_socket"
|
||||
IPC_DATA_EXT = "_data_socket"
|
||||
IPC_REMOTE_PREFILL_REQUEST_EXT = "_remote_prefill_request_socket"
|
||||
IPC_REMOTE_NIXL_METADATA_EXT = "_remote_nixl_metadata_socket"
|
||||
IPC_METRICS_EXT = "_metrics_socket"
|
||||
|
||||
|
||||
class MQEngineDeadError(RuntimeError):
|
||||
@ -36,6 +40,7 @@ class RPCProcessRequest:
|
||||
trace_headers: Optional[Mapping[str, str]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
priority: int = 0
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
@ -78,6 +83,7 @@ class RPCProcessRequest:
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None, # DEPRECATED
|
||||
) -> None:
|
||||
@ -95,7 +101,7 @@ class RPCProcessRequest:
|
||||
self.trace_headers = trace_headers
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
self.priority = priority
|
||||
|
||||
self.remote_prefill_params = remote_prefill_params
|
||||
|
||||
@dataclass
|
||||
class RPCError:
|
||||
@ -116,7 +122,7 @@ class RPCStartupRequest(Enum):
|
||||
@dataclass
|
||||
class RPCStartupResponse:
|
||||
tracing_enabled: bool
|
||||
|
||||
nixl_metadata: Optional[bytes] = None
|
||||
|
||||
class RPCUProfileRequest(Enum):
|
||||
START_PROFILE = 1
|
||||
@ -157,3 +163,13 @@ def ENGINE_DEAD_ERROR(
|
||||
return MQEngineDeadError(
|
||||
"Engine loop is not running. Inspect the stacktrace to "
|
||||
f"find the original error: {repr(error)}.")
|
||||
|
||||
@dataclass
|
||||
class KvMetrics:
|
||||
request_active_slots: int
|
||||
request_total_slots: int
|
||||
kv_active_blocks: int
|
||||
kv_total_blocks: int
|
||||
num_requests_waiting: int
|
||||
gpu_cache_usage_perc: float
|
||||
gpu_prefix_cache_hit_rate: float
|
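KvMetrics is the payload that the engine-side stat logger pickles onto the metrics socket and the client-side metrics loop unpacks for its publisher (both ends appear later in this diff). A hedged construction example, with made-up values:

import pickle

from vllm.engine.multiprocessing import KvMetrics

# Values are illustrative; in the engine they are derived from Stats in
# KvStatLogger.log() further down in this patch.
metrics = KvMetrics(
    request_active_slots=3,
    request_total_slots=256,
    kv_active_blocks=1200,
    kv_total_blocks=8000,
    num_requests_waiting=1,
    gpu_cache_usage_perc=0.15,
    gpu_prefix_cache_hit_rate=0.42,
)
payload = pickle.dumps(metrics)  # this is what travels over the IPC metrics socket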
||||
|
@ -8,6 +8,7 @@ from typing import (Any, AsyncGenerator, Dict, Iterator, List, Mapping,
|
||||
Optional, Union, cast, overload)
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
@ -19,20 +20,23 @@ from vllm import PoolingParams
|
||||
from vllm.config import DecodingConfig, ModelConfig, VllmConfig
|
||||
from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.metrics import Stats
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.async_llm_engine import (
|
||||
build_guided_decoding_logits_processor_async)
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, RPC_REQUEST_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
IPC_OUTPUT_EXT, IPC_REMOTE_PREFILL_REQUEST_EXT,
|
||||
RPC_REQUEST_T,
|
||||
VLLM_RPC_SUCCESS_STR, IPC_REMOTE_NIXL_METADATA_EXT, RPCAbortRequest,
|
||||
IPC_METRICS_EXT,
|
||||
RPCAdapterLoadedResponse, RPCError,
|
||||
RPCLoadAdapterRequest,
|
||||
RPCProcessRequest,
|
||||
RPCResetPrefixCacheRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
RPCUProfileRequest, KvMetrics)
|
||||
from vllm.engine.protocol import EngineClient
|
||||
# yapf: enable
|
||||
from vllm.envs import VLLM_RPC_TIMEOUT
|
||||
@ -46,6 +50,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
|
||||
from vllm.utils import deprecate_kwargs
|
||||
from vllm.remote_prefill import RemotePrefillParams, RemotePrefillRequest, RemotePrefillRequestCallback
|
||||
from vllm.distributed.device_communicators.nixl import NixlMetadata
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -91,6 +97,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
# Get the configs.
|
||||
self.vllm_config = engine_config
|
||||
self.model_config = engine_config.model_config
|
||||
self.decoding_config = engine_config.decoding_config
|
||||
|
||||
@ -115,6 +122,10 @@ class MQLLMEngineClient(EngineClient):
|
||||
self.heartbeat_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.heartbeat_socket.connect(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# Metrics.
|
||||
self.metrics_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.metrics_socket.connect(f"{ipc_path}{IPC_METRICS_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
@ -129,8 +140,27 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Loop to check health of the LLMEngine periodically.
|
||||
# Started after the MQLLMEngine is ready.
|
||||
self.health_loop: Optional[asyncio.Task] = None
|
||||
|
||||
# Loop to check metrics of the LLMEngine periodically.
|
||||
# Started after the MQLLMEngine is ready.
|
||||
self.metrics_loop: Optional[asyncio.Task] = None
|
||||
self.metrics_publisher = None
|
||||
|
||||
self._engine_process = psutil.Process(engine_pid)
|
||||
|
||||
self.nixl_metadata: Optional[NixlMetadata] = None
|
||||
self.remote_prefill_request_socket: Socket = self.context.socket(zmq.constants.PULL)
|
||||
self.remote_nixl_metadata_socket: Socket = self.context.socket(zmq.constants.PUSH)
|
||||
self.remote_prefill_requests_callback: Dict[str, RemotePrefillRequestCallback] = {}
|
||||
if self.using_nixl_connector:
|
||||
self.remote_prefill_request_socket.connect(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
|
||||
self.remote_nixl_metadata_socket.connect(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
|
||||
|
||||
|
||||
@property
|
||||
def using_nixl_connector(self) -> bool:
|
||||
return self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector"
|
||||
|
||||
@staticmethod
|
||||
def is_unsupported_config(engine_args: AsyncEngineArgs):
|
||||
# Pipeline parallel not yet supported
|
||||
@ -180,6 +210,61 @@ class MQLLMEngineClient(EngineClient):
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
|
||||
async def run_remote_prefill_request_handler_loop(self):
|
||||
try:
|
||||
while True:
|
||||
if await self.remote_prefill_request_socket.poll(timeout=VLLM_RPC_TIMEOUT):
|
||||
frames = await self.remote_prefill_request_socket.recv(copy=False)
|
||||
remote_prefill_request = msgspec.msgpack.decode(frames.buffer, type=RemotePrefillRequest)
|
||||
await self.remote_prefill_requests_callback[remote_prefill_request.request_id](remote_prefill_request)
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient remote prefill request handler loop.")
|
||||
|
||||
async def run_metrics_loop(self, timeout: int):
|
||||
"""Background loop that continually checks to ensure the engine process
|
||||
        is still alive, and that forwards any KvMetrics received on the
        metrics socket to the configured metrics publisher.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
# Check if the engine process is running:
|
||||
if not self._engine_process.is_running() or (
|
||||
self._engine_process.status() == psutil.STATUS_ZOMBIE):
|
||||
# NB: is_running() returns True for zombies
|
||||
self._set_errored(
|
||||
RuntimeError(
|
||||
f"Engine process (pid {self._engine_process.pid}) "
|
||||
"died."))
|
||||
break
|
||||
|
||||
if await self.metrics_socket.poll(timeout=timeout):
|
||||
# Metrics received- check the message
|
||||
message: Frame = await self.metrics_socket.recv(copy=False)
|
||||
metrics = pickle.loads(message.buffer)
|
||||
if self.metrics_publisher is not None and isinstance(
|
||||
metrics, KvMetrics
|
||||
):
|
||||
self.metrics_publisher.publish(metrics.request_active_slots,
|
||||
metrics.request_total_slots,
|
||||
metrics.kv_active_blocks,
|
||||
metrics.kv_total_blocks,
|
||||
metrics.num_requests_waiting,
|
||||
metrics.gpu_cache_usage_perc,
|
||||
metrics.gpu_prefix_cache_hit_rate)
|
||||
logger.debug("Metrics successful.")
|
||||
|
||||
# TODO: Investigate sending whole stats object
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug("Shutting down MQLLMEngineClient check metrics loop.")
|
||||
|
||||
except psutil.NoSuchProcess:
|
||||
self._set_errored(
|
||||
RuntimeError(
|
||||
f"Engine process (pid {self._engine_process.pid}) died."))
|
||||
|
||||
except Exception as e:
|
||||
self._set_errored(e)
|
||||
|
||||
async def run_output_handler_loop(self):
|
||||
"""Get RequestOutputs from Engine and stream to Request Queues"""
|
||||
|
||||
@ -278,12 +363,26 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Wait until server is ready.
|
||||
response = await self._wait_for_server_rpc(socket)
|
||||
|
||||
if response.nixl_metadata is not None:
|
||||
assert self.using_nixl_connector
|
||||
self.nixl_metadata = msgspec.msgpack.decode(response.nixl_metadata, type=NixlMetadata)
|
||||
|
||||
self.tracing_flag = response.tracing_enabled
|
||||
|
||||
# Start health_loop.
|
||||
if self.health_loop is None:
|
||||
self.health_loop = asyncio.create_task(
|
||||
self.run_heartbeat_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
|
||||
if self.using_nixl_connector:
|
||||
self.remote_prefill_loop = asyncio.create_task(
|
||||
self.run_remote_prefill_request_handler_loop())
|
||||
|
||||
# Start metrics_loop.
|
||||
if self.metrics_loop is None:
|
||||
self.metrics_loop = asyncio.create_task(
|
||||
self.run_metrics_loop(timeout=VLLM_RPC_TIMEOUT))
|
||||
|
||||
|
||||
def close(self):
|
||||
"""Destroy the ZeroMQ Context."""
|
||||
@ -293,6 +392,8 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Cancel background tasks.
|
||||
if self.health_loop is not None:
|
||||
self.health_loop.cancel()
|
||||
if self.metrics_loop is not None:
|
||||
self.metrics_loop.cancel()
|
||||
if self.output_loop is not None:
|
||||
self.output_loop.cancel()
|
||||
|
||||
@ -415,6 +516,9 @@ class MQLLMEngineClient(EngineClient):
|
||||
"""
|
||||
if self._errored_with is not None:
|
||||
raise self._errored_with
|
||||
|
||||
async def add_remote_nixl_metadata(self, nixl_metadata: NixlMetadata):
|
||||
await self.remote_nixl_metadata_socket.send(msgspec.msgpack.encode(nixl_metadata), copy=False)
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
@ -473,6 +577,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
*,
|
||||
inputs: Optional[PromptType] = None # DEPRECATED
|
||||
) -> AsyncGenerator[RequestOutput, None]:
|
||||
@ -502,7 +607,8 @@ class MQLLMEngineClient(EngineClient):
|
||||
|
||||
return self._process_request(prompt, sampling_params, request_id,
|
||||
lora_request, trace_headers,
|
||||
prompt_adapter_request, priority)
|
||||
prompt_adapter_request, priority,
|
||||
remote_prefill_params)
|
||||
|
||||
@overload
|
||||
def encode(
|
||||
@ -586,6 +692,7 @@ class MQLLMEngineClient(EngineClient):
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> Union[AsyncGenerator[RequestOutput, None], AsyncGenerator[
|
||||
PoolingRequestOutput, None]]:
|
||||
"""Send an RPCGenerateRequest to the RPCServer and stream responses."""
|
||||
@ -630,6 +737,12 @@ class MQLLMEngineClient(EngineClient):
|
||||
else:
|
||||
lp_bytes = None
|
||||
|
||||
if remote_prefill_params is not None:
|
||||
self.remote_prefill_requests_callback[request_id] = remote_prefill_params.remote_prefill_request_callback
|
||||
remote_prefill_params.remote_prefill_request_callback = None
|
||||
else:
|
||||
remote_prefill_request_callback = None
|
||||
|
||||
request_bytes = pickle.dumps(
|
||||
RPCProcessRequest(
|
||||
prompt=prompt,
|
||||
@ -639,11 +752,11 @@ class MQLLMEngineClient(EngineClient):
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=priority,
|
||||
remote_prefill_params=remote_prefill_params,
|
||||
))
|
||||
|
||||
# 3) Send the RPCGenerateRequest to the MQLLMEngine.
|
||||
parts = (request_bytes,
|
||||
lp_bytes) if lp_bytes else (request_bytes, )
|
||||
parts = (request_bytes, lp_bytes) if lp_bytes else (request_bytes,)
|
||||
await self.input_socket.send_multipart(parts, copy=False)
|
||||
|
||||
# 4) Stream the RequestOutputs from the output queue. Note
|
||||
@ -705,3 +818,6 @@ class MQLLMEngineClient(EngineClient):
|
||||
# Raise on error, otherwise happily return None
|
||||
if isinstance(request_output, BaseException):
|
||||
raise request_output
|
||||
|
||||
def set_metrics_publisher(self, metrics_publisher):
|
||||
self.metrics_publisher = metrics_publisher
|
||||
|
@ -3,35 +3,115 @@
|
||||
import pickle
|
||||
import signal
|
||||
from contextlib import contextmanager
|
||||
from typing import Iterator, List, Optional, Union
|
||||
from typing import Iterator, List, Optional, Union, Dict
|
||||
|
||||
import cloudpickle
|
||||
import time
|
||||
import zmq
|
||||
|
||||
import msgspec
|
||||
from vllm import AsyncEngineArgs, SamplingParams
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
# yapf conflicts with isort for this block
|
||||
# yapf: disable
|
||||
from vllm.engine.multiprocessing import (ENGINE_DEAD_ERROR, IPC_DATA_EXT,
|
||||
IPC_HEALTH_EXT, IPC_INPUT_EXT,
|
||||
IPC_OUTPUT_EXT, REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, RPCAbortRequest,
|
||||
REQUEST_OUTPUTS_T,
|
||||
VLLM_RPC_SUCCESS_STR, IPC_REMOTE_PREFILL_REQUEST_EXT,
|
||||
RPCAbortRequest,
|
||||
IPC_OUTPUT_EXT, IPC_METRICS_EXT,
|
||||
RPCAdapterLoadedResponse, RPCError,
|
||||
RPCLoadAdapterRequest,
|
||||
RPCProcessRequest,
|
||||
RPCResetPrefixCacheRequest,
|
||||
RPCStartupRequest, RPCStartupResponse,
|
||||
RPCUProfileRequest)
|
||||
RPCUProfileRequest, IPC_REMOTE_NIXL_METADATA_EXT,
|
||||
KvMetrics)
|
||||
# yapf: enable
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.remote_prefill import RemotePrefillRequest
|
||||
from vllm.distributed.device_communicators.nixl import NixlMetadata
|
||||
|
||||
from vllm.engine.metrics_types import StatLoggerBase, Stats, SupportsMetricsInfo
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
POLLING_TIMEOUT_MS = 10000
|
||||
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
|
||||
|
||||
class KvStatLogger(StatLoggerBase):
|
||||
def __init__(
|
||||
self,
|
||||
max_num_seqs: int,
|
||||
num_total_gpu_blocks: int,
|
||||
metrics_socket
|
||||
):
|
||||
# Must query initialized scheduler for max infos
|
||||
self.request_total_slots = max_num_seqs
|
||||
self.kv_total_blocks = num_total_gpu_blocks
|
||||
self.metrics_socket = metrics_socket
|
||||
|
||||
# KV metrics
|
||||
self._send_kv_metrics(0, 0, 0, 0.0, 0.0)
|
||||
|
||||
def log(self, stats: Stats) -> None:
|
||||
self._send_kv_metrics(
|
||||
stats.num_running_sys,
|
||||
int(stats.gpu_cache_usage_sys * self.kv_total_blocks),
|
||||
stats.num_waiting_sys,
|
||||
stats.gpu_cache_usage_sys,
|
||||
stats.gpu_prefix_cache_hit_rate
|
||||
)
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
pass
|
||||
|
||||
def _send_kv_metrics(
|
||||
self,
|
||||
active_slots,
|
||||
active_kv_blocks,
|
||||
num_requests_waiting,
|
||||
gpu_cache_usage_perc,
|
||||
gpu_prefix_cache_hit_rate,
|
||||
):
|
||||
if not self.metrics_socket.closed:
|
||||
metrics_bytes = pickle.dumps(
|
||||
KvMetrics(
|
||||
active_slots,
|
||||
self.request_total_slots,
|
||||
active_kv_blocks,
|
||||
self.kv_total_blocks,
|
||||
num_requests_waiting,
|
||||
gpu_cache_usage_perc,
|
||||
gpu_prefix_cache_hit_rate,
|
||||
)
|
||||
)
|
||||
self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
|
||||
|
||||
# TODO: Send entire stats object to the client
|
||||
# class StatLogger(StatLoggerBase):
|
||||
# def __init__(
|
||||
# self,
|
||||
# metrics_socket
|
||||
# ):
|
||||
# self.metrics_socket = metrics_socket
|
||||
|
||||
# def log(self, stats: Stats) -> None:
|
||||
# self._send_metrics(stats)
|
||||
|
||||
# def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
# pass
|
||||
|
||||
# def _send_metrics(self, stats: Stats):
|
||||
# if not self.metrics_socket.closed:
|
||||
# metrics_bytes = pickle.dumps(stats)
|
||||
# self.metrics_socket.send_multipart((metrics_bytes, ), copy=False)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class MQLLMEngine:
|
||||
"""A multiprocessing wrapper for :class:`LLMEngine`.
|
||||
@ -94,12 +174,37 @@ class MQLLMEngine:
|
||||
self.heartbeat_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.heartbeat_socket.bind(f"{ipc_path}{IPC_HEALTH_EXT}")
|
||||
|
||||
# Send metrics back to client.
|
||||
self.metrics_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.metrics_socket.bind(f"{ipc_path}{IPC_METRICS_EXT}")
|
||||
|
||||
# IPC path for the data socket.
|
||||
self.data_ipc_path = f"{ipc_path}{IPC_DATA_EXT}"
|
||||
|
||||
# Error state.
|
||||
self._errored_with: Optional[BaseException] = None
|
||||
|
||||
self.remote_prefill_request_socket = self.ctx.socket(zmq.constants.PUSH)
|
||||
self.remote_nixl_metadata_socket = self.ctx.socket(zmq.constants.PULL)
|
||||
if self.engine.is_nixl_initialized:
|
||||
self.remote_prefill_request_socket.bind(f"{ipc_path}{IPC_REMOTE_PREFILL_REQUEST_EXT}")
|
||||
self.remote_nixl_metadata_socket.bind(f"{ipc_path}{IPC_REMOTE_NIXL_METADATA_EXT}")
|
||||
|
||||
|
||||
# Attach logger for continuous metrics publishing
|
||||
self.kv_stat_logger = KvStatLogger(
|
||||
self.engine.scheduler_config.max_num_seqs,
|
||||
self.engine.cache_config.num_gpu_blocks,
|
||||
self.metrics_socket
|
||||
)
|
||||
self.engine.add_logger("kv_metrics", self.kv_stat_logger)
|
||||
|
||||
# TODO investigate sending whole stats object
|
||||
# self.general_stat_logger = StatLogger(
|
||||
# self.metrics_socket
|
||||
# )
|
||||
# self.engine.add_logger("general_metrics", self.general_stat_logger)
|
||||
|
||||
@property
|
||||
def dead_error(self) -> BaseException:
|
||||
if self._errored_with is not None:
|
||||
@ -171,8 +276,17 @@ class MQLLMEngine:
|
||||
# Handle the query from the Client.
|
||||
if request == RPCStartupRequest.IS_SERVER_READY:
|
||||
tracing_enabled = self.engine.is_tracing_enabled()
|
||||
response = RPCStartupResponse(
|
||||
tracing_enabled=tracing_enabled)
|
||||
|
||||
# Send nixl metadata to the client
|
||||
if self.engine.is_nixl_initialized:
|
||||
nixl_metadata = self.engine.get_nixl_metadata()
|
||||
encoded_nixl_metadata = msgspec.msgpack.encode(nixl_metadata)
|
||||
response = RPCStartupResponse(
|
||||
tracing_enabled=tracing_enabled,
|
||||
nixl_metadata=encoded_nixl_metadata)
|
||||
else:
|
||||
response = RPCStartupResponse(
|
||||
tracing_enabled=tracing_enabled)
|
||||
|
||||
except Exception as e:
|
||||
response = e
|
||||
@ -185,6 +299,7 @@ class MQLLMEngine:
|
||||
|
||||
while True:
|
||||
if not self.engine.has_unfinished_requests():
|
||||
logger.debug("No unfinished requests")
|
||||
# Poll until there is work to do.
|
||||
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
|
||||
# When there's no work, check on engine health and send
|
||||
@ -220,6 +335,13 @@ class MQLLMEngine:
|
||||
def handle_new_input(self):
|
||||
"""Handle new input from the socket"""
|
||||
try:
|
||||
if self.engine.is_nixl_initialized:
|
||||
while self.remote_nixl_metadata_socket.poll(timeout=0) != 0:
|
||||
frames = self.remote_nixl_metadata_socket.recv(copy=False)
|
||||
nixl_metadata = msgspec.msgpack.decode(frames.buffer, type=NixlMetadata)
|
||||
logger.debug("Adding remote nixl metadata for engine: %s", nixl_metadata.engine_id)
|
||||
self.engine.add_remote_nixl_metadata(nixl_metadata)
|
||||
|
||||
while self.input_socket.poll(timeout=0) != 0:
|
||||
frames = self.input_socket.recv_multipart(copy=False)
|
||||
request = pickle.loads(frames[0].buffer)
|
||||
@ -262,6 +384,11 @@ class MQLLMEngine:
|
||||
self._send_outputs(rpc_err)
|
||||
|
||||
try:
|
||||
if request.remote_prefill_params is not None and request.remote_prefill_params.is_remote_prefill:
|
||||
def remote_prefill_request_callback(request: RemotePrefillRequest):
|
||||
logger.debug("Sending remote prefill request: %s", request.request_id)
|
||||
self.remote_prefill_request_socket.send(msgspec.msgpack.encode(request), copy=False)
|
||||
request.remote_prefill_params.remote_prefill_request_callback = remote_prefill_request_callback
|
||||
self.engine.add_request(
|
||||
request_id=request_id,
|
||||
prompt=request.prompt,
|
||||
@ -269,7 +396,9 @@ class MQLLMEngine:
|
||||
lora_request=request.lora_request,
|
||||
trace_headers=request.trace_headers,
|
||||
prompt_adapter_request=request.prompt_adapter_request,
|
||||
priority=request.priority)
|
||||
priority=request.priority,
|
||||
remote_prefill_params=request.remote_prefill_params,
|
||||
)
|
||||
|
||||
if self.log_requests:
|
||||
logger.info("Added request %s.", request.request_id)
|
||||
|
@ -34,6 +34,7 @@ from vllm.sampling_params import BeamSearchParams, SamplingParams
|
||||
from vllm.sequence import Logprob
|
||||
from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer
|
||||
from vllm.transformers_utils.tokenizers import maybe_serialize_tool_calls
|
||||
from vllm.remote_prefill import RemotePrefillParams
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -112,6 +113,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
raw_request: Optional[Request] = None,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> Union[AsyncGenerator[str, None], ChatCompletionResponse,
|
||||
ErrorResponse]:
|
||||
"""
|
||||
@ -243,6 +245,7 @@ class OpenAIServingChat(OpenAIServing):
|
||||
trace_headers=trace_headers,
|
||||
prompt_adapter_request=prompt_adapter_request,
|
||||
priority=request.priority,
|
||||
remote_prefill_params=remote_prefill_params,
|
||||
)
|
||||
|
||||
generators.append(generator)
|
||||
|
vllm/envs.py
@ -87,6 +87,10 @@ if TYPE_CHECKING:
|
||||
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
|
||||
VLLM_RAY_PER_WORKER_GPUS: float = 1.0
|
||||
VLLM_RAY_BUNDLE_INDICES: str = ""
|
||||
VLLM_KV_CAPI_PATH: Optional[str] = None
|
||||
VLLM_KV_NAMESPACE: Optional[str] = None
|
||||
VLLM_KV_COMPONENT: Optional[str] = None
|
||||
VLLM_WORKER_ID: Optional[int] = None
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -572,6 +576,21 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
||||
# models the alignment is already naturally aligned to 256 bytes.
|
||||
"VLLM_CUDA_MEM_ALIGN_KV_CACHE":
|
||||
lambda: bool(int(os.getenv("VLLM_CUDA_MEM_ALIGN_KV_CACHE", "1"))),
|
||||
|
||||
# Path to the C API Library
|
||||
"VLLM_KV_CAPI_PATH":
|
||||
lambda: os.environ.get("VLLM_KV_CAPI_PATH", None),
|
||||
|
||||
# Identifiers to publish KV related information
|
||||
"VLLM_KV_NAMESPACE":
|
||||
lambda: os.environ.get("VLLM_KV_NAMESPACE", None),
|
||||
"VLLM_KV_COMPONENT":
|
||||
lambda: os.environ.get("VLLM_KV_COMPONENT", None),
|
||||
|
||||
# Worker ID used for identifying workers in distributed settings
|
||||
"VLLM_WORKER_ID":
|
||||
lambda: int(os.getenv("VLLM_WORKER_ID", "0"))
|
||||
if "VLLM_WORKER_ID" in os.environ else None,
|
||||
}
|
||||
|
||||
# end-env-vars-definition
|
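The new variables follow the existing `vllm.envs` pattern (a lazy lookup in `environment_variables`), so downstream code reads them as module attributes. A small usage sketch; the consuming code shown here is illustrative rather than part of this patch:

import os

# Illustrative only: export the identifiers before reading them back.
os.environ.setdefault("VLLM_KV_NAMESPACE", "dynamo")
os.environ.setdefault("VLLM_KV_COMPONENT", "vllm-worker")
os.environ.setdefault("VLLM_WORKER_ID", "0")

import vllm.envs as envs

# Each lookup is lazy: the lambdas above run when the attribute is read.
print(envs.VLLM_KV_NAMESPACE, envs.VLLM_KV_COMPONENT, envs.VLLM_WORKER_ID)
if envs.VLLM_KV_CAPI_PATH is None:
    print("VLLM_KV_CAPI_PATH not set; KV C API publishing disabled")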
||||
|
@ -585,6 +585,8 @@ class DeepseekV2Model(nn.Module):
|
||||
cache_config = vllm_config.cache_config
|
||||
quant_config = vllm_config.quant_config
|
||||
|
||||
self.config = config
|
||||
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
|
@ -6,16 +6,16 @@ from typing import Dict, Generic, List, MutableSequence, Optional
|
||||
from typing import Sequence as GenericSequence
|
||||
from typing import Union
|
||||
|
||||
import msgspec
|
||||
import torch
|
||||
from typing_extensions import TypeVar, deprecated
|
||||
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.multimodal.inputs import MultiModalPlaceholderDict
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs,
|
||||
SequenceGroup, SequenceGroupBase, SequenceStatus)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompletionOutput:
|
||||
"""The output data of one completion output of a request.
|
||||
|
vllm/remote_prefill.py (new file)
@@ -0,0 +1,67 @@
from dataclasses import dataclass
from typing import Callable, Optional, List
from enum import Enum

import msgspec

from vllm.sampling_params import SamplingParams


class RemotePrefillRequest(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        # required for @cached_property.
        dict=True):
    """The data sent to a remote engine so it can run prefill for a request.

    Args:
        engine_id: The unique ID of the engine.
        request_id: The unique ID of the request.
        prompt_token_ids: The token IDs of the prompt.
        sampling_params: The sampling parameters.
        block_ids: The block IDs of the request.
        computed_block_ids: The computed block IDs of the request.
    """
    engine_id: str
    request_id: str
    prompt_token_ids: List[int]
    sampling_params: SamplingParams
    block_ids: List[int]
    computed_block_ids: List[int]


class MemoryOpType(str, Enum):
    WRITE = "WRITE"
    READ = "READ"


class MemoryTransferRequest(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True):  # type: ignore[call-arg]
    """A request to transfer KV cache blocks between two engines.

    Args:
        request_id: The unique ID of the request.
    """
    request_id: str
    local_block_ids: List[int]
    staging_block_ids: List[int]
    remote_block_ids: List[int]
    remote_engine_id: str
    notify_msg: str
    op_type: MemoryOpType


RemotePrefillRequestCallback = Callable[[RemotePrefillRequest], None]


@dataclass
class RemotePrefillParams:
    """Remote prefill parameters for text generation."""
    is_remote_prefill: bool = False
    is_remote_decode: bool = False
    decode_block_ids: Optional[List[int]] = None
    decode_computed_block_ids: Optional[List[int]] = None
    decode_engine_id: Optional[str] = None
    remote_prefill_request_callback: Optional[RemotePrefillRequestCallback] = None
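To make the relationship between these types concrete, here is a hedged sketch of how a decode-side caller might build `RemotePrefillParams` before handing a request to the engine (the block IDs and engine ID are placeholders; in the real flow they come from the decode engine's scheduler and its NIXL metadata):

from vllm.remote_prefill import RemotePrefillParams

# Illustrative values only.
decode_engine_id = "11111111-2222-3333-4444-555555555555"

remote_decode_params = RemotePrefillParams(
    is_remote_decode=True,
    decode_block_ids=[0, 1, 2, 3],      # blocks allocated on the decode engine
    decode_computed_block_ids=[0, 1],   # prefix-cache hits already present there
    decode_engine_id=decode_engine_id,
)
# On the prefill side these params drive the MemoryTransferRequest split:
# already-computed blocks are READ back, the remaining blocks are WRITTEN
# to the decode engine after the prefill pass.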
@ -83,7 +83,7 @@ class RequestOutputKind(Enum):
|
||||
DELTA = 1
|
||||
# Do not return intermediate RequestOutputs
|
||||
FINAL_ONLY = 2
|
||||
|
||||
|
||||
|
||||
class SamplingParams(
|
||||
msgspec.Struct,
|
||||
|
@ -20,6 +20,7 @@ from vllm.multimodal import MultiModalDataDict, MultiModalPlaceholderDict
|
||||
from vllm.pooling_params import PoolingParams
|
||||
from vllm.prompt_adapter.request import PromptAdapterRequest
|
||||
from vllm.sampling_params import RequestOutputKind, SamplingParams
|
||||
from vllm.remote_prefill import RemotePrefillParams, MemoryTransferRequest
|
||||
|
||||
VLLM_TOKEN_ID_ARRAY_TYPE = "l"
|
||||
|
||||
@@ -59,13 +60,14 @@ class SequenceStatus(enum.IntEnum):
    """Status of a sequence."""
    WAITING = 0
    RUNNING = 1
    SWAPPED = 2
    # Note: anything after SWAPPED (2) will be considered
    REMOTE_PREFILLING = 2
    SWAPPED = 3
    # Note: anything after SWAPPED (3) will be considered
    # as a finished status.
    FINISHED_STOPPED = 3
    FINISHED_LENGTH_CAPPED = 4
    FINISHED_ABORTED = 5
    FINISHED_IGNORED = 6
    FINISHED_STOPPED = 4
    FINISHED_LENGTH_CAPPED = 5
    FINISHED_ABORTED = 6
    FINISHED_IGNORED = 7
||||
|
||||
@staticmethod
|
||||
def is_finished(status: "SequenceStatus") -> bool:
|
||||
@ -409,6 +411,7 @@ class Sequence:
|
||||
eos_token_id: Optional[int] = None,
|
||||
lora_request: Optional[LoRARequest] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> None:
|
||||
self.seq_id = seq_id
|
||||
self.inputs = SingletonInputsAdapter(inputs)
|
||||
@ -416,7 +419,7 @@ class Sequence:
|
||||
self.eos_token_id = eos_token_id
|
||||
self.lora_request = lora_request
|
||||
self.prompt_adapter_request = prompt_adapter_request
|
||||
|
||||
self.remote_prefill_params = remote_prefill_params
|
||||
self.data = SequenceData.from_seqs(self.prompt_token_ids)
|
||||
self.output_logprobs: SampleLogprobs = []
|
||||
self.output_text = ""
|
||||
@ -639,6 +642,7 @@ class SequenceGroup:
|
||||
trace_headers: OpenTelemetry trace headers.
|
||||
prompt_adapter_request: Prompt Adapter request.
|
||||
priority: User-defined priority of the request.
|
||||
remote_prefill_params: Remote prefill parameters.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -654,6 +658,7 @@ class SequenceGroup:
|
||||
trace_headers: Optional[Mapping[str, str]] = None,
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None,
|
||||
priority: int = 0,
|
||||
remote_prefill_params: Optional[RemotePrefillParams] = None,
|
||||
) -> None:
|
||||
self.request_id = request_id
|
||||
self.seqs = seqs
|
||||
@ -678,7 +683,7 @@ class SequenceGroup:
|
||||
self.encoder_seq = encoder_seq
|
||||
self.trace_headers = trace_headers
|
||||
self.priority = priority
|
||||
|
||||
self.remote_prefill_params = remote_prefill_params
|
||||
self.cached_request_output = None
|
||||
|
||||
@property
|
||||
@ -927,6 +932,9 @@ class SequenceGroupMetadata(
|
||||
query tokens for prefill, we don't need sampling.
|
||||
token_chunk_size: The number of tokens to be processed (per sequence).
|
||||
None if chunking is not required.
|
||||
do_remote_prefill: True if remote prefill is required.
|
||||
do_remote_decode: True if remote decode is required.
|
||||
decode_memory_desc: The memory descriptor for the decoder blocks.
|
||||
lora_request: LoRA request.
|
||||
computed_block_nums: The block numbers that are already computed,
|
||||
used in prefix caching.
|
||||
@ -966,6 +974,9 @@ class SequenceGroupMetadata(
|
||||
cross_block_table: Optional[List[int]] = None
|
||||
prompt_adapter_request: Optional[PromptAdapterRequest] = None
|
||||
token_chunk_size: Optional[int] = None
|
||||
do_remote_prefill: bool = False
|
||||
do_remote_decode: bool = False
|
||||
decode_memory_desc: Optional[bytes] = None
|
||||
|
||||
### Stateful fields that are lazily defined. ###
|
||||
# The number of speculative tokens adopted in this request.
|
||||
@ -1310,6 +1321,8 @@ class ExecuteModelRequest(
|
||||
last_sampled_token_ids: Optional[torch.Tensor] = None
|
||||
# Async callback
|
||||
async_callback: Optional[Callable] = None
|
||||
# The memory transfer requests.
|
||||
memory_transfer_requests: Optional[List[MemoryTransferRequest]] = None
|
||||
|
||||
@property
|
||||
def is_first_multi_step(self) -> bool:
|
||||
|
@ -1824,6 +1824,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
return False
|
||||
|
||||
if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
|
||||
return False
|
||||
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
|
||||
@ -1849,6 +1852,9 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
|
||||
|
||||
if self.vllm_config.kv_transfer_config is None:
|
||||
return False
|
||||
|
||||
if self.vllm_config.kv_transfer_config.kv_connector == "DynamoNixlConnector":
|
||||
return False
|
||||
|
||||
prefill_meta = model_input.attn_metadata.prefill_metadata
|
||||
|
||||
|
@ -2,7 +2,7 @@
|
||||
"""A GPU worker class."""
|
||||
import gc
|
||||
import os
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union
|
||||
from typing import Dict, List, Optional, Set, Tuple, Type, Union, TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.distributed
|
||||
@ -31,6 +31,9 @@ from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
|
||||
from vllm.worker.pooling_model_runner import PoolingModelRunner
|
||||
from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, WorkerBase,
|
||||
WorkerInput)
|
||||
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
|
||||
from vllm.remote_prefill import MemoryOpType
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -306,6 +309,46 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
self._init_cache_engine()
|
||||
self._warm_up_model()
|
||||
|
||||
def initialize_nixl(self, engine_id: str) -> List[bytes]:
|
||||
|
||||
# TODO ptarasiewicz nixl can also support DRAM
|
||||
assert self.device_config.device_type == "cuda", "Currently only CUDA is supported for Nixl connector"
|
||||
|
||||
self.nixl_connector = DynamoNixlConnector(self.vllm_config, engine_id, self.local_rank) # TODO ptarasiewicz: rank or local_rank?
|
||||
assert len(self.cache_engine) == 1, "Only one cache engine is supported for now"
|
||||
self.nixl_connector.register_kv_caches(self.cache_engine[0].gpu_cache)
|
||||
return self.nixl_connector.agent_name
|
||||
|
||||
def get_nixl_agent_metadata(self) -> bytes:
|
||||
assert self.nixl_connector is not None, "Nixl connector is not initialized"
|
||||
return self.nixl_connector.get_agent_metadata()
|
||||
|
||||
def add_remote_nixl_metadata(self, engine_id: str, agents_metadata: List[bytes], kv_caches_base_addr: List[List[Tuple[int, int]]], num_blocks: int) -> str:
|
||||
assert self.nixl_connector is not None, "Nixl connector is not initialized"
|
||||
agent_name = self.nixl_connector.add_remote_agent(engine_id, agents_metadata, len(agents_metadata), kv_caches_base_addr, num_blocks) # TODO ptarasiewicz: rank or local_rank?
|
||||
return agent_name
|
||||
|
||||
def get_nixl_kv_caches_base_addr(self) -> List[bytes]:
|
||||
assert self.nixl_connector is not None, "Nixl connector is not initialized"
|
||||
return self.nixl_connector.kv_caches_base_addr[self.nixl_connector.engine_id]
|
||||
|
||||
def _read_blocks(self, worker_input: WorkerInput) -> None:
|
||||
for i, op_type in enumerate(worker_input.op_type):
|
||||
if op_type == MemoryOpType.READ:
|
||||
self.nixl_connector.read_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i])
|
||||
|
||||
def _write_blocks(self, worker_input: WorkerInput) -> None:
|
||||
if not self.is_driver_worker:
|
||||
torch.cuda.synchronize() # to make sure that the blocks are ready, on driver worker we transfer after sampling, so there's no need to synchronize
|
||||
|
||||
for i, op_type in enumerate(worker_input.op_type):
|
||||
if op_type == MemoryOpType.WRITE:
|
||||
self.nixl_connector.write_blocks(worker_input.local_block_ids[i], worker_input.staging_block_ids[i], worker_input.remote_block_ids[i], worker_input.remote_engine_id[i], worker_input.notify_msg[i])
|
||||
|
||||
def shutdown_nixl(self) -> None:
|
||||
assert self.nixl_connector is not None, "Nixl connector is not initialized"
|
||||
self.nixl_connector.shutdown()
|
||||
|
||||
def _init_cache_engine(self):
|
||||
assert self.cache_config.num_gpu_blocks is not None
|
||||
self.cache_engine = [
|
||||
@ -367,6 +410,8 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
blocks_to_copy = torch.tensor(execute_model_req.blocks_to_copy,
|
||||
device=self.device,
|
||||
dtype=torch.int64).view(-1, 2)
|
||||
|
||||
mem_transfer_reqs = execute_model_req.memory_transfer_requests or []
|
||||
|
||||
return WorkerInput(
|
||||
num_seq_groups=num_seq_groups,
|
||||
@ -375,6 +420,12 @@ class Worker(LocalOrDistributedWorkerBase):
|
||||
blocks_to_copy=blocks_to_copy,
|
||||
virtual_engine=virtual_engine,
|
||||
num_steps=num_steps,
|
||||
local_block_ids=[r.local_block_ids for r in mem_transfer_reqs],
|
||||
staging_block_ids=[r.staging_block_ids for r in mem_transfer_reqs],
|
||||
remote_block_ids=[r.remote_block_ids for r in mem_transfer_reqs],
|
||||
remote_engine_id=[r.remote_engine_id for r in mem_transfer_reqs],
|
||||
notify_msg=[r.notify_msg for r in mem_transfer_reqs],
|
||||
op_type=[r.op_type for r in mem_transfer_reqs],
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
|
@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
|
||||
import cloudpickle
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from collections import defaultdict
|
||||
|
||||
from vllm.config import (ObservabilityConfig, VllmConfig,
|
||||
set_current_vllm_config)
|
||||
@ -23,6 +24,9 @@ from vllm.utils import (enable_trace_function_call_for_thread,
|
||||
from vllm.worker.model_runner_base import (BroadcastableModelInput,
|
||||
ModelRunnerBase,
|
||||
ModelRunnerInputBase)
|
||||
from vllm.distributed.device_communicators.nixl import DynamoNixlConnector
|
||||
from vllm.remote_prefill import MemoryOpType
|
||||
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -53,6 +57,8 @@ class WorkerBase(ABC):
|
||||
from vllm.platforms import current_platform
|
||||
self.current_platform = current_platform
|
||||
|
||||
self.nixl_connector: Optional[DynamoNixlConnector] = None
|
||||
|
||||
@abstractmethod
|
||||
def init_device(self) -> None:
|
||||
"""Initialize device state, such as loading the model or other on-device
|
||||
@ -216,6 +222,13 @@ class WorkerInput:
|
||||
virtual_engine: int = 0
|
||||
num_steps: int = 1
|
||||
|
||||
local_block_ids: Optional[List[List[int]]] = None
|
||||
staging_block_ids: Optional[List[List[int]]] = None
|
||||
remote_block_ids: Optional[List[List[int]]] = None
|
||||
remote_engine_id: Optional[List[str]] = None
|
||||
notify_msg: Optional[List[str]] = None
|
||||
op_type: Optional[List[MemoryOpType]] = None
|
||||
|
||||
@classmethod
|
||||
def from_broadcasted_tensor_dict(
|
||||
cls: Type["WorkerInput"],
|
||||
@ -232,6 +245,12 @@ class WorkerInput:
|
||||
blocks_to_copy=tensor_dict.pop("blocks_to_copy"),
|
||||
virtual_engine=tensor_dict["virtual_engine"],
|
||||
num_steps=tensor_dict.pop("num_steps"),
|
||||
local_block_ids=tensor_dict.pop("local_block_ids"),
|
||||
staging_block_ids=tensor_dict.pop("staging_block_ids"),
|
||||
remote_block_ids=tensor_dict.pop("remote_block_ids"),
|
||||
remote_engine_id=tensor_dict.pop("remote_engine_id"),
|
||||
notify_msg=tensor_dict.pop("notify_msg"),
|
||||
op_type=tensor_dict.pop("op_type"),
|
||||
)
|
||||
|
||||
def as_broadcastable_tensor_dict(
|
||||
@ -246,6 +265,12 @@ class WorkerInput:
|
||||
"blocks_to_copy": self.blocks_to_copy,
|
||||
"virtual_engine": self.virtual_engine,
|
||||
"num_steps": self.num_steps,
|
||||
"local_block_ids": self.local_block_ids,
|
||||
"staging_block_ids": self.staging_block_ids,
|
||||
"remote_block_ids": self.remote_block_ids,
|
||||
"remote_engine_id": self.remote_engine_id,
|
||||
"notify_msg": self.notify_msg,
|
||||
"op_type": self.op_type,
|
||||
}
|
||||
|
||||
return tensor_dict
|
||||
@ -316,13 +341,16 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
return None
|
||||
|
||||
worker_input = WorkerInput.from_broadcasted_tensor_dict(broadcast_data)
|
||||
model_input = (
|
||||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
if worker_input.num_seq_groups > 0:
|
||||
model_input = (
|
||||
self.model_runner.make_model_input_from_broadcasted_tensor_dict(
|
||||
broadcast_data))
|
||||
|
||||
kwargs = extract_previous_hidden_states(broadcast_data)
|
||||
kwargs = extract_previous_hidden_states(broadcast_data)
|
||||
|
||||
return model_input, worker_input, kwargs
|
||||
return model_input, worker_input, kwargs
|
||||
else:
|
||||
return None, worker_input, {}
|
||||
|
||||
def _get_driver_input_and_broadcast(
|
||||
self, execute_model_req: ExecuteModelRequest
|
||||
@ -396,49 +424,88 @@ class LocalOrDistributedWorkerBase(WorkerBase):
|
||||
self.execute_worker(worker_input)
|
||||
|
||||
# If there is no input, we don't need to execute the model.
|
||||
if worker_input.num_seq_groups == 0:
|
||||
return []
|
||||
if worker_input.num_seq_groups > 0:
|
||||
|
||||
intermediate_tensors = None
|
||||
orig_model_execute_time = 0.0
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group()))
|
||||
self._read_blocks(worker_input)
|
||||
|
||||
intermediate_tensors = None
|
||||
orig_model_execute_time = 0.0
|
||||
if not get_pp_group().is_first_rank:
|
||||
intermediate_tensors = IntermediateTensors(
|
||||
get_pp_group().recv_tensor_dict(
|
||||
all_gather_group=get_tp_group()))
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
orig_model_execute_time = intermediate_tensors.tensors.get(
|
||||
"model_execute_time", torch.tensor(0)).item()
|
||||
|
||||
output = self.model_runner.execute_model(
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
num_steps=num_steps,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_execute_time = time.perf_counter() - start_time
|
||||
if not get_pp_group().is_last_rank:
|
||||
# output is IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
output.tensors["model_execute_time"] = torch.tensor(
|
||||
model_execute_time + orig_model_execute_time)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return [None]
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
orig_model_execute_time = intermediate_tensors.tensors.get(
|
||||
"model_execute_time", torch.tensor(0)).item()
|
||||
and self.observability_config.collect_model_execute_time
|
||||
and output is not None):
|
||||
for o in output:
|
||||
o.model_execute_time = (orig_model_execute_time +
|
||||
model_execute_time)
|
||||
|
||||
output = self.model_runner.execute_model(
|
||||
model_input=model_input,
|
||||
kv_caches=self.kv_cache[worker_input.virtual_engine]
|
||||
if self.kv_cache is not None else None,
|
||||
intermediate_tensors=intermediate_tensors,
|
||||
num_steps=num_steps,
|
||||
**kwargs,
|
||||
)
|
||||
self._write_blocks(worker_input)
|
||||
|
||||
model_execute_time = time.perf_counter() - start_time
|
||||
if not get_pp_group().is_last_rank:
|
||||
# output is IntermediateTensors
|
||||
assert isinstance(output, IntermediateTensors)
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time):
|
||||
output.tensors["model_execute_time"] = torch.tensor(
|
||||
model_execute_time + orig_model_execute_time)
|
||||
get_pp_group().send_tensor_dict(output.tensors,
|
||||
all_gather_group=get_tp_group())
|
||||
return [None]
|
||||
if (self.observability_config is not None
|
||||
and self.observability_config.collect_model_execute_time
|
||||
and output is not None):
|
||||
for o in output:
|
||||
o.model_execute_time = (orig_model_execute_time +
|
||||
model_execute_time)
|
||||
else:
|
||||
output = []
|
||||
|
||||
# collect kv transfer notifications from non driver workers
|
||||
|
||||
if self.nixl_connector is not None:
|
||||
new_notifs = self.nixl_connector.get_new_notifs()
|
||||
rank = get_tp_group().rank
|
||||
all_new_notifs = [new_notifs]
|
||||
if rank > 0:
|
||||
get_tp_group().send_object(new_notifs, dst=0)
|
||||
else:
|
||||
for i in range(1, get_tp_group().world_size):
|
||||
all_new_notifs.append(get_tp_group().recv_object(src=i))
|
||||
|
||||
request_notif_counter = defaultdict(int)
|
||||
for notifs in all_new_notifs:
|
||||
for req_ids in notifs.values():
|
||||
for req_id in req_ids:
|
||||
request_notif_counter[req_id] += 1
|
||||
|
||||
if request_notif_counter:
|
||||
logger.debug("Request notif counter: %s", request_notif_counter)
|
||||
|
||||
request_done_counter = defaultdict(int)
|
||||
for req_id in self.nixl_connector.get_done_tranfers():
|
||||
request_done_counter[req_id] += 1
|
||||
else:
|
||||
request_notif_counter = {}
|
||||
request_done_counter = {}
|
||||
# output is List[SamplerOutput]
|
||||
return output
|
||||
return output, request_notif_counter, request_done_counter
|
||||
|
||||
def _read_blocks(self, worker_input: WorkerInput) -> None:
|
||||
pass
|
||||
|
||||
def _write_blocks(self, worker_input: WorkerInput) -> None:
|
||||
pass
|
||||
|
||||
def _execute_model_spmd(
|
||||
self,
|
||||
|