Merge pull request #17 from praveingk/batching

Load balance across multiple workers
This commit is contained in:
Robert Shaw
2025-06-30 08:21:03 -04:00
committed by GitHub

View File

@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING, Any, Optional
import msgspec
import torch
import zmq
from concurrent.futures import ThreadPoolExecutor, as_completed
from vllm import envs
from vllm.attention.selector import backend_name_to_enum, get_attn_backend
@@ -984,18 +985,25 @@ class NixlConnectorWorker:
assert len(local_block_descs_ids) == len(remote_block_descs_ids)
# Prepare transfer with Nixl.
CHUNK_SIZE = 100
CHUNK_SIZE = 1000
handles = []
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
handles.append(
self.nixl_wrapper.make_prepped_xfer(
futures = []
with ThreadPoolExecutor() as executor:
for i in range(0, len(local_block_descs_ids), CHUNK_SIZE):
future = executor.submit(
self.nixl_wrapper.make_prepped_xfer,
"READ",
local_xfer_side_handle,
local_block_descs_ids[i:i + CHUNK_SIZE],
remote_xfer_side_handle,
remote_block_descs_ids[i:i + CHUNK_SIZE],
skip_desc_merge=True,
))
)
futures.append(future)
for future in futures:
handles.append(future.result())
# Begin async xfer.
start = time.perf_counter()