KVCache Transfer via Layer-wise Strategy in Disaggregation (#2602)

### What this PR does / why we need it?
See RFC: https://github.com/vllm-project/vllm-ascend/issues/2470

This PR adds a new KV connector for layer-wise KV cache transfer.
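
As a rough sketch of how the new connector would be enabled (only the connector name `MooncakeLayerwiseConnector` and the `prefill`/`decode` `tp_size` extra-config keys come from this diff; the surrounding field names follow vLLM's generic `KVTransferConfig` and the concrete values are illustrative assumptions):

```python
# Hypothetical kv-transfer config sketch; values are placeholders, not taken from this PR.
prefiller_kv_transfer_config = {
    "kv_connector": "MooncakeLayerwiseConnector",  # connector name registered by this PR
    "kv_role": "kv_producer",                      # P node produces the KV cache
    "kv_connector_extra_config": {
        "prefill": {"tp_size": 4},                 # read by AscendConfig to derive pd_tp_ratio
        "decode": {"tp_size": 2},
    },
}

decoder_kv_transfer_config = {
    "kv_connector": "MooncakeLayerwiseConnector",
    "kv_role": "kv_consumer",                      # D node consumes the transferred KV cache
    "kv_connector_extra_config": {
        "prefill": {"tp_size": 4},
        "decode": {"tp_size": 2},
    },
}
```

The JSON form of such a dict would typically be passed via `--kv-transfer-config` when launching each vLLM instance.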

### Does this PR introduce _any_ user-facing change?
Yes, a new KV connector is added; users can now enable the layer-wise KV transfer feature.

### How was this patch tested?


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/releases/v0.11.0

---------

Signed-off-by: leichao.lc <leichao139636@163.com>
Signed-off-by: CaveNightingale <2859066733@qq.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: hanxinlong <50882499@qq.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: CaveNightingale <2859066733@qq.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: hanxinlong <50882499@qq.com>
Commit a486ff8c11 (parent f8c93d8d24), authored by Chao Lei and committed via GitHub on 2025-09-30 15:10:29 +08:00.
10 changed files with 3012 additions and 4 deletions.


@ -0,0 +1,576 @@
# Adapted from https://github.com/vllm-project/vllm/tests/v1/kv_connector/nixl_integration/toy_proxy_server.py
# SPDX-License-Identifier: Apache-2.0
#
# Tutorial: Using the Load Balance Proxy Server Example
#
# This proxy server is designed to distribute requests between multiple
# "prefiller" and "decoder" backend servers for large language model inference.
# It is useful for scaling out inference workloads and balancing load across
# multiple backend instances.
#
# Features:
# - Load balances requests to multiple prefiller and decoder servers.
# - Supports OpenAI-compatible /v1/completions and /v1/chat/completions endpoints.
# - Streams responses from backend servers to clients.
#
# Prerequisites:
# - Python 3.8+
# - Install dependencies:
# pip install fastapi httpx uvicorn vllm
#
# Step 1: Start Your Backend Servers
# ----------------------------------
# You need to have at least one prefiller and one decoder backend running.
# These can be mock servers or actual vLLM servers.
#
# For testing, you can use the provided mock server:
#
# vllm serve --host 0.0.0.0 --port 8100 ... # Prefiller 1
# vllm serve --host 0.0.0.0 --port 8101 ... # Prefiller 2
# vllm serve --host 0.0.0.0 --port 8200 ... # Decoder 1
# vllm serve --host 0.0.0.0 --port 8201 ... # Decoder 2
#
# Step 2: Start the Proxy Server
# ------------------------------
# Run the proxy server, specifying the host/port for each prefiller and decoder:
#
# python load_balance_proxy_server_example.py \
# --host 0.0.0.0 --port 9000 \
# --prefiller-hosts 127.0.0.1 127.0.0.1 \
# --prefiller-ports 8100 8101 \
# --decoder-hosts 127.0.0.1 127.0.0.1 \
# --decoder-ports 8200 8201
#
# This will start the proxy on port 9000, load balancing between two prefiller
# and two decoder servers.
#
# Step 3: Send a Request to the Proxy
# -----------------------------------
# You can now send OpenAI-compatible requests to the proxy. For example:
#
# curl -X POST http://localhost:9000/v1/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "prompt": "The quick brown fox jumps over the lazy dog",
# "max_tokens": 16
# }'
#
# Or for chat completions:
#
# curl -X POST http://localhost:9000/v1/chat/completions \
# -H "Content-Type: application/json" \
# -d '{
# "model": "your-model",
# "messages": [{"role": "user", "content": "Hello!"}],
# "max_tokens": 16
# }'
#
# Step 4: Health Check
# --------------------
# To check if the proxy is running and see how many backend instances are
# connected, use:
#
# curl http://localhost:9000/healthcheck
#
# This will return a JSON object with the status and the number of prefiller
# and decoder instances.
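#
# Example response (matches the /healthcheck handler defined below):
#   {"status": "ok", "prefill_instances": 2, "decode_instances": 2}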
#
# Notes:
# - You can scale the number of prefiller and decoder servers as needed.
# - The proxy will round-robin requests to balance load.
# - For production, ensure your backend servers are robust and secure.
#
# For more details, see the code and comments in this file.
import argparse
import asyncio
import functools
import heapq
import os
import sys
import uuid
import threading
from contextlib import asynccontextmanager
from typing import List
import httpx
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from vllm.logger import init_logger
logger = init_logger(__name__)
# Add uvloop for faster event loop if available
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
except ImportError:
pass
class ServerState:
def __init__(self, host, port):
self.host = host
self.port = port
self.url = f'http://{host}:{port}/v1'
self.client = httpx.AsyncClient(timeout=None,
base_url=self.url,
limits=httpx.Limits(
max_connections=100000,
max_keepalive_connections=100000))
self.active_tokens = 0
self.active_kv_cache = 0 # Only for prefiller
self.active_requests = 0 # Number of active requests
self.aborted_requests = set() # Track aborted requests
# Removed individual server lock - will use global locks instead
class ProxyState:
def __init__(self, prefiller_instances, decoder_instances):
self.prefillers: List[ServerState] = [
ServerState(h, p) for h, p in prefiller_instances
]
self.decoders: List[ServerState] = [
ServerState(h, p) for h, p in decoder_instances
]
self.req_to_prefiller = {}
self.req_id_lock = asyncio.Lock()
# Removed selection locks - no longer needed for synchronous methods
# Initialize priority queues for efficient server selection
# Each entry is (priority_score, server_index, server_reference)
# Lower priority score = higher priority (less loaded)
self.prefiller_heap = [(0, i, server)
for i, server in enumerate(self.prefillers)]
self.decoder_heap = [(0, i, server)
for i, server in enumerate(self.decoders)]
heapq.heapify(self.prefiller_heap)
heapq.heapify(self.decoder_heap)
self.req_id_future = {}
def _update_prefiller_priority(self, server_idx: int):
"""Update the priority of a prefiller server in the heap."""
server = self.prefillers[server_idx]
# Priority based on active_tokens and active_kv_cache
priority = server.active_tokens + server.active_kv_cache * 0.3
# Remove old entry and add new one
self.prefiller_heap = [(p, i, s) for p, i, s in self.prefiller_heap
if i != server_idx]
heapq.heappush(self.prefiller_heap,
(priority, server_idx, server)) # type: ignore
def _update_decoder_priority(self, server_idx: int):
"""Update the priority of a decoder server in the heap."""
server = self.decoders[server_idx]
priority = server.active_tokens
# Remove old entry and add new one
self.decoder_heap = [(p, i, s) for p, i, s in self.decoder_heap
if i != server_idx]
heapq.heappush(self.decoder_heap,
(priority, server_idx, server)) # type: ignore
def abort_prefiller_request(self, server_idx: int,
request_id): # Changed to synchronous
"""
Mark a request as aborted. This helps release the KV cache on the
prefiller node.
"""
# No lock needed - atomic operation
self.prefillers[server_idx].aborted_requests.add(request_id)
def acquire_aborted_prefiller_requests(
self, server_idx: int): # Changed to synchronous
"""
Get the set of aborted requests and clear it.
This is used to release the KV cache on the prefiller node.
"""
# No lock needed - atomic operation
aborted_requests = self.prefillers[server_idx].aborted_requests.copy()
self.prefillers[server_idx].aborted_requests.clear()
return aborted_requests
async def next_req_id(self):
async with self.req_id_lock:
return str(uuid.uuid4())
def select_prefiller(self, token_count): # Changed to synchronous
# No lock needed - entire function is atomic
if not self.prefiller_heap:
raise RuntimeError("No prefiller servers available")
priority, chosen, server = heapq.heappop(self.prefiller_heap)
# Update the chosen server atomically
self.prefillers[chosen].active_tokens += token_count
self.prefillers[chosen].active_kv_cache += token_count
# Update priority and re-add to heap
self._update_prefiller_priority(chosen)
return chosen
def release_prefiller(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
self.prefillers[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)
def release_prefiller_kv(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
if self.prefillers[idx].active_kv_cache > 0:
self.prefillers[idx].active_kv_cache -= token_count
# Update priority queue after releasing
self._update_prefiller_priority(idx)
def select_decoder(self, token_count): # Changed to synchronous
# No lock needed - entire function is atomic
if not self.decoder_heap:
raise RuntimeError("No decoder servers available")
priority, chosen, server = heapq.heappop(self.decoder_heap)
# Update the chosen server atomically
self.decoders[chosen].active_tokens += token_count
# Update priority and re-add to heap
self._update_decoder_priority(chosen)
return chosen
def release_decoder(self, idx, token_count): # Changed to synchronous
# No lock needed - atomic operation
self.decoders[idx].active_tokens -= token_count
# Update priority queue after releasing
self._update_decoder_priority(idx)
# Omni_infer's calculate_input_scores function
def calculate_prefill_scores(self, request_length: int) -> float:
length_score = request_length / 4.0
input_score = length_score * 0.0345 + 120.0745
return input_score
def calculate_decode_scores(self, request_length: int) -> float:
return request_length
proxy_state = None
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument("--port", type=int, default=8000)
parser.add_argument("--host", type=str, default="localhost")
parser.add_argument("--prefiller-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--prefiller-ports",
type=int,
nargs="+",
default=[8001])
parser.add_argument("--decoder-hosts",
type=str,
nargs="+",
default=["localhost"])
parser.add_argument("--decoder-ports", type=int, nargs="+", default=[8002])
parser.add_argument("--max-retries",
type=int,
default=3,
help="Maximum number of retries for HTTP requests")
parser.add_argument(
"--retry-delay",
type=float,
default=0.001,
help="Base delay (seconds) for exponential backoff retries")
args = parser.parse_args()
if len(args.prefiller_hosts) != len(args.prefiller_ports):
raise ValueError(
"Number of prefiller hosts must match number of prefiller ports")
if len(args.decoder_hosts) != len(args.decoder_ports):
raise ValueError(
"Number of decoder hosts must match number of decoder ports")
args.prefiller_instances = list(
zip(args.prefiller_hosts, args.prefiller_ports))
args.decoder_instances = list(zip(args.decoder_hosts, args.decoder_ports))
return args
@asynccontextmanager
async def lifespan(app: FastAPI):
global proxy_state
proxy_state = ProxyState(global_args.prefiller_instances,
global_args.decoder_instances)
print(
f"Initialized {len(proxy_state.prefillers)} prefill clients and {len(proxy_state.decoders)} decode clients."
)
yield
for p in proxy_state.prefillers:
await p.client.aclose()
for d in proxy_state.decoders:
await d.client.aclose()
async def listen_for_disconnect(request: Request) -> None:
"""Return if a disconnect message is received"""
while True:
message = await request.receive()
if message["type"] == "http.disconnect":
break
def with_cancellation(handler_func):
@functools.wraps(handler_func)
async def wrapper(*args, **kwargs):
request = kwargs["request"]
handler_task = asyncio.create_task(handler_func(*args, **kwargs))
cancellation_task = asyncio.create_task(listen_for_disconnect(request))
done, pending = await asyncio.wait([handler_task, cancellation_task],
return_when=asyncio.FIRST_COMPLETED)
for task in pending:
task.cancel()
if handler_task in done:
return handler_task.result()
return None
return wrapper
app = FastAPI(lifespan=lifespan)
async def send_request_to_service(client: httpx.AsyncClient,
prefiller_id: int,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
aborted_requests = proxy_state.acquire_aborted_prefiller_requests(
prefiller_id)
req_data = req_data.copy()
req_data['kv_transfer_params'] = {
"do_remote_decode": True,
"do_remote_prefill": False,
"remote_engine_id": None,
"remote_block_ids": None,
"remote_host": None,
"remote_port": None,
"aborted_request": list(aborted_requests),
"metaserver": f"http://{global_args.host}:{global_args.port}/v1/metaserver"
}
req_data["stream"] = False
req_data["max_tokens"] = 1
if "stream_options" in req_data:
del req_data["stream_options"]
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
last_exc = None
for attempt in range(1, max_retries + 1):
try:
response = await client.post(endpoint,
json=req_data,
headers=headers)
response.raise_for_status()
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(response.json()["kv_transfer_params"])
return
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.warning(
f"Attempt {attempt} failed for {endpoint}: {str(e)}")
last_exc = e
if attempt < max_retries:
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for {endpoint}.")
raise last_exc
async def stream_service_response_with_retry(client: httpx.AsyncClient,
endpoint: str,
req_data: dict,
request_id: str,
max_retries: int = 3,
base_delay: float = 0.2):
headers = {
"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}",
"X-Request-Id": request_id
}
for attempt in range(1, max_retries + 1):
try:
async with client.stream("POST",
endpoint,
json=req_data,
headers=headers) as response:
response.raise_for_status()
first_chunk_sent = False
async for chunk in response.aiter_bytes():
first_chunk_sent = True
yield chunk
return # Success, exit after streaming
except (httpx.RequestError, httpx.HTTPStatusError) as e:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
raise e
except Exception as e:
# If any chunk has been sent, do not retry, just log and drop
if 'first_chunk_sent' in locals() and first_chunk_sent:
logger.error(
f"Streaming to client interrupted after response started: {str(e)}"
)
return
else:
if attempt < max_retries:
logger.warning(
f"Attempt {attempt} failed for streaming {endpoint}: {str(e)}"
)
await asyncio.sleep(base_delay * (2**(attempt - 1)))
else:
logger.error(
f"All {max_retries} attempts failed for streaming {endpoint}."
)
raise e
def get_api_request_id(api, req_id):
if api == "/completions":
return "cmpl-" + req_id + "-0"
elif api == "/chat/completions":
return "chatcmpl-" + req_id
async def _handle_completions(api: str, request: Request):
try:
req_data = await request.json()
req_body = await request.body()
request_length = len(req_body)
prefiller_score = proxy_state.calculate_prefill_scores(request_length)
logger.debug(
f"Request length: {request_length}, Prefiller score: {prefiller_score}"
)
request_id = await proxy_state.next_req_id()
# Select prefiller
prefiller_idx = proxy_state.select_prefiller(prefiller_score)
prefiller = proxy_state.prefillers[prefiller_idx]
result_future = asyncio.Future() # type: ignore
request_id_api = get_api_request_id(api, request_id)
proxy_state.req_id_future[request_id_api] = result_future
# Send request to prefiller
asyncio.get_running_loop().create_task(send_request_to_service(
prefiller.client,
prefiller_idx,
api,
req_data,
request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay))
proxy_state.release_prefiller(prefiller_idx, prefiller_score)
response = await result_future
del proxy_state.req_id_future[request_id_api]
req_data["kv_transfer_params"] = response
# Select decoder
decoder_score = proxy_state.calculate_decode_scores(request_length)
logger.debug("Decoder score: %f", decoder_score)
# Use the prefiller's kv_transfer_params to select decoder
decoder_idx = proxy_state.select_decoder(decoder_score)
decoder = proxy_state.decoders[decoder_idx]
logger.debug("Using %s %s", prefiller.url, decoder.url)
# Stream response from decoder
released_kv = False
async def generate_stream():
nonlocal released_kv
# Only one await per chunk, minimal logic in loop
try:
async for chunk in stream_service_response_with_retry(
decoder.client,
api,
req_data,
request_id=request_id,
max_retries=global_args.max_retries,
base_delay=global_args.retry_delay):
if not released_kv and chunk:
proxy_state.release_prefiller_kv(
prefiller_idx, prefiller_score)
released_kv = True
yield chunk
except Exception as e:
logger.error(
f"Error during streaming from decoder {decoder.url}: {str(e)} the aborted request {request_id} will be routing to the target prefiller when new request is ready to dispatch to it"
)
proxy_state.abort_prefiller_request(prefiller_idx, request_id)
proxy_state.release_prefiller_kv(prefiller_idx,
prefiller_score)
# After streaming done, release tokens
proxy_state.release_decoder(decoder_idx, decoder_score)
return StreamingResponse(generate_stream(),
media_type="application/json")
except Exception as e:
import traceback
exc_info = sys.exc_info()
print("Error occurred in disagg prefill proxy server"
f" - {api} endpoint")
print(e)
print("".join(traceback.format_exception(*exc_info)))
raise
@app.post("/v1/completions")
@with_cancellation
async def handle_completions(request: Request):
return await _handle_completions("/completions", request)
@app.post("/v1/chat/completions")
@with_cancellation
async def handle_chat_completions(request: Request):
return await _handle_completions("/chat/completions", request)
@app.get("/healthcheck")
async def healthcheck():
return {
"status": "ok",
"prefill_instances": len(proxy_state.prefillers),
"decode_instances": len(proxy_state.decoders)
}
@app.post("/v1/metaserver")
async def metaserver(request: Request):
try:
req_data = await request.json()
request_id = req_data.pop("request_id", None)
if request_id in proxy_state.req_id_future:
result_future = proxy_state.req_id_future[request_id]
result_future.set_result(req_data)
except Exception as e:
logger.error(
f"Post metaserver failed with: {str(e)}"
)
if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
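
To make the load-balancing arithmetic above concrete, here is a small self-contained sketch (not part of the PR) that reproduces the two scoring functions and the prefiller heap priority, with constants copied from this file:

```python
def calculate_prefill_scores(request_length: int) -> float:
    length_score = request_length / 4.0
    return length_score * 0.0345 + 120.0745

def calculate_decode_scores(request_length: int) -> float:
    return float(request_length)

def prefiller_priority(active_tokens: float, active_kv_cache: float) -> float:
    # Lower value == less loaded == popped first from the heap.
    return active_tokens + active_kv_cache * 0.3

score = calculate_prefill_scores(1024)       # a 1 KiB request body -> ~128.91
print(score, calculate_decode_scores(1024))  # 128.9065 1024.0
print(prefiller_priority(score, score))      # ~167.58 once the request is charged to a prefiller
```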


@ -544,4 +544,4 @@ if __name__ == '__main__':
global global_args
global_args = parse_args()
import uvicorn
uvicorn.run(app, host=global_args.host, port=global_args.port)
uvicorn.run(app, host=global_args.host, port=global_args.port)


@ -4,8 +4,9 @@ import pytest
from vllm.config import ParallelConfig
from vllm_ascend.distributed.parallel_state import (
_LMTP, _MC2, _OTP, destroy_ascend_model_parallel, get_lmhead_tp_group,
get_mc2_group, get_otp_group, init_ascend_model_parallel)
_LMTP, _MC2, _OTP, _P_TP, destroy_ascend_model_parallel,
get_lmhead_tp_group, get_mc2_group, get_otp_group, get_p_tp_group,
init_ascend_model_parallel)
@pytest.fixture
@ -30,6 +31,7 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mock_ascend_config = MagicMock()
mock_ascend_config.lmhead_tensor_parallel_size = 2
mock_ascend_config.oproj_tensor_parallel_size = 2
mock_ascend_config.pd_tp_ratio = 2
with patch('vllm_ascend.distributed.parallel_state.model_parallel_initialized', return_value=False), \
patch('vllm_ascend.distributed.parallel_state.init_model_parallel_group'), \
patch('vllm_ascend.distributed.parallel_state.get_ascend_config', return_value=mock_ascend_config):
@ -38,11 +40,14 @@ def test_init_ascend_model_parallel(mock_distributed, parallel_config):
mc2_group = get_mc2_group()
lmheadtp_group = get_lmhead_tp_group()
otp_group = get_otp_group()
p_tp_group = get_p_tp_group()
assert mc2_group is not None
assert otp_group is not None
assert lmheadtp_group is not None
assert p_tp_group is not None
destroy_ascend_model_parallel()
assert _MC2 is None
assert _LMTP is None
assert _OTP is None
assert _P_TP is None

File diff suppressed because it is too large


@ -94,6 +94,17 @@ class AscendConfig:
raise AssertionError(
"oproj_tensor_parallel_size is only supported in pd scenario and can only be used in D node."
)
self.pd_tp_ratio = 1
if vllm_config.kv_transfer_config is not None and not vllm_config.model_config.is_deepseek_mla:
prefill_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"prefill", {"tp_size": 1})["tp_size"]
decode_tp_size = vllm_config.kv_transfer_config.get_from_extra_config(
"decode", {"tp_size": 1})["tp_size"]
pd_tp_ratio: int = prefill_tp_size // decode_tp_size
self.pd_tp_ratio = pd_tp_ratio
if self.pd_tp_ratio == 0:
raise AssertionError(
"Only support P node tp size lagger then D node tp size")
class TorchairGraphConfig:
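
The hunk above derives `pd_tp_ratio` from the prefill and decode tensor-parallel sizes carried in the connector's extra config. A minimal standalone sketch of that derivation (the helper name is made up for illustration):

```python
def derive_pd_tp_ratio(prefill_tp_size: int, decode_tp_size: int) -> int:
    # Mirrors the AscendConfig logic above: integer ratio of P-node TP to D-node TP.
    ratio = prefill_tp_size // decode_tp_size
    if ratio == 0:
        # A prefill TP smaller than the decode TP is rejected by the config check above.
        raise AssertionError("P node tp size must be larger than or equal to D node tp size")
    return ratio

assert derive_pd_tp_ratio(4, 2) == 2  # e.g. TP-4 prefillers feeding TP-2 decoders
assert derive_pd_tp_ratio(2, 2) == 1  # equal TP degenerates to a ratio of 1
```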


@ -31,3 +31,8 @@ KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
"MooncakeConnectorV1")
KVConnectorFactory.register_connector(
"MooncakeLayerwiseConnector",
"vllm_ascend.distributed.mooncake_layerwise_connector",
"MooncakeLayerwiseConnector")


@ -1109,4 +1109,4 @@ def ensure_zmq_recv(
logger.error(f"Receive failed after all retries: {e}")
raise RuntimeError(
f"Failed to receive data after {max_retries} "
f"retries: {e}")
f"retries: {e}")

File diff suppressed because it is too large


@ -13,6 +13,7 @@ _MC2: Optional[GroupCoordinator] = None
_MLP_TP: Optional[GroupCoordinator] = None
_OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_P_TP: Optional[GroupCoordinator] = None
def get_mc2_group() -> GroupCoordinator:
@ -37,6 +38,12 @@ def get_mlp_tp_group() -> GroupCoordinator:
return _MLP_TP
def get_p_tp_group() -> GroupCoordinator:
assert _P_TP is not None, (
"distributed prefill tensor parallel group is not initialized")
return _P_TP
def model_parallel_initialized():
return (_MC2 is not None)
@ -54,6 +61,22 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
all_ranks = torch.arange(world_size).reshape(
-1, parallel_config.data_parallel_size *
parallel_config.tensor_parallel_size)
pd_tp_ratio = get_ascend_config().pd_tp_ratio
global _P_TP
assert _P_TP is None, (
"distributed prefill tensor parallel group is already initialized")
prefill_tensor_model_parallel_size = pd_tp_ratio if \
pd_tp_ratio > 0 and pd_tp_ratio < parallel_config.tensor_parallel_size else parallel_config.tensor_parallel_size
group_ranks = all_ranks.view(-1,
prefill_tensor_model_parallel_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
num = get_world_group().local_rank // pd_tp_ratio
_P_TP = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name=f"p_tp_{num}")
global _MC2
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
@ -142,3 +165,8 @@ def destroy_ascend_model_parallel():
if _OTP:
_OTP.destroy()
_OTP = None
global _P_TP
if _P_TP:
_P_TP.destroy()
_P_TP = None
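
To make the new prefill TP grouping concrete, here is a small standalone sketch of the rank arithmetic used above (no process-group initialization required; the sizes are illustrative):

```python
import torch

world_size, dp_size, tp_size, pd_tp_ratio = 8, 2, 4, 2  # illustrative sizes

all_ranks = torch.arange(world_size).reshape(-1, dp_size * tp_size)
# Same rule as init_ascend_model_parallel: use pd_tp_ratio as the prefill TP group
# size when 0 < pd_tp_ratio < tp_size, otherwise fall back to the full TP size.
p_tp_size = pd_tp_ratio if 0 < pd_tp_ratio < tp_size else tp_size
group_ranks = [x.tolist() for x in all_ranks.view(-1, p_tp_size).unbind(0)]
print(group_ranks)  # [[0, 1], [2, 3], [4, 5], [6, 7]]
```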


@ -0,0 +1,47 @@
import torch
import torch.distributed as dist
from vllm_ascend.distributed.parallel_state import get_p_tp_group
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
                              value: torch.Tensor):
if pd_tp_ratio <= 1:
return None, None
elif key is None or value is None:
raise ValueError("key or value is None")
k_output = alltoall_and_rearrange(pd_tp_ratio, key)
v_output = alltoall_and_rearrange(pd_tp_ratio, value)
return k_output, v_output
def alltoall_and_rearrange(tp_ratio: int, input_tensor: torch.Tensor):
num_kv_heads = input_tensor.size(1)
output_tensor = torch.zeros_like(input_tensor)
dist.all_to_all_single(output_tensor,
input_tensor,
group=get_p_tp_group().device_group)
input_tensor = 0  # drop the reference so the input buffer can be reclaimed early
result = rearrange_output(output_tensor, tp_ratio, num_kv_heads)
output_tensor = 0  # likewise release the intermediate all-to-all buffer
return result
def rearrange_output(base_output: torch.Tensor, cut_num: int,
num_kv_heads: int):
size_0 = base_output.size(0)
if size_0 % cut_num != 0:
raise ValueError(
f"The size of dim 0 [{size_0}] must be divisible by the cut_num [{cut_num}]"
)
chunk_size = size_0 // cut_num
reshaped = base_output.view(cut_num, chunk_size, -1)
transposed = reshaped.transpose(0, 1)
return transposed.contiguous().view(size_0, num_kv_heads, -1)
def align_memory(tensor: torch.Tensor, alignment: int) -> torch.Tensor:
data_ptr = tensor.data_ptr()
aligned_addr = (data_ptr + alignment - 1) // alignment * alignment
offset = (aligned_addr - data_ptr) // tensor.element_size()
return tensor[int(offset):]
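
A quick usage check of `rearrange_output` (purely the row permutation; no process group or NPU is needed), using the function as defined above:

```python
import torch

num_kv_heads, head_size, cut_num = 2, 4, 2  # illustrative sizes
base = torch.arange(4 * num_kv_heads * head_size).view(4, num_kv_heads, head_size)

out = rearrange_output(base, cut_num, num_kv_heads)
# Rows are interleaved chunk-wise: [r0, r1, r2, r3] -> [r0, r2, r1, r3].
assert out.shape == base.shape
assert torch.equal(out[1], base[2])
```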