[Bugfix] Route requests requiring KV cache recomputation from the decode instance to the prefill instance (#3448)

### What this PR does / why we need it?
This PR fixes an out-of-memory (OOM) bug caused by recomputation on the decode
instance. When recomputation happens during decode, KV cache usage may exceed
the pre-allocated memory and trigger an OOM error.

We therefore propose a new scheduling strategy: when the decode instance cannot
allocate a new block for a running request, the request that would otherwise be
preempted is stopped instead. The proxy recognizes these stopped requests, sends
them back to a prefill instance to recompute the KV cache, and then routes them
to a decode instance again.
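
Below is a minimal, illustrative sketch (not the exact proxy code from the diff in this PR) of the two proxy-side pieces the strategy relies on: detecting the `recomputed` stop reason in a chunk streamed from the decode instance, and rebuilding the request body before it is re-sent to a prefill instance. The helper names `needs_recompute` and `rebuild_request` are hypothetical.

```python
import json
from typing import Any, Dict


def needs_recompute(sse_chunk: bytes) -> bool:
    """Return True when a decode-side chunk reports stop_reason == "recomputed",
    i.e. the request was stopped instead of preempted and must be re-prefilled."""
    text = sse_chunk.decode("utf-8").strip()
    if text.startswith("data: "):
        text = text[len("data: "):]
    try:
        payload = json.loads(text)
    except json.JSONDecodeError:  # e.g. the trailing "[DONE]" marker
        return False
    choices = payload.get("choices", [])
    return bool(choices) and choices[0].get("stop_reason") == "recomputed"


def rebuild_request(req_data: Dict[str, Any], origin_prompt: str,
                    generated: str, remaining_tokens: int) -> Dict[str, Any]:
    """Fold the tokens generated so far back into the prompt and shrink the token
    budget, so prefill recomputes the KV cache for prompt + partial output."""
    if "messages" in req_data:  # chat-completions request
        req_data["messages"][0]["content"] = origin_prompt + generated
    else:  # plain completions request
        req_data["prompt"] = origin_prompt + generated
    req_data["max_tokens"] = remaining_tokens
    return req_data
```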

This is a temporary fix for the bug. The long-term strategy is to use CPU
offload on the decode instance.

### Does this PR introduce _any_ user-facing change?
Yes. A new Ascend configuration option, `recompute_scheduler_enable`, is added
to enable this strategy. Its default value is `False`.
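
For example, on the decode instance this can be passed through vLLM's additional-config mechanism, e.g. `--additional-config '{"recompute_scheduler_enable": true}'` on the `vllm serve` command line (this follows the usual vllm-ascend `additional_config` pattern; the full P/D-disaggregation launch flags are not shown here).
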
### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
Author: Shirley125
Date: 2025-10-18 15:56:44 +08:00
Committed by: GitHub
Parent: 4750d45d86
Commit: b4233a2ec3
6 changed files with 1761 additions and 114 deletions


@@ -84,17 +84,18 @@
#
# For more details, see the code and comments in this file.
import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
import threading
import uuid
from contextlib import asynccontextmanager
from typing import List
from dataclasses import dataclass
from typing import Any, List
import httpx
from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
@@ -324,7 +326,7 @@ async def listen_for_disconnect(request: Request) -> None:
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
@@ -337,9 +339,9 @@ def with_cancellation(handler_func):
if handler_task in done:
return handler_task.result()
return None
return wrapper
app = FastAPI(lifespan=lifespan)
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
"metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
"metaserver":
f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
req_data["stream"] = False
req_data["max_tokens"] = 1
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
return "chatcmpl-" + req_id
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(
send_request_to_service(prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
@dataclass
class InstanceInfo:
request_id: str
prefiller_idx: int
prefiller_score: float
prefiller: ServerState
decoder_idx: int
decoder_score: float
decoder: ServerState
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response
instance_info = await _handle_select_instance(api, req_data,
request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
else:
origin_prompt = ""
# refer to vLLM sampling_params: max_token default value
origin_max_tokens = req_data.get("max_tokens", 16)
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
async def generate_stream():
nonlocal released_kv
nonlocal instance_info
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
prefiller_idx, prefiller_score)
released_kv = True
yield chunk
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
released_kv = True
chunk_str = chunk.decode("utf-8").strip()
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.warning(
f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
generated_token += content
stop_reason = choice.get(
"stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choices[0]["message"][
"content"] = generated_token
else:
choices[0]["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(prefiller_idx, request_id)
proxy_state.release_prefiller_kv(prefiller_idx,
prefiller_score)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(req_data)
except Exception as e:
logger.error(
f"Post metaserver failed with: {str(e)}"
)
logger.error(f"Post metaserver failed with: {str(e)}")
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)


@@ -84,16 +84,17 @@
#
# For more details, see the code and comments in this file.
import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from typing import List
from dataclasses import dataclass
from typing import Any, List
import httpx
from fastapi import FastAPI, Request
@@ -105,6 +106,7 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
@@ -443,69 +445,170 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
raise e
async def _handle_select_instance(api: str, req_data: Any,
request_length: int):
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
return InstanceInfo(request_id=request_id,
prefiller_idx=prefiller_idx,
prefiller_score=prefiller_score,
prefiller=prefiller,
decoder=decoder,
decoder_idx=decoder_idx,
decoder_score=decoder_score)
@dataclass
class InstanceInfo:
request_id: str
prefiller_idx: int
prefiller_score: float
prefiller: ServerState
decoder_idx: int
decoder_score: float
decoder: ServerState
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
# Send request to prefiller
response = await send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay)
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response_json = response.json()
kv_transfer_params = response_json.get('kv_transfer_params', {})
if kv_transfer_params:
req_data["kv_transfer_params"] = kv_transfer_params
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
instance_info = await _handle_select_instance(api, req_data,
request_length)
stream_flag = bool(req_data.get("stream", False))
chat_flag = "messages" in req_data
if "prompt" in req_data:
origin_prompt = req_data["prompt"]
elif chat_flag:
messages = req_data["messages"]
origin_prompt = messages[0].get("content", "")
else:
origin_prompt = ""
# refer to vLLM sampling_params: max_token default value
origin_max_tokens = req_data.get("max_tokens", 16)
async def generate_stream():
nonlocal released_kv
nonlocal instance_info
generated_token = ""
released_kv = False
retry_count = 0
retry = True
completion_tokens = 0
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
prefiller_idx, prefiller_score)
released_kv = True
yield chunk
while retry:
retry = False
async for chunk in stream_service_response_with_retry(
instance_info.decoder.client,
api,
req_data,
request_id=instance_info.request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
instance_info.prefiller_idx,
instance_info.prefiller_score)
released_kv = True
chunk_str = chunk.decode("utf-8").strip()
if not chunk_str:
continue
if chunk_str.startswith("data: "):
chunk_str = chunk_str[len("data: "):]
try:
chunk_json = json.loads(chunk_str)
except json.JSONDecodeError:
# if chunk is [done], skip it.
logger.warning(
f"Skipping chunk: {chunk_str}")
yield chunk
continue
choices = chunk_json.get("choices", [])
if not choices:
yield chunk
continue
choice = choices[0]
delta = choice.get("delta") or {}
message = choice.get("message") or {}
content = (
delta.get("content")
or message.get("content")
or choice.get("text")
or ""
)
generated_token += content
stop_reason = choice.get(
"stop_reason")
usage = chunk_json.get("usage", {})
completion_tokens = (completion_tokens + 1) if stream_flag else \
(completion_tokens + usage.get("completion_tokens"))
if stop_reason == "recomputed":
retry = True
retry_count += 1
if chat_flag:
messages[0][
"content"] = origin_prompt + generated_token
else:
req_data[
"prompt"] = origin_prompt + generated_token
req_data[
"max_tokens"] = origin_max_tokens - completion_tokens + retry_count
tmp_request_length = len(
json.dumps(req_data).encode("utf-8"))
instance_info = await _handle_select_instance(
api, req_data, tmp_request_length)
break
if retry_count > 0 and not stream_flag:
if chat_flag:
choice["message"][
"content"] = generated_token
else:
choice["text"] = generated_token
chunk = json.dumps(chunk_json).encode("utf-8")
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(prefiller_idx, request_id)
proxy_state.release_prefiller_kv(prefiller_idx,
prefiller_score)
proxy_state.abort_prefiller_request(
instance_info.prefiller_idx, instance_info.request_id)
proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
instance_info.prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
proxy_state.release_decoder(instance_info.decoder_idx,
instance_info.decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
@@ -544,4 +647,5 @@ if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
uvicorn.run(app, host=global_args.host, port=global_args.port)


@@ -70,6 +70,8 @@ class AscendConfig:
) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.recompute_scheduler_enable = additional_config.get(
"recompute_scheduler_enable", False)
self.lmhead_tensor_parallel_size = additional_config.get(
"lmhead_tensor_parallel_size", None)
if self.lmhead_tensor_parallel_size is not None:


@@ -0,0 +1,39 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
from dataclasses import dataclass, fields
from typing import Type, Union
from vllm.config import SchedulerConfig
MAX_INT = 2147483647
@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
scheduler_cls: Union[str, Type[object]] = (
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
@classmethod
def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig):
scheduler_config = {
field.name: getattr(vllm_scheduler_config, field.name)
for field in fields(vllm_scheduler_config) if field.init
}
scheduler_config["scheduler_cls"] = (
"vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
return cls(**scheduler_config)

File diff suppressed because it is too large.


@@ -300,6 +300,12 @@ class NPUPlatform(Platform):
vllm_config.scheduler_config,
ascend_config.ascend_scheduler_config)
vllm_config.scheduler_config = ascend_scheduler_config
elif ascend_config.recompute_scheduler_enable:
from vllm_ascend.core.recompute_schedule_config import \
RecomputeSchedulerConfig
recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
vllm_config.scheduler_config)
vllm_config.scheduler_config = recompute_scheduler_config
@classmethod
def get_attn_backend_cls(