Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)

Compare commits: d31f7844f8...nixl-debug (37 commits)

45c02abd72
d0bb3fa02c
81fdcec214
f65450e3dc
bd57841c7b
f16bf63877
b835205d33
c22a6cb1cc
7fbcbbfc45
ff5a0cfa6e
56939c835d
1172b70b79
15bc311d28
70b76554d1
128eca2ce3
6babd39366
491347cbc3
569de248cb
f015919fc8
39e6bd19fd
c4b9b2e682
17546dc79f
5d8b665366
cda2f2c453
b9be6fd35a
8283d7b85c
c481d30c17
dedb1a5424
ee2a4b0889
f9617c75ad
5d2eac70e7
fea0731cf4
9eaa81b9c9
852ee4b132
87bf6812b2
5b8c64dc77
489e5ba5ce
benchmarks/benchmark_one_concurrent_req.py (new file, 387 lines)
@@ -0,0 +1,387 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp  # Import aiohttp
import numpy as np
from tqdm import tqdm

from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Sends a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets prefix cache between requests.
    """
    outputs = []

    pbar = tqdm(total=len(input_requests))

    # Small request to force a forward pass.
    # Used for resetting the prefix cache.
    # dummy_req_input = RequestFuncInput(
    #     model=model_id,
    #     prompt="0",
    #     api_url=api_url,
    #     prompt_len=1,
    #     output_len=2,
    # )

    # print("Starting initial single prompt test run...")
    # test_output = await request_func(request_func_input=dummy_req_input)
    # if not test_output.success:
    #     raise ValueError(
    #         "Initial test run failed - Please check your configuration. "
    #         "Error: %s", test_output.error)
    # else:
    #     print("Initial test run completed. Starting sequential benchmark...")

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Reset prefix cache if configured, except after the very last request.
        # NOTE: currently disabled via `and False`; dummy_req_input above is also
        # commented out, so both must be re-enabled together.
        if cache_reset_url and False:
            await request_func(request_func_input=dummy_req_input)
            await reset_cache(cache_reset_url)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
                )

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl),
                    )
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
        ],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_name.lower()}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_name.lower()}_ms"),
            )
        )

        for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    parser.add_argument(
        "--backend", type=str, default="vllm", help="Backend to use for requests"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server base URL (overrides --host and --port)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint", type=str, default="/v1/completions", help="API endpoint"
    )
    parser.add_argument("--model", type=str, required=True, help="Name of the model")
    parser.add_argument(
        "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
    )
    parser.add_argument(
        "--num-requests", type=int, default=100, help="Number of requests to process"
    )
    parser.add_argument(
        "--input-len", type=int, default=128, help="Input len for generated prompts"
    )
    parser.add_argument(
        "--output-len", type=int, default=None, help="Override output len for requests"
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace",
    )

    args = parser.parse_args()
    main(args)
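A typical standalone run of this script, assuming a vLLM OpenAI-compatible server is already listening on the default host and port above, might look like the following (the model name is only an example; every flag shown is defined in the argparse setup above):

python benchmarks/benchmark_one_concurrent_req.py \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --host 127.0.0.1 \
    --port 8000 \
    --input-len 2048 \
    --output-len 1 \
    --num-requests 10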
tools/Justfile (new file, 79 lines)
@@ -0,0 +1,79 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"

# MODEL := "Qwen/Qwen3-0.6B"
MODEL := "meta-llama/Llama-3.1-8B-Instruct"
PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"

prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5557 \
    CUDA_VISIBLE_DEVICES=0,7 \
    vllm serve {{MODEL}} \
        --port {{PREFILL_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5567 \
    CUDA_VISIBLE_DEVICES=4,5 \
    vllm serve {{MODEL}} \
        --port {{DECODE_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port {{PROXY_PORT}} \
        --prefiller-port {{PREFILL_PORT}} \
        --decoder-port {{DECODE_PORT}}

send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
        -H "Content-Type: application/json" \
        -d '{ \
            "model": "{{MODEL}}", \
            "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
            "max_tokens": 150, \
            "temperature": 0.7 \
        }'

benchmark NUM_PROMPTS:
    python {{vllm-directory}}benchmarks/benchmark_serving.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --dataset-name random \
        --random-input-len 30000 \
        --random-output-len 10 \
        --num-prompts {{NUM_PROMPTS}} \
        --seed $(date +%s)

benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{DECODE_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

eval:
    lm_eval --model local-completions --tasks gsm8k \
        --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
        --limit 1000
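Taken together, these recipes sketch a disaggregated prefill/decode setup: two vLLM servers joined by the NixlConnector and fronted by the toy proxy. One plausible way to drive them, assuming just is invoked from the directory containing this Justfile and each command runs in its own terminal:

just prefill               # prefill server on PREFILL_PORT (8100)
just decode                # decode server on DECODE_PORT (8200)
just proxy                 # toy proxy on PROXY_PORT (8192)
just benchmark_one 30000   # single-concurrency benchmark with a 30k-token prompt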
@@ -37,6 +37,9 @@ if TYPE_CHECKING:
 Transfer = tuple[int, float]  # (xfer_handle, start_time)
 GET_META_MSG = b"get_meta_msg"
 
+import os
+LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"
+
 logger = init_logger(__name__)
 
 # Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
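The hunks that follow gate their extra timing output behind this flag, so presumably the variable has to be exported before launching the servers for the logs to appear, for example:

VLLM_LOG_XFER_TIME=1 just prefill
VLLM_LOG_XFER_TIME=1 just decode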
@@ -329,7 +332,17 @@ class NixlConnectorWorker:
         self.block_size = vllm_config.cache_config.block_size
 
         # Agent.
-        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
+        import os
+        num_workers = 32
+        # Setting num_workers on the prefiller causes no notifs to be received???
+        # This is a hack to make sure num_workers stays at the default on the prefiller.
+        if os.getenv("VLLM_IS_PREFILL", "0") == "1":
+            num_workers = None
+        print(f"NUM_WORKERS: {num_workers=}")
+        self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()),
+                                        None,
+                                        num_workers=None,
+                                        num_shared_workers=num_workers)
         # Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
         self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
 
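Note that the prefill recipe in the Justfile above does not set VLLM_IS_PREFILL, so to take the num_workers=None path on the prefiller one would presumably launch it as, for example:

VLLM_IS_PREFILL=1 just prefill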
@@ -371,8 +384,8 @@ class NixlConnectorWorker:
         self._registered_descs: list[Any] = []
 
         # In progress transfers.
-        # [req_id -> list[handle]]
-        self._recving_transfers = defaultdict[str, list[Transfer]](list)
+        # [req_id -> (list[handles], agent_name, notif_id, start_time)]
+        self._recving_transfers: dict[str, tuple[list[int], str, str, float]] = {}
 
         # Complete transfer tracker. Used by the rank 0 to track finished
         # transactions on ranks 1 to N-1.
@@ -754,8 +767,11 @@ class NixlConnectorWorker:
         to Rank 0 once their transaction is done + Rank 0 returns
         finished sets to Scheduler only once all ranks are done.
         """
+
+        start = time.perf_counter()
         done_sending = self._get_new_notifs()
         done_recving = self._pop_done_transfers(self._recving_transfers)
+
         if len(done_sending) > 0 or len(done_recving) > 0:
             logger.debug(
                 "Rank %s, get_finished: %s requests done sending "
@@ -796,6 +812,10 @@ class NixlConnectorWorker:
             if self._done_sending_count[req_id] == self.world_size:
                 del self._done_sending_count[req_id]
                 all_done_sending.add(req_id)
 
+        end = time.perf_counter()
+        if LOG_XFER_TIME:
+            logger.info("========== .get_finished(): %s ==========", end - start)
+
         return all_done_sending, all_done_recving
@@ -805,6 +825,10 @@ class NixlConnectorWorker:
             self.tp_group.send_object(finished_req_ids, dst=0)
 
             # Unused as only Rank 0 results are sent to scheduler.
+            end = time.perf_counter()
+            if LOG_XFER_TIME:
+                logger.info("========== .get_finished(): %s ==========", end - start)
+
             return done_sending, done_recving
 
     def _get_new_notifs(self) -> set[str]:
@@ -826,7 +850,8 @@ class NixlConnectorWorker:
         return notified_req_ids
 
     def _pop_done_transfers(
-            self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
+            self, transfers: dict[str, tuple[list[int], str, str,
+                                             float]]) -> set[str]:
         """
         Pop completed xfers by checking for DONE state.
         Args:
@@ -834,19 +859,30 @@ class NixlConnectorWorker:
         Returns:
             set of req_ids that have all done xfers
         """
         done_req_ids: set[str] = set()
-        for req_id, handles in list(transfers.items()):
-            for handle, xfer_stime in handles:
+        for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
+            new_handles = []
+            for handle in handles:
                 xfer_state = self.nixl_wrapper.check_xfer_state(handle)
                 if xfer_state == "DONE":
                     self.nixl_wrapper.release_xfer_handle(handle)
-                    done_req_ids.add(req_id)
-                    del transfers[req_id]
                 elif xfer_state == "PROC":
-                    continue
+                    new_handles.append(handle)
                 else:
                     raise RuntimeError("Transfer failed with state %s",
                                        xfer_state)
+
+            # Done.
+            if len(new_handles) == 0:
+                self.nixl_wrapper.send_notif(agent_name, notif_id)
+                del transfers[req_id]
+                done_req_ids.add(req_id)
+                if LOG_XFER_TIME:
+                    logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
+
+            else:
+                transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
+
         return done_req_ids
 
     def start_load_kv(self, metadata: NixlConnectorMetadata):
@@ -958,22 +994,35 @@ class NixlConnectorWorker:
         assert len(local_block_descs_ids) == len(remote_block_descs_ids)
 
         # Prepare transfer with Nixl.
-        handle = self.nixl_wrapper.make_prepped_xfer(
-            "READ",
-            local_xfer_side_handle,
-            local_block_descs_ids,
-            remote_xfer_side_handle,
-            remote_block_descs_ids,
-            notif_msg=notif_id,
-        )
+        CHUNK_SIZE = 500
+        handles = []
+        # NOTE: this is a hack to split make_prepped_xfer into chunks so that
+        # different workers are allocated for each chunk. Without this change,
+        # nixl was allocating the same worker (0) for all the chunks and the
+        # overall launch time was >300 ms.
+        for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
+            handle = self.nixl_wrapper.make_prepped_xfer(
+                "READ",
+                local_xfer_side_handle,
+                local_block_descs_ids[i:i + CHUNK_SIZE],
+                remote_xfer_side_handle,
+                remote_block_descs_ids[i:i + CHUNK_SIZE],
+                skip_desc_merge=True,
+            )
+            handles.append(handle)
 
         # Begin async xfer.
-        self.nixl_wrapper.transfer(handle)
+        start = time.perf_counter()
+        self.nixl_wrapper.transfer_batched(handles)
+        end = time.perf_counter()
+        if LOG_XFER_TIME:
+            logger.info("========== .transfer_batched(): %s ==========", end - start)
 
-        # Use handle to check completion in future step().
-        # TODO (NickLucche) surface xfer elapsed time
-        self._recving_transfers[request_id].append(
-            (handle, time.perf_counter()))
+        # Keep track of ongoing transfers.
+        remote_rank = self.tp_rank // tp_ratio
+        agent_name = self._remote_agents[dst_engine_id][remote_rank]
+        assert request_id not in self._recving_transfers
+        self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
 
     def _get_block_descs_ids(self,
                              engine_id: str,
@@ -75,7 +75,7 @@ enable_hf_transfer()
 class DisabledTqdm(tqdm):
 
     def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs, disable=True)
+        super().__init__(*args, **kwargs, disable=False)
 
 
 def get_lock(model_name_or_path: Union[str, Path],