[V1][Metrics] add support for kv event publishing (#16750)

Signed-off-by: alec-flowers <aflowers@nvidia.com>
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
Co-authored-by: Mark McLoughlin <markmc@redhat.com>
Alec
2025-04-30 16:44:45 +02:00
committed by GitHub
parent 77073c77bc
commit 0be6d05b5e
15 changed files with 1185 additions and 53 deletions

View File

@@ -0,0 +1,86 @@
#!/bin/bash
# This script demonstrates KV cache event publishing.
# We launch a vLLM instance configured to publish KV cache events
# and a simple subscriber to log those events.
set -xe
echo "🚧🚧 Warning: The usage of KV cache events is experimental and subject to change 🚧🚧"
sleep 1
MODEL_NAME=${HF_MODEL_NAME:-meta-llama/Meta-Llama-3.1-8B-Instruct}
# Trap the SIGINT signal (triggered by Ctrl+C)
trap 'cleanup' INT
# Cleanup function
cleanup() {
echo "Caught Ctrl+C, cleaning up..."
# Cleanup commands
pgrep python | xargs kill -9
pkill -f python
echo "Cleanup complete. Exiting."
exit 0
}
export VLLM_HOST_IP=$(hostname -I | awk '{print $1}')
# a function that waits for the vLLM server to start
wait_for_server() {
local port=$1
timeout 1200 bash -c "
until curl -s localhost:${port}/v1/completions > /dev/null; do
sleep 1
done" && return 0 || return 1
}
vllm serve $MODEL_NAME \
--port 8100 \
--max-model-len 100 \
--enforce-eager \
--gpu-memory-utilization 0.8 \
--trust-remote-code \
--kv-events-config \
'{"enable_kv_cache_events": true, "publisher": "zmq", "topic": "kv-events"}' &
wait_for_server 8100
SCRIPT_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
python3 "$SCRIPT_DIR/kv_events_subscriber.py" &
sleep 1
# send two example requests
output1=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 5-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
output2=$(curl -X POST -s http://localhost:8100/v1/completions \
-H "Content-Type: application/json" \
-d '{
"model": "'"$MODEL_NAME"'",
"prompt": "Explain quantum computing in simple terms a 50-year-old could understand.",
"max_tokens": 80,
"temperature": 0
}')
# Cleanup commands
pkill -9 -u "$USER" -f python
pkill -9 -u "$USER" -f vllm
sleep 1
echo "Cleaned up"
# Print the outputs of the curl requests
echo ""
echo "Output of first request: $output1"
echo "Output of second request: $output2"
echo "🎉🎉 Successfully finished 2 test requests! 🎉🎉"
echo ""

View File

@@ -0,0 +1,114 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional, Union
import msgspec
import zmq
from msgspec.msgpack import Decoder
#
# Types copied from vllm.distributed.kv_events
#
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True,
gc=False):
ts: float
events: list[Any]
class KVCacheEvent(msgspec.Struct,
array_like=True,
omit_defaults=True,
gc=False,
tag=True):
"""Base class for all KV cache-related events"""
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
class AllBlocksCleared(KVCacheEvent):
pass
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
def process_event(event_batch):
print(f"Received event batch at {event_batch.ts}:")
for event in event_batch.events:
print(f" - {event}")
def main():
decoder = Decoder(type=KVEventBatch)
last_seq = -1
context = zmq.Context()
# Set up the main subscription socket
sub = context.socket(zmq.SUB)
sub.connect("tcp://localhost:5557")
topic = "kv-events"
sub.setsockopt_string(zmq.SUBSCRIBE, topic)
# Initialize replay socket
replay = context.socket(zmq.REQ)
replay.connect("tcp://localhost:5558")
poller = zmq.Poller()
poller.register(replay, zmq.POLLIN)
print("Listening for KV cache events on topic:", topic)
while True:
try:
if sub.poll(50):
_, seq_bytes, payload = sub.recv_multipart()
seq = int.from_bytes(seq_bytes, "big")
if last_seq >= 0 and seq > last_seq + 1:
missed = seq - last_seq - 1
print(f"Missed {missed} messages"
f" (last: {last_seq}, current: {seq})")
replay.send((last_seq + 1).to_bytes(8, "big"))
while poller.poll(timeout=200):
seq_bytes, replay_payload = replay.recv_multipart()
if not replay_payload:
# End of replay marker is sent as an empty frame
# for the payload
break
replay_seq = int.from_bytes(seq_bytes, "big")
if replay_seq > last_seq:
event_batch = decoder.decode(replay_payload)
process_event(event_batch)
last_seq = replay_seq
if replay_seq >= seq - 1:
break
event_batch = decoder.decode(payload)
process_event(event_batch)
# ... do other periodic work or check for shutdown ...
except KeyboardInterrupt:
print("Interrupted")
break
except Exception as e:
print("Error decoding message:", e)
if __name__ == "__main__":
main()
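Each published message is a three-frame ZMQ multipart: the topic, an 8-byte big-endian sequence number, and a msgpack-encoded batch. A minimal sketch (assuming the demo's default endpoint and topic, and the KVEventBatch type defined above) that inspects a single message without the replay machinery:

import zmq
from msgspec.msgpack import Decoder

ctx = zmq.Context()
sub = ctx.socket(zmq.SUB)
sub.connect("tcp://localhost:5557")
sub.setsockopt_string(zmq.SUBSCRIBE, "kv-events")
# Frames: topic, 8-byte big-endian sequence number, msgpack payload
topic, seq_bytes, payload = sub.recv_multipart()
print(topic.decode(), int.from_bytes(seq_bytes, "big"))
print(Decoder(type=KVEventBatch).decode(payload))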

View File

@@ -0,0 +1,145 @@
# SPDX-License-Identifier: Apache-2.0
import random
from typing import Optional, Union
import msgspec
import msgspec.msgpack
import pytest
import zmq
from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import EventPublisherFactory
from .test_events import SampleBatch
@pytest.fixture
def random_port():
"""Generate a random port number for testing"""
return random.randint(10000, 60000)
@pytest.fixture
def publisher_config(random_port, request):
"""Create a publisher config with inproc transport"""
how = request.param if hasattr(request, "param") else "inproc"
if how == "inproc":
endpoint = f"inproc://test-{random_port}"
replay_endpoint = endpoint + "-replay"
else:
endpoint = f"tcp://*:{random_port}"
replay_endpoint = f"tcp://*:{random_port + 1}"
return KVEventsConfig(enable_kv_cache_events=True,
publisher="zmq",
endpoint=endpoint,
replay_endpoint=replay_endpoint,
buffer_steps=100,
hwm=1000,
topic="test")
@pytest.fixture
def publisher(publisher_config):
"""Create and return a publisher instance"""
pub = EventPublisherFactory.create(publisher_config)
yield pub
pub.shutdown()
@pytest.fixture
def subscriber(publisher_config):
"""Create and return a subscriber for testing"""
endpoint = publisher_config.endpoint
replay_endpoint = publisher_config.replay_endpoint
if endpoint.startswith("tcp://*"):
endpoint = endpoint.replace("*", "127.0.0.1")
if replay_endpoint and replay_endpoint.startswith("tcp://*"):
replay_endpoint = replay_endpoint.replace("*", "127.0.0.1")
sub = MockSubscriber(endpoint, replay_endpoint, publisher_config.topic)
yield sub
sub.close()
class MockSubscriber:
"""Helper class to receive and verify published events"""
def __init__(self,
pub_endpoint: str,
replay_endpoint: Optional[str] = None,
topic: str = "",
decode_type=SampleBatch):
self.ctx = zmq.Context.instance()
# Set up subscriber socket
self.sub = self.ctx.socket(zmq.SUB)
self.sub.setsockopt(zmq.SUBSCRIBE, topic.encode('utf-8'))
self.sub.connect(pub_endpoint)
# Set up replay socket if provided
self.replay = None
if replay_endpoint:
self.replay = self.ctx.socket(zmq.REQ)
self.replay.connect(replay_endpoint)
self.topic = topic
self.topic_bytes = topic.encode('utf-8')
self.received_msgs: list[tuple[int, SampleBatch]] = []
self.last_seq = -1
self.decoder = msgspec.msgpack.Decoder(type=decode_type)
def receive_one(self,
timeout=1000) -> Union[tuple[int, SampleBatch], None]:
"""Receive a single message with timeout"""
if not self.sub.poll(timeout):
return None
topic_bytes, seq_bytes, payload = self.sub.recv_multipart()
assert topic_bytes == self.topic_bytes
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
self.last_seq = seq
self.received_msgs.append((seq, data))
return seq, data
def request_replay(self, start_seq: int) -> None:
"""Request replay of messages starting from start_seq"""
if not self.replay:
raise ValueError("Replay socket not initialized")
self.replay.send(start_seq.to_bytes(8, "big"))
def receive_replay(self) -> list[tuple[int, SampleBatch]]:
"""Receive replayed messages"""
if not self.replay:
raise ValueError("Replay socket not initialized")
replayed: list[tuple[int, SampleBatch]] = []
while True:
try:
if not self.replay.poll(1000):
break
frames = self.replay.recv_multipart()
if not frames or not frames[-1]:
# End of replay marker
break
seq_bytes, payload = frames
seq = int.from_bytes(seq_bytes, "big")
data = self.decoder.decode(payload)
replayed.append((seq, data))
except zmq.ZMQError as _:
break
return replayed
def close(self):
"""Clean up resources"""
self.sub.close()
if self.replay:
self.replay.close()
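A sketch of how the helper is typically driven in a test (endpoints are illustrative; the fixtures above wire up the real ones):

sub = MockSubscriber("tcp://127.0.0.1:5557",
                     replay_endpoint="tcp://127.0.0.1:5558",
                     topic="test")
result = sub.receive_one(timeout=1000)  # (seq, SampleBatch) or None
sub.request_replay(0)                   # ask for history from seq 0
missed = sub.receive_replay()           # [(seq, SampleBatch), ...]
sub.close()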

View File

@@ -0,0 +1,193 @@
# SPDX-License-Identifier: Apache-2.0
import threading
import time
import msgspec
import pytest
from vllm.distributed.kv_events import (EventBatch, EventPublisherFactory,
NullEventPublisher)
class EventSample(
msgspec.Struct,
tag=True, # type: ignore
array_like=True # type: ignore
):
"""Test event for publisher testing"""
id: int
value: str
class SampleBatch(EventBatch):
"""Test event batch for publisher testing"""
events: list[EventSample]
def create_test_events(count: int) -> SampleBatch:
"""Create a batch of test events"""
events = [EventSample(id=i, value=f"test-{i}") for i in range(count)]
return SampleBatch(ts=time.time(), events=events)
def test_basic_publishing(publisher, subscriber):
"""Test basic event publishing works"""
test_batch = create_test_events(5)
publisher.publish(test_batch)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert received.ts == pytest.approx(test_batch.ts,
abs=0.1), ("Timestamp mismatch")
assert len(received.events) == len(
test_batch.events), ("Number of events mismatch")
for i, event in enumerate(received.events):
assert event.id == i, "Event id mismatch"
assert event.value == f"test-{i}", "Event value mismatch"
def test_multiple_events(publisher, subscriber):
"""Test publishing and receiving multiple event batches"""
for _ in range(10):
batch = create_test_events(2)
publisher.publish(batch)
received = []
for _ in range(10):
data = subscriber.receive_one(timeout=100)
if data:
received.append(data)
assert len(received) == 10, "Number of messages mismatch"
seqs = [seq for seq, _ in received]
assert seqs == list(range(10)), "Sequence numbers mismatch"
def test_replay_mechanism(publisher, subscriber):
"""Test the replay mechanism works correctly"""
for _ in range(19):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Give the publisher time to process the batches above
subscriber.request_replay(10)
batch = create_test_events(1)
publisher.publish(batch) # 20th message
replayed = subscriber.receive_replay()
assert len(replayed) > 0, "No replayed messages received"
seqs = [seq for seq, _ in replayed]
assert all(seq >= 10 for seq in seqs), "Replay should start at requested seq"
assert seqs == list(range(min(seqs),
max(seqs) +
1)), ("Replayed messages not consecutive")
def test_buffer_limit(publisher, subscriber, publisher_config):
"""Test buffer limit behavior"""
buffer_size = publisher_config.buffer_steps
# Publish more events than the buffer can hold
for i in range(buffer_size + 10):
batch = create_test_events(1)
publisher.publish(batch)
time.sleep(0.5) # Give the publisher time to process the batches above
subscriber.request_replay(0)
batch = create_test_events(1)
publisher.publish(batch)
replayed = subscriber.receive_replay()
assert len(replayed) <= buffer_size, "Can't replay more than buffer size"
oldest_seq = min(seq for seq, _ in replayed)
assert oldest_seq >= 10, "The oldest sequence should be at least 10"
def test_topic_filtering(publisher_config):
"""
Test that a subscriber only receives messages matching its topic filter
"""
publisher_config.replay_endpoint = None
cfg = publisher_config.model_copy()
cfg.topic = "foo"
pub = EventPublisherFactory.create(cfg)
from .conftest import MockSubscriber
sub_foo = MockSubscriber(cfg.endpoint, None, "foo")
sub_bar = MockSubscriber(cfg.endpoint, None, "bar")
try:
time.sleep(0.1)
for _ in range(3):
pub.publish(create_test_events(1))
foo_received = [sub_foo.receive_one(timeout=200) for _ in range(3)]
assert all(msg is not None for msg in foo_received), (
"Subscriber with matching topic should receive messages")
bar_received = [sub_bar.receive_one(timeout=200) for _ in range(3)]
assert all(msg is None for msg in bar_received), (
"Subscriber with non-matching topic should receive no messages")
finally:
pub.shutdown()
sub_foo.close()
sub_bar.close()
def test_high_volume(publisher, subscriber):
"""Test publishing and receiving a high volume of events"""
num_batches = 10_000
events_per_batch = 100
# Publish events from a separate thread so receiving is not blocked
def publish_events():
for i in range(num_batches):
batch = create_test_events(events_per_batch)
publisher.publish(batch)
# Small delay to avoid overwhelming the subscriber
if i % 100 == 0:
time.sleep(0.01)
received: list[tuple[int, SampleBatch]] = []
publisher_thread = threading.Thread(target=publish_events)
publisher_thread.start()
start_time = time.time()
while len(received) < num_batches:
if time.time() - start_time > 10: # Timeout after 10 seconds
break
result = subscriber.receive_one(timeout=100)
if result:
received.append(result)
publisher_thread.join()
assert len(received) >= num_batches * 0.9, (
"We should have received most messages")
seqs = [seq for seq, _ in received]
assert sorted(seqs) == seqs, "Sequence numbers should be in order"
def test_null_publisher():
"""Test that NullEventPublisher can be used without errors"""
publisher = NullEventPublisher()
# This should not raise any errors
batch = create_test_events(5)
publisher.publish(batch)
publisher.shutdown()

View File

@@ -6,6 +6,7 @@ from typing import Optional
import pytest
import torch
from vllm.distributed.kv_events import AllBlocksCleared, BlockRemoved
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.utils import sha256
@@ -48,9 +49,10 @@ def make_kv_cache_config(block_size: int, num_blocks: int) -> KVCacheConfig:
num_blocks=num_blocks,
tensors={},
kv_cache_groups=[
KVCacheGroupSpec(['layer'],
FullAttentionSpec(block_size, 1, 1, torch.float32,
False))
KVCacheGroupSpec(
["layer"],
FullAttentionSpec(block_size, 1, 1, torch.float32, False),
)
],
)
@@ -783,6 +785,60 @@ def test_prefix_cache_stats_disabled():
assert manager.prefix_cache_stats is None
@pytest.mark.parametrize("blocks_to_cache", [2, 3, 10])
def test_kv_cache_events(blocks_to_cache: int):
block_size = 16
num_blocks = blocks_to_cache + 1
# Allocate Blocks
# Should see a single BlockStored event containing blocks_to_cache
# block hashes
# take_events should reset the kv_event_queue
manager = KVCacheManager(
make_kv_cache_config(block_size, num_blocks),
max_model_len=8192,
enable_caching=True,
enable_kv_cache_events=True,
)
num_tokens = block_size * blocks_to_cache
req0 = make_request("0", list(range(num_tokens)))
_ = manager.allocate_slots(req0, num_tokens)
events = manager.take_events()
block = events[-1]
assert (len(block.block_hashes) == blocks_to_cache == len(
manager.block_pool.cached_block_hash_to_block))
assert len(block.token_ids) == block.block_size * len(block.block_hashes)
assert len(manager.block_pool.kv_event_queue) == 0
stored_block_hash = block.block_hashes
# Remove blocks and send another request
# Should see blocks_to_cache BlockRemoved events and a new
# BlockStored event
manager.free(req0)
req1 = make_request("1", list(range(num_tokens)))
_ = manager.allocate_slots(req1, num_tokens)
events = manager.take_events()
for blocks in events[:-1]:
assert blocks.block_hashes[0] in stored_block_hash
assert len(events) == blocks_to_cache + 1
assert (isinstance(events[-2], BlockRemoved))
assert (len(events[-1].block_hashes) == blocks_to_cache == len(
manager.block_pool.cached_block_hash_to_block))
# All blocks cleared
# Should see a single AllBlocksCleared event
manager.free(req1)
manager.reset_prefix_cache()
events = manager.take_events()
assert isinstance(events[-1], AllBlocksCleared)
assert len(manager.block_pool.cached_block_hash_to_block) == 0
def test_eagle_enabled_removes_last_block():
"""Verify Eagle does NOT remove blocks when request
length is divisible by block size."""

View File

@@ -13,6 +13,8 @@ from tests.v1.engine.utils import (NUM_PROMPT_LOGPROBS_UNDER_TEST,
from vllm.engine.arg_utils import EngineArgs
from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
from ...distributed.conftest import publisher_config, random_port # noqa: F401
from tests.v1.engine.utils import FULL_STRINGS # isort: skip
EngineCoreSampleLogprobsType = list[tuple[torch.Tensor, torch.Tensor]]

View File

@@ -11,6 +11,7 @@ import pytest
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.distributed.kv_events import BlockStored, KVEventBatch
from vllm.engine.arg_utils import EngineArgs
from vllm.platforms import current_platform
from vllm.usage.usage_lib import UsageContext
@@ -20,6 +21,7 @@ from vllm.v1.engine.core_client import (AsyncMPClient, EngineCoreClient,
SyncMPClient)
from vllm.v1.executor.abstract import Executor
from ...distributed.conftest import MockSubscriber
from ...utils import create_new_process_for_each_test
if not current_platform.is_cuda():
@@ -199,54 +201,142 @@ async def test_engine_core_client_asyncio(monkeypatch: pytest.MonkeyPatch):
log_stats=True,
)
MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
"""Normal Request Cycle."""
try:
MAX_TOKENS = 20
params = SamplingParams(max_tokens=MAX_TOKENS)
"""Normal Request Cycle."""
requests = [make_request(params) for _ in range(10)]
request_ids = [req.request_id for req in requests]
requests = [make_request(params) for _ in range(10)]
request_ids = [req.request_id for req in requests]
# Add requests to the engine.
for request in requests:
await client.add_request_async(request)
await asyncio.sleep(0.01)
# Add requests to the engine.
for request in requests:
await client.add_request_async(request)
await asyncio.sleep(0.01)
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
outputs: dict[str, list] = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{outputs[req_id]=}, {MAX_TOKENS=}")
"""Abort Request Cycle."""
# Add requests to the engine.
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
outputs = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
else:
for req_id in request_ids:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
f"{outputs[req_id]=}, {MAX_TOKENS=}")
"""Abort Request Cycle."""
core_client: AsyncMPClient = client
# Add requests to the engine.
for idx, request in enumerate(requests):
await client.add_request_async(request)
await asyncio.sleep(0.01)
if idx % 2 == 0:
await client.abort_requests_async([request.request_id])
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"
outputs = {req_id: [] for req_id in request_ids}
await loop_until_done_async(client, outputs)
with pytest.raises(Exception) as e_info:
await core_client.call_utility_async("echo", None, "help!")
for idx, req_id in enumerate(request_ids):
if idx % 2 == 0:
assert len(outputs[req_id]) < MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
else:
assert len(outputs[req_id]) == MAX_TOKENS, (
f"{len(outputs[req_id])=}, {MAX_TOKENS=}")
"""Utility method invocation"""
assert str(e_info.value) == "Call to echo method failed: help!"
core_client: AsyncMPClient = client
result = await core_client.call_utility_async("echo", "testarg")
assert result == "testarg"
with pytest.raises(Exception) as e_info:
await core_client.call_utility_async("echo", None, "help!")
assert str(e_info.value) == "Call to echo method failed: help!"
finally:
client.shutdown()
@pytest.mark.parametrize(
"multiprocessing_mode,publisher_config",
[(True, "tcp"), (False, "inproc")],
indirect=["publisher_config"],
)
def test_kv_cache_events(
monkeypatch: pytest.MonkeyPatch,
multiprocessing_mode: bool,
publisher_config,
):
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
block_size = 16
num_blocks = 2
engine_args = EngineArgs(model=MODEL_NAME,
enforce_eager=True,
enable_prefix_caching=True,
block_size=block_size)
engine_args.kv_events_config = publisher_config
vllm_config = engine_args.create_engine_config(
UsageContext.UNKNOWN_CONTEXT)
executor_class = Executor.get_class(vllm_config)
client = EngineCoreClient.make_client(
multiprocess_mode=multiprocessing_mode,
asyncio_mode=False,
vllm_config=vllm_config,
executor_class=executor_class,
log_stats=False,
)
endpoint = publisher_config.endpoint.replace("*", "127.0.0.1")
time.sleep(0.1)
subscriber = MockSubscriber(endpoint,
topic=publisher_config.topic,
decode_type=KVEventBatch)
try:
custom_tokens = list(range(num_blocks * block_size))
request = EngineCoreRequest(
request_id=str(uuid.uuid4()),
prompt_token_ids=custom_tokens,
mm_inputs=None,
mm_hashes=None,
mm_placeholders=None,
sampling_params=SamplingParams(
max_tokens=1), # Short completion for speed
eos_token_id=None,
arrival_time=time.time(),
lora_request=None,
)
client.add_request(request)
outputs: dict[str, list] = {request.request_id: []}
loop_until_done(client, outputs)
result = subscriber.receive_one(timeout=1000)
assert result is not None, "No message received"
seq, received = result
assert seq == 0, "Sequence number mismatch"
assert len(received.events) == 1, (
"We should have exactly one BlockStored event")
event = received.events[0]
assert isinstance(
event, BlockStored), ("We should have a BlockStored event")
assert len(event.block_hashes) == num_blocks, (
"BlockStored event should contain num_blocks block hashes")
assert event.block_size == block_size, (
"Block size should match the configured block size")
assert event.parent_block_hash is None, (
"Parent block hash should be None")
assert event.lora_id is None, "Lora id should be None"
assert len(event.token_ids) == num_blocks * block_size, (
"Number of token ids should match the prompt length")
assert event.token_ids == custom_tokens, (
"Token ids should match the prompt token ids")
finally:
client.shutdown()
return
@pytest.mark.timeout(10)

View File

@@ -1958,6 +1958,8 @@ class SchedulerConfig:
some image tokens can be scheduled (like TTTTIIIII, leaving IIIII),
it will be scheduled as TTTT in one step and IIIIIIIIII in the next."""
# scheduler class or path. "vllm.core.scheduler.Scheduler" (default)
# or "mod.custom_class".
scheduler_cls: Union[str, type[object]] = "vllm.core.scheduler.Scheduler"
"""The scheduler class to use. "vllm.core.scheduler.Scheduler" is the
default scheduler. Can be a class directly or the path to a class of form
@@ -3417,6 +3419,51 @@ class KVTransferConfig(BaseModel):
return self.kv_connector_extra_config.get(key, default)
class KVEventsConfig(BaseModel):
"""Configuration for KV event publishing."""
enable_kv_cache_events: bool = False
"""If True, enable KV cache events for tracking block storage and removal.
Events can be published externally via ZeroMQ using the event publisher
config.
"""
publisher: str = "null"
"""The publisher to use for publishing kv events. Can be "null", "zmq".
"""
endpoint: str = "tcp://*:5557"
"""The zmq endpoint to use for publishing kv events.
"""
replay_endpoint: Optional[str] = None
"""The zmq endpoint to use for replaying kv events.
"""
buffer_steps: int = 10_000
"""The number of steps to cache for replay endpoint. Will only save
events from the last N steps for the replay endpoint.
"""
hwm: int = 100_000
"""The zmq high water mark for the event publisher. After queueing N events,
events will start dropping if the consumer is not keeping up.
"""
max_queue_size: int = 100_000
"""The maximum number of events to queue while waiting for publishing.
"""
topic: str = ""
"""The topic to use for the event publisher. Consumers can subscribe to
this topic to receive events.
"""
@classmethod
def from_cli(cls, cli_value: str) -> "KVEventsConfig":
"""Parse the CLI value for the event publisher config."""
return KVEventsConfig.model_validate_json(cli_value)
class CompilationLevel:
# constants for the levels of the compilation process
NO_COMPILATION = 0
@@ -3779,6 +3826,7 @@ class VllmConfig:
init=True) # type: ignore
kv_transfer_config: KVTransferConfig = field(default=None,
init=True) # type: ignore
kv_events_config: Optional[KVEventsConfig] = None
# some opaque config, only used to provide additional information
# for the hash computation, mainly used for testing, debugging or out of
# tree config registration.
@@ -4038,6 +4086,18 @@
if self.cache_config is not None:
self.cache_config.enable_prefix_caching = False
if (self.kv_events_config
and self.kv_events_config.enable_kv_cache_events
and not self.cache_config.enable_prefix_caching):
logger.warning(
"KV cache events are on, but prefix caching is not enabled."
"Use --enable-prefix-caching to enable.")
if (self.kv_events_config and self.kv_events_config.publisher != "null"
and not self.kv_events_config.enable_kv_cache_events):
logger.warning("KV cache events are disabled,"
"but the scheduler is configured to publish them."
"Modify KVEventsConfig.enable_kv_cache_events"
"to True to enable.")
current_platform.check_and_update_config(self)
if not self.instance_id:
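Since from_cli simply delegates to model_validate_json, the same JSON passed on the command line can be parsed directly; a minimal sketch:

from vllm.config import KVEventsConfig

cfg = KVEventsConfig.from_cli(
    '{"enable_kv_cache_events": true, "publisher": "zmq", '
    '"topic": "kv-events"}')
assert cfg.publisher == "zmq"
assert cfg.endpoint == "tcp://*:5557"  # class default fills the rest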

View File

@@ -0,0 +1,295 @@
# SPDX-License-Identifier: Apache-2.0
import queue
import threading
import time
from abc import ABC, abstractmethod
from collections import deque
from itertools import count
from queue import Queue
from typing import Any, Callable, Optional, Union
import msgspec
import zmq
from vllm.config import KVEventsConfig
from vllm.logger import init_logger
logger = init_logger(__name__)
class EventBatch(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
):
ts: float
events: list[Any]
class KVCacheEvent(
msgspec.Struct,
array_like=True, # type: ignore[call-arg]
omit_defaults=True, # type: ignore[call-arg]
gc=False, # type: ignore[call-arg]
tag=True):
"""Base class for all KV cache-related events"""
class BlockStored(KVCacheEvent):
block_hashes: list[int]
parent_block_hash: Optional[int]
token_ids: list[int]
block_size: int
lora_id: Optional[int]
class BlockRemoved(KVCacheEvent):
block_hashes: list[int]
class AllBlocksCleared(KVCacheEvent):
pass
class KVEventBatch(EventBatch):
events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
class EventPublisher(ABC):
"""Lightweight publisher for EventBatch batches."""
@abstractmethod
def publish(self, events: EventBatch) -> None:
"""Emit events in order.
Implementations should guarantee at-least-once delivery and
monotonic ordering (e.g., via sequence numbers).
"""
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the publisher."""
class NullEventPublisher(EventPublisher):
"""No-op implementation (default when disabled)."""
def publish(self, events) -> None:
return
def shutdown(self) -> None:
return
class ZmqEventPublisher(EventPublisher):
"""Reliable PUB/ROUTER publisher with an in-memory replay buffer.
Spawns a separate thread to handle publishing from a queue.
Parameters
----------
endpoint:
PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
connect.
replay_endpoint:
Optional ROUTER address for replay requests. When given, subscribers can
request missed batches by sending the starting sequence number as an
8-byte big-endian integer.
buffer_steps:
Number of past batches to keep for replay.
hwm:
ZeroMQ high-water-mark for PUB socket.
max_queue_size:
Maximum number of events to buffer in memory.
topic:
Topic to publish events to.
"""
SHUTDOWN_TIMEOUT: float = 1.0
END_SEQ = (-1).to_bytes(8, "big", signed=True)
def __init__(
self,
endpoint: str = "tcp://*:5557",
replay_endpoint: Optional[str] = None,
buffer_steps: int = 10_000,
hwm: int = 100_000,
max_queue_size: int = 100_000,
topic: str = "",
) -> None:
# Storage
self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
# ZMQ sockets
self._ctx = zmq.Context.instance()
self._pub: Optional[zmq.Socket] = None
self._replay: Optional[zmq.Socket] = None
self._endpoint = endpoint
self._replay_endpoint = replay_endpoint
self._hwm = hwm
# Payload
self._seq_gen = count()
self._topic_bytes = topic.encode('utf-8')
# Thread
self._running = True
logger.info("Starting ZMQ publisher thread")
self._thread = threading.Thread(target=self._publisher_thread,
daemon=True,
name="zmq-publisher")
self._thread.start()
def publish(self, events: EventBatch) -> None:
if not self._running:
raise RuntimeError("Publisher is closed")
self._event_queue.put(events)
def shutdown(self) -> None:
"""Stop the publisher thread and clean up resources."""
self._running = False
self._event_queue.put_nowait(None)
start = time.time()
pending_items = True
while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
pending_items = not self._event_queue.empty()
if pending_items:
time.sleep(0.1)
if pending_items:
logger.warning(
"Warning: Queue still has %s items after %s seconds timeout",
self._event_queue.qsize(),
self.SHUTDOWN_TIMEOUT,
)
if self._thread.is_alive():
self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
# Clean up ZMQ resources
try:
if self._pub is not None:
self._pub.close(linger=0)
if self._replay is not None:
self._replay.close(linger=0)
finally:
pass # Do not terminate context; other sockets may use it
def _socket_setup(self) -> None:
"""Initialize sockets
https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
"""
if self._pub is None:
self._pub = self._ctx.socket(zmq.PUB)
self._pub.set_hwm(self._hwm)
# Heuristic: bind if the endpoint contains a wildcard ("*" or "::")
# or uses a local (ipc/inproc) transport, else connect
# (convention: the stable peer binds, the volatile peer connects)
if ("*" in self._endpoint or "::" in self._endpoint
or self._endpoint.startswith("ipc://")
or self._endpoint.startswith("inproc://")):
self._pub.bind(self._endpoint)
else:
self._pub.connect(self._endpoint)
# Set up replay socket: use ROUTER
# 1) handles multiple REQ clients (identities)
# 2) lets us send back one request → many replies (streamed events)
# 3) works in our nonblocking poll loop alongside PUB
if self._replay_endpoint is not None:
self._replay = self._ctx.socket(zmq.ROUTER)
self._replay.bind(self._replay_endpoint)
def _publisher_thread(self) -> None:
"""Background thread that processes the event queue."""
self._pack = msgspec.msgpack.Encoder()
self._socket_setup()
assert self._pub is not None # narrows type for mypy
while self._running or self._event_queue.qsize() > 0:
# --- replay (non-critical) ---------------------------------
if self._replay is not None and self._replay.poll(0):
try:
self._service_replay()
except Exception as e:
logger.exception("Error in replay: %s", e)
# --- main queue (critical) ---------------------------------
try:
event = self._event_queue.get(timeout=0.1)
if event is None:
break # Sentinel received, exit thread
except queue.Empty:
continue
try:
seq = next(self._seq_gen)
payload = self._pack.encode(event)
seq_bytes = seq.to_bytes(8, "big")
self._pub.send_multipart(
(self._topic_bytes, seq_bytes, payload))
self._buffer.append((seq, payload))
self._event_queue.task_done()
except Exception as e:
# Publishing failed; back off a bit to avoid a tight error loop
logger.exception("Error in publisher thread: %s", e)
time.sleep(0.1)
def _service_replay(self) -> None:
"""If a replay request is waiting, send buffered batches."""
assert self._replay is not None # narrows type for mypy
frame = self._replay.recv_multipart()
if len(frame) != 3:
logger.warning("Invalid replay request: %s", frame)
return
client_id, _, start_seq_bytes = frame
start_seq = int.from_bytes(start_seq_bytes, "big")
for seq, buf in self._buffer:
if seq >= start_seq:
# [identity, empty_delim, seq_bytes, payload]
# (identity, empty_delim) are stripped off by the router
# receiving payload is (seq_bytes, payload)
self._replay.send_multipart(
(client_id, b"", seq.to_bytes(8, "big"), buf))
# Send end-of-sequence marker
# receiving payload is (-1, b"")
self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
class EventPublisherFactory:
_registry: dict[str, Callable[..., EventPublisher]] = {
"null": NullEventPublisher,
"zmq": ZmqEventPublisher,
}
@classmethod
def register_publisher(cls, name: str,
ctor: Callable[..., EventPublisher]) -> None:
if name in cls._registry:
raise KeyError(f"publisher '{name}' already registered")
cls._registry[name] = ctor
@classmethod
def create(cls, config: Optional[KVEventsConfig]) -> EventPublisher:
"""Create publisher from a config mapping."""
if not config:
return NullEventPublisher()
config_dict = config.model_dump()
kind = config_dict.pop("publisher", "null")
config_dict.pop("enable_kv_cache_events")
try:
constructor = cls._registry[kind]
except KeyError as exc:
raise ValueError(f"Unknown event publisher '{kind}'") from exc
return constructor(**config_dict)
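Putting the pieces together, a hedged end-to-end sketch of the publisher API defined above (field values are illustrative; real events come from the block pool):

import time

from vllm.config import KVEventsConfig
from vllm.distributed.kv_events import (BlockStored, EventPublisherFactory,
                                        KVEventBatch)

config = KVEventsConfig(enable_kv_cache_events=True,
                        publisher="zmq",
                        topic="kv-events")
publisher = EventPublisherFactory.create(config)
# Illustrative event: one cached block of 16 tokens with no parent.
event = BlockStored(block_hashes=[42],
                    parent_block_hash=None,
                    token_ids=list(range(16)),
                    block_size=16,
                    lora_id=None)
publisher.publish(KVEventBatch(ts=time.time(), events=[event]))
publisher.shutdown()  # drains the queue (best effort), then closes sockets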

View File

@@ -19,14 +19,14 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
ConfigFormat, ConfigType, DecodingConfig, Device,
DeviceConfig, DistributedExecutorBackend,
GuidedDecodingBackend, GuidedDecodingBackendV1,
HfOverrides, KVTransferConfig, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ModelDType, ModelImpl,
MultiModalConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PrefixCachingHashAlgo,
PromptAdapterConfig, SchedulerConfig, SchedulerPolicy,
SpeculativeConfig, TaskOption, TokenizerMode,
TokenizerPoolConfig, VllmConfig, get_attr_docs,
get_field)
HfOverrides, KVEventsConfig, KVTransferConfig,
LoadConfig, LoadFormat, LoRAConfig, ModelConfig,
ModelDType, ModelImpl, MultiModalConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PrefixCachingHashAlgo, PromptAdapterConfig,
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
TaskOption, TokenizerMode, TokenizerPoolConfig,
VllmConfig, get_attr_docs, get_field)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QuantizationMethods
@@ -353,6 +353,7 @@ class EngineArgs:
worker_extension_cls: str = ParallelConfig.worker_extension_cls
kv_transfer_config: Optional[KVTransferConfig] = None
kv_events_config: Optional[KVEventsConfig] = None
generation_config: str = ModelConfig.generation_config
enable_sleep_mode: bool = ModelConfig.enable_sleep_mode
@@ -769,6 +770,10 @@
default=None,
help='The configurations for distributed KV cache '
'transfer. Should be a JSON string.')
parser.add_argument('--kv-events-config',
type=KVEventsConfig.from_cli,
default=None,
help='The configurations for event publishing.')
parser.add_argument(
'--worker-cls',
@@ -1125,6 +1130,7 @@
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
kv_transfer_config=self.kv_transfer_config,
kv_events_config=self.kv_events_config,
additional_config=self.additional_config,
)

View File

@@ -3,6 +3,8 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Callable, Optional
from vllm.distributed.kv_events import (AllBlocksCleared, BlockRemoved,
BlockStored, KVCacheEvent)
from vllm.logger import init_logger
from vllm.v1.core.kv_cache_utils import (BlockHashType, FreeKVCacheBlockQueue,
KVCacheBlock,
@@ -26,7 +28,12 @@ class BlockPool:
enable_caching: Whether to enable prefix caching.
"""
def __init__(self, num_gpu_blocks: int, enable_caching: bool):
def __init__(
self,
num_gpu_blocks: int,
enable_caching: bool,
enable_kv_cache_events: bool = False,
):
assert isinstance(num_gpu_blocks, int) and num_gpu_blocks > 0
self.num_gpu_blocks = num_gpu_blocks
self.enable_caching = enable_caching
@@ -56,6 +63,9 @@
# avoid freeing it.
self.null_block = self.free_block_queue.popleft()
self.enable_kv_cache_events = enable_kv_cache_events
self.kv_event_queue: list[KVCacheEvent] = []
def get_cached_block(self,
block_hash: BlockHashType) -> Optional[KVCacheBlock]:
"""Get a cached block by the block hash, or None if cache miss.
@@ -116,6 +126,9 @@
assert prev_block.block_hash is not None
prev_block_hash_value = prev_block.block_hash.hash_value
parent_block_hash = prev_block_hash_value
new_hashes: Optional[list[int]] = ([] if self.enable_kv_cache_events
else None)
for i, blk in enumerate(new_full_blocks):
assert blk.block_hash is None
@@ -153,8 +166,23 @@
# Update and add the full block to the cache.
blk.block_hash = block_hash
self.cached_block_hash_to_block[block_hash][blk.block_id] = blk
if new_hashes is not None:
new_hashes.append(block_hash.hash_value)
prev_block_hash_value = block_hash.hash_value
if self.enable_kv_cache_events:
self.kv_event_queue.append(
BlockStored(
block_hashes=new_hashes,
parent_block_hash=parent_block_hash,
token_ids=request.
all_token_ids[num_cached_blocks *
block_size:num_full_blocks * block_size],
block_size=block_size,
lora_id=request.lora_request.id
if request.lora_request else None,
))
def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.
@@ -206,6 +234,9 @@
if len(self.cached_block_hash_to_block[block_hash]) == 0:
del self.cached_block_hash_to_block[block_hash]
if self.enable_kv_cache_events:
self.kv_event_queue.append(
BlockRemoved(block_hashes=[block_hash.hash_value]))
return True
return False
@@ -262,6 +293,10 @@
block.reset_hash()
logger.info("Successfully reset prefix cache")
if self.enable_kv_cache_events:
self.kv_event_queue.append(AllBlocksCleared())
return True
def get_num_free_blocks(self) -> int:
@@ -279,3 +314,15 @@
The KV cache usage (between 0.0 and 1.0).
"""
return 1.0 - (self.get_num_free_blocks() / self.num_gpu_blocks)
def take_events(self) -> list[KVCacheEvent]:
"""Atomically takes all events and clears the queue.
Returns:
A list of KV cache events.
"""
if not self.enable_kv_cache_events:
return []
events = self.kv_event_queue
self.kv_event_queue = []
return events

View File

@@ -4,6 +4,7 @@ from collections import defaultdict
from collections.abc import Iterable
from typing import Optional
from vllm.distributed.kv_events import KVCacheEvent
from vllm.logger import init_logger
from vllm.utils import cdiv, sha256
from vllm.v1.core.block_pool import BlockPool
@@ -27,6 +28,7 @@ class KVCacheManager:
caching_hash_algo: str = "builtin",
use_eagle: bool = False,
log_stats: bool = False,
enable_kv_cache_events: bool = False,
) -> None:
assert len(kv_cache_config.kv_cache_groups) == 1, (
"KVCacheManager does not support hybrid models with more than 1 "
@@ -44,7 +46,9 @@
# FIXME: make prefix cache stats conditional on log_stats
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching)
self.block_pool = BlockPool(self.num_gpu_blocks, enable_caching,
enable_kv_cache_events)
self.specialized_manager = get_specialized_manager(
kv_cache_spec=kv_cache_spec,
block_pool=self.block_pool,
@@ -383,3 +387,11 @@
is finished, not when it is preempted.
"""
self.req_to_block_hashes.pop(request.request_id, None)
def take_events(self) -> list[KVCacheEvent]:
"""Take the KV cache events from the block pool.
Returns:
A list of KV cache events.
"""
return self.block_pool.take_events()

View File

@@ -132,3 +132,8 @@ class SchedulerInterface(ABC):
The SchedulerStats object is created for every scheduling step.
"""
raise NotImplementedError
@abstractmethod
def shutdown(self) -> None:
"""Shutdown the scheduler."""
raise NotImplementedError

View File

@@ -8,6 +8,7 @@ from collections.abc import Iterable
from typing import Optional, Union
from vllm.config import VllmConfig
from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch
from vllm.distributed.kv_transfer.kv_connector.factory import (
KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
@@ -48,6 +49,7 @@ class Scheduler(SchedulerInterface):
self.cache_config = vllm_config.cache_config
self.lora_config = vllm_config.lora_config
self.kv_cache_config = kv_cache_config
self.kv_events_config = vllm_config.kv_events_config
self.log_stats = log_stats
self.structured_output_manager = structured_output_manager
@@ -62,6 +64,9 @@
self.max_num_scheduled_tokens = \
self.scheduler_config.max_num_batched_tokens
self.max_model_len = self.scheduler_config.max_model_len
self.enable_kv_cache_events = (
self.kv_events_config is not None
and self.kv_events_config.enable_kv_cache_events)
# Create KVConnector for the Scheduler. Note that each Worker
# will have a corresponding KVConnector with Role=WORKER.
@@ -71,6 +76,9 @@
self.connector = KVConnectorFactory.create_connector_v1(
config=self.vllm_config, role=KVConnectorRole.SCHEDULER)
self.kv_event_publisher = EventPublisherFactory.create(
self.kv_events_config)
num_gpu_blocks = self.cache_config.num_gpu_blocks
assert num_gpu_blocks is not None and num_gpu_blocks > 0
@@ -132,7 +140,9 @@
enable_caching=self.cache_config.enable_prefix_caching,
caching_hash_algo=self.cache_config.prefix_caching_hash_algo,
use_eagle=self.use_eagle,
log_stats=self.log_stats)
log_stats=self.log_stats,
enable_kv_cache_events=self.enable_kv_cache_events,
)
def schedule(self) -> SchedulerOutput:
# NOTE(woosuk) on the scheduling algorithm:
@@ -493,6 +503,11 @@
meta = self.connector.build_connector_meta(scheduler_output)
scheduler_output.kv_connector_metadata = meta
events = self.kv_cache_manager.take_events()
if events:
batch = KVEventBatch(ts=time.time(), events=events)
self.kv_event_publisher.publish(batch)
# Advance the number of computed tokens for the request AFTER
# the request is scheduled.
# 1. The scheduler_output of the current step has to include the
@@ -843,3 +858,7 @@
num_draft_tokens=num_draft_tokens,
num_accepted_tokens=num_accepted_tokens)
return spec_decoding_stats
def shutdown(self) -> None:
if self.kv_event_publisher:
self.kv_event_publisher.shutdown()

View File

@@ -259,6 +259,8 @@ class EngineCore:
self.structured_output_manager.clear_backend()
if self.model_executor:
self.model_executor.shutdown()
if self.scheduler:
self.scheduler.shutdown()
def profile(self, is_start: bool = True):
self.model_executor.profile(is_start)