Files
vllm-ascend/vllm_ascend/distributed/mooncake_layerwise_connector.py
zxr2333 c2c1db78a7 [Bugfix] fix ZeroDivisionError when prefill_tp_size > num_kv_head and fix tp_resharding README (#3437)
### What this PR does / why we need it?
Fix a ZeroDivisionError when prefill_tp_size > num_kv_head. In this
situation, num_head_replica can be 0 while still being used as a
divisor, so this PR restricts its minimum value to 1. This PR also
fixes the tp_resharding README.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
2025-10-15 08:45:44 +08:00

# SPDX-License-Identifier: Apache-2.0
import contextlib
import hashlib
import math
import queue
import random
import struct
import threading
import time
from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
import httpx
import msgspec
import numpy as np
import numpy.typing as npt
import torch
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm import envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
get_tp_group, get_world_group)
from vllm.utils import get_ip, logger, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import (align_memory,
kv_alltoall_and_rearrange)
if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.request import Request
GET_META_MSG = b"get_meta_msg"
DONE_RECVING_MSG = b"done_recving_msg"
class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
engine_id: str
te_rpc_port: int
kv_caches_base_addr: list[int]
num_blocks: int
@dataclass
class ReqMeta:
local_block_ids: list[int]
# Not None if layer-wise is disabled
remote_block_ids: Optional[list[int]]
remote_host: Optional[str]
remote_port: Optional[int]
remote_engine_id: Optional[str]
# Not None if layer-wise is enabled
metaserver: Optional[str]
remote_tp_size: Optional[int]
class DecodeMooncakeAgentMetadata(msgspec.Struct,
omit_defaults=True,
dict=True):
req_id: str
block_ids: list[int]
host: str
port: int
engine_id: str
te_rpc_port: int
kv_caches_base_addr: list[int]
num_blocks: int
class KVCacheTaskTracker:
def __init__(self,
target_count: int = 1,
on_done: Callable[[str], None] = lambda x: None,
on_timeout: Callable[[set[str]], Any] = lambda x: None):
super().__init__()
self.target_count = target_count
self.done_task_lock = threading.Lock()
self.done_task_counts: defaultdict[str, int] = defaultdict(int)
self.finished_requests: set[str] = set()
# Only used in prefill node. Tracks requests whose kv blocks freeing is
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
# Notice: In layer-wise mode, the transfer may complete before it is
# added to delayed_free_requests when prefill node finishes forwarding.
# Therefore we need to track requests that are removed before being added.
self.delayed_free_requests: dict[str, float] = {}
self.removed_delayed_free_requests: set[str] = set()
self.on_done = on_done
self.on_timeout = on_timeout
def update_done_task_count(self, request_id: str):
self.done_task_counts[request_id] += 1
if self.done_task_counts[request_id] == self.target_count:
with self.done_task_lock:
self.finished_requests.add(request_id)
self.done_task_counts.pop(request_id)
self.on_done(request_id)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
expired_requests = self._retrieve_expired_requests()
finished_requests.update(expired_requests)
self.finished_requests.clear()
self.on_timeout(expired_requests)
return finished_requests
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request, where delay_start_time is monotonic increasing."""
with self.done_task_lock:
if request_id in self.removed_delayed_free_requests:
self.removed_delayed_free_requests.remove(request_id)
else:
self.delayed_free_requests[request_id] = delay_start_time
def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
expired_requests: set[str] = set()
# Free delayed requests if they exceed the timeout
current_time = time.time()
while self.delayed_free_requests:
request_id, delay_start_time = next(
iter(self.delayed_free_requests.items()))
if (current_time - delay_start_time
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.pop(request_id)
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
break
return expired_requests
def remove_delayed_request(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
with self.done_task_lock:
if self.delayed_free_requests.pop(request_id, None) is None:
self.removed_delayed_free_requests.add(request_id)
class KVCacheSendingLayerThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, decode_tp_size: int,
local_engine_id: str, side_channel_host: str,
side_channel_port: int, metadata: MooncakeAgentMetadata,
ready_event: threading.Event, total_layers: int,
engine: TransferEngine, local_kv_base_addr: list[int],
block_len: list[int], use_mla: bool,
first_kv_cache: torch.Tensor):
super().__init__(daemon=True, name="KVCacheSendingLayerThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
self.decode_tp_size = decode_tp_size
self.local_engine_id = local_engine_id
self.side_channel_host = side_channel_host
self.side_channel_port = side_channel_port
self.task_tracker = KVCacheTaskTracker(total_layers,
on_done=self._post_transfer,
on_timeout=self._abort_requests)
self.send_layer_thread = SendingLayerThread(
self.task_tracker, total_layers, engine, local_kv_base_addr,
block_len, use_mla, self.tp_rank, first_kv_cache)
self.ready_decode = dict[str, DecodeMooncakeAgentMetadata]()
self.pending_decode = dict[str,
list[tuple[list[int], int, torch.Tensor,
torch.Tensor]]]()
self.total_layers = total_layers
self.lock = threading.Lock()
self.ready_event = ready_event
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
# vllm won't call us if all inference is done, so we can't do step 9 here
return self.task_tracker.get_and_clear_finished_requests()
def add_delayed_request(self, request_id: str, delay_start_time: float):
return self.task_tracker.add_delayed_request(request_id,
delay_start_time)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.send_layer_thread.start()
handshake_port = self.side_channel_port + self.tp_rank
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
self.ready_event.set()
decoder = msgspec.msgpack.Decoder(type=DecodeMooncakeAgentMetadata)
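            # Each decode rank pushes a DecodeMooncakeAgentMetadata handshake to
            # this prefill rank's port; we ACK it, mark the request as ready, and
            # flush any KV layers that were queued before the handshake arrived.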
while True:
try:
frames = sock.recv_multipart()
if len(frames) < 2:
logger.error("Invalid message format: %s", frames)
continue
identity = frames[0]
payload = [f for f in frames[1:] if f != b""]
if len(payload) != 1:
logger.error("Invalid message format: %s", frames)
continue
metadata = decoder.decode(payload[0])
request_id = metadata.req_id
logger.debug(
f"Prefiller has received that request {request_id} from the decoder."
)
sock.send_multipart((identity, b"", b"ACK"))
self.task_tracker.remove_delayed_request(request_id)
with self.lock:
self.ready_decode[request_id] = metadata
pending = self.pending_decode.pop(request_id, [])
for local_block_ids, layer_index, key, value in pending:
self.send_layer_thread.send_queue.put(
(metadata, request_id, local_block_ids,
layer_index, key, value))
except Exception as e:
logger.error("Failed to decode message: %s", e)
def _post_transfer(self, request_id: str):
with self.lock:
decoder_meta = self.ready_decode.pop(request_id)
path = make_zmq_path("tcp", decoder_meta.host, decoder_meta.port)
msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode(request_id)
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data)
ack = sock.recv()
if ack != b"ACK":
raise ValueError(f"Unexpected ACK response: {ack}")
def add_request(self, request_id: str, local_block_ids: list[int],
layer_index: int, key: torch.Tensor, value: torch.Tensor):
# add request to send layer thread
with self.lock:
if request_id in self.ready_decode:
self.send_layer_thread.send_queue.put(
(self.ready_decode[request_id], request_id,
local_block_ids, layer_index, key, value))
else:
self.pending_decode.setdefault(request_id, []).append(
(local_block_ids, layer_index, key, value))
def _abort_requests(self, request_ids: set[str]):
with self.lock:
for request_id in request_ids:
self.pending_decode.pop(request_id, None)
class SendingLayerThread(threading.Thread):
def __init__(self, task_tracker: KVCacheTaskTracker, total_layers: int,
engine: TransferEngine, local_kv_base_addr: list[int],
block_len: list[int], use_mla: bool, tp_rank: int,
first_kv_cache: torch.Tensor):
super().__init__(daemon=True, name="KVCacheRecvingPrefillerByeThread")
self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str,
list[int], int, torch.Tensor,
torch.Tensor]]()
self.completion_event: Optional[threading.Event] = None
self.completion_event_count: int
self.task_tracker = task_tracker
self.total_layers = total_layers
self.local_kv_base_addr = local_kv_base_addr
self.block_len = block_len
self.use_mla = use_mla
self.engine = engine
self.tp_rank = tp_rank
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.num_head_replica = get_ascend_config().num_head_replica
self.pd_head_ratio = get_ascend_config().pd_head_ratio
vllm_config = get_current_vllm_config()
max_model_len = vllm_config.scheduler_config.max_model_len
first_kv_cache = first_kv_cache[:max_model_len]
alignment = 2 * 1024 * 1024
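        # Staging buffers for resharded K/V: over-allocate by one alignment unit,
        # align to 2 MiB, then slice to size. The Mooncake memory registration
        # below requires the registered address to be 2 MiB-aligned.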
self.k_buffer = torch.zeros(
first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
            device=first_kv_cache.device)  # e.g. [4, 1, 128] -> [1000, 128]
self.k_buffer = align_memory(self.k_buffer,
alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
dtype=first_kv_cache.dtype,
device=first_kv_cache.device)
self.v_buffer = align_memory(self.v_buffer,
alignment)[:first_kv_cache.numel()].view(
-1, first_kv_cache.shape[-1])
for tensor in (self.k_buffer, self.v_buffer):
assert tensor.data_ptr(
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(tensor.data_ptr(),
tensor.numel())
logger.info(
f"Sendinglayerthread register_memory {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
def run(self):
"""Run the thread to handle KV cache receiving for prefiller bye messages."""
# send kv cache for request in send_queue
local_rank = get_world_group().local_rank
device = torch.device(f"npu:{local_rank}")
torch.npu.set_device(device)
while True:
request = self.send_queue.get()
self._handle_request(request)
def _handle_request(self, request: tuple[DecodeMooncakeAgentMetadata, str,
list[int], int, torch.Tensor,
torch.Tensor]):
# send kv layer to remote
req_meta, request_id, local_block_ids, layer_index, key, value = request
try:
logger.debug(
f"Starting to transfer KV cache for request {request_id}.")
self._transfer_kv_cache(req_meta, local_block_ids, layer_index,
key, value)
logger.debug(
f"Finished transferring KV cache for request {request_id}.")
except Exception as e:
logger.error("Failed to transfer KV cache for request "
f"{request_id}: {e}")
finally:
self.task_tracker.update_done_task_count(request_id)
self.send_queue.task_done()
def _transfer_kv_cache(self, req_meta: DecodeMooncakeAgentMetadata,
local_block_ids: list[int], layer_index: int, key,
value):
# send kv layer to remote
if len(local_block_ids) == 0:
return
remote_host = req_meta.host
remote_te_port = req_meta.te_rpc_port
remote_kv_base_addrs = req_meta.kv_caches_base_addr
remote_block_ids = req_meta.block_ids
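        # Three cases follow: ranks that only hold a replicated KV head skip the
        # send; when prefill and decode share the same head layout
        # (pd_head_ratio == 1) the layer's blocks are written directly from the
        # KV cache; otherwise the staged K/V buffers are written into the
        # decoder's wider per-block layout at this rank's head offset.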
if self.tp_rank % self.num_head_replica != 0:
pass
elif self.pd_head_ratio == 1:
layer_local_kv_base_addr = [
self.local_kv_base_addr[i]
for i in [2 * layer_index, 2 * layer_index + 1]
]
layer_remote_kv_base_addr = [
remote_kv_base_addrs[i]
for i in [2 * layer_index, 2 * layer_index + 1]
]
grouped_remote_block_ids, grouped_local_block_ids = \
group_concurrent_contiguous(remote_block_ids, local_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
block_len = self.block_len[
k % 2] if self.use_mla else self.block_len[0]
for group_remote_block_id, group_local_block_id in zip(
grouped_remote_block_ids, grouped_local_block_ids):
src = src_layer_base_addr + group_local_block_id[
0] * block_len
dst = dst_layer_base_addr + group_remote_block_id[
0] * block_len
length = len(group_local_block_id) * block_len
src_list.append(src)
dst_list.append(dst)
length_list.append(length)
torch.npu.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s",
req_meta.req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
else:
key = key.view(-1, key.shape[-1])
            value = value.view(-1, value.shape[-1])
            self.k_buffer[:key.shape[0]].copy_(key)  # stage K into the aligned send buffer
self.v_buffer[:value.shape[0]].copy_(value)
layer_local_kv_base_addr = [
self.k_buffer.data_ptr(),
self.v_buffer.data_ptr()
]
layer_remote_kv_base_addr = [
remote_kv_base_addrs[i]
for i in [2 * layer_index, 2 * layer_index + 1]
]
grouped_remote_block_ids, _ = group_concurrent_contiguous(
remote_block_ids)
session_id = f"{remote_host}:{remote_te_port}"
src_list, dst_list, length_list = [], [], []
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
src_layer_addr = src_layer_base_addr
for group_remote_block_id in grouped_remote_block_ids:
block_len = self.block_len[0]
remote_block_len = self.block_len[0] * self.pd_head_ratio
src_list.append(src_layer_addr)
if src_layer_addr + len(
group_remote_block_id
) * block_len > src_layer_base_addr + key.numel(
) * key.element_size():
length = src_layer_base_addr + key.numel(
) * key.element_size() - src_layer_addr
else:
length = len(group_remote_block_id) * block_len
length_list.append(length)
dst_list.append(dst_layer_base_addr +
group_remote_block_id[0] *
remote_block_len + length *
((self.tp_rank // self.num_head_replica) %
self.pd_head_ratio))
src_layer_addr += length
torch.npu.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
logger.error("Mooncake transfer failed for request %s",
req_meta.req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if self.completion_event is not None:
self.completion_event_count -= 1
if self.completion_event_count == 0:
self.completion_event.set()
self.completion_event = None
def add_event(self, event: threading.Event, count: int) -> None:
self.completion_event = event
self.completion_event_count = count
class KVCacheRecvingLayerThread(threading.Thread):
def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
local_engine_id: str, ready_event: threading.Event):
super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
self.tp_rank = tp_rank
self.tp_size = tp_size
self.local_engine_id = local_engine_id
self.side_channel_host = get_ip()
self.side_channel_port = side_channel_port
self.lock = threading.Lock()
self.done_requests = set[str]()
self.ready_event = ready_event
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.lock:
finished_requests = self.done_requests
self.done_requests = set()
return finished_requests
def run(self):
"""Run the thread to handle KV cache transfer requests."""
#TODO layerwise step9
# with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
# while True:
# recv_msg from prefill request send finish=
# Listen for new requests for metadata.
        # NOTE(rob): we need each rank to have a unique port. This hack keeps us
        # moving. We will switch when moving to etcd or when we have a single
        # ZMQ socket in the scheduler.
handshake_port = self.side_channel_port + self.tp_rank
path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
logger.info("Starting listening on path: %s", path)
with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore
self.ready_event.set()
decoder = msgspec.msgpack.Decoder(type=str)
while True:
try:
frames = sock.recv_multipart()
if len(frames) < 2:
logger.error("Invalid message format: %s", frames)
continue
identity = frames[0]
payload = [f for f in frames[1:] if f != b""]
if len(payload) != 1:
logger.error("Invalid message format: %s", frames)
continue
request_id = decoder.decode(payload[0])
with self.lock:
self.done_requests.add(request_id)
sock.send_multipart((identity, b"", b"ACK"))
except Exception as e:
logger.error("Failed to decode message: %s", e)
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
def __init__(self):
self.requests: dict[str, ReqMeta] = {}
self.requests_to_send: dict[str, float] = {}
def add_new_req(self,
request_id: str,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
metaserver=None):
self.requests[request_id] = ReqMeta(
local_block_ids=local_block_ids,
remote_block_ids=kv_transfer_params.get("remote_block_ids", None),
remote_engine_id=kv_transfer_params["remote_engine_id"],
remote_host=kv_transfer_params["remote_host"],
remote_port=kv_transfer_params["remote_port"],
metaserver=metaserver,
remote_tp_size=kv_transfer_params.get("remote_tp_size", None),
)
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \
MooncakeLayerwiseConnectorScheduler(vllm_config, str(self.engine_id))
self.connector_worker: Optional[
MooncakeLayerwiseConnectorWorker] = None
elif role == KVConnectorRole.WORKER:
self.connector_scheduler = None
self.connector_worker = MooncakeLayerwiseConnectorWorker(
vllm_config, str(self.engine_id))
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
def get_finished_count(self) -> Optional[int]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_finished_count()
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
return self.connector_worker.get_finished()
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
MooncakeLayerwiseConnectorMetadata)
self.connector_worker.start_load_kv(self._connector_metadata)
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeLayerwiseConnector does not do layerwise saving."""
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
MooncakeLayerwiseConnectorMetadata)
self.connector_worker.wait_for_layer_load(layer_name)
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
assert self.connector_worker is not None
assert isinstance(self._connector_metadata,
MooncakeLayerwiseConnectorMetadata)
self.connector_worker.save_kv_layer(layer_name, kv_layer,
attn_metadata,
self._connector_metadata)
def wait_for_save(self):
"""MooncakeLayerwiseConnector does not save explicitly."""
pass
class MooncakeLayerwiseConnectorScheduler:
"""Implementation of Scheduler side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.engine_id = engine_id
logger.info("Initializing Mooncake Scheduler %s", engine_id)
self.side_channel_host = get_ip()
self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
vllm_config.parallel_config.data_parallel_size
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.tensor_parallel_size)
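        # Each DP replica thus owns a distinct contiguous port range; every TP
        # rank later adds its tp_rank to this base to form its handshake port.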
# Requests that need to start recv.
# New requests are added by update_state_after_alloc in
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int]]] = {}
self._reqs_need_send: dict[str, float] = {}
self._reqs_need_send_layerwise: dict[str, tuple[str, int,
list[int]]] = {}
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
"""
For remote prefill, pull all prompt blocks from remote
asynchronously relative to engine execution.
Args:
request (Request): the request object.
num_computed_tokens (int): the number of locally
computed tokens for this request
Returns:
* the number of tokens that can be loaded from the
external KV cache beyond what is already computed.
* true if the external KV cache tokens will be loaded
asynchronously (between scheduler steps).
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeLayerwiseConnector get_num_new_matched_tokens: "
"num_computed_tokens=%s, kv_transfer_params=%s",
num_computed_tokens, params)
if params is not None and params.get("do_remote_prefill"):
assert num_computed_tokens == 0, "Currently only support " \
"prefill with num_computed_tokens == 0."
# Assume that the request's KV cache is already fully prefilled and
# can be fetched entirely from the prefill node.
count = len(request.prompt_token_ids)
if count > 0:
return count, True
# No remote prefill for this request.
return 0, False
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
params = request.kv_transfer_params
logger.debug(
"MooncakeLayerwiseConnector update_state_after_alloc: "
"num_external_tokens=%s, kv_transfer_params=%s",
num_external_tokens, params)
if params is not None and params.get("do_remote_prefill"):
if all(p in params for p in ("remote_engine_id", "remote_host",
"remote_port")):
local_block_ids = (blocks.get_unhashed_block_ids()
if num_external_tokens > 0 else [])
# Get unhashed blocks to pull from remote.
self._reqs_need_recv[request.request_id] = (request,
local_block_ids)
else:
logger.warning(
"Got invalid KVTransferParams: %s. This "
"request will not utilize KVTransfer", params)
params["do_remote_prefill"] = False
# Layerwise prefiller add request need send
if params is not None and params.get("do_remote_decode"):
local_block_ids = (blocks.get_block_ids()[0])
self._reqs_need_send_layerwise[request.request_id] = (
params["metaserver"], len(request.all_token_ids),
local_block_ids)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
meta = MooncakeLayerwiseConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[
req_id]
block_ids.extend(new_blocks[0])
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
[(x.req_id, x.num_computed_tokens) for x in new_reqs])
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
):
if req_id in self._reqs_need_send_layerwise:
metaserver, total_tokens, block_ids = self._reqs_need_send_layerwise[
req_id]
current_tokens = computed_tokens.get(req_id,
0) + scheduled_tokens
if current_tokens == total_tokens:
meta.add_new_req(
request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=defaultdict(lambda: None),
metaserver=metaserver)
self._reqs_need_send_layerwise.pop(req_id)
meta.requests_to_send = self._reqs_need_send
self._reqs_need_send = {}
return meta
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
"""
Once a request is finished, determine whether request blocks
should be freed now or will be sent asynchronously and freed later.
"""
params = request.kv_transfer_params
logger.debug(
"MooncakeLayerwiseConnector request_finished, request_status=%s, "
"kv_transfer_params=%s", request.status, params)
if (params is None or not params.get("do_remote_decode")
or request.status != RequestStatus.FINISHED_LENGTH_CAPPED):
return False, None
computed_block_ids = block_ids
delay_free_blocks = len(computed_block_ids) > 0
if delay_free_blocks:
logger.info("Delaying free of %d blocks for request %s",
len(computed_block_ids), request.request_id)
self._reqs_need_send[request.request_id] = time.time()
return delay_free_blocks, dict(
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
remote_block_ids=computed_block_ids,
)
def get_finished_count(self) -> Optional[int]:
prefill_parallel_config: dict[
str,
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {})
assert "tp_size" in prefill_parallel_config.keys()
self._prefill_tp_size = prefill_parallel_config["tp_size"]
decode_parallel_config: dict[
str,
Any] = self.vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {})
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
if self.vllm_config.model_config.use_mla:
return self._decode_tp_size
else:
# TODO support mha and gqa
return None
class MooncakeLayerwiseConnectorWorker:
"""Implementation of Worker side methods"""
def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._get_prefill_decode_size(vllm_config)
if self._prefill_tp_size < self._decode_tp_size:
raise ValueError(
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
f" or equal to the decode_tp_size: {self._decode_tp_size}")
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = TransferEngine()
# Metadata.
self.completion_event: threading.Event
self.vllm_config = vllm_config
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.tp_group = get_tp_group()
self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local
self.dp_size = vllm_config.parallel_config.data_parallel_size_local
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.max_device_id = self.tp_size * self.dp_size
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.total_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.executor = ThreadPoolExecutor(1)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000),
timeout=None) if self.tp_rank == 0 else None
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
vllm_config.parallel_config.data_parallel_rank_local *
vllm_config.parallel_config.tensor_parallel_size)
self.handshake_port = self.side_channel_port + self.tp_rank
self.sockets: dict = {}
# get tp device id
# TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
# introducing some changes
device_ids_str = envs_ascend.PHYSICAL_DEVICES
if device_ids_str is None:
device_ids = list(
range(self.dp_rank * self.tp_size,
(self.dp_rank + 1) * self.tp_size))
else:
device_ids = list(map(int, device_ids_str.split(',')))
start_index = self.dp_rank * self.tp_size
end_index = start_index + self.tp_size
if len(device_ids) < end_index:
raise ValueError(
f"Not enough physical devices available for DP rank {self.dp_rank}. "
f"Expected at least {end_index} devices, but found {len(device_ids)} "
"in PHYSICAL_DEVICES.")
device_ids = device_ids[start_index:end_index]
assert len(device_ids) > self.tp_rank # type: ignore
self.device_id = device_ids[self.tp_rank] # type: ignore
if vllm_config.kv_transfer_config.get_from_extra_config(
'use_ascend_direct', False):
hostname = self.side_channel_host
else:
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
self._initialize(hostname=hostname, device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
# Background thread for sending or receiving KV caches.
self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None
self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
self.kv_caches_base_addr: list[int] = []
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
self.pd_head_ratio = get_ascend_config().pd_head_ratio
self.first_kv_cache = None
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
prefill_parallel_config: dict[
str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {})
assert "tp_size" in prefill_parallel_config.keys()
self._prefill_tp_size = prefill_parallel_config["tp_size"]
assert "dp_size" in prefill_parallel_config.keys()
self._prefill_dp_size = prefill_parallel_config["dp_size"]
# get decode tp and dp size from extra config
decode_parallel_config: dict[
str, Any] = vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {})
assert "tp_size" in decode_parallel_config.keys()
self._decode_tp_size = decode_parallel_config["tp_size"]
assert "dp_size" in decode_parallel_config.keys()
self._decode_dp_size = decode_parallel_config["dp_size"]
def _initialize(
self,
hostname: str,
device_name: Optional[str],
) -> None:
"""Initialize the mooncake instance."""
device_name = device_name if device_name is not None else ""
ret_value = self.engine.initialize(hostname, "P2PHANDSHAKE", "ascend",
device_name)
if ret_value != 0:
raise RuntimeError(
f"Mooncake initialization failed with ret_value: {ret_value}")
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data."""
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
first_kv_cache = first_kv_cache_tuple[0]
self.first_kv_cache = first_kv_cache
# TODO(tms): Find a more robust way to detect and handle MLA
self.use_mla = first_kv_cache_tuple[0].size(
-1) != first_kv_cache_tuple[1].size(-1)
if self.use_mla:
# MLA case.[num_block, block_size, 1, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
block_rank = 3 # [block_size, latent_dim]
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
self.block_len = [
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
first_kv_cache[1].element_size() * math.prod(block_shape_pe)
]
logger.info(
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
self.num_blocks, block_shape_norm, block_shape_pe)
else:
# [num_block, block_size, num_head, hidden_dim]
self.num_blocks = first_kv_cache.shape[0]
kv_elem_size = first_kv_cache.element_size()
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_len = [kv_elem_size * math.prod(block_shape)]
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
block_shape)
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
self.use_mla, first_kv_cache.shape)
self.kv_caches = kv_caches
kv_caches_base_addr = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
else:
                for cache in cache_or_caches:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
self.kv_caches_base_addr = kv_caches_base_addr
# After KV Caches registered, start the sending or receiving thread.
metadata = MooncakeAgentMetadata(
engine_id=self.engine_id,
te_rpc_port=self.te_rpc_port,
kv_caches_base_addr=kv_caches_base_addr,
num_blocks=self.num_blocks,
)
ready_event = threading.Event()
if self.kv_role == 'kv_producer':
self.kv_send_layer_thread = KVCacheSendingLayerThread(
self.tp_rank, self.tp_size, self._decode_tp_size,
self.engine_id, self.side_channel_host, self.side_channel_port,
metadata, ready_event, self.total_layers, self.engine,
kv_caches_base_addr, self.block_len, self.use_mla,
self.first_kv_cache)
self.kv_send_layer_thread.start()
else:
self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
self.tp_rank, self.side_channel_port, self.tp_size,
self.engine_id, ready_event)
self.kv_recv_layer_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.info(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
ret_value = self.engine.register_memory(ptr, length)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed.")
def _access_metaserver(self, url, message):
self.metaserver_client.post(url, json=message)
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_layer_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role == 'kv_producer' else set())
done_recving = (
self.kv_recv_layer_thread.
get_and_clear_finished_requests( # type: ignore[union-attr]
) if self.kv_role == 'kv_consumer' else set())
if self.tp_rank == 0:
logger.debug(
"Number of completed KV cache send requests: %d, receive "
"requests: %d", len(done_sending), len(done_recving))
return done_sending, done_recving
def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
"""Start loading KV blocks from remote engine."""
self.current_layer = 0
if self.vllm_config.kv_transfer_config.is_kv_producer:
for req_id, meta in metadata.requests.items():
logger.debug(
f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
)
if self.tp_rank == 0:
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
request_id=req_id,
do_remote_prefill=True,
do_remote_decode=False,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port)
future = self.executor.submit(
self._access_metaserver,
url=meta.metaserver,
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}"
)
future.add_done_callback(handle_exception)
else:
for req_id, meta in metadata.requests.items():
for offset in range(self.pd_tp_ratio):
path = make_zmq_path(
"tcp", meta.remote_host, meta.remote_port +
self.tp_rank * self.pd_tp_ratio + offset)
logger.info(
f"Notify the prefiller: {path} that request: {req_id} from decoder is ready."
)
msg_encoder = msgspec.msgpack.Encoder()
                    decode_metadata = DecodeMooncakeAgentMetadata(
req_id=req_id,
block_ids=meta.local_block_ids,
port=self.handshake_port,
host=self.side_channel_host,
engine_id=self.engine_id,
te_rpc_port=self.te_rpc_port,
kv_caches_base_addr=self.kv_caches_base_addr,
num_blocks=self.num_blocks)
                    encoded_data = msg_encoder.encode(decode_metadata)
size_in_bytes = len(encoded_data)
logger.debug(
"Size of encoded Mooncake agent metadata: %d bytes",
size_in_bytes)
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data)
ack = sock.recv()
if ack != b"ACK":
raise ValueError(
f"Unexpected ACK from prefill node: {ack}")
if self.kv_send_layer_thread is not None:
for req_id, delay_start_time in metadata.requests_to_send.items():
if self.tp_rank in self._get_remote_tp_ranks_for_req(req_id):
self.kv_send_layer_thread.add_delayed_request(
req_id, delay_start_time)
def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor,
torch.Tensor],
attn_metadata: "AttentionMetadata",
connector_metadata: MooncakeLayerwiseConnectorMetadata,
**kwargs) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
if self.kv_role == 'kv_producer':
if self.pd_head_ratio != 1:
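                # Prefill and decode shard KV heads differently: gather this
                # layer's K/V for every finishing request, exchange and
                # rearrange them across the prefill TP group, then slice per
                # request before queueing the sends below.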
if self.current_layer != 0:
self.completion_event.wait()
self.completion_event = threading.Event()
if self.kv_send_layer_thread is not None:
self.kv_send_layer_thread.send_layer_thread.add_event(
self.completion_event,
len(connector_metadata.requests.keys()))
def sort_kv_cache(input_kv: list[list[int]]):
return torch.cat([
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
for x in range(self.pd_head_ratio)
for tensor in input_kv
])
total_block_ids = [
request.local_block_ids
for request in connector_metadata.requests.values()
]
keys = [
kv_layer[0][block_ids].reshape(
-1, *kv_layer[0].shape[2:]).clone()
for block_ids in total_block_ids
]
values = [
kv_layer[1][block_ids].reshape(
-1, *kv_layer[1].shape[2:]).clone()
for block_ids in total_block_ids
]
key_block_size = keys[0].size(0) // len(total_block_ids[0])
value_block_size = values[0].size(0) // len(total_block_ids[0])
keys = sort_kv_cache(keys) # [req1_key, req2_key]
values = sort_kv_cache(values)
(keys,
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
values)
key_start_id = 0
value_start_id = 0
else:
key = None
value = None
for req_id, request in connector_metadata.requests.items():
logger.info(f"Add request {req_id} to kv send layer thread. ")
if self.pd_head_ratio != 1:
key_block_num = len(
request.local_block_ids) * key_block_size
value_block_num = len(
request.local_block_ids) * value_block_size
key = keys[key_start_id:key_start_id +
key_block_num] #.clone().contiguous()
value = values[value_start_id:value_start_id +
value_block_num] #.clone().contiguous()
key_start_id += key_block_num
value_start_id += value_block_num
if self.kv_send_layer_thread is not None:
self.kv_send_layer_thread.add_request(
request_id=req_id,
local_block_ids=request.local_block_ids,
layer_index=self.current_layer,
key=key,
value=value)
self.current_layer += 1
def wait_for_layer_load(self, layer_name: str) -> None:
pass
def _get_remote_tp_rank(self, req_id: str) -> int:
return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank]
def _get_remote_tp_ranks_for_req(self, req_id: str) -> list[int]:
if self._prefill_tp_size == self._decode_tp_size:
return list(range(self._prefill_tp_size))
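        # Seed with a hash of the request id so prefill and decode workers
        # independently derive the same prefill-rank sample for this request.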
seed = string_to_int64_hash(req_id)
rand = random.Random(seed)
sampled_nums = rand.sample(range(self._prefill_tp_size),
self._decode_tp_size)
return sampled_nums
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
addr: str) -> Iterator[zmq.Socket]: # type: ignore
"""Context manager for a ZMQ socket"""
if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER): # type: ignore
raise ValueError(f"Unexpected socket type: {socket_type}")
ctx: Optional[zmq.Context] = None # type: ignore
try:
ctx = zmq.Context() # type: ignore
yield make_zmq_socket(ctx=ctx,
path=addr,
socket_type=socket_type,
bind=socket_type == zmq.ROUTER) # type: ignore
finally:
if ctx is not None:
ctx.destroy(linger=0)
def group_concurrent_contiguous(
src: List[int],
dst: List[int] = []
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
"""Vectorised NumPy implementation."""
if not dst:
src_only_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
if src_only_indices.size == 0:
return [], []
brk = np.where((np.diff(src_only_indices) != 1))[0] + 1
src_groups = np.split(src_only_indices, brk)
src_groups = [g.tolist() for g in src_groups]
return src_groups, []
else:
src_indices: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
dst_indices: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64)
if src_indices.size == 0:
return [], []
brk = np.where((np.diff(src_indices) != 1)
| (np.diff(dst_indices) != 1))[0] + 1
src_groups = np.split(src_indices, brk)
dst_groups = np.split(dst_indices, brk)
src_groups = [g.tolist() for g in src_groups]
dst_groups = [g.tolist() for g in dst_groups]
return src_groups, dst_groups
def string_to_int64_hash(input_str):
"""
Hash the string using SHA-256 and convert it into an int64 integer.
"""
hashed_bytes = hashlib.sha256(input_str.encode("utf-8")).digest()
trunked_bytes = hashed_bytes[:8]
uint64_value = struct.unpack("<Q", trunked_bytes)[0]
return uint64_value
def ensure_zmq_send(
socket: zmq.Socket, # type: ignore
data: bytes,
max_retries: int = 3):
retries_left = max_retries
while True:
try:
socket.send(data)
return
except zmq.ZMQError as e: # type: ignore
retries_left -= 1
if retries_left > 0:
logger.warning(
f"Send failed: {e}, retrying... ({retries_left} "
"attempts left)")
time.sleep(0.1)
else:
logger.error(f"Send failed after all retries: {e}")
raise RuntimeError(f"Failed to send data after {max_retries} "
f"retries: {e}")
def ensure_zmq_recv(
socket: zmq.Socket, # type: ignore
poller: zmq.Poller, # type: ignore
timeout: float = 1.0,
max_retries: int = 3) -> bytes:
retries_left = max_retries
while True:
try:
if dict(poller.poll(int(timeout * 1000))): # milliseconds
data = socket.recv()
return data
else:
raise zmq.ZMQError("Receive timeout") # type: ignore
except zmq.ZMQError as e: # type: ignore
retries_left -= 1
if retries_left > 0:
logger.warning(f"Receive failed: {e}, retrying... "
f"({retries_left} attempts left)")
time.sleep(0.1)
else:
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")