Compare commits

...

4 Commits

Author SHA1 Message Date
8a8b40d417 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:40:35 +00:00
c3f7afa6a8 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:39:45 +00:00
6cd8dec23f updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-16 20:29:24 +00:00
723263fa23 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-15 22:06:34 +00:00

View File

@ -10,6 +10,7 @@ from collections import defaultdict
from collections.abc import Iterator
from concurrent.futures import Future, ThreadPoolExecutor
from dataclasses import dataclass
from importlib import metadata
from typing import TYPE_CHECKING, Any, Optional
import msgspec
@ -42,16 +43,19 @@ EngineId = str
ReqId = str
GET_META_MSG = b"get_meta_msg"
import os
VLLM_DEBUG_NIXL_XFER_TIME = os.getenv("VLLM_DEBUG_NIXL_XFER_TIME", "1") == "1"
logger = init_logger(__name__)
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
try:
from nixl._api import nixl_agent as NixlWrapper
logger.info("NIXL is available")
from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config
NIXL_VERSION = metadata.version("nixl")
except ImportError:
logger.warning("NIXL is not available")
NixlWrapper = None
NIXL_VERSION = None
class NixlAgentMetadata(
msgspec.Struct,
@ -352,16 +356,20 @@ class NixlConnectorWorker:
def __init__(self, vllm_config: VllmConfig, engine_id: str):
if NixlWrapper is None:
logger.error("NIXL is not available")
raise RuntimeError("NIXL is not available")
logger.info("Initializing NIXL wrapper")
logger.info("Initializing NIXL worker %s", engine_id)
raise RuntimeError("NIXL is not available.")
logger.info("Initializing NIXL v%s: worker %s", NIXL_VERSION, engine_id)
# Config.
self.vllm_config = vllm_config
self.block_size = vllm_config.cache_config.block_size
# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
import os
NIXL_NUM_WORKERS = int(os.getenv("VLLM_NIXL_NUM_WORKERS", "8"))
logger.info(f"Using NIXL_NUM_WORKERS={NIXL_NUM_WORKERS} for NIXL agent.")
config = nixl_agent_config(enable_prog_thread=False, num_threads=NIXL_NUM_WORKERS)
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
@ -449,7 +457,8 @@ class NixlConnectorWorker:
def __del__(self):
    """Clean up background threads on destruction.

    ``__del__`` is also invoked on instances whose ``__init__`` raised
    before every attribute was assigned (e.g. when NIXL is unavailable),
    so each attribute is looked up defensively instead of being assumed
    to exist. Missing or falsy attributes are skipped silently.
    """
    # wait=False: return immediately rather than waiting on the executor.
    if executor := getattr(self, "_handshake_initiation_executor", None):
        executor.shutdown(wait=False)
    # timeout=0: never block the destructor on the listener thread.
    if listener_t := getattr(self, "_nixl_handshake_listener_t", None):
        listener_t.join(timeout=0)
@ -1019,10 +1028,16 @@ class NixlConnectorWorker:
remote_xfer_side_handle,
remote_block_descs_ids,
notif_msg=notif_id,
skip_desc_merge=True,
)
# Begin async xfer.
start = time.perf_counter()
self.nixl_wrapper.transfer(handle)
end = time.perf_counter()
if VLLM_DEBUG_NIXL_XFER_TIME:
# Log how long the call that starts the transfer took; the transfer itself is async.
logger.info(f"TIME: {end - start}")
# Use handle to check completion in future step().
# TODO (NickLucche) surface xfer elapsed time