[Core] Shared memory based object store for Multimodal data caching and IPC (#20452)
Signed-off-by: donglu <donglu@cohere.com>
@@ -789,6 +789,8 @@ steps:
  commands:
  - pytest -v -s distributed/test_comm_ops.py
  - pytest -v -s distributed/test_shm_broadcast.py
  - pytest -v -s distributed/test_shm_buffer.py
  - pytest -v -s distributed/test_shm_storage.py

- label: 2 Node Tests (4 GPUs in total) # 16min
  timeout_in_minutes: 30
@@ -230,6 +230,20 @@ Multi-modal IPC caching is automatically enabled when
there is a one-to-one correspondence between API (`P0`) and engine core (`P1`) processes,
to avoid repeatedly transferring the same multi-modal inputs between them.

#### Key-Replicated Cache

By default, IPC caching uses a **key-replicated cache**, where cache keys exist
in both the API (`P0`) and engine core (`P1`) processes, but the actual cache
data resides only in `P1`.

#### Shared Memory Cache

When multiple worker processes are involved (e.g., when TP > 1), a
**shared-memory cache** is more efficient. This can be enabled by setting
`mm_processor_cache_type="shm"`. In this mode, cache keys are stored on `P0`,
while the cache data itself lives in shared memory accessible by all processes.

### Configuration

You can adjust the size of the cache by setting the value of `mm_processor_cache_gb` (default 4 GiB).
@@ -244,6 +258,12 @@ Examples:
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=8)

# Use a shared-memory based IPC cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          tensor_parallel_size=2,
          mm_processor_cache_type="shm",
          mm_processor_cache_gb=8)

# Disable the cache
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
          mm_processor_cache_gb=0)
@@ -253,11 +273,12 @@ llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",

Based on the configuration, the contents of the multi-modal caches on `P0` and `P1` are as follows:

-| Processor Caching | IPC Caching | `P0` Cache | `P1` Cache | Max. Memory |
-|-------------------|-------------|------------|------------|-------------|
-| ✅ | ✅ | K | K + V | `mm_processor_cache_gb * data_parallel_size` |
-| ✅ | ❌ | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
-| ❌ | ❌ | N/A | N/A | `0` |
+| mm_processor_cache_type | Cache Type | `P0` Cache | `P1` Engine Cache | `P1` Worker Cache | Max. Memory |
+|-------------------------|------------|------------|-------------------|-------------------|-------------|
+| lru | Processor Caching | K + V | N/A | N/A | `mm_processor_cache_gb * data_parallel_size` |
+| lru | Key-Replicated Caching | K | K + V | N/A | `mm_processor_cache_gb * api_server_count` |
+| shm | Shared Memory Caching | K | N/A | V | `mm_processor_cache_gb * api_server_count` |
+| N/A | Disabled | N/A | N/A | N/A | `0` |

K: Stores the hashes of multi-modal items
V: Stores the processed tensor data of multi-modal items
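For example (illustrative numbers): with `mm_processor_cache_type="shm"`, `mm_processor_cache_gb=8`, and two API server processes (`api_server_count=2`), the shared-memory cache is bounded by `8 GiB * 2 = 16 GiB` in total.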
|
||||
|
tests/distributed/test_shm_buffer.py (new file, 172 lines)
@@ -0,0 +1,172 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import traceback
|
||||
import unittest
|
||||
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
SingleWriterShmRingBuffer)
|
||||
|
||||
|
||||
class TestSingleWriterShmRingBuffer(unittest.TestCase):
|
||||
"""Test suite for the ring buffer implementation"""
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures"""
|
||||
self.buffer_size = 4096
|
||||
self.ring_buffer = None
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after tests"""
|
||||
if self.ring_buffer:
|
||||
del self.ring_buffer
|
||||
|
||||
def test_buffer_opening(self):
|
||||
"""Test opening an existing buffer"""
|
||||
# First create a buffer
|
||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=self.buffer_size, create=True)
|
||||
|
||||
# Then open it with another instance
|
||||
reader_buffer = SingleWriterShmRingBuffer(*self.ring_buffer.handle())
|
||||
self.assertFalse(reader_buffer.is_writer)
|
||||
self.assertEqual(reader_buffer.shared_memory.name,
|
||||
self.ring_buffer.shared_memory.name)
|
||||
|
||||
def test_buffer_access(self):
|
||||
"""Test accessing allocated buffers"""
|
||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=self.buffer_size, create=True)
|
||||
|
||||
size = 100
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||
|
||||
# Write some test data
|
||||
test_data = b"Hello, World!" * 7 # 91 bytes
|
||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||
data_buf[0:len(test_data)] = test_data
|
||||
|
||||
# Read it back
|
||||
with self.ring_buffer.access_buf(address) as (data_buf2, metadata2):
|
||||
read_data = bytes(data_buf2[0:len(test_data)])
|
||||
read_id = metadata2[0]
|
||||
|
||||
self.assertEqual(read_data, test_data)
|
||||
self.assertEqual(read_id, monotonic_id)
|
||||
|
||||
def test_memory_error_on_full_buffer(self):
|
||||
"""Test that MemoryError is raised when buffer is full"""
|
||||
small_buffer_size = 200
|
||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=small_buffer_size, create=True)
|
||||
|
||||
# Fill up the buffer
|
||||
self.ring_buffer.allocate_buf(100)
|
||||
self.ring_buffer.allocate_buf(80) # Total: 196 bytes used
|
||||
|
||||
# This should fail
|
||||
with self.assertRaises(MemoryError):
|
||||
self.ring_buffer.allocate_buf(1) # Would exceed buffer capacity
|
||||
|
||||
def test_allocation_and_free(self):
|
||||
"""Test allocation and freeing of buffers"""
|
||||
small_buffer_size = 200
|
||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=small_buffer_size, create=True)
|
||||
|
||||
size = 80
|
||||
# Write some data
|
||||
test_data = b"Repeated test data"
|
||||
for i in range(5):
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(size)
|
||||
with self.ring_buffer.access_buf(address) as (data_buf, metadata):
|
||||
data_buf[0:4] = (0).to_bytes(4, "little") # 0 for not in-use
|
||||
data_buf[4:len(test_data) + 4] = test_data
|
||||
print(self.ring_buffer.metadata)
|
||||
freed_ids = self.ring_buffer.free_buf(lambda *args: True)
|
||||
print(f" Freed IDs: {freed_ids}")
|
||||
self.assertEqual(freed_ids[0], i)
|
||||
|
||||
def test_clear_buffer(self):
|
||||
"""Test clearing the buffer"""
|
||||
self.ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=self.buffer_size, create=True)
|
||||
|
||||
# Allocate some buffers
|
||||
for _ in range(3):
|
||||
self.ring_buffer.allocate_buf(100)
|
||||
|
||||
# Clear the buffer
|
||||
self.ring_buffer.clear()
|
||||
|
||||
# Check that metadata is empty and IDs reset
|
||||
self.assertEqual(len(self.ring_buffer.metadata), 0)
|
||||
self.assertEqual(self.ring_buffer.monotonic_id_start, 0)
|
||||
self.assertEqual(self.ring_buffer.monotonic_id_end, 0)
|
||||
self.assertEqual(self.ring_buffer.data_buffer_start, 0)
|
||||
self.assertEqual(self.ring_buffer.data_buffer_end, 0)
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function demonstrating usage and running tests"""
|
||||
print("=== SingleWriterShmRingBuffer Test Suite ===\n")
|
||||
|
||||
# Run unit tests
|
||||
print("Running unit tests...")
|
||||
unittest.main(argv=[""], exit=False, verbosity=2)
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("=== Manual Demo ===\n")
|
||||
|
||||
# Manual demonstration
|
||||
try:
|
||||
print("Creating ring buffer...")
|
||||
writer_buffer = SingleWriterShmRingBuffer(data_buffer_size=2048,
|
||||
create=True)
|
||||
reader_buffer = SingleWriterShmRingBuffer(*writer_buffer.handle())
|
||||
|
||||
print(f"Buffer created with name: {writer_buffer.shared_memory.name}")
|
||||
|
||||
# Allocate some buffers
|
||||
print("\nAllocating buffers...")
|
||||
address_array = []
|
||||
for i in range(3):
|
||||
size = 100 + i * 50
|
||||
try:
|
||||
writer_buffer.free_buf(lambda *args: True)
|
||||
address, monotonic_id = writer_buffer.allocate_buf(size)
|
||||
address_array.append((address, size, monotonic_id))
|
||||
|
||||
# Write some test data
|
||||
with writer_buffer.access_buf(address) as (data_buf, metadata):
|
||||
test_message = f"Test message {i}".encode()
|
||||
data_buf[0:len(test_message)] = test_message
|
||||
|
||||
except MemoryError as e:
|
||||
print(f" Failed to allocate {size} bytes: {e}")
|
||||
|
||||
print("\nBuffer state:")
|
||||
print(f" Data buffer start: {writer_buffer.data_buffer_start}")
|
||||
print(f" Data buffer end: {writer_buffer.data_buffer_end}")
|
||||
print(f" Monotonic ID start: {writer_buffer.monotonic_id_start}")
|
||||
print(f" Monotonic ID end: {writer_buffer.monotonic_id_end}")
|
||||
print(f" Metadata entries: {len(writer_buffer.metadata)}")
|
||||
|
||||
# Try to read back the data
|
||||
print("\nReading back data...")
|
||||
for address, size, monotonic_id in address_array:
|
||||
with reader_buffer.access_buf(address) as (data_buf, metadata):
|
||||
# Read back the whole data region and decode it
|
||||
data_bytes = bytes(data_buf[0:size])
|
||||
message = data_bytes.decode()
|
||||
print(f" ID {monotonic_id}: '{message}'")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Demo error: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
print("\n=== Demo Complete ===")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
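For reference, the single-writer/multi-reader flow exercised by the test file above reduces to the following minimal sketch (single-process for brevity; the chunk size and payload are illustrative):

```python
from vllm.distributed.device_communicators.shm_object_storage import (
    SingleWriterShmRingBuffer)

# Writer side: create the shared-memory ring and append one chunk.
writer = SingleWriterShmRingBuffer(data_buffer_size=4096, create=True)
address, monotonic_id = writer.allocate_buf(64)
with writer.access_buf(address) as (data_buf, metadata):
    payload = b"hello"
    data_buf[0:len(payload)] = payload

# Reader side (normally another process): attach via the writer's handle.
reader = SingleWriterShmRingBuffer(*writer.handle())
with reader.access_buf(address) as (data_buf, metadata):
    assert bytes(data_buf[0:5]) == b"hello"
    assert metadata[0] == monotonic_id  # the id guards against chunk reuse

# Only the writer reclaims space, guided by an is_free callback.
writer.free_buf(lambda chunk_id, buf: True)
```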
tests/distributed/test_shm_storage.py (new file, 327 lines)
@@ -0,0 +1,327 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import multiprocessing
|
||||
import random
|
||||
import time
|
||||
import traceback
|
||||
import unittest
|
||||
from multiprocessing import Lock
|
||||
|
||||
import torch
|
||||
|
||||
# Assuming these are imported from your module
|
||||
from vllm.distributed.device_communicators.shm_object_storage import (
|
||||
MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
|
||||
MultiModalSharedField)
|
||||
|
||||
|
||||
def _dummy_elem(modality: str, key: str, size: int):
|
||||
return MultiModalFieldElem(
|
||||
modality=modality,
|
||||
key=key,
|
||||
data=torch.empty((size, ), dtype=torch.int8),
|
||||
field=MultiModalSharedField(1),
|
||||
)
|
||||
|
||||
|
||||
def _dummy_item(modality: str, size_by_key: dict[str, int]):
|
||||
return MultiModalKwargsItem.from_elems([
|
||||
_dummy_elem(modality, key, size) for key, size in size_by_key.items()
|
||||
])
|
||||
|
||||
|
||||
class TestSingleWriterShmObjectStorage(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
"""Set up test fixtures before each test method."""
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=1024 * 100,
|
||||
create=True,  # 100 KiB buffer
|
||||
)
|
||||
self.storage = SingleWriterShmObjectStorage(
|
||||
max_object_size=1024 * 10, # 10KB max object
|
||||
n_readers=2,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
reader_lock=Lock(),
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
"""Clean up after each test."""
|
||||
if self.storage:
|
||||
del self.storage
|
||||
|
||||
def test_minimal_put_get_cycle(self):
|
||||
"""Test basic put and get operations."""
|
||||
key = "test_key"
|
||||
value = _dummy_item("text", {"field1": 10, "field2": 20})
|
||||
|
||||
# Put operation
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
|
||||
# Verify key is in index
|
||||
self.assertIn(key, self.storage.key_index)
|
||||
self.assertEqual(self.storage.key_index[key], (address, monotonic_id))
|
||||
self.assertEqual(self.storage.id_index[monotonic_id], key)
|
||||
|
||||
# Get operation
|
||||
result = self.storage.get(address, monotonic_id)
|
||||
|
||||
# Verify result
|
||||
self.assertEqual(result, value)
|
||||
|
||||
def test_put_same_key_twice(self):
|
||||
"""Test behavior when putting the same key multiple times."""
|
||||
key = "duplicate_key"
|
||||
value1 = "first value"
|
||||
value2 = "second value"
|
||||
|
||||
# First put
|
||||
address1, id1 = self.storage.put(key, value1)
|
||||
retrieved1 = self.storage.get(address1, id1)
|
||||
self.assertEqual(retrieved1, value1)
|
||||
|
||||
# should raise an error on second put
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.storage.put(key, value2)
|
||||
|
||||
self.assertIn("already exists in the storage", str(context.exception))
|
||||
|
||||
def test_large_object_rejection(self):
|
||||
"""Test that objects exceeding max_object_size are rejected."""
|
||||
# Create an object larger than max_object_size
|
||||
large_data = "x" * (self.storage.max_object_size + 100)
|
||||
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.storage.put("large_key", large_data)
|
||||
|
||||
self.assertIn("exceeds max object size", str(context.exception))
|
||||
|
||||
def test_buffer_overflow_and_cleanup(self):
|
||||
"""Test behavior when buffer fills up and needs cleanup."""
|
||||
# Fill up the buffer with many small objects
|
||||
stored_items = []
|
||||
|
||||
try:
|
||||
for i in range(1000): # Try to store many items
|
||||
key = f"item_{i}"
|
||||
value = f"data_{i}" * 100 # Make it reasonably sized
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
stored_items.append((key, value, address, monotonic_id))
|
||||
except MemoryError:
|
||||
print(f"Buffer filled after {len(stored_items)} items")
|
||||
|
||||
# Verify that some items are still accessible
|
||||
accessible_count = 0
|
||||
for key, original_value, address, monotonic_id in stored_items:
|
||||
for i in range(self.storage.n_readers):
|
||||
retrieved = self.storage.get(address, monotonic_id)
|
||||
if retrieved == original_value:
|
||||
accessible_count += 1
|
||||
|
||||
self.assertEqual(accessible_count, len(stored_items))
|
||||
|
||||
try:
|
||||
for i in range(len(stored_items), 1000): # Try to store many items
|
||||
key = f"item_{i}"
|
||||
value = f"data_{i}" * 100 # Make it reasonably sized
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
stored_items.append((key, value, address, monotonic_id))
|
||||
except MemoryError:
|
||||
print(f"Buffer filled after {len(stored_items)} items")
|
||||
|
||||
# Verify which items are still accessible
|
||||
for key, original_value, address, monotonic_id in stored_items:
|
||||
try:
|
||||
for i in range(self.storage.n_readers):
|
||||
retrieved = self.storage.get(address, monotonic_id)
|
||||
if retrieved == original_value:
|
||||
accessible_count += 1
|
||||
except ValueError as e:
|
||||
print(f"Error retrieving {key}: {e}")
|
||||
|
||||
# some items from the first batch may still be accessible
|
||||
self.assertGreaterEqual(accessible_count, len(stored_items))
|
||||
|
||||
def test_blocking_unread_object(self):
|
||||
"""Test behavior when buffer fills up and needs cleanup."""
|
||||
# Fill up the buffer with many small objects
|
||||
stored_items = []
|
||||
|
||||
try:
|
||||
for i in range(1000): # Try to store many items
|
||||
key = f"item_{i}"
|
||||
value = f"data_{i}" * 100 # Make it reasonably sized
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
stored_items.append((key, value, address, monotonic_id))
|
||||
except MemoryError:
|
||||
print(f"Buffer filled after {len(stored_items)} items")
|
||||
|
||||
# read all items except the first one
|
||||
# to simulate a blocking situation
|
||||
accessible_count = 0
|
||||
for key, original_value, address, monotonic_id in stored_items[1:]:
|
||||
for i in range(self.storage.n_readers):
|
||||
retrieved = self.storage.get(address, monotonic_id)
|
||||
if retrieved == original_value:
|
||||
accessible_count += 1
|
||||
|
||||
self.assertEqual(accessible_count, len(stored_items) - 1)
|
||||
|
||||
try:
|
||||
key = f"item_{len(stored_items)}"
|
||||
value = f"data_{len(stored_items)}" * 100
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
except MemoryError:
|
||||
print(f"Buffer filled after {len(stored_items)} items")
|
||||
|
||||
# read the first item
|
||||
for i in range(self.storage.n_readers):
|
||||
key, original_value, address, monotonic_id = stored_items[0]
|
||||
retrieved = self.storage.get(address, monotonic_id)
|
||||
self.assertEqual(retrieved, original_value)
|
||||
|
||||
try:
|
||||
for i in range(len(stored_items), 1000): # Try to store many items
|
||||
key = f"item_{i}"
|
||||
value = f"data_{i}" * 100 # Make it reasonably sized
|
||||
address, monotonic_id = self.storage.put(key, value)
|
||||
stored_items.append((key, value, address, monotonic_id))
|
||||
except MemoryError:
|
||||
print(f"Buffer filled after {len(stored_items)} items")
|
||||
|
||||
# some items from the first batch may still be accessible
|
||||
self.assertGreaterEqual(len(stored_items), accessible_count + 10)
|
||||
|
||||
def test_invalid_get_operations(self):
|
||||
"""Test various invalid get operations."""
|
||||
# Test with non-existent address
|
||||
with self.assertRaises(ValueError):  # invalid address must be rejected
|
||||
self.storage.get(99999, 1)
|
||||
|
||||
# Store something first
|
||||
address, monotonic_id = self.storage.put("test", "value")
|
||||
|
||||
# Test with wrong monotonic_id
|
||||
with self.assertRaises(ValueError) as context:
|
||||
self.storage.get(address, monotonic_id + 100)
|
||||
|
||||
self.assertIn("has been modified or is invalid", \
|
||||
str(context.exception))
|
||||
|
||||
def test_clear_storage(self):
|
||||
"""Test clearing the storage."""
|
||||
# Store some items
|
||||
for i in range(5):
|
||||
self.storage.put(f"item_{i}", f"value_{i}")
|
||||
|
||||
# Clear the storage
|
||||
self.storage.clear()
|
||||
|
||||
# Verify that all indices are empty
|
||||
self.assertEqual(len(self.storage.key_index), 0)
|
||||
self.assertEqual(len(self.storage.id_index), 0)
|
||||
self.assertEqual(len(self.storage.ring_buffer.metadata), 0)
|
||||
|
||||
# Verify that new items can be added after clearing
|
||||
address, monotonic_id = self.storage.put("new_item", "new_value")
|
||||
self.assertIn("new_item", self.storage.key_index)
|
||||
self.assertEqual((address, monotonic_id), (0, 0))
|
||||
|
||||
|
||||
# Reader process function
|
||||
def reader_process(process_id, storage_handle, items_to_read):
|
||||
"""Reader process that connects to existing shared memory and reads data."""
|
||||
reader_storage = SingleWriterShmObjectStorage.create_from_handle(
|
||||
storage_handle)
|
||||
|
||||
print(f"Reader {process_id} started")
|
||||
|
||||
errors = []
|
||||
|
||||
for key, original_value, address, monotonic_id in items_to_read:
|
||||
time.sleep(random.random() / 100)
|
||||
try:
|
||||
# Read data from shared memory
|
||||
retrieved_value = reader_storage.get(address, monotonic_id)
|
||||
|
||||
# Verify data integrity
|
||||
assert retrieved_value == original_value
|
||||
print(f"Reader {process_id} retrieved {key}: {retrieved_value}")
|
||||
except Exception as e:
|
||||
errors.append((key, str(e), type(e).__name__))
|
||||
|
||||
|
||||
def run_multiprocess_example():
|
||||
"""Run a minimal working example with real shared memory."""
|
||||
print("=== Minimal Object Storage Example ===")
|
||||
|
||||
try:
|
||||
# Create storage instance
|
||||
ring_buffer = SingleWriterShmRingBuffer(
|
||||
data_buffer_size=1024 * 100,
|
||||
create=True,  # 100 KiB buffer
|
||||
)
|
||||
storage = SingleWriterShmObjectStorage(
|
||||
max_object_size=1024,
|
||||
n_readers=3,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=MsgpackSerde,
|
||||
reader_lock=Lock(),
|
||||
)
|
||||
|
||||
print(f"Created storage (writer: {storage.is_writer})")
|
||||
|
||||
# Test basic data types
|
||||
test_data = [
|
||||
("user_data", {
|
||||
"name": "Alice",
|
||||
"age": 30,
|
||||
"scores": [95, 87, 92]
|
||||
}),
|
||||
("simple_string", "Hello, World!"),
|
||||
("number", 42),
|
||||
("list_data", [1, 2, 3, "four", 5.0]),
|
||||
]
|
||||
|
||||
stored_items = []
|
||||
|
||||
# Store all data
|
||||
for key, value in test_data:
|
||||
print(f"Storing {key}: {value}")
|
||||
address, monotonic_id = storage.put(key, value)
|
||||
stored_items.append((key, value, address, monotonic_id))
|
||||
print(f" -> Stored at address {address}, ID {monotonic_id}")
|
||||
|
||||
print("\n--- Retrieving Data ---")
|
||||
processes = []
|
||||
handle = storage.handle()
|
||||
# initialize lock for reader processes
|
||||
handle.reader_lock = Lock()
|
||||
for i in range(storage.n_readers):
|
||||
p = multiprocessing.Process(target=reader_process,
|
||||
args=(i, handle, stored_items))
|
||||
processes.append(p)
|
||||
p.start()
|
||||
|
||||
for p in processes:
|
||||
p.join(timeout=10)
|
||||
if p.is_alive():
|
||||
p.terminate()
|
||||
p.join()
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in minimal example: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the minimal example first
|
||||
run_multiprocess_example()
|
||||
print("\n" + "=" * 50 + "\n")
|
||||
|
||||
# Run the test suite
|
||||
print("Running comprehensive test suite...")
|
||||
unittest.main(verbosity=2, exit=False)
|
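Putting the two pieces together, the object-store flow that the tests above exercise looks roughly like the sketch below (single-process for brevity; the key, value, and sizes are illustrative — in vLLM the reader normally lives in a worker process and receives the handle over IPC):

```python
from multiprocessing import Lock

from vllm.distributed.device_communicators.shm_object_storage import (
    MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)

# Writer (P0): back the storage with a shared-memory ring buffer.
storage = SingleWriterShmObjectStorage(
    max_object_size=1024 * 10,
    n_readers=1,
    ring_buffer=SingleWriterShmRingBuffer(data_buffer_size=1024 * 100,
                                          create=True),
    serde_class=MsgpackSerde,
    reader_lock=Lock(),
)
address, monotonic_id = storage.put("mm_hash_0", {"field": [1, 2, 3]})

# Reader: rebuild the storage from the handle and fetch the object by
# (address, monotonic_id) instead of receiving the data itself over IPC.
reader = SingleWriterShmObjectStorage.create_from_handle(storage.handle())
assert reader.get(address, monotonic_id) == {"field": [1, 2, 3]}
```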
@ -10,8 +10,8 @@ from vllm.config import ModelConfig, ParallelConfig, VllmConfig
|
||||
from vllm.multimodal.cache import (MultiModalCache,
|
||||
MultiModalProcessorCacheItem,
|
||||
MultiModalProcessorCacheItemMetadata,
|
||||
processor_cache_from_config,
|
||||
receiver_cache_from_config)
|
||||
engine_receiver_cache_from_config,
|
||||
processor_cache_from_config)
|
||||
from vllm.multimodal.hasher import MultiModalHasher
|
||||
from vllm.multimodal.inputs import (MultiModalFieldElem, MultiModalKwargsItem,
|
||||
MultiModalKwargsItems,
|
||||
@ -115,9 +115,9 @@ def _compare_caches(
|
||||
):
|
||||
mm_registry = MultiModalRegistry()
|
||||
cache_0_p0 = processor_cache_from_config(config_0, mm_registry)
|
||||
cache_0_p1 = receiver_cache_from_config(config_0, mm_registry)
|
||||
cache_0_p1 = engine_receiver_cache_from_config(config_0, mm_registry)
|
||||
cache_1_p0 = processor_cache_from_config(config_1, mm_registry)
|
||||
cache_1_p1 = receiver_cache_from_config(config_1, mm_registry)
|
||||
cache_1_p1 = engine_receiver_cache_from_config(config_1, mm_registry)
|
||||
|
||||
cache_size_gb = max(
|
||||
config_0.model_config.mm_processor_cache_gb,
|
||||
|
@ -90,6 +90,7 @@ class DummyExecutor(UniProcExecutor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
self.mm_receiver_cache = None
|
||||
self.collective_rpc("init_worker", args=([kwargs], ))
|
||||
self.collective_rpc("init_device")
|
||||
|
||||
|
@ -39,6 +39,7 @@ ALLOWED_FILES = set([
|
||||
'vllm/engine/multiprocessing/client.py',
|
||||
'vllm/distributed/device_communicators/all_reduce_utils.py',
|
||||
'vllm/distributed/device_communicators/shm_broadcast.py',
|
||||
'vllm/distributed/device_communicators/shm_object_storage.py',
|
||||
'vllm/engine/multiprocessing/engine.py',
|
||||
'benchmarks/kernels/graph_machete_bench.py',
|
||||
'benchmarks/kernels/benchmark_lora.py',
|
||||
|
@ -262,6 +262,7 @@ def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
TokenizerMode = Literal["auto", "slow", "mistral", "custom"]
|
||||
ModelDType = Literal["auto", "half", "float16", "bfloat16", "float", "float32"]
|
||||
MMEncoderTPMode = Literal["weights", "data"]
|
||||
MMCacheType = Literal["shm", "lru"]
|
||||
|
||||
|
||||
class LogprobsMode(enum.Enum):
|
||||
@ -450,6 +451,13 @@ class ModelConfig:
|
||||
`mm_processor_cache_gb * (api_server_count + data_parallel_size)`.
|
||||
|
||||
Set to `0` to disable this cache completely (not recommended)."""
|
||||
mm_processor_cache_type: MMCacheType = "lru"
|
||||
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
|
||||
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
|
||||
mm_shm_cache_max_object_size_mb: int = 128
|
||||
"""Size limit (in MiB) for each object stored in the multi-modal processor
|
||||
shared memory cache. Only effective when `mm_processor_cache_type` is
|
||||
`"shm"`."""
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
|
||||
"""Indicates how to optimize multi-modal encoder inference using
|
||||
tensor parallelism (TP).
|
||||
@ -881,6 +889,9 @@ class ModelConfig:
|
||||
media_io_kwargs=self.media_io_kwargs,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
||||
mm_processor_cache_type=self.mm_processor_cache_type,
|
||||
mm_shm_cache_max_object_size_mb=self.
|
||||
mm_shm_cache_max_object_size_mb,
|
||||
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
||||
interleave_mm_strings=self.interleave_mm_strings,
|
||||
skip_mm_profiling=self.skip_mm_profiling,
|
||||
@ -2448,6 +2459,15 @@ class MultiModalConfig:
|
||||
Set to `0` to disable this cache completely (not recommended).
|
||||
"""
|
||||
|
||||
mm_processor_cache_type: MMCacheType = "lru"
|
||||
"""Type of cache to use for the multi-modal preprocessor/mapper. If `shm`,
|
||||
use shared memory FIFO cache. If `lru`, use mirrored LRU cache."""
|
||||
|
||||
mm_shm_cache_max_object_size_mb: int = 128
|
||||
"""Size limit (in MiB) for each object stored in the multi-modal processor
|
||||
shared memory cache. Only effective when `mm_processor_cache_type` is
|
||||
`"shm"`."""
|
||||
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = "weights"
|
||||
"""
|
||||
Indicates how to optimize multi-modal encoder inference using
|
||||
|
vllm/distributed/device_communicators/shm_object_storage.py (new file, 635 lines)
@@ -0,0 +1,635 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import pickle
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Iterable
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain
|
||||
from multiprocessing import shared_memory
|
||||
from multiprocessing.synchronize import Lock as LockType
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
class SingleWriterShmRingBuffer:
|
||||
"""
|
||||
A single-writer, multiple-reader ring buffer implementation using shared
|
||||
memory. This class provides a thread-safe ring buffer where one process
|
||||
can write data while multiple processes/threads can read from it.
|
||||
|
||||
Architecture:
|
||||
- Uses shared memory for cross-process communication
|
||||
- Maintains metadata for each allocated buffer chunk in the writer process
|
||||
- Supports custom "is_free_fn" functions to determine when buffers can be
|
||||
reused
|
||||
- Each buffer chunk contains: [4-byte id][4-byte size][actual_data]
|
||||
|
||||
Key Concepts:
|
||||
- monotonic_id_start/end: Track the range of active buffer IDs
|
||||
- data_buffer_start/end: Track the physical memory range in use
|
||||
- Automatic wraparound when reaching buffer end
|
||||
- Lazy garbage collection based on is_free_fn checks
|
||||
|
||||
Example Usage Scenarios:
|
||||
|
||||
Scenario 1: Simple Linear Allocation
|
||||
```
|
||||
Buffer size: 100 bytes
|
||||
Initial state: [................................................. ]
|
||||
^start=end(0)
|
||||
|
||||
After allocating 20 bytes (id=0):
|
||||
[id:0|size:20|data........][...................................]
|
||||
^start(0) ^end(28)
|
||||
|
||||
After allocating 30 bytes (id=1):
|
||||
[id:0|size:20|data........][id:1|size:30|data..............][..]
|
||||
^start(0) ^end(66)
|
||||
```
|
||||
|
||||
Scenario 2: Memory Reclamation
|
||||
```
|
||||
Before freeing (both buffers still in use):
|
||||
[id:0|size:20|data........][id:1|size:30|data..............][..]
|
||||
^start(0) ^end(66)
|
||||
|
||||
After id:0 is marked free by readers:
|
||||
[FREED.................... ][id:1|size:30|data..............][..]
|
||||
^start(28) ^end(66)
|
||||
|
||||
After both are freed:
|
||||
[FREED..............................................][..]
|
||||
^start=end(66)
|
||||
```
|
||||
|
||||
Scenario 3: Wraparound Allocation (continuing from Scenario 2)
|
||||
```
|
||||
Starting from after memory reclamation in Scenario 2:
|
||||
[FREED..............................................][..]
|
||||
^start=end(66)
|
||||
|
||||
Allocate 40 bytes (id=2) - only 34 bytes available at end, so wraparound:
|
||||
[id:2|size:40|data........................][FREED.............][..]
|
||||
^end(148) ^start(66)
|
||||
```
|
||||
|
||||
Scenario 4: Error Handling - Out of Space
|
||||
```
|
||||
Starting from after wraparound allocation in Scenario 3:
|
||||
[id:2|size:40|data........................][FREED.............][..]
|
||||
^end(148) ^start(66)
|
||||
|
||||
Trying to allocate 20 more bytes:
|
||||
occupied_size_new = end + size - start = 148 + 28 - 66 > buffer_size(100)
|
||||
-> Raises MemoryError: "Not enough space in the data buffer"
|
||||
```
|
||||
|
||||
Thread Safety:
|
||||
- Single writer: Only one process/thread should write (allocate_buf)
|
||||
- Multiple readers: Multiple processes/threads can read (access_buf)
|
||||
- Reader synchronization handled by is_free_fn callback
|
||||
- Writer handles garbage collection (free_buf) based on reader feedback
|
||||
|
||||
Memory Layout per Buffer Chunk:
|
||||
[4-byte monotonic_id][4-byte chunk_size][actual_data...]
|
||||
^metadata_start ^data_start
|
||||
|
||||
The monotonic_id ensures data integrity - readers can verify they're
|
||||
accessing the correct data even after buffer wraparound or reuse.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
data_buffer_size: int,
|
||||
name: Optional[str] = None,
|
||||
create: bool = False,
|
||||
):
|
||||
self.data_buffer_size = data_buffer_size
|
||||
self.is_writer = create
|
||||
|
||||
self.ID_NBYTES = 4
|
||||
self.ID_MAX = 2**31 # exclusive, so 2**31 - 1 is the max value
|
||||
self.SIZE_NBYTES = 4
|
||||
# 4 bytes for id, 4 bytes for buffer size
|
||||
self.MD_SIZE = self.ID_NBYTES + self.SIZE_NBYTES
|
||||
self.monotonic_id_end = 0
|
||||
self.monotonic_id_start = 0
|
||||
self.data_buffer_start = 0
|
||||
self.data_buffer_end = 0
|
||||
|
||||
if create:
|
||||
# we are creating a buffer
|
||||
self.metadata = {
|
||||
self.monotonic_id_end: self.data_buffer_end
|
||||
} # monotonic_id -> start address
|
||||
self.shared_memory = shared_memory.SharedMemory(
|
||||
create=True, size=self.data_buffer_size, name=name)
|
||||
else:
|
||||
# we are opening an existing buffer
|
||||
# fix to https://stackoverflow.com/q/62748654/9191338
|
||||
# Python incorrectly tracks shared memory even if it is not
|
||||
# created by the process. The following patch is a workaround.
|
||||
with patch(
|
||||
"multiprocessing.resource_tracker.register",
|
||||
lambda *args, **kwargs: None,
|
||||
):
|
||||
self.shared_memory = shared_memory.SharedMemory(name=name)
|
||||
# See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
|
||||
# Some platforms allocate memory based on page size,
|
||||
# so the shared memory block size may be larger or equal
|
||||
# to the requested size. The size parameter is ignored
|
||||
# when attaching to an existing block.
|
||||
assert self.shared_memory.size >= self.data_buffer_size
|
||||
|
||||
logger.debug("Shared memory created/opened with name: %s, size: %d",
|
||||
self.shared_memory.name, self.data_buffer_size)
|
||||
|
||||
def handle(self):
|
||||
return (
|
||||
self.data_buffer_size,
|
||||
self.shared_memory.name,
|
||||
)
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the ring buffer."""
|
||||
assert self.is_writer, "Only the writer can clear the buffer."
|
||||
self.metadata.clear()
|
||||
self.monotonic_id_end = 0
|
||||
self.monotonic_id_start = 0
|
||||
self.data_buffer_start = 0
|
||||
self.data_buffer_end = 0
|
||||
|
||||
def __del__(self):
|
||||
if hasattr(self, "shared_memory"):
|
||||
self.shared_memory.close()
|
||||
if self.is_writer:
|
||||
self.shared_memory.unlink()
|
||||
|
||||
def int2byte(self, integer: int) -> bytes:
|
||||
"""Convert an integer to bytes."""
|
||||
return integer.to_bytes(self.ID_NBYTES, "little", signed=True)
|
||||
|
||||
def byte2int(self, byte_data: bytes) -> int:
|
||||
"""Convert bytes back to an integer."""
|
||||
return int.from_bytes(byte_data, "little", signed=True)
|
||||
|
||||
def allocate_buf(self, size: int) -> tuple[int, int]:
|
||||
'''
|
||||
Allocate a buffer `MD_SIZE` + `size` bytes in the shared memory.
|
||||
Memory layout:
|
||||
[4-byte monotonic_id][4-byte size][buffer data...]
|
||||
'''
|
||||
assert self.is_writer, "Only the writer can allocate buffers."
|
||||
assert size > 0, "Size must be greater than 0"
|
||||
size += self.MD_SIZE # add metadata size to the buffer size
|
||||
# wrap to the beginning if the buffer does not have enough contiguous space
|
||||
buffer_end_reset = self.data_buffer_end % self.data_buffer_size
|
||||
if buffer_end_reset + size > self.data_buffer_size:
|
||||
buffer_end_reset = (self.data_buffer_end // self.data_buffer_size +
|
||||
1) * self.data_buffer_size
|
||||
else: # no reset needed
|
||||
buffer_end_reset = self.data_buffer_end
|
||||
|
||||
# check if we have enough space in the data buffer
|
||||
# i.e. if the new end (self.data_buffer_end + size)
|
||||
# exceeds the start of the data buffer
|
||||
occupied_size_new = buffer_end_reset + size - self.data_buffer_start
|
||||
if occupied_size_new > self.data_buffer_size:
|
||||
raise MemoryError("Not enough space in the data buffer, "
|
||||
"try calling free_buf() to free up space")
|
||||
self.data_buffer_end = buffer_end_reset
|
||||
|
||||
# first 4 bytes as the monotonic id
|
||||
buf_idx = self.data_buffer_end % self.data_buffer_size
|
||||
self.shared_memory.buf[buf_idx:buf_idx + self.ID_NBYTES] = \
|
||||
self.int2byte(self.monotonic_id_end)
|
||||
# next 4 bytes as the size of the data buffer
|
||||
self.shared_memory.buf[buf_idx + self.ID_NBYTES: \
|
||||
buf_idx + self.MD_SIZE] = self.int2byte(size)
|
||||
|
||||
# record metadata
|
||||
self.metadata[self.monotonic_id_end %
|
||||
self.ID_MAX] = self.data_buffer_end
|
||||
# update buffer and monotonic id indices
|
||||
current_buffer_end = self.data_buffer_end
|
||||
current_id_end = self.monotonic_id_end
|
||||
self.data_buffer_end += size
|
||||
self.monotonic_id_end = (self.monotonic_id_end + 1) % self.ID_MAX
|
||||
return current_buffer_end, current_id_end
|
||||
|
||||
@contextmanager
|
||||
def access_buf(self, address: int):
|
||||
buf_idx = address % self.data_buffer_size
|
||||
|
||||
# read metadata
|
||||
metadata_buff = self.shared_memory.buf[buf_idx:buf_idx + self.MD_SIZE]
|
||||
id = self.byte2int(metadata_buff[:self.ID_NBYTES])
|
||||
size = self.byte2int(metadata_buff[self.ID_NBYTES:self.MD_SIZE])
|
||||
|
||||
# yield the data buffer and metadata
|
||||
data_buff = self.shared_memory.buf[buf_idx + self.MD_SIZE:buf_idx +
|
||||
size]
|
||||
with memoryview(data_buff) as data_view:
|
||||
yield data_view, (id, size)
|
||||
|
||||
def free_buf(self,
|
||||
is_free_fn: Callable[[int, memoryview], bool],
|
||||
nbytes: Optional[int] = None) -> Iterable[int]:
|
||||
'''
|
||||
Free allocated chunks, oldest first, until roughly `nbytes` have been
reclaimed. The shared memory itself is untouched; only the writer-side
metadata is updated.

If the freed memory spans the end and start of the ring buffer, it is
reclaimed as two segments, so there still might not be a contiguous
span of `nbytes` available afterwards.

Args:
    is_free_fn: Callback `(monotonic_id, data_buf) -> bool` that
        returns True once a chunk may be reclaimed.
    nbytes (int, optional): The number of bytes to try to free. If
        None, frees as much as possible.

Returns:
    The monotonic ids of the freed chunks.
'''
|
||||
|
||||
assert self.is_writer, "Only the writer can free buffers."
|
||||
logger.debug(
|
||||
"Freeing up space in the ring buffer, "
|
||||
"monotonic_id_start: %d, monotonic_id_end: %d",
|
||||
self.monotonic_id_start, self.monotonic_id_end)
|
||||
monotonic_id_before = self.monotonic_id_start
|
||||
# if nbytes is None, free up the maximum size of the ring buffer
|
||||
if nbytes is None:
|
||||
nbytes = self.data_buffer_size
|
||||
freed_bytes = 0
|
||||
while self.monotonic_id_start in self.metadata and freed_bytes < nbytes:
|
||||
address = self.metadata[self.monotonic_id_start]
|
||||
with self.access_buf(address) as (data_buff, metadata):
|
||||
if is_free_fn(self.monotonic_id_start, data_buff):
|
||||
# check passed, we can free the buffer
|
||||
del self.metadata[self.monotonic_id_start]
|
||||
self.monotonic_id_start = ((self.monotonic_id_start + 1) %
|
||||
self.ID_MAX)
|
||||
self.data_buffer_start = address
|
||||
freed_bytes += metadata[1]
|
||||
else:
|
||||
# there are still readers, we cannot free the buffer
|
||||
break
|
||||
|
||||
logger.debug(
|
||||
"Freed %d bytes from the ring buffer, "
|
||||
"monotonic_id_start: %d, monotonic_id_end: %d", freed_bytes,
|
||||
self.monotonic_id_start, self.monotonic_id_end)
|
||||
|
||||
# buffer wrap around
|
||||
if self.data_buffer_start >= self.data_buffer_size:
|
||||
self.data_buffer_start -= self.data_buffer_size
|
||||
self.data_buffer_end -= self.data_buffer_size
|
||||
|
||||
monotonic_id_after = self.monotonic_id_start
|
||||
# id wrap around
|
||||
if monotonic_id_after >= monotonic_id_before:
|
||||
return range(monotonic_id_before, monotonic_id_after)
|
||||
else:
|
||||
return chain(range(monotonic_id_before, self.ID_MAX),
|
||||
range(0, monotonic_id_after))
|
||||
|
||||
|
||||
class ObjectSerde(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def serialize(self, value: Any) -> tuple[Any, int, bytes, int]:
|
||||
"""Serialize an object to bytes."""
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def deserialize(self, data: memoryview) -> Any:
|
||||
"""Deserialize bytes back to an object."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MsgpackSerde(ObjectSerde):
|
||||
|
||||
def __init__(self):
|
||||
# Delayed import to avoid circular dependency
|
||||
from vllm.multimodal.inputs import MultiModalKwargsItem
|
||||
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
|
||||
|
||||
self.encoder = MsgpackEncoder()
|
||||
self.tensor_decoder = MsgpackDecoder(torch.Tensor)
|
||||
self.mm_decoder = MsgpackDecoder(MultiModalKwargsItem)
|
||||
self._mm_kwargs_item_cls = MultiModalKwargsItem
|
||||
|
||||
def serialize(
|
||||
self,
|
||||
value: Any) -> tuple[Union[bytes, list[bytes]], int, bytes, int]:
|
||||
len_arr = None
|
||||
if isinstance(value, (torch.Tensor, self._mm_kwargs_item_cls)):
|
||||
type_name = type(value).__name__
|
||||
value = self.encoder.encode(value)
|
||||
len_arr = [len(s) for s in value]
|
||||
nbytes = sum(len_arr)
|
||||
else:
|
||||
value = pickle.dumps(value, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
type_name = type(value).__name__
|
||||
nbytes = len(value)
|
||||
|
||||
object_metadata = (type_name, nbytes, len_arr)
|
||||
serialized_metadata = pickle.dumps(object_metadata,
|
||||
protocol=pickle.HIGHEST_PROTOCOL)
|
||||
return value, nbytes, serialized_metadata, len(serialized_metadata)
|
||||
|
||||
def deserialize(self, data_view: memoryview) -> Any:
|
||||
# pickle.loads does not read past the end of a pickled object
# within a larger buffer, so we can skip storing the metadata size
|
||||
type_name, nbytes, len_arr = pickle.loads(data_view)
|
||||
serialized_data = bytearray(data_view[-nbytes:])
|
||||
|
||||
if type_name == torch.Tensor.__name__:
|
||||
obj = []
|
||||
start_idx = 0
|
||||
for length in len_arr:
|
||||
item_bytes = serialized_data[start_idx:start_idx + length]
|
||||
obj.append(item_bytes)
|
||||
start_idx += length
|
||||
obj = self.tensor_decoder.decode(obj)
|
||||
elif type_name == self._mm_kwargs_item_cls.__name__:
|
||||
obj = []
|
||||
start_idx = 0
|
||||
for length in len_arr:
|
||||
item_bytes = serialized_data[start_idx:start_idx + length]
|
||||
obj.append(item_bytes)
|
||||
start_idx += length
|
||||
obj = self.mm_decoder.decode(obj)
|
||||
elif type_name == bytes.__name__:
|
||||
obj = pickle.loads(serialized_data)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported object type '{type_name}' in metadata")
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShmObjectStorageHandle:
|
||||
max_object_size: int
|
||||
n_readers: int
|
||||
ring_buffer_handle: tuple[int, str]
|
||||
serde_class: type[ObjectSerde]
|
||||
reader_lock: Optional[LockType]
|
||||
|
||||
|
||||
class SingleWriterShmObjectStorage:
|
||||
"""
|
||||
A single-writer, multiple-reader object storage system built on top of a
|
||||
shared memory ring buffer. Provides key-value storage with automatic memory
|
||||
management and cross-process serialization support.
|
||||
|
||||
This storage system follows a FIFO (First-In-First-Out) eviction policy
|
||||
where the oldest objects are automatically freed when memory runs low.
|
||||
Memory is reclaimed based on reader reference counting - objects are only
|
||||
freed when all readers have finished accessing them.
|
||||
|
||||
Architecture:
|
||||
- Single writer process can put(key, value) objects
|
||||
- Multiple reader processes can get(address, monotonic_id) objects
|
||||
- Built on SingleWriterShmRingBuffer for efficient shared memory management
|
||||
- Thread-safe operations with reader synchronization via locks
|
||||
|
||||
Key Features:
|
||||
- FIFO Eviction: Oldest objects are evicted first when memory is full
|
||||
- Reference Counting: Objects are only freed when no readers are
|
||||
accessing them
|
||||
- Duplicate Key Handling: Existing keys are not overwritten, just
|
||||
re-referenced
|
||||
- Customized Serialization: By default uses Msgpack for efficient
|
||||
serialization of Python objects, but can be extended for custom types
|
||||
- Cross-Process Safety: Uses shared memory with proper synchronization
|
||||
- Automatic Cleanup: Garbage collection happens transparently during
|
||||
allocation
|
||||
|
||||
Memory Layout per Object:
|
||||
[4-byte reference_count][metadata_size][serialized_object_data]
|
||||
|
||||
Thread Safety:
|
||||
- Writer operations (put, clear) are single-threaded by design
|
||||
- Reader operations (get) are thread-safe with lock-based reference
|
||||
counting
|
||||
- Memory reclamation is handled exclusively by the writer process
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_object_size: int,
|
||||
n_readers: int,
|
||||
ring_buffer: SingleWriterShmRingBuffer,
|
||||
serde_class: type[ObjectSerde] = MsgpackSerde,
|
||||
reader_lock: Optional[LockType] = None,
|
||||
):
|
||||
"""
|
||||
Initialize the object storage.
|
||||
|
||||
Args:
|
||||
max_object_size: Maximum size for a single object in bytes.
|
||||
n_readers: Number of reader processes that can access the storage.
|
||||
ring_buffer: The shared memory ring buffer for storing objects.
|
||||
serde_class: Serializer/deserializer for objects.
|
||||
reader_lock: Optional lock for synchronizing reader access.
|
||||
Raises:
|
||||
ValueError: If reader_lock is None for readers.
|
||||
"""
|
||||
|
||||
self.max_object_size = max_object_size
|
||||
self.n_readers = n_readers
|
||||
self.serde_class = serde_class
|
||||
self.ser_de = serde_class()
|
||||
self.ring_buffer = ring_buffer
|
||||
self.is_writer = self.ring_buffer.is_writer
|
||||
|
||||
self.flag_bytes = 4 # for in-use flag
|
||||
|
||||
if self.is_writer:
|
||||
# Key-value mapping: key -> (address, monotonic_id)
|
||||
self.key_index: dict[str, tuple[int, int]] = {}
|
||||
# Reverse mapping: monotonic_id -> key
|
||||
self.id_index: dict[int, str] = {}
|
||||
# Writer flag to track in-use status: monotonic_id -> count
|
||||
self.writer_flag: dict[int, int] = {}
|
||||
else:
|
||||
if reader_lock is None:
|
||||
raise ValueError("Lock must be provided for readers.")
|
||||
|
||||
self._reader_lock = reader_lock
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear the object storage."""
|
||||
if self.is_writer:
|
||||
self.ring_buffer.clear()
|
||||
self.key_index.clear()
|
||||
self.id_index.clear()
|
||||
self.writer_flag.clear()
|
||||
logger.debug("Object storage cleared and reinitialized.")
|
||||
|
||||
def copy_to_buffer(
|
||||
self,
|
||||
data: Union[bytes, list[bytes]],
|
||||
data_bytes: int,
|
||||
metadata: bytes,
|
||||
md_bytes: int,
|
||||
data_view: memoryview,
|
||||
) -> None:
|
||||
data_view[self.flag_bytes:self.flag_bytes + md_bytes] = metadata
|
||||
if isinstance(data, bytes):
|
||||
data_view[-data_bytes:] = data
|
||||
elif isinstance(data, list):
|
||||
start_idx = self.flag_bytes + md_bytes
|
||||
for item_bytes in data:
|
||||
item_size = len(item_bytes)
|
||||
data_view[start_idx:start_idx + item_size] = item_bytes
|
||||
start_idx += item_size
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported data type for serialization: {type(data)}")
|
||||
|
||||
def increment_writer_flag(self, id: int) -> None:
|
||||
"""Set the in-use flag for the writer."""
|
||||
self.writer_flag[id] = self.writer_flag.get(id, 0) + 1
|
||||
|
||||
def increment_reader_flag(self, data_view: memoryview) -> None:
|
||||
"""Set the in-use flag for the reader."""
|
||||
# >0 for in-use flag
|
||||
reader_count = self.ring_buffer.byte2int(data_view)
|
||||
data_view[:] = self.ring_buffer.int2byte(reader_count + 1)
|
||||
|
||||
def free_unused(self) -> None:
|
||||
"""Free unused buffers in the ring buffer."""
|
||||
# try to free up 2*max_object_size bytes of space in the ring buffer,
|
||||
# since the buffer might be fragmented
|
||||
freed_ids = self.ring_buffer.free_buf(self.default_is_free_check,
|
||||
2 * self.max_object_size)
|
||||
# update the metadata after freeing up space
|
||||
for freed_id in freed_ids:
|
||||
key_to_free = self.id_index[freed_id]
|
||||
del self.key_index[key_to_free]
|
||||
del self.id_index[freed_id]
|
||||
del self.writer_flag[freed_id]
|
||||
|
||||
def is_cached(self, key: str) -> bool:
|
||||
"""
|
||||
Check if the object with the given key is cached.
|
||||
"""
|
||||
return key in self.key_index
|
||||
|
||||
def get_cached(self, key: str) -> tuple[int, int]:
|
||||
"""
|
||||
Get the cached object by key if it exists.
|
||||
"""
|
||||
address, monotonic_id = self.key_index[key]
|
||||
self.increment_writer_flag(monotonic_id)
|
||||
return address, monotonic_id
|
||||
|
||||
def put(self, key: str, value: Any) -> tuple[int, int]:
|
||||
"""
|
||||
Store a key-value pair in the object storage.
|
||||
Attempts to free max_object_size bytes using FIFO order
|
||||
when the ring buffer runs out of space during a put() operation.
|
||||
|
||||
Args:
|
||||
key: String key to identify the object
|
||||
value: Any serializable Python object
|
||||
|
||||
Raises:
|
||||
MemoryError: If there's not enough space in the buffer
|
||||
ValueError: If the serialized object is too large
|
||||
ValueError: If the key already exists in the storage
|
||||
"""
|
||||
if key in self.key_index:
|
||||
raise ValueError(f"Key '{key}' already exists in the storage.")
|
||||
|
||||
object_data, data_bytes, object_metadata, md_bytes = \
|
||||
self.ser_de.serialize(value)
|
||||
buffer_size = self.flag_bytes + data_bytes + md_bytes
|
||||
|
||||
# Sanity checks
|
||||
if buffer_size > self.max_object_size:
|
||||
raise ValueError(
|
||||
f"Serialized object size ({buffer_size} bytes) exceeds "
|
||||
f"max object size ({self.max_object_size} bytes)")
|
||||
|
||||
# Allocate new buffer
|
||||
try:
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
|
||||
except MemoryError:
|
||||
self.free_unused()
|
||||
# try again after freeing up space
|
||||
address, monotonic_id = self.ring_buffer.allocate_buf(buffer_size)
|
||||
|
||||
# Write data to buffer
|
||||
with self.ring_buffer.access_buf(address) as (data_view, metadata):
|
||||
data_view[:self.flag_bytes] = self.ring_buffer.int2byte(0)
|
||||
self.copy_to_buffer(object_data, data_bytes, object_metadata,
|
||||
md_bytes, data_view)
|
||||
self.increment_writer_flag(monotonic_id)
|
||||
|
||||
# Update key index
|
||||
self.key_index[key] = (address, monotonic_id)
|
||||
self.id_index[monotonic_id] = key
|
||||
return address, monotonic_id
|
||||
|
||||
def get(self, address: int, monotonic_id: int) -> Any:
|
||||
# Read data from buffer
|
||||
with self.ring_buffer.access_buf(address) as (data_view, buf_metadata):
|
||||
# check id from metadata
|
||||
if buf_metadata[0] != monotonic_id:
|
||||
raise ValueError(
|
||||
f"Data for address:id '{address}:{monotonic_id}'"
|
||||
" has been modified or is invalid.")
|
||||
|
||||
obj = self.ser_de.deserialize(data_view[self.flag_bytes:])
|
||||
|
||||
# increment the reader count for this read
|
||||
if self._reader_lock is not None:
|
||||
with self._reader_lock:
|
||||
self.increment_reader_flag(data_view[:self.flag_bytes])
|
||||
else:
|
||||
# if self._reader_lock is None, it means we are the writer
# reading back our own data, so the reader count is left unchanged
|
||||
assert self.is_writer
|
||||
|
||||
return obj
|
||||
|
||||
def handle(self):
|
||||
"""Get handle for sharing across processes."""
|
||||
return ShmObjectStorageHandle(
|
||||
max_object_size=self.max_object_size,
|
||||
n_readers=self.n_readers,
|
||||
ring_buffer_handle=self.ring_buffer.handle(),
|
||||
serde_class=self.serde_class,
|
||||
reader_lock=self._reader_lock,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_from_handle(
|
||||
handle: ShmObjectStorageHandle) -> "SingleWriterShmObjectStorage":
|
||||
logger.debug("Creating storage from handle: %s", handle)
|
||||
ring_buffer = SingleWriterShmRingBuffer(*handle.ring_buffer_handle)
|
||||
return SingleWriterShmObjectStorage(
|
||||
max_object_size=handle.max_object_size,
|
||||
n_readers=handle.n_readers,
|
||||
ring_buffer=ring_buffer,
|
||||
serde_class=handle.serde_class,
|
||||
reader_lock=handle.reader_lock,
|
||||
)
|
||||
|
||||
def default_is_free_check(self, id: int, buf: memoryview) -> bool:
|
||||
"""
|
||||
Default is_free function: a chunk is free once every reader has
consumed it as many times as the writer handed it out, i.e.
reader_count >= writer_count * n_readers.
|
||||
"""
|
||||
reader_count = int.from_bytes(buf[0:4], "little", signed=True)
|
||||
writer_count = self.writer_flag[id]
|
||||
return reader_count >= writer_count * self.n_readers
|
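As a concrete instance of this reclamation rule (numbers are illustrative): with `n_readers=2`, an object that the writer has handed out twice (`writer_count = 2`) only becomes reclaimable once it has been read at least `2 * 2 = 4` times across all readers.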
@ -27,8 +27,8 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
DistributedExecutorBackend, EPLBConfig,
|
||||
GuidedDecodingBackend, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||
ModelDType, ModelImpl, MultiModalConfig,
|
||||
LoRAConfig, MambaDType, MMCacheType, MMEncoderTPMode,
|
||||
ModelConfig, ModelDType, ModelImpl, MultiModalConfig,
|
||||
ObservabilityConfig, ParallelConfig, PoolerConfig,
|
||||
PrefixCachingHashAlgo, RunnerOption, SchedulerConfig,
|
||||
SchedulerPolicy, SpeculativeConfig, TaskOption,
|
||||
@ -373,6 +373,10 @@ class EngineArgs:
|
||||
MultiModalConfig.mm_processor_kwargs
|
||||
disable_mm_preprocessor_cache: bool = False # DEPRECATED
|
||||
mm_processor_cache_gb: float = MultiModalConfig.mm_processor_cache_gb
|
||||
mm_processor_cache_type: Optional[MMCacheType] = \
|
||||
MultiModalConfig.mm_processor_cache_type
|
||||
mm_shm_cache_max_object_size_mb: int = \
|
||||
MultiModalConfig.mm_shm_cache_max_object_size_mb
|
||||
mm_encoder_tp_mode: MMEncoderTPMode = MultiModalConfig.mm_encoder_tp_mode
|
||||
io_processor_plugin: Optional[str] = None
|
||||
skip_mm_profiling: bool = MultiModalConfig.skip_mm_profiling
|
||||
@ -782,6 +786,12 @@ class EngineArgs:
|
||||
multimodal_group.add_argument("--disable-mm-preprocessor-cache",
|
||||
action="store_true",
|
||||
deprecated=True)
|
||||
multimodal_group.add_argument(
|
||||
"--mm-processor-cache-type",
|
||||
**multimodal_kwargs["mm_processor_cache_type"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-shm-cache-max-object-size-mb",
|
||||
**multimodal_kwargs["mm_shm_cache_max_object_size_mb"])
|
||||
multimodal_group.add_argument(
|
||||
"--mm-encoder-tp-mode", **multimodal_kwargs["mm_encoder_tp_mode"])
|
||||
multimodal_group.add_argument(
|
||||
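Programmatically, the same settings can be supplied through `EngineArgs`; a minimal sketch (the model name is only an example, and `create_engine_config()` is assumed to be called with its defaults):

```python
from vllm.engine.arg_utils import EngineArgs

# Example model; the multi-modal cache settings are the point here.
engine_args = EngineArgs(
    model="Qwen/Qwen2.5-VL-3B-Instruct",
    tensor_parallel_size=2,
    mm_processor_cache_type="shm",
    mm_processor_cache_gb=8,
    mm_shm_cache_max_object_size_mb=128,
)
vllm_config = engine_args.create_engine_config()
```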
@ -998,6 +1008,9 @@ class EngineArgs:
|
||||
config_format=self.config_format,
|
||||
mm_processor_kwargs=self.mm_processor_kwargs,
|
||||
mm_processor_cache_gb=self.mm_processor_cache_gb,
|
||||
mm_processor_cache_type=self.mm_processor_cache_type,
|
||||
mm_shm_cache_max_object_size_mb=self.
|
||||
mm_shm_cache_max_object_size_mb,
|
||||
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
||||
override_pooler_config=self.override_pooler_config,
|
||||
logits_processor_pattern=self.logits_processor_pattern,
|
||||
|
@ -175,6 +175,7 @@ if TYPE_CHECKING:
|
||||
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
|
||||
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
|
||||
VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
|
||||
VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME: str = "VLLM_OBJECT_STORAGE_SHM_BUFFER"
|
||||
|
||||
|
||||
def get_default_cache_root():
|
||||
@ -1241,6 +1242,12 @@ environment_variables: dict[str, Callable[[], Any]] = {
|
||||
# raw bytes. Defaults to True for backward compatibility.
|
||||
"VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
|
||||
lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
|
||||
|
||||
# Name of the shared memory buffer used for object storage.
|
||||
# Only effective when mm_config.mm_processor_cache_type == "shm".
|
||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME":
|
||||
lambda: os.getenv("VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME",
|
||||
"VLLM_OBJECT_STORAGE_SHM_BUFFER"),
|
||||
}
|
||||
|
||||
# --8<-- [end:env-vars-definition]
|
||||
|
@ -2,6 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import os
|
||||
from multiprocessing import Lock
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
@ -10,9 +11,12 @@ import torch.distributed as dist
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
from vllm.multimodal.cache import worker_receiver_cache_from_config
|
||||
from vllm.utils import (get_distributed_init_method, get_ip, get_open_port,
|
||||
run_method)
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.executor.utils import get_and_update_mm_cache
|
||||
from vllm.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -44,6 +48,8 @@ class UniProcExecutor(ExecutorBase):
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
|
||||
self.collective_rpc("init_worker", args=([kwargs], ))
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("load_model")
|
||||
@ -55,6 +61,8 @@ class UniProcExecutor(ExecutorBase):
|
||||
kwargs: Optional[Dict] = None) -> List[Any]:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
if self.mm_receiver_cache is not None and method == "execute_model":
|
||||
get_and_update_mm_cache(self.mm_receiver_cache, args)
|
||||
answer = run_method(self.driver_worker, method, args, kwargs)
|
||||
return [answer]
|
||||
|
||||
@ -128,6 +136,8 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=is_driver_worker,
|
||||
)
|
||||
self.mm_receiver_cache = worker_receiver_cache_from_config(
|
||||
self.vllm_config, MULTIMODAL_REGISTRY, Lock())
|
||||
self.collective_rpc("init_worker", args=([kwargs], ))
|
||||
self.collective_rpc("init_device")
|
||||
self.collective_rpc("load_model")
|
||||
|
@ -3,19 +3,24 @@
import sys
from abc import ABC, abstractmethod
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union
from multiprocessing.synchronize import Lock as LockType
from typing import TYPE_CHECKING, Generic, Optional, TypeVar, Union, cast

import torch
from typing_extensions import TypeAlias, override

from vllm.distributed.device_communicators.shm_object_storage import (
    MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)
from vllm.envs import VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME
from vllm.logger import init_logger
from vllm.utils import GiB_bytes, LRUCache
from vllm.utils import GiB_bytes, LRUCache, MiB_bytes
from vllm.utils.jsontree import (json_count_leaves, json_map_leaves,
                                 json_reduce_leaves)

from .inputs import (MultiModalFeatureSpec, MultiModalFieldElem,
                     MultiModalKwargs, MultiModalKwargsItem,
                     MultiModalKwargsItems, NestedTensors)
from .inputs import (MultiModalBatchedField, MultiModalFeatureSpec,
                     MultiModalFieldElem, MultiModalKwargs,
                     MultiModalKwargsItem, MultiModalKwargsItems,
                     NestedTensors)

if TYPE_CHECKING:
    from vllm.config import ModelConfig, VllmConfig
@ -389,6 +394,106 @@ class MultiModalProcessorSenderCache(BaseMultiModalProcessorCache):
        self._cache.clear()


class ShmObjectStoreSenderCache(BaseMultiModalProcessorCache):
    """
    The cache which is used on P0 when IPC caching is enabled.

    How to update each item:

    - If the item is already in the cache, clear the input to avoid
      unnecessary IPC.

    - If the item is not in the cache, store the data in shared memory.
    """

    def __init__(self, vllm_config: "VllmConfig") -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=True,  # sender is the writer
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
            MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
        )
        # cache (prompt_updates, modality) for P0 only
        self._p0_cache: dict[str, tuple[Sequence[ResolvedPromptUpdate],
                                        str]] = {}

    @override
    def is_cached_item(self, mm_hash: str) -> bool:
        return self._shm_cache.is_cached(mm_hash)

    @override
    def get_and_update_item(
        self,
        mm_item: MultiModalProcessorCacheInItem,
        mm_hash: str,
    ) -> MultiModalProcessorCacheOutItem:

        if self._shm_cache.is_cached(mm_hash):
            address, monotonic_id = self._shm_cache.get_cached(mm_hash)
            prompt_updates, modality = self._p0_cache[mm_hash]
            return self.address_as_item(address, monotonic_id,
                                        modality), prompt_updates

        assert mm_item is not None, f"Expected a cached item for {mm_hash=}"

        try:
            address, monotonic_id = self._shm_cache.put(mm_hash, mm_item[0])
            # Try to remove dangling items if p0 cache is too large.
            if len(self._p0_cache) >= 2 * len(self._shm_cache.key_index):
                self.remove_dangling_items()
            self._p0_cache[mm_hash] = mm_item[1], mm_item[0].modality
            address_item = self.address_as_item(address, monotonic_id,
                                                mm_item[0].modality)
            return address_item, mm_item[1]
        except (ValueError, MemoryError) as e:
            # put may fail if the object is too large or
            # the cache is full.
            # In this case we log the error and keep the original mm_input.
            logger.debug("Failed to cache mm_input with hash %s: %s", mm_hash,
                         e)
            return mm_item

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()
        self._p0_cache.clear()

    def remove_dangling_items(self) -> None:
        """Remove items that are no longer in the shared memory cache."""
        cached_hashes = self._shm_cache.key_index.keys()
        dangling_hashes = set(self._p0_cache.keys()) - cached_hashes
        for mm_hash in dangling_hashes:
            del self._p0_cache[mm_hash]

    def address_as_item(self, address: int, monotonic_id: int,
                        modality: str) -> MultiModalKwargsItem:
        addr_elem = MultiModalFieldElem(
            modality=modality,
            key="address",
            data=address,
            field=MultiModalBatchedField(),
        )
        id_elem = MultiModalFieldElem(
            modality=modality,
            key="monotonic_id",
            data=monotonic_id,
            field=MultiModalBatchedField(),
        )
        mm_item = MultiModalKwargsItem.from_elems([addr_elem, id_elem])
        return mm_item


def _enable_processor_cache(
    model_config: "ModelConfig",
    mm_registry: "MultiModalRegistry",
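The sender above writes each processed item into the ring buffer once and hands out only an `(address, monotonic_id)` ticket; any reader that attaches to the same buffer name can redeem it. A rough sketch of that roundtrip using the constructor arguments shown above, with both roles in one process and made-up sizes; whether arbitrary Python objects serialize exactly like this depends on `MsgpackSerde`, so treat it as illustrative:

from multiprocessing import Lock

from vllm.distributed.device_communicators.shm_object_storage import (
    MsgpackSerde, SingleWriterShmObjectStorage, SingleWriterShmRingBuffer)

# Writer side (P0): owns the buffer.
writer = SingleWriterShmObjectStorage(
    max_object_size=1 << 20,
    n_readers=1,
    ring_buffer=SingleWriterShmRingBuffer(data_buffer_size=1 << 24,
                                          name="demo_mm_buffer", create=True),
    serde_class=MsgpackSerde,
)

# Reader side (worker): attaches to the same buffer name.
reader = SingleWriterShmObjectStorage(
    max_object_size=1 << 20,
    n_readers=1,
    ring_buffer=SingleWriterShmRingBuffer(data_buffer_size=1 << 24,
                                          name="demo_mm_buffer", create=False),
    serde_class=MsgpackSerde,
    reader_lock=Lock(),
)

address, monotonic_id = writer.put("demo-mm-hash", {"pixel_values": [1, 2, 3]})
assert writer.is_cached("demo-mm-hash")
restored = reader.get(address, monotonic_id)  # deserialized copy of the item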
@ -408,6 +513,17 @@ def _enable_ipc_cache(vllm_config: "VllmConfig") -> bool:
    return supports_ipc_cache


def _enable_mm_input_shm_cache(vllm_config: "VllmConfig") -> bool:
    """Whether the shared memory based cache should be enabled."""

    if not _enable_ipc_cache(vllm_config):
        return False

    mm_config = vllm_config.model_config.get_multimodal_config()

    return mm_config.mm_processor_cache_type == "shm"


def processor_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
@ -421,7 +537,9 @@ def processor_cache_from_config(
    if not _enable_ipc_cache(vllm_config):
        return MultiModalProcessorOnlyCache(model_config)

    return MultiModalProcessorSenderCache(model_config)
    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalProcessorSenderCache(model_config)
    return ShmObjectStoreSenderCache(vllm_config)


def processor_only_cache_from_config(
@ -491,11 +609,68 @@ class MultiModalReceiverCache(BaseMultiModalReceiverCache):
        self._cache.clear()


def receiver_cache_from_config(
class ShmObjectStoreReceiverCache(BaseMultiModalReceiverCache):
    """
    The cache which is used on P1 Worker Process when IPC caching is enabled.

    How to update each item:

    - If the item has an address, replace the input with the cached item.
    - If not, return the input.
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        shared_worker_lock: LockType,
    ) -> None:
        super().__init__()

        self.world_size = vllm_config.parallel_config.world_size
        mm_config = vllm_config.model_config.get_multimodal_config()

        ring_buffer = SingleWriterShmRingBuffer(
            data_buffer_size=int(mm_config.mm_processor_cache_gb * GiB_bytes),
            name=VLLM_OBJECT_STORAGE_SHM_BUFFER_NAME,
            create=False,  # Server is a reader
        )
        self._shm_cache = SingleWriterShmObjectStorage(
            max_object_size=mm_config.mm_shm_cache_max_object_size_mb *
            MiB_bytes,
            n_readers=self.world_size,
            ring_buffer=ring_buffer,
            serde_class=MsgpackSerde,
            reader_lock=shared_worker_lock,
        )

    @override
    def get_and_update_item(
        self,
        mm_item: Optional[MultiModalKwargsItem],
        mm_hash: str,
    ) -> MultiModalKwargsItem:
        assert mm_item is not None, f"Expected an address item for {mm_hash=}"
        if "address" in mm_item:
            address = cast(int, mm_item["address"].data)
            monotonic_id = cast(int, mm_item["monotonic_id"].data)
            return self._shm_cache.get(address, monotonic_id)

        return mm_item

    @override
    def clear_cache(self) -> None:
        self._shm_cache.clear()


def engine_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
) -> Optional[BaseMultiModalReceiverCache]:
    """Return a `BaseMultiModalReceiverCache`, if enabled."""
    """
    This is used in the engine process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="lru".
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
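On the wire, a cached input is therefore reduced to a two-element `MultiModalKwargsItem` that carries only `address` and `monotonic_id`; the receiver above detects that shape and swaps the real tensors back in. A small sketch of that placeholder encoding, mirroring `address_as_item` with made-up values:

from vllm.multimodal.inputs import (MultiModalBatchedField,
                                    MultiModalFieldElem, MultiModalKwargsItem)

addr_elem = MultiModalFieldElem(modality="image", key="address", data=1024,
                                field=MultiModalBatchedField())
id_elem = MultiModalFieldElem(modality="image", key="monotonic_id", data=7,
                              field=MultiModalBatchedField())
placeholder = MultiModalKwargsItem.from_elems([addr_elem, id_elem])

# This is the check the receiver cache performs before reading shared memory.
assert "address" in placeholder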
@ -504,4 +679,31 @@ def receiver_cache_from_config(
    if not _enable_ipc_cache(vllm_config):
        return None

    return MultiModalReceiverCache(model_config)
    if not _enable_mm_input_shm_cache(vllm_config):
        return MultiModalReceiverCache(model_config)

    return None


def worker_receiver_cache_from_config(
    vllm_config: "VllmConfig",
    mm_registry: "MultiModalRegistry",
    shared_worker_lock: LockType,
) -> Optional[BaseMultiModalReceiverCache]:
    """
    This is used in the worker process.
    Return a `BaseMultiModalReceiverCache` only when IPC caching is enabled and
    mm_processor_cache_type=="shm".
    """
    model_config = vllm_config.model_config

    if not _enable_processor_cache(model_config, mm_registry):
        return None

    if not _enable_ipc_cache(vllm_config):
        return None

    if not _enable_mm_input_shm_cache(vllm_config):
        return None

    return ShmObjectStoreReceiverCache(vllm_config, shared_worker_lock)
@ -163,6 +163,12 @@ STR_FLASH_ATTN_VAL: str = "FLASH_ATTN"
STR_DUAL_CHUNK_FLASH_ATTN_VAL: str = "DUAL_CHUNK_FLASH_ATTN"
STR_INVALID_VAL: str = "INVALID"

MB_bytes = 1_000_000
"""The number of bytes in one megabyte (MB)."""

MiB_bytes = 1 << 20
"""The number of bytes in one mebibyte (MiB)."""

GB_bytes = 1_000_000_000
"""The number of bytes in one gigabyte (GB)."""
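These constants keep the two sizing knobs consistent even though they use different units: the ring buffer is sized in GiB while the per-object cap is given in MiB. A quick check of the arithmetic used by the caches above (the 128 MiB figure is an example, not necessarily the shipped default):

from vllm.utils import GiB_bytes, MiB_bytes

mm_processor_cache_gb = 4                # ring buffer capacity, in GiB
mm_shm_cache_max_object_size_mb = 128    # per-object cap, in MiB (example)

data_buffer_size = int(mm_processor_cache_gb * GiB_bytes)      # 4_294_967_296
max_object_size = mm_shm_cache_max_object_size_mb * MiB_bytes  # 134_217_728
assert max_object_size < data_buffer_size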
@ -23,7 +23,7 @@ from vllm.logger import init_logger
from vllm.logging_utils.dump_input import dump_engine_exception
from vllm.lora.request import LoRARequest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import receiver_cache_from_config
from vllm.multimodal.cache import engine_receiver_cache_from_config
from vllm.tasks import POOLING_TASKS, SupportedTask
from vllm.transformers_utils.config import (
    maybe_register_config_serialize_by_value)
@ -131,7 +131,7 @@ class EngineCore:
        self.use_spec_decode = vllm_config.speculative_config is not None

        self.mm_registry = mm_registry = MULTIMODAL_REGISTRY
        self.mm_receiver_cache = receiver_cache_from_config(
        self.mm_receiver_cache = engine_receiver_cache_from_config(
            vllm_config, mm_registry)

        # Setup batch queue for pipeline parallelism.
@ -14,6 +14,7 @@ from enum import Enum, auto
from functools import partial
from multiprocessing.connection import Connection
from multiprocessing.process import BaseProcess
from multiprocessing.synchronize import Lock as LockType
from threading import Thread
from typing import Any, Callable, Optional, Union, cast

@ -31,10 +32,13 @@ from vllm.distributed.parallel_state import (get_dp_group, get_ep_group,
from vllm.executor.multiproc_worker_utils import (
    set_multiprocessing_worker_envs)
from vllm.logger import init_logger
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.cache import worker_receiver_cache_from_config
from vllm.utils import (decorate_logs, get_distributed_init_method,
                        get_loopback_ip, get_mp_context, get_open_port,
                        set_process_title)
from vllm.v1.executor.abstract import Executor, FailureCallback
from vllm.v1.executor.utils import get_and_update_mm_cache
from vllm.v1.outputs import (AsyncModelRunnerOutput, DraftTokenIds,
                             ModelRunnerOutput)
from vllm.worker.worker_base import WorkerWrapperBase
@ -81,6 +85,8 @@ class MultiprocExecutor(Executor):
        scheduler_output_handle = self.rpc_broadcast_mq.export_handle()

        # Create workers
        context = get_mp_context()
        shared_worker_lock = context.Lock()
        unready_workers: list[UnreadyWorkerProcHandle] = []
        success = False
        try:
@ -92,6 +98,7 @@ class MultiprocExecutor(Executor):
                        rank=rank,
                        distributed_init_method=distributed_init_method,
                        input_shm_handle=scheduler_output_handle,
                        shared_worker_lock=shared_worker_lock,
                    ))

            # Workers must be created before wait_for_ready to avoid
@ -380,6 +387,7 @@ class WorkerProc:
        rank: int,
        distributed_init_method: str,
        input_shm_handle: Handle,
        shared_worker_lock: LockType,
    ):
        self.rank = rank
        wrapper = WorkerWrapperBase(vllm_config=vllm_config, rpc_rank=rank)
@ -416,6 +424,10 @@ class WorkerProc:
                                               name="WorkerAsyncOutputCopy")
        self.async_output_copy_thread.start()

        # Initialize multimodal receiver cache if needed
        self.mm_receiver_cache = worker_receiver_cache_from_config(
            vllm_config, MULTIMODAL_REGISTRY, shared_worker_lock)

        # Initialize device
        self.worker.init_device()

@ -428,11 +440,12 @@ class WorkerProc:

    @staticmethod
    def make_worker_process(
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
            vllm_config: VllmConfig,
            local_rank: int,
            rank: int,
            distributed_init_method: str,
            input_shm_handle,  # Receive SchedulerOutput
            shared_worker_lock: LockType,
    ) -> UnreadyWorkerProcHandle:
        context = get_mp_context()
        # (reader, writer)
@ -449,6 +462,7 @@ class WorkerProc:
            "input_shm_handle": input_shm_handle,
            "ready_pipe": (reader, writer),
            "death_pipe": death_reader,
            "shared_worker_lock": shared_worker_lock,
        }
        # Run EngineCore busy loop in background process.
        proc = context.Process(target=WorkerProc.worker_main,
@ -646,6 +660,10 @@ class WorkerProc:
                    func = getattr(self.worker, method)
                elif isinstance(method, bytes):
                    func = partial(cloudpickle.loads(method), self.worker)
                # retrieve from shm cache if available
                if self.mm_receiver_cache is not None \
                        and func.__name__ == "execute_model":
                    get_and_update_mm_cache(self.mm_receiver_cache, args)
                output = func(*args, **kwargs)
            except Exception as e:
                # Notes have been introduced in python 3.11
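The reader lock handed to each `WorkerProc` must come from the same multiprocessing context that spawns the workers so it can be inherited through the process kwargs. A stripped-down sketch of that plumbing using only the standard library (no vLLM types involved):

import multiprocessing as mp


def worker_main(shared_worker_lock, rank):
    # Readers of the shared-memory store serialize their accesses here.
    with shared_worker_lock:
        print(f"rank {rank} holds the reader lock")


if __name__ == "__main__":
    ctx = mp.get_context("spawn")      # same idea as get_mp_context()
    shared_worker_lock = ctx.Lock()    # created once by the parent process
    procs = [ctx.Process(target=worker_main, args=(shared_worker_lock, r))
             for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()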
25
vllm/v1/executor/utils.py
Normal file
@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.multimodal.cache import ShmObjectStoreReceiverCache
from vllm.v1.core.sched.output import SchedulerOutput


def get_and_update_mm_cache(
    receiver_cache: ShmObjectStoreReceiverCache,
    args: tuple[SchedulerOutput],
) -> None:
    """
    For each MultiModalKwargsItem in SchedulerOutput, fetch from shared memory
    cache as needed.

    Args:
        receiver_cache: The receiver cache to update.
        args: According to the collective_rpc call of execute_model method in
            executor, args is a tuple of only one SchedulerOutput element.
    """
    scheduler_output = args[0]
    for request_data in scheduler_output.scheduled_new_reqs:
        for i in range(len(request_data.mm_kwargs)):
            mm_input = request_data.mm_kwargs[i]
            request_data.mm_kwargs[i] = \
                receiver_cache.get_and_update_item(mm_input, None)
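Because the helper mutates `scheduler_output.scheduled_new_reqs` in place, callers just run it on the RPC args tuple and then execute the model as usual. A minimal caller-side sketch, assuming a worker object that exposes `execute_model(scheduler_output)`:

from vllm.v1.executor.utils import get_and_update_mm_cache


def run_execute_model(worker, mm_receiver_cache, args):
    # args is the (SchedulerOutput,) tuple received by the execute_model RPC.
    if mm_receiver_cache is not None:
        get_and_update_mm_cache(mm_receiver_cache, args)  # in-place rewrite
    return worker.execute_model(*args)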