KVCache Transfer via Layer-wise Strategy in Disaggregation (#2602)
### What this PR does / why we need it?
See RFC: https://github.com/vllm-project/vllm-ascend/issues/2470
This PR adds a new KV connector for layer-wise KV transfer.

### Does this PR introduce _any_ user-facing change?
Yes, a new KV connector is added. Users can now enable the layer-wise transfer feature.

### How was this patch tested?
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: leichao.lc <leichao139636@163.com>
Signed-off-by: CaveNightingale <2859066733@qq.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: hanxinlong <50882499@qq.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: CaveNightingale <2859066733@qq.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: hanxinlong <50882499@qq.com>
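A minimal sketch of how the new connector might be selected (illustrative only: `MooncakeLayerwiseConnector` and the `prefill`/`decode` `tp_size` extra-config keys come from this patch, while the surrounding fields follow vLLM's standard `--kv-transfer-config` JSON and may need adjusting for a real deployment):

    # Illustrative kv-transfer configuration for the prefiller (producer) side;
    # the decoder side would use "kv_role": "kv_consumer".
    prefiller_kv_transfer_config = {
        "kv_connector": "MooncakeLayerwiseConnector",  # name registered by this PR
        "kv_role": "kv_producer",
        "kv_connector_extra_config": {
            "prefill": {"tp_size": 2},  # read by AscendConfig to derive pd_tp_ratio
            "decode": {"tp_size": 1},
        },
    }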
@@ -0,0 +1,576 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
#     pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use the provided mock server:
#
#   vllm serve --host 0.0.0.0 --port 8100 ...  # Prefiller 1
#   vllm serve --host 0.0.0.0 --port 8101 ...  # Prefiller 2
#   vllm serve --host 0.0.0.0 --port 8200 ...  # Decoder 1
#   vllm serve --host 0.0.0.0 --port 8201 ...  # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
#   python load_balance_proxy_server_example.py \
#     --host 0.0.0.0 --port 9000 \
#     --prefiller-hosts 127.0.0.1 127.0.0.1 \
#     --prefiller-ports 8100 8101 \
#     --decoder-hosts 127.0.0.1 127.0.0.1 \
#     --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
#   curl -X POST http://localhost:9000/v1/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "your-model",
#       "prompt": "The quick brown fox jumps over the lazy dog",
#       "max_tokens": 16
#     }'
#
# Or for chat completions:
#
#   curl -X POST http://localhost:9000/v1/chat/completions \
#     -H "Content-Type: application/json" \
#     -d '{
#       "model": "your-model",
#       "messages": [{"role": "user", "content": "Hello!"}],
#       "max_tokens": 16
#     }'
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
#   curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
#
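# For example, with the four backend servers above, the response would look
# like this (illustrative values):
#
#   {"status": "ok", "prefill_instances": 2, "decode_instances": 2}
#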
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy routes each request to the least-loaded prefiller and decoder
#   (based on a per-server load score) rather than strict round-robin.
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.

import argparse
import asyncio
import functools
import heapq
import os
import sys
import uuid
import threading
from contextlib import asynccontextmanager
from typing import List

import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger

logger = init_logger(__name__)

# Add uvloop for faster event loop if available
try:
    import uvloop
    asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
    pass


class ServerState:

    def __init__(self, host, port):
        self.host = host
        self.port = port
        self.url = f'http://{host}:{port}/v1'
        self.client = httpx.AsyncClient(timeout=None,
                                        base_url=self.url,
                                        limits=httpx.Limits(
                                            max_connections=100000,
                                            max_keepalive_connections=100000))
        self.active_tokens = 0
        self.active_kv_cache = 0  # Only for prefiller
        self.active_requests = 0  # Number of active requests
        self.aborted_requests = set()  # Track aborted requests
        # Removed individual server lock - will use global locks instead


class ProxyState:

    def __init__(self, prefiller_instances, decoder_instances):
        self.prefillers: List[ServerState] = [
            ServerState(h, p) for h, p in prefiller_instances
        ]
        self.decoders: List[ServerState] = [
            ServerState(h, p) for h, p in decoder_instances
        ]
        self.req_to_prefiller = {}
        self.req_id_lock = asyncio.Lock()
        # Removed selection locks - no longer needed for synchronous methods

        # Initialize priority queues for efficient server selection
        # Each entry is (priority_score, server_index, server_reference)
        # Lower priority score = higher priority (less loaded)
        self.prefiller_heap = [(0, i, server)
                               for i, server in enumerate(self.prefillers)]
        self.decoder_heap = [(0, i, server)
                             for i, server in enumerate(self.decoders)]
        heapq.heapify(self.prefiller_heap)
        heapq.heapify(self.decoder_heap)
        self.req_id_future = {}

    def _update_prefiller_priority(self, server_idx: int):
        """Update the priority of a prefiller server in the heap."""
        server = self.prefillers[server_idx]
        # Priority based on active_tokens and active_kv_cache
        priority = server.active_tokens + server.active_kv_cache * 0.3
        # Remove old entry and add new one
        self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
                               if i != server_idx]
        heapq.heappush(self.prefiller_heap,
                       (priority, server_idx, server))  # type: ignore

    def _update_decoder_priority(self, server_idx: int):
        """Update the priority of a decoder server in the heap."""
        server = self.decoders[server_idx]
        priority = server.active_tokens
        # Remove old entry and add new one
        self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
                             if i != server_idx]
        heapq.heappush(self.decoder_heap,
                       (priority, server_idx, server))  # type: ignore

    def abort_prefiller_request(self, server_idx: int,
                                request_id):  # Changed to synchronous
        """
        Mark a request as aborted. This helps to release the kv cache on the
        prefiller node.
        """
        # No lock needed - atomic operation
        self.prefillers[server_idx].aborted_requests.add(request_id)

    def aquire_aborted_prefiller_requests(
            self, server_idx: int):  # Changed to synchronous
        """
        Get the set of aborted requests and clear it.
        This is used to release kv cache in prefiller node.
        """
        # No lock needed - atomic operation
        aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
        self.prefillers[server_idx].aborted_requests.clear()
        return aborted_requests

    async def next_req_id(self):
        async with self.req_id_lock:
            return str(uuid.uuid4())

    def select_prefiller(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.prefiller_heap:
            raise RuntimeError("No prefiller servers available")

        priority, chosen, server = heapq.heappop(self.prefiller_heap)

        # Update the chosen server atomically
        self.prefillers[chosen].active_tokens += token_count
        self.prefillers[chosen].active_kv_cache += token_count

        # Update priority and re-add to heap
        self._update_prefiller_priority(chosen)

        return chosen

    def release_prefiller(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.prefillers[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_prefiller_priority(idx)

    def release_prefiller_kv(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        if self.prefillers[idx].active_kv_cache > 0:
            self.prefillers[idx].active_kv_cache -= token_count
            # Update priority queue after releasing
            self._update_prefiller_priority(idx)

    def select_decoder(self, token_count):  # Changed to synchronous
        # No lock needed - entire function is atomic
        if not self.decoder_heap:
            raise RuntimeError("No decoder servers available")

        priority, chosen, server = heapq.heappop(self.decoder_heap)

        # Update the chosen server atomically
        self.decoders[chosen].active_tokens += token_count

        # Update priority and re-add to heap
        self._update_decoder_priority(chosen)

        return chosen

    def release_decoder(self, idx, token_count):  # Changed to synchronous
        # No lock needed - atomic operation
        self.decoders[idx].active_tokens -= token_count
        # Update priority queue after releasing
        self._update_decoder_priority(idx)

    # Omni_infer's calculate_input_scores function
    def calculate_prefill_scores(self, request_length: int) -> float:
        length_score = request_length / 4.0
        input_score = length_score * 0.0345 + 120.0745
        return input_score

    def calculate_decode_scores(self, request_length: int) -> float:
        return request_length
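
# Illustrative magnitudes for ProxyState's scoring functions: a 4096-byte
# request body gives length_score = 4096 / 4.0 = 1024.0 and a prefill score of
# 1024.0 * 0.0345 + 120.0745 = 155.4025, while the decode score is simply the
# raw request length (4096).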

proxy_state = None


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--prefiller-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--prefiller-ports",
                        type=int,
                        nargs="+",
                        default=[8001])
    parser.add_argument("--decoder-hosts",
                        type=str,
                        nargs="+",
                        default=["localhost"])
    parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
    parser.add_argument("--max-retries",
                        type=int,
                        default=3,
                        help="Maximum number of retries for HTTP requests")
    parser.add_argument(
        "--retry-delay",
        type=float,
        default=0.001,
        help="Base delay (seconds) for exponential backoff retries")
    args = parser.parse_args()
    if len(args.prefiller_hosts) != len(args.prefiller_ports):
        raise ValueError(
            "Number of prefiller hosts must match number of prefiller ports")
    if len(args.decoder_hosts) != len(args.decoder_ports):
        raise ValueError(
            "Number of decoder hosts must match number of decoder ports")
    args.prefiller_instances = list(
        zip(args.prefiller_hosts, args.prefiller_ports))
    args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
    return args


@asynccontextmanager
async def lifespan(app: FastAPI):
    global proxy_state
    proxy_state = ProxyState(global_args.prefiller_instances,
                             global_args.decoder_instances)
    print(
        f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
    )
    yield
    for p in proxy_state.prefillers:
        await p.client.aclose()
    for d in proxy_state.decoders:
        await d.client.aclose()


async def listen_for_disconnect(request: Request) -> None:
    """Return if a disconnect message is received"""
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


def with_cancellation(handler_func):

    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        request = kwargs["request"]
        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
        done, pending = await asyncio.wait([handler_task, cancellation_task],
                                           return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            task.cancel()
        if handler_task in done:
            return handler_task.result()
        return None

    return wrapper


app = FastAPI(lifespan=lifespan)


async def send_request_to_service(client: httpx.AsyncClient,
                                  prefiller_id: int,
                                  endpoint: str,
                                  req_data: dict,
                                  request_id: str,
                                  max_retries: int = 3,
                                  base_delay: float = 0.2):
    aborted_requests = proxy_state.aquire_aborted_prefiller_requests(
        prefiller_id)
    req_data = req_data.copy()
    req_data['kv_transfer_params'] = {
        "do_remote_decode": True,
        "do_remote_prefill": False,
        "remote_engine_id": None,
        "remote_block_ids": None,
        "remote_host": None,
        "remote_port": None,
        "aborted_request": list(aborted_requests),
        "metaserver":
        f"http://{global_args.host}:{global_args.port}/v1/metaserver"
    }
    req_data["stream"] = False
    req_data["max_tokens"] = 1
    if "stream_options" in req_data:
        del req_data["stream_options"]
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    last_exc = None
    for attempt in range(1, max_retries + 1):
        try:
            response = await client.post(endpoint,
                                         json=req_data,
                                         headers=headers)
            response.raise_for_status()
            if request_id in proxy_state.req_id_future:
                result_future = proxy_state.req_id_future[request_id]
                result_future.set_result(
                    response.json()["kv_transfer_params"])
            return
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.warning(
                f"Attempt {attempt} failed for {endpoint}: {str(e)}")
            last_exc = e
            if attempt < max_retries:
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for {endpoint}.")
                raise last_exc

async def stream_service_response_with_retry(client: httpx.AsyncClient,
                                             endpoint: str,
                                             req_data: dict,
                                             request_id: str,
                                             max_retries: int = 3,
                                             base_delay: float = 0.2):
    headers = {
        "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
        "X-Request-Id": request_id
    }
    for attempt in range(1, max_retries + 1):
        try:
            async with client.stream("POST",
                                     endpoint,
                                     json=req_data,
                                     headers=headers) as response:
                response.raise_for_status()
                first_chunk_sent = False
                async for chunk in response.aiter_bytes():
                    first_chunk_sent = True
                    yield chunk
                return  # Success, exit after streaming
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            if attempt < max_retries:
                logger.warning(
                    f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                )
                await asyncio.sleep(base_delay * (2**(attempt - 1)))
            else:
                logger.error(
                    f"All {max_retries} attempts failed for streaming {endpoint}."
                )
                raise e
        except Exception as e:
            # If any chunk has been sent, do not retry, just log and drop
            if 'first_chunk_sent' in locals() and first_chunk_sent:
                logger.error(
                    f"Streaming to client interrupted after response started: {str(e)}"
                )
                return
            else:
                if attempt < max_retries:
                    logger.warning(
                        f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
                    )
                    await asyncio.sleep(base_delay * (2**(attempt - 1)))
                else:
                    logger.error(
                        f"All {max_retries} attempts failed for streaming {endpoint}."
                    )
                    raise e


def get_api_request_id(api, req_id):
    if api == "/completions":
        return "cmpl-" + req_id + "-0"
    elif api == "/chat/completions":
        return "chatcmpl-" + req_id


async def _handle_completions(api: str, request: Request):
    try:
        req_data = await request.json()
        req_body = await request.body()
        request_length = len(req_body)
        prefiller_score = proxy_state.calculate_prefill_scores(request_length)
        logger.debug(
            f"Request length: {request_length}, Prefiller score: {prefiller_score}"
        )
        request_id = await proxy_state.next_req_id()
        # Select prefiller
        prefiller_idx = proxy_state.select_prefiller(prefiller_score)
        prefiller = proxy_state.prefillers[prefiller_idx]
        result_future = asyncio.Future()  # type: ignore
        request_id_api = get_api_request_id(api, request_id)
        proxy_state.req_id_future[request_id_api] = result_future
        # Send request to prefiller
        asyncio.get_running_loop().create_task(
            send_request_to_service(prefiller.client,
                                    prefiller_idx,
                                    api,
                                    req_data,
                                    request_id,
                                    max_retries=global_args.max_retries,
                                    base_delay=global_args.retry_delay))
        proxy_state.release_prefiller(prefiller_idx, prefiller_score)

        response = await result_future
        del proxy_state.req_id_future[request_id_api]
        req_data["kv_transfer_params"] = response

        # Select decoder
        decoder_score = proxy_state.calculate_decode_scores(request_length)
        logger.debug("Decoder score: %f", decoder_score)
        # Use the prefiller's kv_transfer_params to select decoder
        decoder_idx = proxy_state.select_decoder(decoder_score)
        decoder = proxy_state.decoders[decoder_idx]
        logger.debug("Using %s %s", prefiller.url, decoder.url)
        # Stream response from decoder
        released_kv = False

        async def generate_stream():
            nonlocal released_kv
            # Only one await per chunk, minimal logic in loop
            try:
                async for chunk in stream_service_response_with_retry(
                        decoder.client,
                        api,
                        req_data,
                        request_id=request_id,
                        max_retries=global_args.max_retries,
                        base_delay=global_args.retry_delay):
                    if not released_kv and chunk:
                        proxy_state.release_prefiller_kv(
                            prefiller_idx, prefiller_score)
                        released_kv = True
                    yield chunk
            except Exception as e:
                logger.error(
                    f"Error during streaming from decoder {decoder.url}: {str(e)}; "
                    f"the aborted request {request_id} will be routed to the "
                    "target prefiller when a new request is dispatched to it")
                proxy_state.abort_prefiller_request(prefiller_idx, request_id)
                proxy_state.release_prefiller_kv(prefiller_idx,
                                                 prefiller_score)

            # After streaming done, release tokens
            proxy_state.release_decoder(decoder_idx, decoder_score)

        return StreamingResponse(generate_stream(),
                                 media_type="application/json")
    except Exception as e:
        import traceback
        exc_info = sys.exc_info()
        print("Error occurred in disagg prefill proxy server"
              f" - {api} endpoint")
        print(e)
        print("".join(traceback.format_exception(*exc_info)))
        raise


@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
    return await _handle_completions("/completions", request)


@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
    return await _handle_completions("/chat/completions", request)


@app.get("/healthcheck")
async def healthcheck():
    return {
        "status": "ok",
        "prefill_instances": len(proxy_state.prefillers),
        "decode_instances": len(proxy_state.decoders)
    }


@app.post("/v1/metaserver")
async def metaserver(request: Request):
    try:
        req_data = await request.json()
        request_id = req_data.pop("request_id", None)
        if request_id in proxy_state.req_id_future:
            result_future = proxy_state.req_id_future[request_id]
            result_future.set_result(req_data)
    except Exception as e:
        logger.error(f"Post metaserver failed with: {str(e)}")


if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn
    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -544,4 +544,4 @@ if __name__ == '__main__':
    global global_args
    global_args = parse_args()
    import uvicorn
-    uvicorn.run(app, host=global_args.host, port=global_args.port)
+    uvicorn.run(app, host=global_args.host, port=global_args.port)
@@ -4,8 +4,9 @@ import pytest
from vllm.config import ParallelConfig

from vllm_ascend.distributed.parallel_state import (
-    _LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
-    get_mc2_group, get_otp_group, init_ascend_model_parallel)
+    _LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
+    get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
+    init_ascend_model_parallel)


@pytest.fixture
@@ -30,6 +31,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
    mock_ascend_config = MagicMock()
    mock_ascend_config.lmhead_tensor_parallel_size = 2
    mock_ascend_config.oproj_tensor_parallel_size = 2
+    mock_ascend_config.pd_tp_ratio = 2
    with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
        patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
        patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
@@ -38,11 +40,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
        mc2_group = get_mc2_group()
        lmheadtp_group = get_lmhead_tp_group()
        otp_group = get_otp_group()
+        p_tp_group = get_p_tp_group()
        assert mc2_group is not None
        assert otp_group is not None
        assert lmheadtp_group is not None
+        assert p_tp_group is not None

        destroy_ascend_model_parallel()
        assert _MC2 is None
        assert _LMTP is None
        assert _OTP is None
+        assert _P_TP is None
tests/ut/kv_connector/test_mooncake_layerwise_connector.py (new file, 1001 lines)
File diff suppressed because it is too large
@@ -94,6 +94,17 @@ class AscendConfig:
            raise AssertionError(
                "oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
            )
+        self.pd_tp_ratio = 1
+        if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
+            prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
+                "prefill", {"tp_size": 1})["tp_size"]
+            decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
+                "decode", {"tp_size": 1})["tp_size"]
+            pd_tp_ratio: int = prefill_tp_size // decode_tp_size
+            self.pd_tp_ratio = pd_tp_ratio
+            if self.pd_tp_ratio == 0:
+                raise AssertionError(
+                    "Only support P node tp size larger than or equal to D node tp size")


class TorchairGraphConfig:
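A worked example of the ratio computed above (illustrative numbers only): with a prefill deployment at tp_size 4 and a decode deployment at tp_size 2 in kv_connector_extra_config, pd_tp_ratio comes out as 2; a prefill tp smaller than the decode tp floor-divides to 0 and trips the assertion.

    extra_config = {"prefill": {"tp_size": 4}, "decode": {"tp_size": 2}}
    pd_tp_ratio = extra_config["prefill"]["tp_size"] // extra_config["decode"]["tp_size"]
    assert pd_tp_ratio == 2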
@@ -31,3 +31,8 @@ KVConnectorFactory.register_connector(
    "MooncakeConnectorStoreV1",
    "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
    "MooncakeConnectorV1")
+
+KVConnectorFactory.register_connector(
+    "MooncakeLayerwiseConnector",
+    "vllm_ascend.distributed.mooncake_layerwise_connector",
+    "MooncakeLayerwiseConnector")
@@ -1109,4 +1109,4 @@ def ensure_zmq_recv(
                logger.error(f"Receive failed after all retries: {e}")
                raise RuntimeError(
                    f"Failed to receive data after {max_retries} "
-                    f"retries: {e}")
+                    f"retries: {e}")
vllm_ascend/distributed/mooncake_layerwise_connector.py (new file, 1335 lines)
File diff suppressed because it is too large
@@ -13,6 +13,7 @@ _MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
+_P_TP: Optional[GroupCoordinator] = None


def get_mc2_group() -> GroupCoordinator:
@@ -37,6 +38,12 @@ def get_mlp_tp_group() -> GroupCoordinator:
    return _MLP_TP


+def get_p_tp_group() -> GroupCoordinator:
+    assert _P_TP is not None, (
+        "distributed prefill tensor parallel group is not initialized")
+    return _P_TP


def model_parallel_initialized():
    return (_MC2 is not None)
@@ -54,6 +61,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
    all_ranks = torch.arange(world_size).reshape(
        -1, parallel_config.data_parallel_size *
        parallel_config.tensor_parallel_size)

+    pd_tp_ratio = get_ascend_config().pd_tp_ratio
+    global _P_TP
+    assert _P_TP is None, (
+        "distributed prefill tensor parallel group is already initialized")
+    prefill_tensor_model_parallel_size = pd_tp_ratio if \
+        pd_tp_ratio > 0 and pd_tp_ratio < parallel_config.tensor_parallel_size else parallel_config.tensor_parallel_size
+    group_ranks = all_ranks.view(-1,
+                                 prefill_tensor_model_parallel_size).unbind(0)
+    group_ranks = [x.tolist() for x in group_ranks]
+    num = get_world_group().local_rank // pd_tp_ratio
+    _P_TP = init_model_parallel_group(group_ranks,
+                                      get_world_group().local_rank,
+                                      backend,
+                                      group_name=f"p_tp_{num}")
+
    global _MC2
    group_ranks = all_ranks.unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
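To make the grouping above concrete, a small sketch with assumed sizes (4 ranks, data_parallel_size=1, tensor_parallel_size=4, pd_tp_ratio=2): the prefill TP group size becomes 2 and the ranks split pairwise.

    import torch

    all_ranks = torch.arange(4).reshape(-1, 1 * 4)   # [[0, 1, 2, 3]]
    prefill_tp = 2                                   # pd_tp_ratio, since 0 < 2 < 4
    group_ranks = [x.tolist() for x in all_ranks.view(-1, prefill_tp).unbind(0)]
    assert group_ranks == [[0, 1], [2, 3]]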
@@ -142,3 +165,8 @@ def destroy_ascend_model_parallel():
    if _OTP:
        _OTP.destroy()
    _OTP = None
+
+    global _P_TP
+    if _P_TP:
+        _P_TP.destroy()
+    _P_TP = None
vllm_ascend/distributed/utils.py (new file, 47 lines)
@@ -0,0 +1,47 @@
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_p_tp_group
|
||||
|
||||
|
||||
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
|
||||
value: torch.TensorType):
|
||||
if pd_tp_ratio <= 1:
|
||||
return None, None
|
||||
elif key is None or value is None:
|
||||
raise ValueError("key or value is None")
|
||||
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
|
||||
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
|
||||
return k_output, v_output
|
||||
|
||||
|
||||
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
|
||||
num_kv_heads = input_tensor.size(1)
|
||||
output_tensor = torch.zeros_like(input_tensor)
|
||||
dist.all_to_all_single(output_tensor,
|
||||
input_tensor,
|
||||
group=get_p_tp_group().device_group)
|
||||
input_tensor = 0
|
||||
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
|
||||
output_tensor = 0
|
||||
return result
|
||||
|
||||
|
||||
def rearrange_output(base_output: torch.Tensor, cut_num: int,
|
||||
num_kv_heads: int):
|
||||
size_0 = base_output.size(0)
|
||||
if size_0 % cut_num != 0:
|
||||
raise ValueError(
|
||||
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
|
||||
)
|
||||
chunk_size = size_0 // cut_num
|
||||
reshaped = base_output.view(cut_num, chunk_size, -1)
|
||||
transposed = reshaped.transpose(0, 1)
|
||||
return transposed.contiguous().view(size_0, num_kv_heads, -1)
|
||||
|
||||
|
||||
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
|
||||
data_ptr = tensor.data_ptr()
|
||||
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
|
||||
offset = (aligned_addr - data_ptr) // tensor.element_size()
|
||||
return tensor[int(offset):]
|
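A small usage sketch of rearrange_output with toy shapes (illustrative only; it assumes the function above is in scope): with cut_num=2, the two halves of dim 0, which come back from all_to_all_single grouped per peer, are interleaved row by row.

    import torch

    base = torch.arange(8).reshape(4, 2, 1)   # 4 tokens, 2 kv heads, head_dim 1
    out = rearrange_output(base, cut_num=2, num_kv_heads=2)
    # rows are now peer0[0], peer1[0], peer0[1], peer1[1]
    assert out.reshape(4, -1).tolist() == [[0, 1], [4, 5], [2, 3], [6, 7]]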