mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -32,6 +32,7 @@ class MyType:
|
||||
large_f_contig_tensor: torch.Tensor
|
||||
small_non_contig_tensor: torch.Tensor
|
||||
large_non_contig_tensor: torch.Tensor
|
||||
empty_tensor: torch.Tensor
|
||||
|
||||
|
||||
def test_encode_decode():
|
||||
@ -58,6 +59,7 @@ def test_encode_decode():
|
||||
large_f_contig_tensor=torch.rand(1024, 4).t(),
|
||||
small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
|
||||
large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
|
||||
empty_tensor=torch.empty(0),
|
||||
)
|
||||
|
||||
encoder = MsgpackEncoder(size_threshold=256)
|
||||
@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
|
||||
obj2.small_non_contig_tensor)
|
||||
assert torch.equal(obj1.large_non_contig_tensor,
|
||||
obj2.large_non_contig_tensor)
|
||||
assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
|
||||
|
@ -5,6 +5,7 @@ import signal
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
from concurrent.futures import Future
|
||||
from inspect import isclass, signature
|
||||
from logging import DEBUG
|
||||
@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
|
||||
|
||||
# Msgpack serialization encoding.
|
||||
encoder = MsgpackEncoder()
|
||||
# Reuse send buffer.
|
||||
buffer = bytearray()
|
||||
# Send buffers to reuse.
|
||||
reuse_buffers: list[bytearray] = []
|
||||
# Keep references to outputs and buffers until zmq is finished
|
||||
# with them (outputs may contain tensors/np arrays whose
|
||||
# backing buffers were extracted for zero-copy send).
|
||||
pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
|
||||
|
||||
# We must set linger to ensure the ENGINE_CORE_DEAD
|
||||
# message is sent prior to closing the socket.
|
||||
@ -541,8 +546,22 @@ class EngineCoreProc(EngineCore):
|
||||
break
|
||||
assert not isinstance(outputs, bytes)
|
||||
outputs.engine_index = engine_index
|
||||
|
||||
# Reclaim buffers that zmq is finished with.
|
||||
while pending and pending[-1][0].done:
|
||||
reuse_buffers.append(pending.pop()[2])
|
||||
|
||||
buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
|
||||
buffers = encoder.encode_into(outputs, buffer)
|
||||
socket.send_multipart(buffers, copy=False)
|
||||
tracker = socket.send_multipart(buffers,
|
||||
copy=False,
|
||||
track=True)
|
||||
if not tracker.done:
|
||||
ref = outputs if len(buffers) > 1 else None
|
||||
pending.appendleft((tracker, ref, buffer))
|
||||
elif len(reuse_buffers) < 2:
|
||||
# Keep at most 2 buffers to reuse.
|
||||
reuse_buffers.append(buffer)
|
||||
|
||||
|
||||
class DPEngineCoreProc(EngineCoreProc):
|
||||
|
@ -1,9 +1,11 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import asyncio
|
||||
import contextlib
|
||||
import queue
|
||||
import uuid
|
||||
import weakref
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import deque
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from concurrent.futures import Future
|
||||
from dataclasses import dataclass, field
|
||||
@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
|
||||
self._wait_for_engine_startup()
|
||||
|
||||
self.utility_results: dict[int, AnyFuture] = {}
|
||||
|
||||
# Request objects which may contain pytorch-allocated tensors
|
||||
# that we need to keep references to until zmq is done with the
|
||||
# underlying data.
|
||||
self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
|
||||
|
||||
success = True
|
||||
finally:
|
||||
if not success:
|
||||
@ -459,6 +467,14 @@ class MPClient(EngineCoreClient):
|
||||
if self.resources.engine_dead:
|
||||
raise EngineDeadError()
|
||||
|
||||
def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
|
||||
if not tracker.done:
|
||||
self.pending_messages.appendleft((tracker, msg))
|
||||
|
||||
def free_pending_messages(self):
|
||||
while self.pending_messages and self.pending_messages[-1][0].done:
|
||||
self.pending_messages.pop()
|
||||
|
||||
|
||||
def _process_utility_output(output: UtilityOutput,
|
||||
utility_results: dict[int, AnyFuture]):
|
||||
@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
|
||||
|
||||
def _send_input(self, request_type: EngineCoreRequestType, request: Any):
|
||||
self.ensure_alive()
|
||||
self.free_pending_messages()
|
||||
# (Identity, RequestType, SerializedRequest)
|
||||
msg = (self.core_engine.identity, request_type.value,
|
||||
*self.encoder.encode(request))
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
if len(msg) <= 3:
|
||||
# No auxiliary buffers => no tensor backing buffers in request.
|
||||
self.input_socket.send_multipart(msg, copy=False)
|
||||
return
|
||||
|
||||
tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
|
||||
self.add_pending_message(tracker, request)
|
||||
|
||||
def call_utility(self, method: str, *args) -> Any:
|
||||
call_id = uuid.uuid1().int >> 64
|
||||
@ -698,19 +722,38 @@ class AsyncMPClient(MPClient):
|
||||
def _send_input(self,
|
||||
request_type: EngineCoreRequestType,
|
||||
request: Any,
|
||||
engine: Optional[CoreEngine] = None) -> Awaitable[None]:
|
||||
engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
|
||||
self.ensure_alive()
|
||||
if engine is None:
|
||||
engine = self.core_engine
|
||||
|
||||
message = (request_type.value, *self.encoder.encode(request))
|
||||
return self._send_input_message(message, engine)
|
||||
return self._send_input_message(message, engine, request)
|
||||
|
||||
def _send_input_message(self, message: tuple[bytestr, ...],
|
||||
engine: CoreEngine) -> Awaitable[None]:
|
||||
def _send_input_message(self, message: tuple[bytestr,
|
||||
...], engine: CoreEngine,
|
||||
objects: Any) -> Awaitable[Any]:
|
||||
"""
|
||||
objects is a reference to retain until zmq is finished with the
|
||||
buffers, in case they were extracted from tensors in the request.
|
||||
"""
|
||||
self.ensure_alive()
|
||||
message = (engine.identity, ) + message
|
||||
return self.input_socket.send_multipart(message, copy=False)
|
||||
self.free_pending_messages()
|
||||
|
||||
msg = (engine.identity, ) + message
|
||||
if not objects or len(msg) <= 3:
|
||||
# No auxiliary buffers => no tensor backing buffers in request.
|
||||
return self.input_socket.send_multipart(msg, copy=False)
|
||||
|
||||
future: asyncio.Future[zmq.MessageTracker]
|
||||
future = self.input_socket.send_multipart(msg, copy=False, track=True)
|
||||
|
||||
def add_pending(f: asyncio.Future[zmq.MessageTracker]):
|
||||
with contextlib.suppress(BaseException):
|
||||
self.add_pending_message(f.result(), objects)
|
||||
|
||||
future.add_done_callback(add_pending)
|
||||
return future
|
||||
|
||||
async def call_utility_async(self, method: str, *args) -> Any:
|
||||
return await self._call_utility_async(method,
|
||||
@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
|
||||
self.utility_results[call_id] = future
|
||||
message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
|
||||
(call_id, method, args)))
|
||||
await self._send_input_message(message, engine)
|
||||
await self._send_input_message(message, engine, args)
|
||||
self._ensure_output_queue_task()
|
||||
return await future
|
||||
|
||||
|
Reference in New Issue
Block a user