mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
37 Commits
v0.11.0rc1
...
nixl-debug
Author | SHA1 | Date | |
---|---|---|---|
45c02abd72 | |||
d0bb3fa02c | |||
81fdcec214 | |||
f65450e3dc | |||
bd57841c7b | |||
f16bf63877 | |||
b835205d33 | |||
c22a6cb1cc | |||
7fbcbbfc45 | |||
ff5a0cfa6e | |||
56939c835d | |||
1172b70b79 | |||
15bc311d28 | |||
70b76554d1 | |||
128eca2ce3 | |||
6babd39366 | |||
491347cbc3 | |||
569de248cb | |||
f015919fc8 | |||
39e6bd19fd | |||
c4b9b2e682 | |||
17546dc79f | |||
5d8b665366 | |||
cda2f2c453 | |||
b9be6fd35a | |||
8283d7b85c | |||
c481d30c17 | |||
dedb1a5424 | |||
ee2a4b0889 | |||
f9617c75ad | |||
5d2eac70e7 | |||
fea0731cf4 | |||
9eaa81b9c9 | |||
852ee4b132 | |||
87bf6812b2 | |||
5b8c64dc77 | |||
489e5ba5ce |
387
benchmarks/benchmark_one_concurrent_req.py
Normal file
387
benchmarks/benchmark_one_concurrent_req.py
Normal file
@ -0,0 +1,387 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp # Import aiohttp
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from backend_request_func import RequestFuncInput, RequestFuncOutput
|
||||
from benchmark_dataset import RandomDataset, SampleRequest
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
except ImportError:
|
||||
from backend_request_func import get_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
completed: int
|
||||
total_input: int
|
||||
total_output: int
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
percentiles_ttft_ms: list[tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
percentiles_itl_ms: list[tuple[float, float]]
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
async def reset_cache(reset_url: str):
|
||||
"""Sends a POST request to reset the prefix cache."""
|
||||
logger.debug("Resetting prefix cache at %s", reset_url)
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(reset_url) as response,
|
||||
):
|
||||
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
||||
logger.debug("Prefix cache reset successful: %s", response.status)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error("Failed to connect to cache reset endpoint %s: %s}", reset_url, e)
|
||||
except aiohttp.ClientResponseError as e:
|
||||
logger.error(
|
||||
"Cache reset request failed with status %s: %s", e.status, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("An unexpected error occurred during cache reset: %s", e)
|
||||
|
||||
|
||||
async def sequential_benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
model_id: str,
|
||||
tokenizer,
|
||||
input_requests: list[SampleRequest],
|
||||
request_func,
|
||||
selected_percentiles: list[float],
|
||||
cache_reset_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Benchmark that processes requests sequentially, waiting for each to complete
|
||||
before starting the next one. Resets prefix cache between requests.
|
||||
"""
|
||||
outputs = []
|
||||
|
||||
pbar = tqdm(total=len(input_requests))
|
||||
|
||||
# Small request to force a forward pass.
|
||||
# Used for resetting the prefix cache.
|
||||
# dummy_req_input = RequestFuncInput(
|
||||
# model=model_id,
|
||||
# prompt="0",
|
||||
# api_url=api_url,
|
||||
# prompt_len=1,
|
||||
# output_len=2,
|
||||
# )
|
||||
|
||||
# print("Starting initial single prompt test run...")
|
||||
# test_output = await request_func(request_func_input=dummy_req_input)
|
||||
# if not test_output.success:
|
||||
# raise ValueError(
|
||||
# "Initial test run failed - Please check your configuration. "
|
||||
# "Error: %s", test_output.error)
|
||||
# else:
|
||||
# print("Initial test run completed. Starting sequential benchmark...")
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
# Process requests sequentially
|
||||
for request in input_requests:
|
||||
prompt, prompt_len, output_len = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
request.expected_output_len,
|
||||
)
|
||||
|
||||
logger.info("Sending request with len %s", request.prompt_len)
|
||||
logger.debug('Request str: "%s"', request.prompt[:50])
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
# print(f"{prompt=}")
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
)
|
||||
|
||||
output = await request_func(request_func_input=request_func_input)
|
||||
|
||||
request_end_time = time.perf_counter()
|
||||
# Add timing information
|
||||
if output.success and not hasattr(output, "latency"):
|
||||
output.latency = request_end_time - request_start_time
|
||||
logger.info("Finished request with latency %.4f s", output.latency)
|
||||
|
||||
outputs.append(output)
|
||||
pbar.update(1)
|
||||
|
||||
# Reset prefix cache if configured, except after the very last request
|
||||
if cache_reset_url and False:
|
||||
await request_func(request_func_input=dummy_req_input)
|
||||
await reset_cache(cache_reset_url)
|
||||
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
# Calculate metrics
|
||||
metrics = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
selected_percentiles=selected_percentiles,
|
||||
)
|
||||
|
||||
print_results(metrics, benchmark_duration)
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"input_lens": [request.prompt_len for request in input_requests],
|
||||
"output_lens": [
|
||||
output.output_tokens if output.success else 0 for output in outputs
|
||||
],
|
||||
"ttfts": [output.ttft for output in outputs if output.success],
|
||||
"itls": [output.itl for output in outputs if output.success],
|
||||
"generated_texts": [
|
||||
output.generated_text for output in outputs if output.success
|
||||
],
|
||||
"errors": [output.error for output in outputs if not output.success],
|
||||
}
|
||||
|
||||
# Add summary statistics
|
||||
for stat_name in ["ttft", "itl", "e2el"]:
|
||||
for metric_name in ["mean", "median", "std"]:
|
||||
result[f"{metric_name}_{stat_name}_ms"] = getattr(
|
||||
metrics, f"{metric_name}_{stat_name}_ms"
|
||||
)
|
||||
|
||||
for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
result[f"p{p_word}_{stat_name}_ms"] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer,
|
||||
selected_percentiles: list[float],
|
||||
) -> BenchmarkMetrics:
|
||||
"""Calculate benchmark metrics from results."""
|
||||
total_input = 0
|
||||
completed = 0
|
||||
total_output = 0
|
||||
ttfts = []
|
||||
itls = []
|
||||
e2els = []
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
if output.success:
|
||||
output_len = output.output_tokens
|
||||
|
||||
if not output_len:
|
||||
# Use tokenizer to count output tokens if not provided
|
||||
output_len = len(
|
||||
tokenizer(output.generated_text, add_special_tokens=False).input_ids
|
||||
)
|
||||
|
||||
total_output += output_len
|
||||
total_input += input_requests[i].prompt_len
|
||||
|
||||
if hasattr(output, "ttft") and output.ttft is not None:
|
||||
ttfts.append(output.ttft)
|
||||
|
||||
if hasattr(output, "itl") and output.itl:
|
||||
# Ensure itl is a list of floats
|
||||
if isinstance(output.itl, list):
|
||||
itls.extend(output.itl)
|
||||
else:
|
||||
logger.warning(
|
||||
"Expected list for ITL but got %s. Appending as is.",
|
||||
type(output.itl),
|
||||
)
|
||||
itls.append(output.itl)
|
||||
|
||||
if hasattr(output, "latency") and output.latency is not None:
|
||||
e2els.append(output.latency)
|
||||
|
||||
completed += 1
|
||||
|
||||
return BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=total_output,
|
||||
mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or [0]) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or [0]) * 1000,
|
||||
percentiles_ttft_ms=[
|
||||
(p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_itl_ms=np.mean(itls or [0]) * 1000,
|
||||
median_itl_ms=np.median(itls or [0]) * 1000,
|
||||
std_itl_ms=np.std(itls or [0]) * 1000,
|
||||
percentiles_itl_ms=[
|
||||
(p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_e2el_ms=np.mean(e2els or [0]) * 1000,
|
||||
median_e2el_ms=np.median(e2els or [0]) * 1000,
|
||||
std_e2el_ms=np.std(e2els or [0]) * 1000,
|
||||
percentiles_e2el_ms=[
|
||||
(p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
|
||||
"""Print benchmark results in a formatted way."""
|
||||
print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
||||
|
||||
def print_metric_stats(metric_name, header):
|
||||
print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_name.lower()}_ms"),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_name.lower()}_ms"),
|
||||
)
|
||||
)
|
||||
|
||||
for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
||||
|
||||
print_metric_stats("TTFT", "Time to First Token")
|
||||
print_metric_stats("ITL", "Inter-token Latency")
|
||||
print_metric_stats("E2EL", "End-to-end Latency")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async def main_async(args):
|
||||
# Import needed functions based on your setup
|
||||
from backend_request_func import ASYNC_REQUEST_FUNCS
|
||||
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
# Set up API URL
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
|
||||
# Set up Cache Reset URL
|
||||
cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
|
||||
logger.info("Prefix cache reset configured at: %s", cache_reset_url)
|
||||
|
||||
# Get tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)
|
||||
|
||||
# Get request function
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
input_requests = RandomDataset().sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_requests,
|
||||
prefix_len=0,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
range_ratio=0.0,
|
||||
)
|
||||
|
||||
# Run benchmark
|
||||
result = await sequential_benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_func=request_func,
|
||||
selected_percentiles=[50, 90, 95, 99],
|
||||
cache_reset_url=cache_reset_url,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
asyncio.run(main_async(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
|
||||
parser.add_argument(
|
||||
"--backend", type=str, default="vllm", help="Backend to use for requests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Server base URL (overrides --host and --port)",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument(
|
||||
"--endpoint", type=str, default="/v1/completions", help="API endpoint"
|
||||
)
|
||||
parser.add_argument("--model", type=str, required=True, help="Name of the model")
|
||||
parser.add_argument(
|
||||
"--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-requests", type=int, default=100, help="Number of requests to process"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len", type=int, default=128, help="Input len for generated prompts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len", type=int, default=None, help="Override output len for requests"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code from HuggingFace",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
79
tools/Justfile
Normal file
79
tools/Justfile
Normal file
@ -0,0 +1,79 @@
|
||||
# Needed for the proxy server
|
||||
vllm-directory := "/home/rshaw/vllm/"
|
||||
|
||||
# MODEL := "Qwen/Qwen3-0.6B"
|
||||
MODEL := "meta-llama/Llama-3.1-8B-Instruct"
|
||||
PROXY_PORT := "8192"
|
||||
PREFILL_PORT := "8100"
|
||||
DECODE_PORT := "8200"
|
||||
|
||||
prefill:
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5557 \
|
||||
CUDA_VISIBLE_DEVICES=0,7 \
|
||||
vllm serve {{MODEL}} \
|
||||
--port {{PREFILL_PORT}} \
|
||||
--tensor-parallel-size 2 \
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--block-size 128 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
|
||||
decode:
|
||||
VLLM_NIXL_SIDE_CHANNEL_PORT=5567 \
|
||||
CUDA_VISIBLE_DEVICES=4,5 \
|
||||
vllm serve {{MODEL}} \
|
||||
--port {{DECODE_PORT}} \
|
||||
--tensor-parallel-size 2 \
|
||||
--enforce-eager \
|
||||
--disable-log-requests \
|
||||
--block-size 128 \
|
||||
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
|
||||
|
||||
proxy:
|
||||
python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
|
||||
--port {{PROXY_PORT}} \
|
||||
--prefiller-port {{PREFILL_PORT}} \
|
||||
--decoder-port {{DECODE_PORT}}
|
||||
|
||||
send_request:
|
||||
curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{ \
|
||||
"model": "{{MODEL}}", \
|
||||
"prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
|
||||
"max_tokens": 150, \
|
||||
"temperature": 0.7 \
|
||||
}'
|
||||
|
||||
benchmark NUM_PROMPTS:
|
||||
python {{vllm-directory}}/benchmarks/benchmark_serving.py \
|
||||
--port {{PROXY_PORT}} \
|
||||
--model {{MODEL}} \
|
||||
--dataset-name random \
|
||||
--random-input-len 30000 \
|
||||
--random-output-len 10 \
|
||||
--num-prompts {{NUM_PROMPTS}} \
|
||||
--seed $(date +%s) \
|
||||
|
||||
benchmark_one INPUT_LEN:
|
||||
python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
|
||||
--port {{PROXY_PORT}} \
|
||||
--model {{MODEL}} \
|
||||
--input-len {{INPUT_LEN}} \
|
||||
--output-len 1 \
|
||||
--num-requests 10 \
|
||||
--seed $(date +%s)
|
||||
|
||||
benchmark_one_no_pd INPUT_LEN:
|
||||
python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
|
||||
--port {{DECODE_PORT}} \
|
||||
--model {{MODEL}} \
|
||||
--input-len {{INPUT_LEN}} \
|
||||
--output-len 1 \
|
||||
--num-requests 10 \
|
||||
--seed $(date +%s)
|
||||
|
||||
eval:
|
||||
lm_eval --model local-completions --tasks gsm8k \
|
||||
--model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
|
||||
--limit 1000
|
@ -37,6 +37,9 @@ if TYPE_CHECKING:
|
||||
Transfer = tuple[int, float] # (xfer_handle, start_time)
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
import os
|
||||
LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
@ -329,7 +332,17 @@ class NixlConnectorWorker:
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Agent.
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||
import os
|
||||
num_workers = 32
|
||||
# setting num_workers on the prefiller causes no notifs to be recved???
|
||||
# this is a hack to make sure we set num workers on the prefiller to 1.
|
||||
if os.getenv("VLLM_IS_PREFILL", "0") == "1":
|
||||
num_workers = None
|
||||
print(f"NUM_WORKERS: {num_workers=}")
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()),
|
||||
None,
|
||||
num_workers=None,
|
||||
num_shared_workers=num_workers)
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
|
||||
|
||||
@ -371,8 +384,8 @@ class NixlConnectorWorker:
|
||||
self._registered_descs: list[Any] = []
|
||||
|
||||
# In progress transfers.
|
||||
# [req_id -> list[handle]]
|
||||
self._recving_transfers = defaultdict[str, list[Transfer]](list)
|
||||
# [req_id -> list[handles], agent_name, notif_id]
|
||||
self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
|
||||
|
||||
# Complete transfer tracker. Used by the rank 0 to track finished
|
||||
# transactions on ranks 1 to N-1.
|
||||
@ -754,8 +767,11 @@ class NixlConnectorWorker:
|
||||
to Rank 0 once their transaction is done + Rank 0 returns
|
||||
finished sets to Scheduler only once all ranks are done.
|
||||
"""
|
||||
|
||||
start = time.perf_counter()
|
||||
done_sending = self._get_new_notifs()
|
||||
done_recving = self._pop_done_transfers(self._recving_transfers)
|
||||
|
||||
if len(done_sending) > 0 or len(done_recving) > 0:
|
||||
logger.debug(
|
||||
"Rank %s, get_finished: %s requests done sending "
|
||||
@ -796,6 +812,10 @@ class NixlConnectorWorker:
|
||||
if self._done_sending_count[req_id] == self.world_size:
|
||||
del self._done_sending_count[req_id]
|
||||
all_done_sending.add(req_id)
|
||||
|
||||
end = time.perf_counter()
|
||||
if LOG_XFER_TIME:
|
||||
logger.info("========== .get_finished(): %s ==========", end - start)
|
||||
|
||||
return all_done_sending, all_done_recving
|
||||
|
||||
@ -805,6 +825,10 @@ class NixlConnectorWorker:
|
||||
self.tp_group.send_object(finished_req_ids, dst=0)
|
||||
|
||||
# Unused as only Rank 0 results are sent to scheduler.
|
||||
end = time.perf_counter()
|
||||
if LOG_XFER_TIME:
|
||||
logger.info("========== .get_finished(): %s ==========", end - start)
|
||||
|
||||
return done_sending, done_recving
|
||||
|
||||
def _get_new_notifs(self) -> set[str]:
|
||||
@ -826,7 +850,8 @@ class NixlConnectorWorker:
|
||||
return notified_req_ids
|
||||
|
||||
def _pop_done_transfers(
|
||||
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
|
||||
self, transfers: dict[str, tuple[list[int], str,
|
||||
str]]) -> set[str]:
|
||||
"""
|
||||
Pop completed xfers by checking for DONE state.
|
||||
Args:
|
||||
@ -834,19 +859,30 @@ class NixlConnectorWorker:
|
||||
Returns:
|
||||
set of req_ids that have all done xfers
|
||||
"""
|
||||
done_req_ids: set[str] = set()
|
||||
for req_id, handles in list(transfers.items()):
|
||||
for handle, xfer_stime in handles:
|
||||
done_req_ids: set[str, float] = set()
|
||||
for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
|
||||
new_handles = []
|
||||
for handle in handles:
|
||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||
if xfer_state == "DONE":
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
done_req_ids.add(req_id)
|
||||
del transfers[req_id]
|
||||
elif xfer_state == "PROC":
|
||||
continue
|
||||
new_handles.append(handle)
|
||||
else:
|
||||
raise RuntimeError("Transfer failed with state %s",
|
||||
xfer_state)
|
||||
|
||||
# Done.
|
||||
if len(new_handles) == 0:
|
||||
self.nixl_wrapper.send_notif(agent_name, notif_id)
|
||||
del transfers[req_id]
|
||||
done_req_ids.add(req_id)
|
||||
if LOG_XFER_TIME:
|
||||
logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
|
||||
|
||||
else:
|
||||
transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
|
||||
|
||||
return done_req_ids
|
||||
|
||||
def start_load_kv(self, metadata: NixlConnectorMetadata):
|
||||
@ -958,22 +994,35 @@ class NixlConnectorWorker:
|
||||
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
|
||||
|
||||
# Prepare transfer with Nixl.
|
||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||
"READ",
|
||||
local_xfer_side_handle,
|
||||
local_block_descs_ids,
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids,
|
||||
notif_msg=notif_id,
|
||||
)
|
||||
CHUNK_SIZE = 500
|
||||
handles = []
|
||||
# NOTE: this is a hack to make make_prepped_xfer into threads so that
|
||||
# different workers are allocated for each chuck. Without this change,
|
||||
# nixl was allocating the same worker (0) for all the chunks and the
|
||||
# overall launch time was >300 ms.
|
||||
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
|
||||
handle = self.nixl_wrapper.make_prepped_xfer(
|
||||
"READ",
|
||||
local_xfer_side_handle,
|
||||
local_block_descs_ids[i:i + CHUNK_SIZE],
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids[i:i + CHUNK_SIZE],
|
||||
skip_desc_merge=True,
|
||||
)
|
||||
handles.append(handle)
|
||||
|
||||
# Begin async xfer.
|
||||
self.nixl_wrapper.transfer(handle)
|
||||
start = time.perf_counter()
|
||||
self.nixl_wrapper.transfer_batched(handles)
|
||||
end = time.perf_counter()
|
||||
if LOG_XFER_TIME:
|
||||
logger.info("========== .transfer_batched(): %s ==========", end - start)
|
||||
|
||||
# Use handle to check completion in future step().
|
||||
# TODO (NickLucche) surface xfer elapsed time
|
||||
self._recving_transfers[request_id].append(
|
||||
(handle, time.perf_counter()))
|
||||
# Keep track of ongoing transfers.
|
||||
remote_rank = self.tp_rank // tp_ratio
|
||||
agent_name = self._remote_agents[dst_engine_id][remote_rank]
|
||||
assert request_id not in self._recving_transfers
|
||||
self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
|
||||
|
||||
def _get_block_descs_ids(self,
|
||||
engine_id: str,
|
||||
|
@ -75,7 +75,7 @@ enable_hf_transfer()
|
||||
class DisabledTqdm(tqdm):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs, disable=True)
|
||||
super().__init__(*args, **kwargs, disable=False)
|
||||
|
||||
|
||||
def get_lock(model_name_or_path: Union[str, Path],
|
||||
|
Reference in New Issue
Block a user