[P/D][BugFix]Mooncake timeout release bug fix (#2899)

### What this PR does / why we need it?
In the P node timeout release mechanism during PD separation, the req_id
that requires timeout release is transmitted from the scheduler to the
worker. If the KV cache between PDs is transferred too quickly, the P
node's req_id may be released twice. The first release is when the D
node notifies the P node that the KV cache has been pulled, and the
second release is when the scheduler transmits the timeout release to
the worker.

To address this bug, an intermediate component is introduced to manage
the release of req_ids.

Pull kv and forward2 may occur one after the other in timing. The
previous timeout defaulted to forward2 being before pull_kv.


### How was this patch tested?

- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9

---------

Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
baxingpiaochong
2025-09-24 11:22:46 +08:00
committed by GitHub
parent 6995a7bc5b
commit eb205d9f35
2 changed files with 43 additions and 11 deletions

View File

@ -7,6 +7,7 @@ import time
import types
import unittest
from collections import defaultdict, deque
from typing import OrderedDict
from unittest.mock import MagicMock, patch
import msgspec
@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase):
tracker = KVCacheTaskTracker()
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
self.assertIsInstance(tracker.finished_requests, set)
self.assertIsInstance(tracker.delayed_free_requests, deque)
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase):
def test_update_done_task_count(self):
self.assertEqual(len(self.tracker.finished_requests), 0)
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
self.assertEqual(len(self.tracker.record_finished_requests), 0)
current_time = time.time()
self.tracker.add_delayed_request("req_1", current_time)
result = self.tracker.delayed_free_requests
result_record = self.tracker.record_finished_requests
self.assertEqual(len(result), 1)
self.assertEqual(result[0], ("req_1", current_time))
self.assertEqual(result["req_1"], current_time)
self.assertEqual(len(result_record), 0)
self.tracker.update_done_task_count("req_1")
result_finished = self.tracker.finished_requests
result_delayed = self.tracker.delayed_free_requests
result_record = self.tracker.record_finished_requests
self.assertEqual(result_finished, {"req_1"})
self.assertEqual(len(result_delayed), 0)
self.assertEqual(len(result_record), 0)
self.tracker.update_done_task_count("req_2")
result_finished = self.tracker.finished_requests
result_delayed = self.tracker.delayed_free_requests
result_record = self.tracker.record_finished_requests
self.assertEqual(result_finished, {"req_1", "req_2"})
self.assertEqual(len(result_delayed), 0)
self.assertEqual(len(result_record), 1)
self.assertEqual(result_record, {"req_2"})
def test_updtate_add_delayed_request(self) -> None:
self.tracker.update_done_task_count("req2")
result_start_record = self.tracker.record_finished_requests
self.assertEqual(len(result_start_record), 1)
self.tracker.add_delayed_request("req2", time.time())
result_delayed = self.tracker.delayed_free_requests
result_end_record = self.tracker.record_finished_requests
self.assertEqual(len(result_delayed), 0)
self.assertEqual(len(result_end_record), 0)
def test_retrieve_expired_requests(self):
current_time = time.time()
@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase):
})
result_delay = self.tracker.delayed_free_requests
self.assertEqual(len(result_delay), 1)
self.assertEqual(result_delay[0], ("req_2", current_time))
self.assertIn("req_2", result_delay)
def test_duplicate_task_update(self):
self.tracker.update_done_task_count("req1")

View File

@ -11,7 +11,7 @@ from collections import defaultdict, deque
from collections.abc import Iterator
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
import msgspec
import numpy as np
@ -68,12 +68,16 @@ class KVCacheTaskTracker:
# intentionally delayed. Each entry is a tuple of (request_id,
# timestamp). If a request remains in this queue for too long, it will
# be force-freed.
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
self.record_finished_requests: set[str] = set()
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
def update_done_task_count(self, request_id: str):
with self.done_task_lock:
self.finished_requests.add(request_id)
self._remove_delayed_requests(request_id)
if request_id in self.delayed_free_requests:
self._remove_delayed_requests(request_id)
else:
self.record_finished_requests.add(request_id)
def get_and_clear_finished_requests(self) -> set[str]:
"""
@ -91,7 +95,10 @@ class KVCacheTaskTracker:
def add_delayed_request(self, request_id: str, delay_start_time: float):
"""Add a delayed free request."""
with self.done_task_lock:
self.delayed_free_requests.append((request_id, delay_start_time))
if request_id not in self.record_finished_requests:
self.delayed_free_requests[request_id] = delay_start_time
else:
self.record_finished_requests.discard(request_id)
def _retrieve_expired_requests(self):
"""Retrieve all expired delayed requests."""
@ -99,10 +106,11 @@ class KVCacheTaskTracker:
# Free delayed requests if they exceed the timeout
current_time = time.time()
while self.delayed_free_requests:
request_id, delay_start_time = self.delayed_free_requests[0]
request_id = next(iter(self.delayed_free_requests))
delay_start_time = self.delayed_free_requests[request_id]
if (current_time - delay_start_time
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
self.delayed_free_requests.popleft()
self.delayed_free_requests.popitem(last=False)
expired_requests.add(request_id)
logger.info("Force freed request: %s", request_id)
else:
@ -111,8 +119,7 @@ class KVCacheTaskTracker:
def _remove_delayed_requests(self, request_id: str):
"""Remove all delayed free requests matching the given request_id."""
self.delayed_free_requests = deque(
(r, t) for r, t in self.delayed_free_requests if r != request_id)
self.delayed_free_requests.pop(request_id)
class KVCacheSendingThread(threading.Thread):