Compare commits

...

37 Commits

Author SHA1 Message Date
45c02abd72 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-11 00:57:50 +00:00
d0bb3fa02c fix
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-10 13:54:33 +00:00
81fdcec214 added logging
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-10 13:32:28 +00:00
f65450e3dc updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-07 01:27:40 +00:00
bd57841c7b updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-07 01:14:10 +00:00
f16bf63877 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-07 01:13:20 +00:00
b835205d33 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-07 00:32:42 +00:00
c22a6cb1cc cleanup
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-07 00:30:51 +00:00
7fbcbbfc45 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-07-01 03:15:16 +00:00
ff5a0cfa6e updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-07-01 02:49:54 +00:00
56939c835d updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-07-01 01:34:46 +00:00
1172b70b79 updated vllm
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-07-01 00:16:07 +00:00
15bc311d28 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 20:09:12 +00:00
70b76554d1 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 20:01:56 +00:00
128eca2ce3 update for use batched
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 19:48:33 +00:00
6babd39366 print out
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 19:30:14 +00:00
491347cbc3 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 13:42:36 +00:00
569de248cb cleanup
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 12:36:19 +00:00
f015919fc8 add comment about hack
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 12:25:48 +00:00
39e6bd19fd Merge pull request #17 from praveingk/batching
Load balance across multiple workers
2025-06-30 08:21:03 -04:00
c4b9b2e682 Increase chunk size to reduce no. of threads
2025-06-30 15:03:52 +05:30
17546dc79f Add threading for load-balancing to different workers
2025-06-30 14:40:18 +05:30
5d8b665366 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:59:02 +00:00
cda2f2c453 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:54:43 +00:00
b9be6fd35a updated to make send_notif work
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:51:37 +00:00
8283d7b85c updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:45:03 +00:00
c481d30c17 update
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:39:15 +00:00
dedb1a5424 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:30:06 +00:00
ee2a4b0889 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-30 01:11:22 +00:00
f9617c75ad updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-27 18:48:05 +00:00
5d2eac70e7 update
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-27 15:12:03 +00:00
fea0731cf4 update
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-27 15:11:23 +00:00
9eaa81b9c9 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-19 13:18:39 +00:00
852ee4b132 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-19 13:16:50 +00:00
87bf6812b2 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-19 13:15:50 +00:00
5b8c64dc77 updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-19 13:12:43 +00:00
489e5ba5ce updated
Signed-off-by: rshaw@neuralmagic.com <robertgshaw2@gmail.com>
2025-06-19 13:10:52 +00:00
4 changed files with 539 additions and 24 deletions

View File

@@ -0,0 +1,387 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional
import aiohttp # Import aiohttp
import numpy as np
from tqdm import tqdm
from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest
try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)

@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Sends a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets the prefix cache between requests.
    """
    outputs = []
    pbar = tqdm(total=len(input_requests))

    # Small request to force a forward pass.
    # Used for resetting the prefix cache.
    # dummy_req_input = RequestFuncInput(
    #     model=model_id,
    #     prompt="0",
    #     api_url=api_url,
    #     prompt_len=1,
    #     output_len=2,
    # )

    # print("Starting initial single prompt test run...")
    # test_output = await request_func(request_func_input=dummy_req_input)
    # if not test_output.success:
    #     raise ValueError(
    #         "Initial test run failed - Please check your configuration. "
    #         "Error: %s", test_output.error)
    # else:
    #     print("Initial test run completed. Starting sequential benchmark...")

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        # print(f"{prompt=}")
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)
        request_end_time = time.perf_counter()

        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Reset prefix cache if configured, except after the very last request.
        # NOTE: currently disabled by the `and False` guard below.
        if cache_reset_url and False:
            await request_func(request_func_input=dummy_req_input)
            await reset_cache(cache_reset_url)

    pbar.close()
    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
                )

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl),
                    )
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
        ],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_name.lower()}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_name.lower()}_ms"),
            )
        )
        for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    parser.add_argument(
        "--backend", type=str, default="vllm", help="Backend to use for requests"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server base URL (overrides --host and --port)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint", type=str, default="/v1/completions", help="API endpoint"
    )
    parser.add_argument("--model", type=str, required=True, help="Name of the model")
    parser.add_argument(
        "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
    )
    parser.add_argument(
        "--num-requests", type=int, default=100, help="Number of requests to process"
    )
    parser.add_argument(
        "--input-len", type=int, default=128, help="Input len for generated prompts"
    )
    parser.add_argument(
        "--output-len", type=int, default=None, help="Override output len for requests"
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace",
    )

    args = parser.parse_args()
    main(args)
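
A minimal sketch of driving the benchmark above programmatically rather than via the CLI. It assumes the file is saved as benchmark_one_concurrent_req.py (the name the Justfile below invokes) and that a vLLM-compatible server is already listening on 127.0.0.1:8000; the argument values here are placeholders, not the author's settings.

import argparse
import asyncio

import benchmark_one_concurrent_req as bench  # assumed module name

# Mirror the CLI arguments defined in the argparse block above.
args = argparse.Namespace(
    backend="vllm", base_url=None, host="127.0.0.1", port=8000,
    endpoint="/v1/completions", model="meta-llama/Llama-3.1-8B-Instruct",
    tokenizer=None, num_requests=10, input_len=1024, output_len=64,
    seed=42, trust_remote_code=False,
)
result = asyncio.run(bench.main_async(args))
# Summary keys follow the p{percentile}_{metric}_ms naming built in the result dict above.
print(result["mean_ttft_ms"], result["p99_e2el_ms"])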

tools/Justfile
View File

@@ -0,0 +1,79 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"
# MODEL := "Qwen/Qwen3-0.6B"
MODEL := "meta-llama/Llama-3.1-8B-Instruct"
PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"

prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5557 \
    CUDA_VISIBLE_DEVICES=0,7 \
    vllm serve {{MODEL}} \
        --port {{PREFILL_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT=5567 \
    CUDA_VISIBLE_DEVICES=4,5 \
    vllm serve {{MODEL}} \
        --port {{DECODE_PORT}} \
        --tensor-parallel-size 2 \
        --enforce-eager \
        --disable-log-requests \
        --block-size 128 \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port {{PROXY_PORT}} \
        --prefiller-port {{PREFILL_PORT}} \
        --decoder-port {{DECODE_PORT}}

send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
        -H "Content-Type: application/json" \
        -d '{ \
            "model": "{{MODEL}}", \
            "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
            "max_tokens": 150, \
            "temperature": 0.7 \
        }'

benchmark NUM_PROMPTS:
    python {{vllm-directory}}benchmarks/benchmark_serving.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --dataset-name random \
        --random-input-len 30000 \
        --random-output-len 10 \
        --num-prompts {{NUM_PROMPTS}} \
        --seed $(date +%s)

benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{DECODE_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

eval:
    lm_eval --model local-completions --tasks gsm8k \
        --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
        --limit 1000
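
For reference, a rough Python equivalent of the send_request recipe above, kept in the same language as the benchmark script. It is a sketch that assumes the proxy recipe is already running on localhost:8192 (PROXY_PORT) and that the requests package is installed; the prompt is shortened for brevity.

import requests

resp = requests.post(
    "http://localhost:8192/v1/completions",  # PROXY_PORT from the Justfile
    headers={"Content-Type": "application/json"},
    json={
        "model": "meta-llama/Llama-3.1-8B-Instruct",
        "prompt": "I love vLLM because",
        "max_tokens": 150,
        "temperature": 0.7,
    },
    timeout=120,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])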

View File

@@ -37,6 +37,9 @@ if TYPE_CHECKING:
Transfer = tuple[int, float] # (xfer_handle, start_time)
GET_META_MSG = b"get_meta_msg"
import os
LOG_XFER_TIME = os.getenv("VLLM_LOG_XFER_TIME", "0") == "1"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
@@ -329,7 +332,17 @@ class NixlConnectorWorker:
self.block_size = vllm_config.cache_config.block_size
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
import os
num_workers = 32
# Setting num_workers on the prefiller causes no notifs to be received???
# This is a hack to make sure we set num_workers on the prefiller to 1.
if os.getenv("VLLM_IS_PREFILL", "0") == "1":
num_workers = None
print(f"NUM_WORKERS: {num_workers=}")
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()),
None,
num_workers=None,
num_shared_workers=num_workers)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)
@@ -371,8 +384,8 @@ class NixlConnectorWorker:
self._registered_descs: list[Any] = []
# In progress transfers.
# [req_id -> list[handle]]
self._recving_transfers = defaultdict[str, list[Transfer]](list)
# [req_id -> list[handles], agent_name, notif_id]
self._recving_transfers: dict[str, tuple[list[int], str, str]] = {}
# Complete transfer tracker. Used by the rank 0 to track finished
# transactions on ranks 1 to N-1.
@@ -754,8 +767,11 @@ class NixlConnectorWorker:
to Rank 0 once their transaction is done + Rank 0 returns
finished sets to Scheduler only once all ranks are done.
"""
start = time.perf_counter()
done_sending = self._get_new_notifs()
done_recving = self._pop_done_transfers(self._recving_transfers)
if len(done_sending) > 0 or len(done_recving) > 0:
logger.debug(
"Rank %s, get_finished: %s requests done sending "
@@ -796,6 +812,10 @@ class NixlConnectorWorker:
if self._done_sending_count[req_id] == self.world_size:
del self._done_sending_count[req_id]
all_done_sending.add(req_id)
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .get_finished(): %s ==========", end - start)
return all_done_sending, all_done_recving
@@ -805,6 +825,10 @@ class NixlConnectorWorker:
self.tp_group.send_object(finished_req_ids, dst=0)
# Unused as only Rank 0 results are sent to scheduler.
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .get_finished(): %s ==========", end - start)
return done_sending, done_recving
def _get_new_notifs(self) -> set[str]:
@@ -826,7 +850,8 @@ class NixlConnectorWorker:
return notified_req_ids
def _pop_done_transfers(
self, transfers: dict[str, list[tuple[int, float]]]) -> set[str]:
self, transfers: dict[str, tuple[list[int], str,
str]]) -> set[str]:
"""
Pop completed xfers by checking for DONE state.
Args:
@@ -834,19 +859,30 @@ class NixlConnectorWorker:
Returns:
set of req_ids that have all done xfers
"""
done_req_ids: set[str] = set()
for req_id, handles in list(transfers.items()):
for handle, xfer_stime in handles:
done_req_ids: set[str] = set()
for req_id, (handles, agent_name, notif_id, start_time) in list(transfers.items()):
new_handles = []
for handle in handles:
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
if xfer_state == "DONE":
self.nixl_wrapper.release_xfer_handle(handle)
done_req_ids.add(req_id)
del transfers[req_id]
elif xfer_state == "PROC":
continue
new_handles.append(handle)
else:
raise RuntimeError("Transfer failed with state %s",
xfer_state)
# Done.
if len(new_handles) == 0:
self.nixl_wrapper.send_notif(agent_name, notif_id)
del transfers[req_id]
done_req_ids.add(req_id)
if LOG_XFER_TIME:
logger.info("========== transmission time: %s ==========", time.perf_counter() - start_time)
else:
transfers[req_id] = (new_handles, agent_name, notif_id, start_time)
return done_req_ids
def start_load_kv(self, metadata: NixlConnectorMetadata):
@@ -958,22 +994,35 @@ class NixlConnectorWorker:
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
# Prepare transfer with Nixl.
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids,
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
)
CHUNK_SIZE = 500
handles = []
# NOTE: this is a hack to issue make_prepped_xfer in chunks so that
# different workers are allocated for each chunk. Without this change,
# nixl was allocating the same worker (0) for all the chunks and the
# overall launch time was >300 ms.
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
handle = self.nixl_wrapper.make_prepped_xfer(
"READ",
local_xfer_side_handle,
local_block_descs_ids[i:i + CHUNK_SIZE],
remote_xfer_side_handle,
remote_block_descs_ids[i:i + CHUNK_SIZE],
skip_desc_merge=True,
)
handles.append(handle)
# Begin async xfer.
self.nixl_wrapper.transfer(handle)
start = time.perf_counter()
self.nixl_wrapper.transfer_batched(handles)
end = time.perf_counter()
if LOG_XFER_TIME:
logger.info("========== .transfer_batched(): %s ==========", end - start)
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time
self._recving_transfers[request_id].append(
(handle, time.perf_counter()))
# Keep track of ongoing transfers.
remote_rank = self.tp_rank // tp_ratio
agent_name = self._remote_agents[dst_engine_id][remote_rank]
assert request_id not in self._recving_transfers
self._recving_transfers[request_id] = (handles, agent_name, notif_id, time.perf_counter())
def _get_block_descs_ids(self,
engine_id: str,
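
To make the new bookkeeping easier to follow, here is a standalone sketch of what the modified _recving_transfers / _pop_done_transfers pair does, with stand-in callables rather than the real nixl_wrapper API (check_state, release, and send_notif are hypothetical parameters, not NIXL calls): each request keeps its outstanding chunk handles plus the peer agent name, notif id, and start time, and the completion notif is sent only after every chunked handle reports DONE.

import time

# req_id -> (handles, agent_name, notif_id, start_time)
transfers: dict[str, tuple[list[int], str, str, float]] = {}

def pop_done(check_state, release, send_notif) -> set[str]:
    done: set[str] = set()
    for req_id, (handles, agent, notif, start) in list(transfers.items()):
        pending = []
        for handle in handles:
            if check_state(handle) == "DONE":
                release(handle)
            else:
                pending.append(handle)
        if pending:
            # Some chunks are still in flight; keep waiting on them.
            transfers[req_id] = (pending, agent, notif, start)
        else:
            # All chunks finished: notify the remote agent once per request.
            send_notif(agent, notif)
            del transfers[req_id]
            done.add(req_id)
            print("transmission time:", time.perf_counter() - start)
    return done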

View File

@@ -75,7 +75,7 @@ enable_hf_transfer()
class DisabledTqdm(tqdm):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs, disable=True)
super().__init__(*args, **kwargs, disable=False)
def get_lock(model_name_or_path: Union[str, Path],