mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[P/D][BugFix]Mooncake timeout release bug fix (#2899)
### What this PR does / why we need it?
In the P node timeout release mechanism during PD separation, the req_id
that requires timeout release is transmitted from the scheduler to the
worker. If the KV cache between PDs is transferred too quickly, the P
node's req_id may be released twice. The first release is when the D
node notifies the P node that the KV cache has been pulled, and the
second release is when the scheduler transmits the timeout release to
the worker.
To address this bug, an intermediate component is introduced to manage
the release of req_ids.
Pull kv and forward2 may occur one after the other in timing. The
previous timeout defaulted to forward2 being before pull_kv.
### How was this patch tested?
- vLLM version: v0.10.2
- vLLM main:
f225ea7dd9
---------
Signed-off-by: baxingpiaochong <771405853@qq.com>
This commit is contained in:
@ -7,6 +7,7 @@ import time
|
||||
import types
|
||||
import unittest
|
||||
from collections import defaultdict, deque
|
||||
from typing import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import msgspec
|
||||
@ -34,7 +35,7 @@ class TestKVCacheTaskTrackerInit(unittest.TestCase):
|
||||
tracker = KVCacheTaskTracker()
|
||||
self.assertIsInstance(tracker.done_task_lock, type(threading.Lock()))
|
||||
self.assertIsInstance(tracker.finished_requests, set)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, deque)
|
||||
self.assertIsInstance(tracker.delayed_free_requests, OrderedDict)
|
||||
|
||||
|
||||
class TestGetAndClearFinishedSingleRequests(unittest.TestCase):
|
||||
@ -495,18 +496,42 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
def test_update_done_task_count(self):
|
||||
self.assertEqual(len(self.tracker.finished_requests), 0)
|
||||
self.assertEqual(len(self.tracker.delayed_free_requests), 0)
|
||||
self.assertEqual(len(self.tracker.record_finished_requests), 0)
|
||||
|
||||
current_time = time.time()
|
||||
self.tracker.add_delayed_request("req_1", current_time)
|
||||
result = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result), 1)
|
||||
self.assertEqual(result[0], ("req_1", current_time))
|
||||
self.assertEqual(result["req_1"], current_time)
|
||||
self.assertEqual(len(result_record), 0)
|
||||
|
||||
self.tracker.update_done_task_count("req_1")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(result_finished, {"req_1"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_record), 0)
|
||||
|
||||
self.tracker.update_done_task_count("req_2")
|
||||
result_finished = self.tracker.finished_requests
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(result_finished, {"req_1", "req_2"})
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_record), 1)
|
||||
self.assertEqual(result_record, {"req_2"})
|
||||
|
||||
def test_updtate_add_delayed_request(self) -> None:
|
||||
self.tracker.update_done_task_count("req2")
|
||||
result_start_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result_start_record), 1)
|
||||
self.tracker.add_delayed_request("req2", time.time())
|
||||
result_delayed = self.tracker.delayed_free_requests
|
||||
result_end_record = self.tracker.record_finished_requests
|
||||
self.assertEqual(len(result_delayed), 0)
|
||||
self.assertEqual(len(result_end_record), 0)
|
||||
|
||||
def test_retrieve_expired_requests(self):
|
||||
current_time = time.time()
|
||||
@ -518,7 +543,7 @@ class TestKVCacheTaskTracker(unittest.TestCase):
|
||||
})
|
||||
result_delay = self.tracker.delayed_free_requests
|
||||
self.assertEqual(len(result_delay), 1)
|
||||
self.assertEqual(result_delay[0], ("req_2", current_time))
|
||||
self.assertIn("req_2", result_delay)
|
||||
|
||||
def test_duplicate_task_update(self):
|
||||
self.tracker.update_done_task_count("req1")
|
||||
|
@ -11,7 +11,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, List, Optional, OrderedDict, Tuple
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@ -68,12 +68,16 @@ class KVCacheTaskTracker:
|
||||
# intentionally delayed. Each entry is a tuple of (request_id,
|
||||
# timestamp). If a request remains in this queue for too long, it will
|
||||
# be force-freed.
|
||||
self.delayed_free_requests: deque[Tuple[str, float]] = deque()
|
||||
self.record_finished_requests: set[str] = set()
|
||||
self.delayed_free_requests: OrderedDict[str, float] = OrderedDict()
|
||||
|
||||
def update_done_task_count(self, request_id: str):
|
||||
with self.done_task_lock:
|
||||
self.finished_requests.add(request_id)
|
||||
self._remove_delayed_requests(request_id)
|
||||
if request_id in self.delayed_free_requests:
|
||||
self._remove_delayed_requests(request_id)
|
||||
else:
|
||||
self.record_finished_requests.add(request_id)
|
||||
|
||||
def get_and_clear_finished_requests(self) -> set[str]:
|
||||
"""
|
||||
@ -91,7 +95,10 @@ class KVCacheTaskTracker:
|
||||
def add_delayed_request(self, request_id: str, delay_start_time: float):
|
||||
"""Add a delayed free request."""
|
||||
with self.done_task_lock:
|
||||
self.delayed_free_requests.append((request_id, delay_start_time))
|
||||
if request_id not in self.record_finished_requests:
|
||||
self.delayed_free_requests[request_id] = delay_start_time
|
||||
else:
|
||||
self.record_finished_requests.discard(request_id)
|
||||
|
||||
def _retrieve_expired_requests(self):
|
||||
"""Retrieve all expired delayed requests."""
|
||||
@ -99,10 +106,11 @@ class KVCacheTaskTracker:
|
||||
# Free delayed requests if they exceed the timeout
|
||||
current_time = time.time()
|
||||
while self.delayed_free_requests:
|
||||
request_id, delay_start_time = self.delayed_free_requests[0]
|
||||
request_id = next(iter(self.delayed_free_requests))
|
||||
delay_start_time = self.delayed_free_requests[request_id]
|
||||
if (current_time - delay_start_time
|
||||
> envs.VLLM_NIXL_ABORT_REQUEST_TIMEOUT):
|
||||
self.delayed_free_requests.popleft()
|
||||
self.delayed_free_requests.popitem(last=False)
|
||||
expired_requests.add(request_id)
|
||||
logger.info("Force freed request: %s", request_id)
|
||||
else:
|
||||
@ -111,8 +119,7 @@ class KVCacheTaskTracker:
|
||||
|
||||
def _remove_delayed_requests(self, request_id: str):
|
||||
"""Remove all delayed free requests matching the given request_id."""
|
||||
self.delayed_free_requests = deque(
|
||||
(r, t) for r, t in self.delayed_free_requests if r != request_id)
|
||||
self.delayed_free_requests.pop(request_id)
|
||||
|
||||
|
||||
class KVCacheSendingThread(threading.Thread):
|
||||
|
Reference in New Issue
Block a user