[BugFix] Avoid race conditions in zero-copy tensor transmission (#17203)

Signed-off-by: Nick Hill <nhill@redhat.com>
Nick Hill authored on 2025-04-25 23:00:07 -07:00 · committed by GitHub
parent 53e8cf53a4 · commit b07bf83c7d
3 changed files with 76 additions and 11 deletions
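
With `copy=False`, pyzmq hands libzmq references to Python-owned memory, and `send_multipart()` can return before the bytes have actually left the process. The engine previously reused a single `bytearray` encode buffer on every loop iteration, and the clients dropped their only reference to a request's tensors as soon as the send call returned, so an in-flight message could be overwritten or garbage-collected underneath zmq. The fix tracks every zero-copy send with a `zmq.MessageTracker` (`track=True`) and pins the buffer and object references until the tracker reports done. A minimal sketch of that primitive, with an illustrative endpoint and payload (not vLLM code):

```python
import zmq

ctx = zmq.Context()
sock = ctx.socket(zmq.PUSH)
sock.bind("ipc:///tmp/example")  # hypothetical endpoint

payload = bytearray(1_000_000)
# track=True returns a MessageTracker; .done flips once libzmq no
# longer references the buffer that was sent with copy=False.
tracker = sock.send(payload, copy=False, track=True)

# Mutating or freeing `payload` before tracker.done is the race this
# commit fixes; polling .done (or tracker.wait()) makes reuse safe.
if not tracker.done:
    tracker.wait()
payload[:1] = b"y"  # safe to reuse now
```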

View File

@@ -32,6 +32,7 @@ class MyType:
     large_f_contig_tensor: torch.Tensor
     small_non_contig_tensor: torch.Tensor
     large_non_contig_tensor: torch.Tensor
+    empty_tensor: torch.Tensor
 
 
 def test_encode_decode():
@@ -58,6 +59,7 @@ def test_encode_decode():
         large_f_contig_tensor=torch.rand(1024, 4).t(),
         small_non_contig_tensor=torch.rand(2, 4)[:, 1:3],
         large_non_contig_tensor=torch.rand(1024, 512)[:, 10:20],
+        empty_tensor=torch.empty(0),
     )
 
     encoder = MsgpackEncoder(size_threshold=256)
@@ -193,3 +195,4 @@ def assert_equal(obj1: MyType, obj2: MyType):
                        obj2.small_non_contig_tensor)
     assert torch.equal(obj1.large_non_contig_tensor,
                        obj2.large_non_contig_tensor)
+    assert torch.equal(obj1.empty_tensor, obj2.empty_tensor)
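
For context on the test: `MsgpackEncoder(size_threshold=256)` inlines small tensors in the msgpack payload, while larger backing buffers travel as separate frames suited to zero-copy sends (which is what makes the race above possible); the new `empty_tensor` field exercises the zero-length edge case. A hedged round-trip sketch, assuming the import path from vLLM's serial utils and that `MsgpackDecoder` can target a bare `torch.Tensor`:

```python
import torch
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder  # assumed path

encoder = MsgpackEncoder(size_threshold=256)
decoder = MsgpackDecoder(torch.Tensor)  # assumed top-level tensor target

t = torch.rand(1024, 512)[:, 10:20]  # large and non-contiguous
frames = encoder.encode(t)           # multiple frames => zero-copy buffers
assert torch.equal(decoder.decode(frames), t)

# Zero-length tensors must survive the trip too, per the new test field.
assert torch.equal(decoder.decode(encoder.encode(torch.empty(0))),
                   torch.empty(0))
```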

View File

@@ -5,6 +5,7 @@ import signal
 import sys
 import threading
 import time
+from collections import deque
 from concurrent.futures import Future
 from inspect import isclass, signature
 from logging import DEBUG
@@ -527,8 +528,12 @@ class EngineCoreProc(EngineCore):
 
         # Msgpack serialization encoding.
         encoder = MsgpackEncoder()
-        # Reuse send buffer.
-        buffer = bytearray()
+        # Send buffers to reuse.
+        reuse_buffers: list[bytearray] = []
+        # Keep references to outputs and buffers until zmq is finished
+        # with them (outputs may contain tensors/np arrays whose
+        # backing buffers were extracted for zero-copy send).
+        pending = deque[tuple[zmq.MessageTracker, Any, bytearray]]()
 
         # We must set linger to ensure the ENGINE_CORE_DEAD
         # message is sent prior to closing the socket.
@@ -541,8 +546,22 @@
                     break
                 assert not isinstance(outputs, bytes)
                 outputs.engine_index = engine_index
+
+                # Reclaim buffers that zmq is finished with.
+                while pending and pending[-1][0].done:
+                    reuse_buffers.append(pending.pop()[2])
+
+                buffer = reuse_buffers.pop() if reuse_buffers else bytearray()
                 buffers = encoder.encode_into(outputs, buffer)
-                socket.send_multipart(buffers, copy=False)
+                tracker = socket.send_multipart(buffers,
+                                                copy=False,
+                                                track=True)
+                if not tracker.done:
+                    ref = outputs if len(buffers) > 1 else None
+                    pending.appendleft((tracker, ref, buffer))
+                elif len(reuse_buffers) < 2:
+                    # Keep at most 2 buffers to reuse.
+                    reuse_buffers.append(buffer)
 
 
 class DPEngineCoreProc(EngineCoreProc):
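
The output-socket loop now recycles encode buffers through two structures: a free list capped at two buffers (`reuse_buffers`) and a deque of in-flight sends (`pending`). Frames queued on a single socket complete in order, so only the oldest pending entries, at the right end of the deque, ever need polling. A standalone sketch of the scheme with illustrative names:

```python
from collections import deque
from typing import Any

import zmq

reuse_buffers: list[bytearray] = []
pending: deque[tuple[zmq.MessageTracker, Any, bytearray]] = deque()

def acquire_buffer() -> bytearray:
    # Sends complete in FIFO order, so reclaim from the oldest (right) end.
    while pending and pending[-1][0].done:
        reuse_buffers.append(pending.pop()[2])
    return reuse_buffers.pop() if reuse_buffers else bytearray()

def tracked_send(socket: zmq.Socket, frames: list, outputs: Any,
                 buffer: bytearray) -> None:
    tracker = socket.send_multipart(frames, copy=False, track=True)
    if not tracker.done:
        # zmq still references the frames: pin outputs + buffer until done.
        pending.appendleft((tracker, outputs, buffer))
    elif len(reuse_buffers) < 2:
        # Completed synchronously; keep at most two spare buffers.
        reuse_buffers.append(buffer)
```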

View File

@@ -1,9 +1,11 @@
 # SPDX-License-Identifier: Apache-2.0
 import asyncio
+import contextlib
 import queue
 import uuid
 import weakref
 from abc import ABC, abstractmethod
+from collections import deque
 from collections.abc import Awaitable, Sequence
 from concurrent.futures import Future
 from dataclasses import dataclass, field
@@ -396,6 +398,12 @@ class MPClient(EngineCoreClient):
             self._wait_for_engine_startup()
 
             self.utility_results: dict[int, AnyFuture] = {}
+
+            # Request objects which may contain pytorch-allocated tensors
+            # that we need to keep references to until zmq is done with the
+            # underlying data.
+            self.pending_messages = deque[tuple[zmq.MessageTracker, Any]]()
+
             success = True
         finally:
             if not success:
@@ -459,6 +467,14 @@
         if self.resources.engine_dead:
             raise EngineDeadError()
 
+    def add_pending_message(self, tracker: zmq.MessageTracker, msg: Any):
+        if not tracker.done:
+            self.pending_messages.appendleft((tracker, msg))
+
+    def free_pending_messages(self):
+        while self.pending_messages and self.pending_messages[-1][0].done:
+            self.pending_messages.pop()
+
 
 def _process_utility_output(output: UtilityOutput,
                             utility_results: dict[int, AnyFuture]):
@@ -544,10 +560,18 @@ class SyncMPClient(MPClient):
     def _send_input(self, request_type: EngineCoreRequestType, request: Any):
         self.ensure_alive()
+        self.free_pending_messages()
 
         # (Identity, RequestType, SerializedRequest)
         msg = (self.core_engine.identity, request_type.value,
                *self.encoder.encode(request))
-        self.input_socket.send_multipart(msg, copy=False)
+
+        if len(msg) <= 3:
+            # No auxiliary buffers => no tensor backing buffers in request.
+            self.input_socket.send_multipart(msg, copy=False)
+            return
+
+        tracker = self.input_socket.send_multipart(msg, copy=False, track=True)
+        self.add_pending_message(tracker, request)
 
     def call_utility(self, method: str, *args) -> Any:
         call_id = uuid.uuid1().int >> 64
@@ -698,19 +722,38 @@
     def _send_input(self,
                     request_type: EngineCoreRequestType,
                     request: Any,
-                    engine: Optional[CoreEngine] = None) -> Awaitable[None]:
+                    engine: Optional[CoreEngine] = None) -> Awaitable[Any]:
         self.ensure_alive()
 
         if engine is None:
             engine = self.core_engine
 
         message = (request_type.value, *self.encoder.encode(request))
-        return self._send_input_message(message, engine)
+        return self._send_input_message(message, engine, request)
 
-    def _send_input_message(self, message: tuple[bytestr, ...],
-                            engine: CoreEngine) -> Awaitable[None]:
+    def _send_input_message(self, message: tuple[bytestr,
+                                                 ...], engine: CoreEngine,
+                            objects: Any) -> Awaitable[Any]:
+        """
+        objects is a reference to retain until zmq is finished with the
+        buffers, in case they were extracted from tensors in the request.
+        """
         self.ensure_alive()
-        message = (engine.identity, ) + message
-        return self.input_socket.send_multipart(message, copy=False)
+        self.free_pending_messages()
+
+        msg = (engine.identity, ) + message
+        if not objects or len(msg) <= 3:
+            # No auxiliary buffers => no tensor backing buffers in request.
+            return self.input_socket.send_multipart(msg, copy=False)
+
+        future: asyncio.Future[zmq.MessageTracker]
+        future = self.input_socket.send_multipart(msg, copy=False, track=True)
+
+        def add_pending(f: asyncio.Future[zmq.MessageTracker]):
+            with contextlib.suppress(BaseException):
+                self.add_pending_message(f.result(), objects)
+
+        future.add_done_callback(add_pending)
+        return future
 
     async def call_utility_async(self, method: str, *args) -> Any:
         return await self._call_utility_async(method,
@@ -724,7 +767,7 @@ class AsyncMPClient(MPClient):
         self.utility_results[call_id] = future
         message = (EngineCoreRequestType.UTILITY.value, *self.encoder.encode(
             (call_id, method, args)))
-        await self._send_input_message(message, engine)
+        await self._send_input_message(message, engine, args)
         self._ensure_output_queue_task()
         return await future
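
One subtlety in the async client: `input_socket` there is a `zmq.asyncio` socket, so `send_multipart(..., track=True)` returns an `asyncio.Future` that resolves to the `MessageTracker` only after the send completes, which is why the (tracker, request) pair is recorded in a done-callback rather than inline. A sketch of that shape, with illustrative names:

```python
import asyncio
import contextlib
from collections import deque
from typing import Any

import zmq
import zmq.asyncio

pending: deque[tuple[zmq.MessageTracker, Any]] = deque()

def send_tracked(sock: zmq.asyncio.Socket, frames: list,
                 request: Any) -> asyncio.Future:
    future = sock.send_multipart(frames, copy=False, track=True)

    def record(f: "asyncio.Future[zmq.MessageTracker]") -> None:
        # A failed or cancelled send means zmq never took ownership of
        # the buffers, so there is nothing to pin in that case.
        with contextlib.suppress(BaseException):
            pending.appendleft((f.result(), request))

    future.add_done_callback(record)
    return future
```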