Mirror of https://github.com/vllm-project/vllm-ascend.git
[Bugfix] Route requests requiring KVC recomputation from the decode instance to the P instance (#3448)
### What this PR does / why we need it?
This PR fixes an out-of-memory bug caused by recomputation on the decode instance. When recomputation happens during decode, KV-cache usage can exceed the pre-allocated memory and cause OOM. We therefore propose a new scheduling strategy: when the decode instance cannot allocate a new block for a running request, the request that would otherwise be preempted is stopped instead. The proxy recognizes these stopped requests, sends them back to the prefill instance to recompute the KV cache, and then routes them to the decode instance again. This is a temporary fix; the long-term strategy is to use CPU offload on the decode instance.
### Does this PR introduce _any_ user-facing change?
An extra ascend configuration option **recompute_scheduler_enable = True** is added to enable this strategy. The default value is False.
### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
---------
Signed-off-by: CHEN <116010019@link.cuhk.edu.cn>
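For context, the new option is read from `additional_config` (see the `AscendConfig` hunk below). A minimal sketch of enabling it from the offline API, assuming the usual vllm-ascend `additional_config` plumbing; the model name and other engine arguments are placeholders, not part of this PR:

```python
# Illustrative only: enable the recompute scheduler on the decode instance.
# `recompute_scheduler_enable` defaults to False; the model name is a placeholder.
from vllm import LLM

llm = LLM(
    model="your-model",  # placeholder
    additional_config={"recompute_scheduler_enable": True},
)
```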
@@ -84,17 +84,18 @@
#
# For more details, see the code and comments in this file.


import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
import threading
import uuid
from contextlib import asynccontextmanager
from typing import List
from dataclasses import dataclass
from typing import Any, List

import httpx
from fastapi import FastAPI, Request
@@ -106,6 +107,7 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass
@@ -362,7 +364,8 @@ async def send_request_to_service(client: httpx.AsyncClient,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
        "metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
        "metaserver":
        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
@@ -455,72 +458,174 @@ def get_api_request_id(api, req_id):
    return "chatcmpl-" + req_id


async def _handle_select_instance(api: str, req_data: Any,
                                  request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(
        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
    )
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    result_future = asyncio.Future()  # type: ignore
    request_id_api = get_api_request_id(api, request_id)
    proxy_state.req_id_future[request_id_api] = result_future
    # Send request to prefiller
    asyncio.get_running_loop().create_task(
        send_request_to_service(prefiller.client,
                                prefiller_idx,
                                api,
                                req_data,
                                request_id,
                                max_retries=global_args.max_retries,
                                base_delay=global_args.retry_delay))
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)

    response = await result_future
    del proxy_state.req_id_future[request_id_api]
    req_data["kv_transfer_params"] = response

    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(request_id=request_id,
                        prefiller_idx=prefiller_idx,
                        prefiller_score=prefiller_score,
                        prefiller=prefiller,
                        decoder=decoder,
                        decoder_idx=decoder_idx,
                        decoder_score=decoder_score)


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        result_future = asyncio.Future()  # type: ignore
        request_id_api = get_api_request_id(api, request_id)
        proxy_state.req_id_future[request_id_api] = result_future
        # Send request to prefiller
        asyncio.get_running_loop().create_task(send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay))
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
        instance_info = await _handle_select_instance(api, req_data,
                                                      request_length)
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data

        response = await result_future
        del proxy_state.req_id_future[request_id_api]
        req_data["kv_transfer_params"] = response
        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
        else:
            origin_prompt = ""
        # refer to vLLM sampling_params: max_token default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False
        async def generate_stream():
            nonlocal released_kv
            nonlocal instance_info
            generated_token = ""
            released_kv = False
            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                            instance_info.decoder.client,
                            api,
                            req_data,
                            request_id=instance_info.request_id,
                            max_retries=global_args.max_retries,
                            base_delay=global_args.retry_delay):
                        if not released_kv and chunk:
                            proxy_state.release_prefiller_kv(
                                instance_info.prefiller_idx,
                                instance_info.prefiller_score)
                            released_kv = True
                        chunk_str = chunk.decode("utf-8").strip()
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # if chunk is [done], skip it.
                            logger.warning(
                                f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue

                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = (
                            delta.get("content")
                            or message.get("content")
                            or choice.get("text")
                            or ""
                        )
                        generated_token += content

                        stop_reason = choice.get(
                            "stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (completion_tokens + 1) if stream_flag else \
                            (completion_tokens + usage.get("completion_tokens"))
                        if stop_reason == "recomputed":
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0][
                                    "content"] = origin_prompt + generated_token
                            else:
                                req_data[
                                    "prompt"] = origin_prompt + generated_token
                            req_data[
                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                            tmp_request_length = len(
                                json.dumps(req_data).encode("utf-8"))
                            instance_info = await _handle_select_instance(
                                api, req_data, tmp_request_length)
                            break
                        if retry_count > 0 and not stream_flag:
                            if chat_flag:
                                choices[0]["message"][
                                    "content"] = generated_token
                            else:
                                choices[0]["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                )
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)
                proxy_state.abort_prefiller_request(
                    instance_info.prefiller_idx, instance_info.request_id)
                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
                                                 instance_info.prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)
            proxy_state.release_decoder(instance_info.decoder_idx,
                                        instance_info.decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
@@ -564,13 +669,12 @@ async def metaserver(request: Request):
        result_future = proxy_state.req_id_future[request_id]
        result_future.set_result(req_data)
    except Exception as e:
        logger.error(
            f"Post metaserver failed with: {str(e)}"
        )
        logger.error(f"Post metaserver failed with: {str(e)}")


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -84,16 +84,17 @@
#
# For more details, see the code and comments in this file.


import argparse
import asyncio
import functools
import heapq
import json
import os
import sys
import uuid
from contextlib import asynccontextmanager
from typing import List
from dataclasses import dataclass
from typing import Any, List

import httpx
from fastapi import FastAPI, Request
@@ -105,6 +106,7 @@ logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
    import uvloop

    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass
@@ -443,69 +445,170 @@ async def stream_service_response_with_retry(client: httpx.AsyncClient,
                raise e


async def _handle_select_instance(api: str, req_data: Any,
                                  request_length: int):
    prefiller_score = proxy_state.calculate_prefill_scores(request_length)
    logger.debug(
        f"Request length: {request_length}, Prefiller score: {prefiller_score}"
    )
    request_id = await proxy_state.next_req_id()
    # Select prefiller
    prefiller_idx = proxy_state.select_prefiller(prefiller_score)
    prefiller = proxy_state.prefillers[prefiller_idx]
    # Send request to prefiller
    response = await send_request_to_service(
        prefiller.client,
        prefiller_idx,
        api,
        req_data,
        request_id,
        max_retries=global_args.max_retries,
        base_delay=global_args.retry_delay)
    proxy_state.release_prefiller(prefiller_idx, prefiller_score)
    response_json = response.json()
    kv_transfer_params = response_json.get('kv_transfer_params', {})
    if kv_transfer_params:
        req_data["kv_transfer_params"] = kv_transfer_params
    # Select decoder
    decoder_score = proxy_state.calculate_decode_scores(request_length)
    logger.debug("Decoder score: %f", decoder_score)
    # Use the prefiller's kv_transfer_params to select decoder
    decoder_idx = proxy_state.select_decoder(decoder_score)
    decoder = proxy_state.decoders[decoder_idx]
    logger.debug("Using %s %s", prefiller.url, decoder.url)
    return InstanceInfo(request_id=request_id,
                        prefiller_idx=prefiller_idx,
                        prefiller_score=prefiller_score,
                        prefiller=prefiller,
                        decoder=decoder,
                        decoder_idx=decoder_idx,
                        decoder_score=decoder_score)


@dataclass
class InstanceInfo:
    request_id: str
    prefiller_idx: int
    prefiller_score: float
    prefiller: ServerState
    decoder_idx: int
    decoder_score: float
    decoder: ServerState


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        # Send request to prefiller
        response = await send_request_to_service(
            prefiller.client,
            prefiller_idx,
            api,
            req_data,
            request_id,
            max_retries=global_args.max_retries,
            base_delay=global_args.retry_delay)
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)
        response_json = response.json()
        kv_transfer_params = response_json.get('kv_transfer_params', {})
        if kv_transfer_params:
            req_data["kv_transfer_params"] = kv_transfer_params
        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False
        instance_info = await _handle_select_instance(api, req_data,
                                                      request_length)
        stream_flag = bool(req_data.get("stream", False))
        chat_flag = "messages" in req_data

        if "prompt" in req_data:
            origin_prompt = req_data["prompt"]
        elif chat_flag:
            messages = req_data["messages"]
            origin_prompt = messages[0].get("content", "")
        else:
            origin_prompt = ""
        # refer to vLLM sampling_params: max_token default value
        origin_max_tokens = req_data.get("max_tokens", 16)

        async def generate_stream():
            nonlocal released_kv
            nonlocal instance_info
            generated_token = ""
            released_kv = False
            retry_count = 0
            retry = True
            completion_tokens = 0
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
                while retry:
                    retry = False
                    async for chunk in stream_service_response_with_retry(
                            instance_info.decoder.client,
                            api,
                            req_data,
                            request_id=instance_info.request_id,
                            max_retries=global_args.max_retries,
                            base_delay=global_args.retry_delay):
                        if not released_kv and chunk:
                            proxy_state.release_prefiller_kv(
                                instance_info.prefiller_idx,
                                instance_info.prefiller_score)
                            released_kv = True
                        chunk_str = chunk.decode("utf-8").strip()
                        if not chunk_str:
                            continue
                        if chunk_str.startswith("data: "):
                            chunk_str = chunk_str[len("data: "):]
                        try:
                            chunk_json = json.loads(chunk_str)
                        except json.JSONDecodeError:
                            # if chunk is [done], skip it.
                            logger.warning(
                                f"Skipping chunk: {chunk_str}")
                            yield chunk
                            continue
                        choices = chunk_json.get("choices", [])
                        if not choices:
                            yield chunk
                            continue

                        choice = choices[0]
                        delta = choice.get("delta") or {}
                        message = choice.get("message") or {}
                        content = (
                            delta.get("content")
                            or message.get("content")
                            or choice.get("text")
                            or ""
                        )
                        generated_token += content

                        stop_reason = choice.get(
                            "stop_reason")
                        usage = chunk_json.get("usage", {})
                        completion_tokens = (completion_tokens + 1) if stream_flag else \
                            (completion_tokens + usage.get("completion_tokens"))
                        if stop_reason == "recomputed":
                            retry = True
                            retry_count += 1
                            if chat_flag:
                                messages[0][
                                    "content"] = origin_prompt + generated_token
                            else:
                                req_data[
                                    "prompt"] = origin_prompt + generated_token
                            req_data[
                                "max_tokens"] = origin_max_tokens - completion_tokens + retry_count
                            tmp_request_length = len(
                                json.dumps(req_data).encode("utf-8"))
                            instance_info = await _handle_select_instance(
                                api, req_data, tmp_request_length)
                            break
                        if retry_count > 0 and not stream_flag:
                            if chat_flag:
                                choice["message"][
                                    "content"] = generated_token
                            else:
                                choice["text"] = generated_token
                            chunk = json.dumps(chunk_json).encode("utf-8")
                        yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                    f"Error during streaming from decoder {instance_info.decoder.url}: {str(e)} the aborted request {instance_info.request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
                )
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)
                proxy_state.abort_prefiller_request(
                    instance_info.prefiller_idx, instance_info.request_id)
                proxy_state.release_prefiller_kv(instance_info.prefiller_idx,
                                                 instance_info.prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)
            proxy_state.release_decoder(instance_info.decoder_idx,
                                        instance_info.decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
@@ -544,4 +647,5 @@ if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn

    uvicorn.run(app, host=global_args.host, port=global_args.port)
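To make the proxy-side detection in the two diffs above concrete: the decode instance marks a stopped (preempted) request with `stop_reason == "recomputed"`, and `generate_stream` reacts by re-prefilling with the prompt extended by whatever was already generated. A hypothetical chunk of that shape (illustrative field values, not captured from a real run):

```python
# Hypothetical decode-side chunk that triggers one more prefill/decode round.
# The proxy only inspects "stop_reason", the "text"/"delta"/"message" content,
# and "usage.completion_tokens" (non-streaming case).
recomputed_chunk = {
    "choices": [{
        "index": 0,
        "text": " partial answer so far",
        "stop_reason": "recomputed",
    }],
    "usage": {"completion_tokens": 7},
}

choice = recomputed_chunk["choices"][0]
if choice.get("stop_reason") == "recomputed":
    # The proxy rebuilds the request as origin_prompt + generated text,
    # shrinks max_tokens by the tokens already produced, and selects a
    # new prefill/decode pair via _handle_select_instance.
    pass
```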
@@ -70,6 +70,8 @@ class AscendConfig:
        ) and not self.torchair_graph_config.enabled and vllm_config.parallel_config.enable_expert_parallel
        self.multistream_overlap_shared_expert = additional_config.get(
            "multistream_overlap_shared_expert", False)
        self.recompute_scheduler_enable = additional_config.get(
            "recompute_scheduler_enable", False)
        self.lmhead_tensor_parallel_size = additional_config.get(
            "lmhead_tensor_parallel_size", None)
        if self.lmhead_tensor_parallel_size is not None:
vllm_ascend/core/recompute_schedule_config.py (new file, 39 lines)
@@ -0,0 +1,39 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#

from dataclasses import dataclass, fields
from typing import Type, Union

from vllm.config import SchedulerConfig

MAX_INT = 2147483647


@dataclass
class RecomputeSchedulerConfig(SchedulerConfig):
    scheduler_cls: Union[str, Type[object]] = (
        "vllm_ascend.core.recompute_scheduler.RecomputeScheduler")

    @classmethod
    def initialize_from_config(cls, vllm_scheduler_config: SchedulerConfig):
        scheduler_config = {
            field.name: getattr(vllm_scheduler_config, field.name)
            for field in fields(vllm_scheduler_config) if field.init
        }
        scheduler_config["scheduler_cls"] = (
            "vllm_ascend.core.recompute_scheduler.RecomputeScheduler")
        return cls(**scheduler_config)
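As a usage note, `initialize_from_config` copies every `init` field of the incoming `SchedulerConfig` and only overrides `scheduler_cls`; the platform hunk below uses it exactly this way. A minimal sketch of that swap, wrapped in a hypothetical helper (the function name is illustrative, not part of this PR):

```python
# Sketch of the swap NPUPlatform performs when recompute_scheduler_enable is set.
from vllm.config import VllmConfig
from vllm_ascend.core.recompute_schedule_config import RecomputeSchedulerConfig


def swap_in_recompute_scheduler(vllm_config: VllmConfig) -> None:
    # All existing scheduler settings are preserved; only the scheduler class changes.
    recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
        vllm_config.scheduler_config)
    assert recompute_scheduler_config.scheduler_cls == \
        "vllm_ascend.core.recompute_scheduler.RecomputeScheduler"
    vllm_config.scheduler_config = recompute_scheduler_config
```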
vllm_ascend/core/recompute_scheduler.py (new file, 1392 lines)
File diff suppressed because it is too large
@@ -300,6 +300,12 @@ class NPUPlatform(Platform):
                vllm_config.scheduler_config,
                ascend_config.ascend_scheduler_config)
            vllm_config.scheduler_config = ascend_scheduler_config
        elif ascend_config.recompute_scheduler_enable:
            from vllm_ascend.core.recompute_schedule_config import \
                RecomputeSchedulerConfig
            recompute_scheduler_config = RecomputeSchedulerConfig.initialize_from_config(
                vllm_config.scheduler_config)
            vllm_config.scheduler_config = recompute_scheduler_config

    @classmethod
    def get_attn_backend_cls(