mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[V1][Metrics] add support for kv event publishing (#16750)
Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
86 examples/online_serving/kv_events.sh (new file)
@@ -0,0 +1,86 @@
#!/bin/bash

# This file demonstrates KV cache event publishing.
# We will launch a vLLM instance configured to publish KV cache
# events and launch a simple subscriber to log those events.

set -xe

echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1

MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}

# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT

# Cleanup function
cleanup() {
    echo "Caught Ctrl+C, cleaning up..."
    # Cleanup commands
    pgrep python | xargs kill -9
    pkill -f python
    echo "Cleanup complete. Exiting."
    exit 0
}

export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')

# A function that waits for the vLLM server to start
wait_for_server() {
    local port=$1
    timeout 1200 bash -c "
        until curl -s localhost:${port}/v1/completions > /dev/null; do
            sleep 1
        done" && return 0 || return 1
}

vllm serve $MODEL_NAME \
    --port 8100 \
    --max-model-len 100 \
    --enforce-eager \
    --gpu-memory-utilization 0.8 \
    --trust-remote-code \
    --kv-events-config \
    '{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &

wait_for_server 8100

SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"

python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1

# Serve two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "'"$MODEL_NAME"'",
        "prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
        "max_tokens": 80,
        "temperature": 0
    }')

output2=$(curl -X POST -s http://localhost:8100/v1/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "'"$MODEL_NAME"'",
        "prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
        "max_tokens": 80,
        "temperature": 0
    }')

# Cleanup commands
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm

sleep 1

echo "Cleaned up"

# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"

echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""
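For orientation before the full subscriber below: each published message is three ZMQ frames — the topic, an 8-byte big-endian sequence number, and a msgpack payload. A minimal sketch (not part of this commit) that reads one raw message, assuming the server above is publishing on its default tcp://*:5557 endpoint with topic "kv-events":

import zmq

# Minimal sketch: receive one raw KV event message and print its
# sequence number, without decoding the msgpack payload.
ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.connect("tcp://localhost:5557")
sub.setsockopt_string(zmq.SUBSCRIBE, "kv-events")

topic, seq_bytes, payload = sub.recv_multipart()
print("topic:", topic.decode(), "seq:", int.from_bytes(seq_bytes, "big"))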
114 examples/online_serving/kv_events_subscriber.py (new file)
@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union

import msgspec
import zmq
from msgspec.msgpack import Decoder


#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
                 gc=False):
    ts: float
    events: list[Any]


class KVCacheEvent(msgspec.Struct,
                   array_like=True,
                   omit_defaults=True,
                   gc=False,
                   tag=True):
    """Base class for all KV cache-related events"""


class BlockStored(KVCacheEvent):
    block_hashes: list[int]
    parent_block_hash: Optional[int]
    token_ids: list[int]
    block_size: int
    lora_id: Optional[int]


class BlockRemoved(KVCacheEvent):
    block_hashes: list[int]


class AllBlocksCleared(KVCacheEvent):
    pass


class KVEventBatch(EventBatch):
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]


def process_event(event_batch):
    print(f"Received event batch at {event_batch.ts}:")
    for event in event_batch.events:
        print(f"  - {event}")


def main():
    decoder = Decoder(type=KVEventBatch)
    last_seq = -1

    context = zmq.Context()

    # Set up the main subscription socket
    sub = context.socket(zmq.SUB)
    sub.connect("tcp://localhost:5557")
    topic = "kv-events"
    sub.setsockopt_string(zmq.SUBSCRIBE, topic)

    # Initialize replay socket
    replay = context.socket(zmq.REQ)
    replay.connect("tcp://localhost:5558")
    poller = zmq.Poller()
    poller.register(replay, zmq.POLLIN)

    print("Listening for KV cache events on topic:", topic)

    while True:
        try:
            if sub.poll(50):
                _, seq_bytes, payload = sub.recv_multipart()
                seq = int.from_bytes(seq_bytes, "big")

                if last_seq >= 0 and seq > last_seq + 1:
                    missed = seq - last_seq - 1
                    print(f"Missed {missed} messages"
                          f" (last: {last_seq}, current: {seq})")

                    replay.send((last_seq + 1).to_bytes(8, "big"))

                    while poller.poll(timeout=200):
                        seq_bytes, replay_payload = replay.recv_multipart()
                        if not replay_payload:
                            # End of replay marker is sent as an empty frame
                            # for the payload
                            break

                        replay_seq = int.from_bytes(seq_bytes, "big")

                        if replay_seq > last_seq:
                            event_batch = decoder.decode(replay_payload)
                            process_event(event_batch)
                            last_seq = replay_seq
                            if replay_seq >= seq - 1:
                                break

                event_batch = decoder.decode(payload)
                process_event(event_batch)
                # Track the newest sequence so the gap detection above works
                last_seq = seq

            # ... do other periodic work or check for shutdown ...

        except KeyboardInterrupt:
            print("Interrupted")
            break
        except Exception as e:
            print("Error decoding message:", e)


if __name__ == "__main__":
    main()
145 tests/distributed/conftest.py (new file)
@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union

import msgspec
import msgspec.msgpack
import pytest
import zmq

from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory

from .test_events import SampleBatch


@pytest.fixture
def random_port():
    """Generate a random port number for testing"""
    return random.randint(10000, 60000)


@pytest.fixture
def publisher_config(random_port, request):
    """Create a publisher config with inproc transport"""
    how = request.param if hasattr(request, "param") else "inproc"

    if how == "inproc":
        endpoint = f"inproc://test-{random_port}"
        replay_endpoint = endpoint + "-replay"
    else:
        endpoint = f"tcp://*:{random_port}"
        replay_endpoint = f"tcp://*:{random_port + 1}"

    return KVEventsConfig(enable_kv_cache_events=True,
                          publisher="zmq",
                          endpoint=endpoint,
                          replay_endpoint=replay_endpoint,
                          buffer_steps=100,
                          hwm=1000,
                          topic="test")


@pytest.fixture
def publisher(publisher_config):
    """Create and return a publisher instance"""
    pub = EventPublisherFactory.create(publisher_config)
    yield pub
    pub.shutdown()


@pytest.fixture
def subscriber(publisher_config):
    """Create and return a subscriber for testing"""
    endpoint = publisher_config.endpoint
    replay_endpoint = publisher_config.replay_endpoint

    if endpoint.startswith("tcp://*"):
        endpoint = endpoint.replace("*", "127.0.0.1")
    if replay_endpoint and replay_endpoint.startswith("tcp://*"):
        replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")

    sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
    yield sub
    sub.close()


class MockSubscriber:
    """Helper class to receive and verify published events"""

    def __init__(self,
                 pub_endpoint: str,
                 replay_endpoint: Optional[str] = None,
                 topic: str = "",
                 decode_type=SampleBatch):
        self.ctx = zmq.Context.instance()

        # Set up subscriber socket
        self.sub = self.ctx.socket(zmq.SUB)
        self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
        self.sub.connect(pub_endpoint)

        # Set up replay socket if provided
        self.replay = None
        if replay_endpoint:
            self.replay = self.ctx.socket(zmq.REQ)
            self.replay.connect(replay_endpoint)

        self.topic = topic
        self.topic_bytes = topic.encode('utf-8')
        self.received_msgs: list[tuple[int, SampleBatch]] = []
        self.last_seq = -1
        self.decoder = msgspec.msgpack.Decoder(type=decode_type)

    def receive_one(self,
                    timeout=1000) -> Union[tuple[int, SampleBatch], None]:
        """Receive a single message with timeout"""
        if not self.sub.poll(timeout):
            return None

        topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
        assert topic_bytes == self.topic_bytes

        seq = int.from_bytes(seq_bytes, "big")
        data = self.decoder.decode(payload)
        self.last_seq = seq
        self.received_msgs.append((seq, data))
        return seq, data

    def request_replay(self, start_seq: int) -> None:
        """Request replay of messages starting from start_seq"""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        self.replay.send(start_seq.to_bytes(8, "big"))

    def receive_replay(self) -> list[tuple[int, SampleBatch]]:
        """Receive replayed messages"""
        if not self.replay:
            raise ValueError("Replay socket not initialized")

        replayed: list[tuple[int, SampleBatch]] = []
        while True:
            try:
                if not self.replay.poll(1000):
                    break

                frames = self.replay.recv_multipart()
                if not frames or not frames[-1]:
                    # End of replay marker
                    break

                seq_bytes, payload = frames
                seq = int.from_bytes(seq_bytes, "big")
                data = self.decoder.decode(payload)
                replayed.append((seq, data))
            except zmq.ZMQError:
                break

        return replayed

    def close(self):
        """Clean up resources"""
        self.sub.close()
        if self.replay:
            self.replay.close()
193 tests/distributed/test_events.py (new file)
@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import threading
import time

import msgspec
import pytest

from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
                                        NullEventPublisher)


class EventSample(
        msgspec.Struct,
        tag=True,  # type: ignore
        array_like=True  # type: ignore
):
    """Test event for publisher testing"""
    id: int
    value: str


class SampleBatch(EventBatch):
    """Test event batch for publisher testing"""
    events: list[EventSample]


def create_test_events(count: int) -> SampleBatch:
    """Create a batch of test events"""
    events = [EventSample(id=i, value=f"test-{i}") for i in range(count)]
    return SampleBatch(ts=time.time(), events=events)


def test_basic_publishing(publisher, subscriber):
    """Test basic event publishing works"""

    test_batch = create_test_events(5)
    publisher.publish(test_batch)

    result = subscriber.receive_one(timeout=1000)
    assert result is not None, "No message received"

    seq, received = result
    assert seq == 0, "Sequence number mismatch"
    assert received.ts == pytest.approx(test_batch.ts,
                                        abs=0.1), ("Timestamp mismatch")
    assert len(received.events) == len(
        test_batch.events), ("Number of events mismatch")

    for i, event in enumerate(received.events):
        assert event.id == i, "Event id mismatch"
        assert event.value == f"test-{i}", "Event value mismatch"


def test_multiple_events(publisher, subscriber):
    """Test publishing and receiving multiple event batches"""
    for _ in range(10):
        batch = create_test_events(2)
        publisher.publish(batch)

    received = []
    for _ in range(10):
        data = subscriber.receive_one(timeout=100)
        if data:
            received.append(data)

    assert len(received) == 10, "Number of messages mismatch"
    seqs = [seq for seq, _ in received]
    assert seqs == list(range(10)), "Sequence numbers mismatch"


def test_replay_mechanism(publisher, subscriber):
    """Test the replay mechanism works correctly"""
    for _ in range(19):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Give the publisher time to process the batches above
    subscriber.request_replay(10)

    batch = create_test_events(1)
    publisher.publish(batch)  # 20th message

    replayed = subscriber.receive_replay()

    assert len(replayed) > 0, "No replayed messages received"
    seqs = [seq for seq, _ in replayed]
    assert all(seq >= 10 for seq in seqs), (
        "Replay should start at the requested sequence")
    assert seqs == list(range(min(seqs),
                              max(seqs) +
                              1)), ("Replayed messages not consecutive")


def test_buffer_limit(publisher, subscriber, publisher_config):
    """Test buffer limit behavior"""
    buffer_size = publisher_config.buffer_steps

    # Publish more events than the buffer can hold
    for i in range(buffer_size + 10):
        batch = create_test_events(1)
        publisher.publish(batch)

    time.sleep(0.5)  # Give the publisher time to process the batches above
    subscriber.request_replay(0)

    batch = create_test_events(1)
    publisher.publish(batch)

    replayed = subscriber.receive_replay()

    assert len(replayed) <= buffer_size, "Can't replay more than buffer size"

    oldest_seq = min(seq for seq, _ in replayed)
    assert oldest_seq >= 10, "The oldest sequence should be at least 10"


def test_topic_filtering(publisher_config):
    """
    Test that a subscriber only receives messages matching its topic filter
    """
    publisher_config.replay_endpoint = None

    cfg = publisher_config.model_copy()
    cfg.topic = "foo"
    pub = EventPublisherFactory.create(cfg)

    from .conftest import MockSubscriber
    sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
    sub_bar = MockSubscriber(cfg.endpoint, None, "bar")

    try:
        time.sleep(0.1)

        for _ in range(3):
            pub.publish(create_test_events(1))

        foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is not None for msg in foo_received), (
            "Subscriber with matching topic should receive messages")

        bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
        assert all(msg is None for msg in bar_received), (
            "Subscriber with non-matching topic should receive no messages")
    finally:
        pub.shutdown()
        sub_foo.close()
        sub_bar.close()


def test_high_volume(publisher, subscriber):
    """Test publishing and receiving a high volume of events"""
    num_batches = 10_000
    events_per_batch = 100

    # Publish events in a separate thread so receiving is not blocked
    def publish_events():
        for i in range(num_batches):
            batch = create_test_events(events_per_batch)
            publisher.publish(batch)
            # Small delay to avoid overwhelming the publisher
            if i % 100 == 0:
                time.sleep(0.01)

    received: list[tuple[int, SampleBatch]] = []

    publisher_thread = threading.Thread(target=publish_events)
    publisher_thread.start()

    start_time = time.time()
    while len(received) < num_batches:
        if time.time() - start_time > 10:  # Timeout after 10 seconds
            break

        result = subscriber.receive_one(timeout=100)
        if result:
            received.append(result)

    publisher_thread.join()

    assert len(received) >= num_batches * 0.9, (
        "We should have received most messages")

    seqs = [seq for seq, _ in received]
    assert sorted(seqs) == seqs, "Sequence numbers should be in order"


def test_null_publisher():
    """Test that NullEventPublisher can be used without errors"""
    publisher = NullEventPublisher()

    # This should not raise any errors
    batch = create_test_events(5)
    publisher.publish(batch)
    publisher.shutdown()
@@ -6,6 +6,7 @@ from typing import Optional
 import pytest
 import torch

+from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
 from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
 from vllm.sampling_params import SamplingParams
 from vllm.utils import sha256
@@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
         num_blocks=num_blocks,
         tensors={},
         kv_cache_groups=[
-            KVCacheGroupSpec(['layer'],
-                             FullAttentionSpec(block_size, 1, 1, torch.float32,
-                                               False))
+            KVCacheGroupSpec(
+                ["layer"],
+                FullAttentionSpec(block_size, 1, 1, torch.float32, False),
+            )
         ],
     )
@@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled():
     assert manager.prefix_cache_stats is None


+@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
+def test_kv_cache_events(blocks_to_cache: int):
+    block_size = 16
+    num_blocks = blocks_to_cache + 1
+
+    # Allocate Blocks
+    # Should see a single block stored event with a blocks_to_cache number of
+    # block hashes
+    # take_events should reset the kv_event_queue
+    manager = KVCacheManager(
+        make_kv_cache_config(block_size, num_blocks),
+        max_model_len=8192,
+        enable_caching=True,
+        enable_kv_cache_events=True,
+    )
+
+    num_tokens = block_size * blocks_to_cache
+    req0 = make_request("0", list(range(num_tokens)))
+    _ = manager.allocate_slots(req0, num_tokens)
+    events = manager.take_events()
+
+    block = events[-1]
+    assert (len(block.block_hashes) == blocks_to_cache == len(
+        manager.block_pool.cached_block_hash_to_block))
+    assert len(block.token_ids) == block.block_size * len(block.block_hashes)
+    assert len(manager.block_pool.kv_event_queue) == 0
+
+    stored_block_hash = block.block_hashes
+
+    # Remove blocks and send another request
+    # Should see block_to_cache number of removed block events and a new block
+    # stored event
+    manager.free(req0)
+    req1 = make_request("1", list(range(num_tokens)))
+    _ = manager.allocate_slots(req1, num_tokens)
+    events = manager.take_events()
+
+    for blocks in events[:-1]:
+        assert blocks.block_hashes[0] in stored_block_hash
+    assert len(events) == blocks_to_cache + 1
+    assert (isinstance(events[-2], BlockRemoved))
+    assert (len(events[-1].block_hashes) == blocks_to_cache == len(
+        manager.block_pool.cached_block_hash_to_block))
+
+    # All Blocks Cleared
+    # Should see a single all blocks cleared event
+    manager.free(req1)
+    manager.reset_prefix_cache()
+    events = manager.take_events()
+
+    assert isinstance(events[-1], AllBlocksCleared)
+    assert len(manager.block_pool.cached_block_hash_to_block) == 0
+
+
 def test_eagle_enabled_removes_last_block():
     """Verify Eagle does NOT remove blocks when request
     length is divisible by block size."""
@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
 from vllm.engine.arg_utils import EngineArgs
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs

+from ...distributed.conftest import publisher_config, random_port  # noqa: F401
+
 from tests.v1.engine.utils import FULL_STRINGS  # isort: skip

 EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]
@@ -11,6 +11,7 @@ import pytest
 from transformers import AutoTokenizer

 from vllm import SamplingParams
+from vllm.distributed.kv_events import BlockStored, KVEventBatch
 from vllm.engine.arg_utils import EngineArgs
 from vllm.platforms import current_platform
 from vllm.usage.usage_lib import UsageContext
@@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
                                         SyncMPClient)
 from vllm.v1.executor.abstract import Executor

+from ...distributed.conftest import MockSubscriber
 from ...utils import create_new_process_for_each_test

 if not current_platform.is_cuda():
@@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
         log_stats=True,
     )

-    MAX_TOKENS = 20
-    params = SamplingParams(max_tokens=MAX_TOKENS)
-    """Normal Request Cycle."""
+    try:
+        MAX_TOKENS = 20
+        params = SamplingParams(max_tokens=MAX_TOKENS)
+        """Normal Request Cycle."""

-    requests = [make_request(params) for _ in range(10)]
-    request_ids = [req.request_id for req in requests]
+        requests = [make_request(params) for _ in range(10)]
+        request_ids = [req.request_id for req in requests]

-    # Add requests to the engine.
-    for request in requests:
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
+        # Add requests to the engine.
+        for request in requests:
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)

-    outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+        outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)

-    for req_id in request_ids:
-        assert len(outputs[req_id]) == MAX_TOKENS, (
-            f"{outputs[req_id]=}, {MAX_TOKENS=}")
-    """Abort Request Cycle."""
+        for req_id in request_ids:
+            assert len(outputs[req_id]) == MAX_TOKENS, (
+                f"{outputs[req_id]=}, {MAX_TOKENS=}")
+        """Abort Request Cycle."""

-    # Add requests to the engine.
-    for idx, request in enumerate(requests):
-        await client.add_request_async(request)
-        await asyncio.sleep(0.01)
-        if idx % 2 == 0:
-            await client.abort_requests_async([request.request_id])
+        # Add requests to the engine.
+        for idx, request in enumerate(requests):
+            await client.add_request_async(request)
+            await asyncio.sleep(0.01)
+            if idx % 2 == 0:
+                await client.abort_requests_async([request.request_id])

-    outputs = {req_id: [] for req_id in request_ids}
-    await loop_until_done_async(client, outputs)
+        outputs = {req_id: [] for req_id in request_ids}
+        await loop_until_done_async(client, outputs)

-    for idx, req_id in enumerate(request_ids):
-        if idx % 2 == 0:
-            assert len(outputs[req_id]) < MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-        else:
-            assert len(outputs[req_id]) == MAX_TOKENS, (
-                f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
-    """Utility method invocation"""
+        for idx, req_id in enumerate(request_ids):
+            if idx % 2 == 0:
+                assert len(outputs[req_id]) < MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+            else:
+                assert len(outputs[req_id]) == MAX_TOKENS, (
+                    f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
+        """Utility method invocation"""

-    core_client: AsyncMPClient = client
+        core_client: AsyncMPClient = client

-    result = await core_client.call_utility_async("echo", "testarg")
-    assert result == "testarg"
+        result = await core_client.call_utility_async("echo", "testarg")
+        assert result == "testarg"

-    with pytest.raises(Exception) as e_info:
-        await core_client.call_utility_async("echo", None, "help!")
+        with pytest.raises(Exception) as e_info:
+            await core_client.call_utility_async("echo", None, "help!")

-    assert str(e_info.value) == "Call to echo method failed: help!"
+        assert str(e_info.value) == "Call to echo method failed: help!"
+    finally:
+        client.shutdown()


+@pytest.mark.parametrize(
+    "multiprocessing_mode,publisher_config",
+    [(True, "tcp"), (False, "inproc")],
+    indirect=["publisher_config"],
+)
+def test_kv_cache_events(
+    monkeypatch: pytest.MonkeyPatch,
+    multiprocessing_mode: bool,
+    publisher_config,
+):
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+        block_size = 16
+        num_blocks = 2
+
+        engine_args = EngineArgs(model=MODEL_NAME,
+                                 enforce_eager=True,
+                                 enable_prefix_caching=True,
+                                 block_size=block_size)
+        engine_args.kv_events_config = publisher_config
+
+        vllm_config = engine_args.create_engine_config(
+            UsageContext.UNKNOWN_CONTEXT)
+
+        executor_class = Executor.get_class(vllm_config)
+        client = EngineCoreClient.make_client(
+            multiprocess_mode=multiprocessing_mode,
+            asyncio_mode=False,
+            vllm_config=vllm_config,
+            executor_class=executor_class,
+            log_stats=False,
+        )
+        endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
+        time.sleep(0.1)
+        subscriber = MockSubscriber(endpoint,
+                                    topic=publisher_config.topic,
+                                    decode_type=KVEventBatch)
+
+        try:
+            custom_tokens = list(range(num_blocks * block_size))
+            request = EngineCoreRequest(
+                request_id=str(uuid.uuid4()),
+                prompt_token_ids=custom_tokens,
+                mm_inputs=None,
+                mm_hashes=None,
+                mm_placeholders=None,
+                sampling_params=SamplingParams(
+                    max_tokens=1),  # Short completion for speed
+                eos_token_id=None,
+                arrival_time=time.time(),
+                lora_request=None,
+            )
+            client.add_request(request)
+
+            outputs: dict[str, list] = {request.request_id: []}
+            loop_until_done(client, outputs)
+
+            result = subscriber.receive_one(timeout=1000)
+            assert result is not None, "No message received"
+
+            seq, received = result
+
+            assert seq == 0, "Sequence number mismatch"
+            assert len(received.events) == 1, (
+                "We should have exactly one BlockStored event")
+            event = received.events[0]
+            assert isinstance(
+                event, BlockStored), ("We should have a BlockStored event")
+            assert len(event.block_hashes) == num_blocks, (
+                "We should have a BlockStored event with 2 block_hashes")
+            assert event.block_size == block_size, (
+                "Block size should be the same as the block size")
+            assert event.parent_block_hash is None, (
+                "Parent block hash should be None")
+            assert event.lora_id is None, "Lora id should be None"
+            assert len(event.token_ids) == num_blocks * block_size, (
+                "Token ids should be the same as the custom tokens")
+            assert event.token_ids == custom_tokens, (
+                "Token ids should be the same as the custom tokens")
+        finally:
+            client.shutdown()
+        return


 @pytest.mark.timeout(10)
@@ -1958,6 +1958,8 @@ class SchedulerConfig:
     some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
     it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""

+    # scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
+    # or "mod.custom_class".
     scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
     """The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
     default scheduler. Can be a class directly or the path to a class of form
@@ -3417,6 +3419,51 @@ class KVTransferConfig(BaseModel):
         return self.kv_connector_extra_config.get(key, default)


+class KVEventsConfig(BaseModel):
+    """Configuration for KV event publishing."""
+
+    enable_kv_cache_events: bool = False
+    """If True, enable KV cache events for tracking block storage and removal.
+    Events can be published externally by zmq using the event publisher config.
+    """
+
+    publisher: str = "null"
+    """The publisher to use for publishing kv events. Can be "null", "zmq".
+    """
+
+    endpoint: str = "tcp://*:5557"
+    """The zmq endpoint to use for publishing kv events.
+    """
+
+    replay_endpoint: Optional[str] = None
+    """The zmq endpoint to use for replaying kv events.
+    """
+
+    buffer_steps: int = 10_000
+    """The number of steps to cache for replay endpoint. Will only save
+    events from the last N steps for the replay endpoint.
+    """
+
+    hwm: int = 100_000
+    """The zmq high water mark for the event publisher. After queueing N events,
+    events will start dropping if the consumer is not keeping up.
+    """
+
+    max_queue_size: int = 100_000
+    """The maximum number of events to queue while waiting for publishing.
+    """
+
+    topic: str = ""
+    """The topic to use for the event publisher. Consumers can subscribe to
+    this topic to receive events.
+    """
+
+    @classmethod
+    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
+        """Parse the CLI value for the event publisher config."""
+        return KVEventsConfig.model_validate_json(cli_value)
+
+
 class CompilationLevel:
     # constants for the levels of the compilation process
     NO_COMPILATION = 0
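To make the CLI mapping concrete, a small sketch (not part of this commit) of how the JSON string passed to --kv-events-config becomes a KVEventsConfig via from_cli; unset fields keep the class defaults documented above:

from vllm.config import KVEventsConfig

# Parse the same JSON used in examples/online_serving/kv_events.sh.
cfg = KVEventsConfig.from_cli(
    '{"enable_kv_cache_events": true, "publisher": "zmq", '
    '"topic": "kv-events"}')
# Unspecified fields fall back to the defaults above.
assert cfg.endpoint == "tcp://*:5557"
assert cfg.replay_endpoint is None
assert cfg.buffer_steps == 10_000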
@@ -3779,6 +3826,7 @@ class VllmConfig:
                                                  init=True)  # type: ignore
     kv_transfer_config: KVTransferConfig = field(default=None,
                                                  init=True)  # type: ignore
+    kv_events_config: Optional[KVEventsConfig] = None
     # some opaque config, only used to provide additional information
     # for the hash computation, mainly used for testing, debugging or out of
     # tree config registration.
@@ -4038,6 +4086,18 @@ class VllmConfig:
             if self.cache_config is not None:
                 self.cache_config.enable_prefix_caching = False

+        if (self.kv_events_config
+                and self.kv_events_config.enable_kv_cache_events
+                and not self.cache_config.enable_prefix_caching):
+            logger.warning(
+                "KV cache events are on, but prefix caching is not enabled. "
+                "Use --enable-prefix-caching to enable.")
+        if (self.kv_events_config and self.kv_events_config.publisher != "null"
+                and not self.kv_events_config.enable_kv_cache_events):
+            logger.warning("KV cache events are disabled, "
+                           "but the scheduler is configured to publish them. "
+                           "Modify KVEventsConfig.enable_kv_cache_events "
+                           "to True to enable.")
         current_platform.check_and_update_config(self)

         if not self.instance_id:
295 vllm/distributed/kv_events.py (new file)
@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0

import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union

import msgspec
import zmq

from vllm.config import KVEventsConfig
from vllm.logger import init_logger

logger = init_logger(__name__)


class EventBatch(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False,  # type: ignore[call-arg]
):
    ts: float
    events: list[Any]


class KVCacheEvent(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
        omit_defaults=True,  # type: ignore[call-arg]
        gc=False,  # type: ignore[call-arg]
        tag=True):
    """Base class for all KV cache-related events"""


class BlockStored(KVCacheEvent):
    block_hashes: list[int]
    parent_block_hash: Optional[int]
    token_ids: list[int]
    block_size: int
    lora_id: Optional[int]


class BlockRemoved(KVCacheEvent):
    block_hashes: list[int]


class AllBlocksCleared(KVCacheEvent):
    pass


class KVEventBatch(EventBatch):
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]


class EventPublisher(ABC):
    """Lightweight publisher for EventBatch batches."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit events in order.

        Implementations should guarantee at-least-once delivery and
        monotonic ordering (e.g., via sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Shutdown the publisher."""


class NullEventPublisher(EventPublisher):
    """No-op implementation (default when disabled)."""

    def publish(self, events) -> None:
        return

    def shutdown(self) -> None:
        return


class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers
        can request missed batches by sending the starting sequence number as
        an 8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """
    SHUTDOWN_TIMEOUT: float = 1.0
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode('utf-8')

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(target=self._publisher_thread,
                                        daemon=True,
                                        name="zmq-publisher")
        self._thread.start()

    def publish(self, events: EventBatch) -> None:
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources."""
        self._running = False
        self._event_queue.put_nowait(None)

        start = time.time()
        pending_items = True
        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
            if self._replay is not None:
                self._replay.close(linger=0)
        finally:
            pass  # Do not terminate context; other sockets may use it

    def _socket_setup(self) -> None:
        """Initialize sockets

        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if ("*" in self._endpoint or "::" in self._endpoint
                    or self._endpoint.startswith("ipc://")
                    or self._endpoint.startswith("inproc://")):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request -> many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()
        self._socket_setup()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
                if event is None:
                    break  # Sentinel received, exit thread
            except queue.Empty:
                continue

            try:
                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart(
                    (self._topic_bytes, seq_bytes, payload))

                self._buffer.append((seq, payload))
                self._event_queue.task_done()

            except Exception as e:
                # Publishing failed; back off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router;
                # the receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf))
        # Send end of sequence marker;
        # the receiving payload is (END_SEQ, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))


class EventPublisherFactory:
    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str,
                           ctor: Callable[..., EventPublisher]) -> None:
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
        """Create publisher from a config mapping."""
        if not config:
            return NullEventPublisher()

        config_dict = config.model_dump()

        kind = config_dict.pop("publisher", "null")
        config_dict.pop("enable_kv_cache_events")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(**config_dict)
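The factory's registry is also an extension point for third-party publishers. A hedged sketch (not part of this commit; StdoutEventPublisher is hypothetical) — note that create() forwards the remaining KVEventsConfig fields as keyword arguments, so a custom constructor must accept them:

from vllm.distributed.kv_events import (EventBatch, EventPublisher,
                                        EventPublisherFactory)


class StdoutEventPublisher(EventPublisher):
    """Hypothetical publisher that just prints batches."""

    def __init__(self, **kwargs) -> None:
        # create() passes endpoint, replay_endpoint, buffer_steps, hwm,
        # max_queue_size and topic; this toy publisher ignores them.
        pass

    def publish(self, events: EventBatch) -> None:
        print(f"batch ts={events.ts} with {len(events.events)} events")

    def shutdown(self) -> None:
        pass


# Selectable afterwards via KVEventsConfig(publisher="stdout", ...).
EventPublisherFactory.register_publisher("stdout", StdoutEventPublisher)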
@@ -19,14 +19,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
                          ConfigFormat, ConfigType, DecodingConfig, Device,
                          DeviceConfig, DistributedExecutorBackend,
                          GuidedDecodingBackend, GuidedDecodingBackendV1,
-                         HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
-                         LoRAConfig, ModelConfig, ModelDType, ModelImpl,
-                         MultiModalConfig, ObservabilityConfig, ParallelConfig,
-                         PoolerConfig, PrefixCachingHashAlgo,
-                         PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
-                         SpeculativeConfig, TaskOption, TokenizerMode,
-                         TokenizerPoolConfig, VllmConfig, get_attr_docs,
-                         get_field)
+                         HfOverrides, KVEventsConfig, KVTransferConfig,
+                         LoadConfig, LoadFormat, LoRAConfig, ModelConfig,
+                         ModelDType, ModelImpl, MultiModalConfig,
+                         ObservabilityConfig, ParallelConfig, PoolerConfig,
+                         PrefixCachingHashAlgo, PromptAdapterConfig,
+                         SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
+                         TaskOption, TokenizerMode, TokenizerPoolConfig,
+                         VllmConfig, get_attr_docs, get_field)
 from vllm.executor.executor_base import ExecutorBase
 from vllm.logger import init_logger
 from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -353,6 +353,7 @@ class EngineArgs:
     worker_extension_cls: str = ParallelConfig.worker_extension_cls

     kv_transfer_config: Optional[KVTransferConfig] = None
+    kv_events_config: Optional[KVEventsConfig] = None

     generation_config: str = ModelConfig.generation_config
     enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
@@ -769,6 +770,10 @@ class EngineArgs:
                             default=None,
                             help='The configurations for distributed KV cache '
                             'transfer. Should be a JSON string.')
+        parser.add_argument('--kv-events-config',
+                            type=KVEventsConfig.from_cli,
+                            default=None,
+                            help='The configurations for event publishing.')

         parser.add_argument(
             '--worker-cls',
@@ -1125,6 +1130,7 @@ class EngineArgs:
             prompt_adapter_config=prompt_adapter_config,
             compilation_config=self.compilation_config,
             kv_transfer_config=self.kv_transfer_config,
+            kv_events_config=self.kv_events_config,
             additional_config=self.additional_config,
         )

@@ -3,6 +3,8 @@ from collections import defaultdict
 from collections.abc import Iterable
 from typing import Callable, Optional

+from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
+                                        BlockStored, KVCacheEvent)
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
                                          KVCacheBlock,
@@ -26,7 +28,12 @@ class BlockPool:
         enable_caching: Whether to enable prefix caching.
     """

-    def __init__(self, num_gpu_blocks: int, enable_caching: bool):
+    def __init__(
+        self,
+        num_gpu_blocks: int,
+        enable_caching: bool,
+        enable_kv_cache_events: bool = False,
+    ):
         assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
         self.num_gpu_blocks = num_gpu_blocks
         self.enable_caching = enable_caching
@@ -56,6 +63,9 @@ class BlockPool:
         # avoid freeing it.
         self.null_block = self.free_block_queue.popleft()

+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue: list[KVCacheEvent] = []
+
     def get_cached_block(self,
                          block_hash: BlockHashType) -> Optional[KVCacheBlock]:
         """Get a cached block by the block hash, or None if cache miss.
@@ -116,6 +126,9 @@ class BlockPool:
             assert prev_block.block_hash is not None
             prev_block_hash_value = prev_block.block_hash.hash_value

+        parent_block_hash = prev_block_hash_value
+        new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
+                                           else None)
         for i, blk in enumerate(new_full_blocks):
             assert blk.block_hash is None

@@ -153,8 +166,23 @@ class BlockPool:
             # Update and added the full block to the cache.
             blk.block_hash = block_hash
             self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
+            if new_hashes is not None:
+                new_hashes.append(block_hash.hash_value)
             prev_block_hash_value = block_hash.hash_value

+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=new_hashes,
+                    parent_block_hash=parent_block_hash,
+                    token_ids=request.
+                    all_token_ids[num_cached_blocks *
+                                  block_size:num_full_blocks * block_size],
+                    block_size=block_size,
+                    lora_id=request.lora_request.id
+                    if request.lora_request else None,
+                ))
+
     def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
         """Get new blocks from the free block pool.

@@ -206,6 +234,9 @@ class BlockPool:
             if len(self.cached_block_hash_to_block[block_hash]) == 0:
                 del self.cached_block_hash_to_block[block_hash]

+            if self.enable_kv_cache_events:
+                self.kv_event_queue.append(
+                    BlockRemoved(block_hashes=[block_hash.hash_value]))
             return True
         return False

@@ -262,6 +293,10 @@ class BlockPool:
             block.reset_hash()

         logger.info("Successfully reset prefix cache")

+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(AllBlocksCleared())
+
         return True

     def get_num_free_blocks(self) -> int:
@@ -279,3 +314,15 @@ class BlockPool:
         The KV cache usage (between 0.0 and 1.0).
         """
         return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
+
+    def take_events(self) -> list[KVCacheEvent]:
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
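take_events hands ownership of the whole queue to the caller, so each event reaches the publisher exactly once. A small usage sketch of the drain-and-publish pattern the scheduler adopts below (not part of this commit; the pool sizes are arbitrary):

import time

from vllm.distributed.kv_events import KVEventBatch, NullEventPublisher
from vllm.v1.core.block_pool import BlockPool

# Drain once per scheduling step; the internal list is swapped out
# wholesale, so a second drain returns an empty list.
pool = BlockPool(num_gpu_blocks=4, enable_caching=True,
                 enable_kv_cache_events=True)
publisher = NullEventPublisher()

events = pool.take_events()  # empty here; populated by cache activity
if events:
    publisher.publish(KVEventBatch(ts=time.time(), events=events))
assert pool.take_events() == []  # nothing left after a drain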
@@ -4,6 +4,7 @@ from collections import defaultdict
 from collections.abc import Iterable
 from typing import Optional

+from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.utils import cdiv, sha256
 from vllm.v1.core.block_pool import BlockPool
@@ -27,6 +28,7 @@ class KVCacheManager:
         caching_hash_algo: str = "builtin",
         use_eagle: bool = False,
         log_stats: bool = False,
+        enable_kv_cache_events: bool = False,
     ) -> None:
         assert len(kv_cache_config.kv_cache_groups) == 1, (
             "KVCacheManager does not support hybrid models with more than 1 "
@@ -44,7 +46,9 @@ class KVCacheManager:
         # FIXME: make prefix cache stats conditional on log_stats
         self.prefix_cache_stats = PrefixCacheStats() if log_stats else None

-        self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
+        self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
+                                    enable_kv_cache_events)

         self.specialized_manager = get_specialized_manager(
             kv_cache_spec=kv_cache_spec,
             block_pool=self.block_pool,
@@ -383,3 +387,11 @@ class KVCacheManager:
         is finished, not when it is preempted.
         """
         self.req_to_block_hashes.pop(request.request_id, None)
+
+    def take_events(self) -> list[KVCacheEvent]:
+        """Take the KV cache events from the block pool.
+
+        Returns:
+            A list of KV cache events.
+        """
+        return self.block_pool.take_events()
@@ -132,3 +132,8 @@ class SchedulerInterface(ABC):
         The SchedulerStats object is created for every scheduling step.
         """
         raise NotImplementedError
+
+    @abstractmethod
+    def shutdown(self) -> None:
+        """Shutdown the scheduler."""
+        raise NotImplementedError
@@ -8,6 +8,7 @@ from collections.abc import Iterable
 from typing import Optional, Union

 from vllm.config import VllmConfig
+from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
 from vllm.distributed.kv_transfer.kv_connector.factory import (
     KVConnectorFactory)
 from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
@@ -48,6 +49,7 @@ class Scheduler(SchedulerInterface):
         self.cache_config = vllm_config.cache_config
         self.lora_config = vllm_config.lora_config
         self.kv_cache_config = kv_cache_config
+        self.kv_events_config = vllm_config.kv_events_config
         self.log_stats = log_stats
         self.structured_output_manager = structured_output_manager

@@ -62,6 +64,9 @@ class Scheduler(SchedulerInterface):
         self.max_num_scheduled_tokens = \
             self.scheduler_config.max_num_batched_tokens
         self.max_model_len = self.scheduler_config.max_model_len
+        self.enable_kv_cache_events = (
+            self.kv_events_config is not None
+            and self.kv_events_config.enable_kv_cache_events)

         # Create KVConnector for the Scheduler. Note that each Worker
         # will have a corresponding KVConnector with Role=WORKER.
@@ -71,6 +76,9 @@ class Scheduler(SchedulerInterface):
             self.connector = KVConnectorFactory.create_connector_v1(
                 config=self.vllm_config, role=KVConnectorRole.SCHEDULER)

+        self.kv_event_publisher = EventPublisherFactory.create(
+            self.kv_events_config)
+
         num_gpu_blocks = self.cache_config.num_gpu_blocks
         assert num_gpu_blocks is not None and num_gpu_blocks > 0

@@ -132,7 +140,9 @@ class Scheduler(SchedulerInterface):
             enable_caching=self.cache_config.enable_prefix_caching,
             caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
             use_eagle=self.use_eagle,
-            log_stats=self.log_stats)
+            log_stats=self.log_stats,
+            enable_kv_cache_events=self.enable_kv_cache_events,
+        )

     def schedule(self) -> SchedulerOutput:
         # NOTE(woosuk) on the scheduling algorithm:
@@ -493,6 +503,11 @@ class Scheduler(SchedulerInterface):
             meta = self.connector.build_connector_meta(scheduler_output)
             scheduler_output.kv_connector_metadata = meta

+        events = self.kv_cache_manager.take_events()
+        if events:
+            batch = KVEventBatch(ts=time.time(), events=events)
+            self.kv_event_publisher.publish(batch)
+
         # Advance the number of computed tokens for the request AFTER
         # the request is scheduled.
         # 1. The scheduler_output of the current step has to include the
@@ -843,3 +858,7 @@ class Scheduler(SchedulerInterface):
                 num_draft_tokens=num_draft_tokens,
                 num_accepted_tokens=num_accepted_tokens)
         return spec_decoding_stats
+
+    def shutdown(self) -> None:
+        if self.kv_event_publisher:
+            self.kv_event_publisher.shutdown()
@@ -259,6 +259,8 @@ class EngineCore:
             self.structured_output_manager.clear_backend()
         if self.model_executor:
             self.model_executor.shutdown()
+        if self.scheduler:
+            self.scheduler.shutdown()

     def profile(self, is_start: bool = True):
         self.model_executor.profile(is_start)
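Putting the pieces together, a minimal end-to-end sketch of publishing one synthetic batch (not part of this commit; it assumes port 5557 is free locally and a subscriber such as examples/online_serving/kv_events_subscriber.py is attached):

import time

from vllm.distributed.kv_events import (BlockStored, KVEventBatch,
                                        ZmqEventPublisher)

# Publish one synthetic BlockStored batch over the default endpoint.
pub = ZmqEventPublisher(endpoint="tcp://*:5557", topic="kv-events")
batch = KVEventBatch(ts=time.time(),
                     events=[
                         BlockStored(block_hashes=[42],
                                     parent_block_hash=None,
                                     token_ids=list(range(16)),
                                     block_size=16,
                                     lora_id=None)
                     ])
pub.publish(batch)
# Subscribers now receive frames (b"kv-events", seq, msgpack payload).
time.sleep(0.5)  # give the background thread time to send
pub.shutdown()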