mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
4 Commits
6fad29b11b
...
nixl-upstr
Author | SHA1 | Date | |
---|---|---|---|
8a8b40d417 | |||
c3f7afa6a8 | |||
6cd8dec23f | |||
723263fa23 |
@ -10,6 +10,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from importlib import metadata
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import msgspec
|
||||
@ -42,16 +43,19 @@ EngineId = str
|
||||
ReqId = str
|
||||
GET_META_MSG = b"get_meta_msg"
|
||||
|
||||
import os
|
||||
VLLM_DEBUG_NIXL_XFER_TIME = os.getenv("VLLM_DEBUG_NIXL_XFER_TIME", "1") == "1"
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Lazy import nixl_wrapper to avoid loading nixl_bindings if nixl is not used
|
||||
try:
|
||||
from nixl._api import nixl_agent as NixlWrapper
|
||||
logger.info("NIXL is available")
|
||||
from nixl._api import nixl_agent as NixlWrapper, nixl_agent_config
|
||||
NIXL_VERSION = metadata.version("nixl")
|
||||
except ImportError:
|
||||
logger.warning("NIXL is not available")
|
||||
NixlWrapper = None
|
||||
|
||||
NIXL_VERSION = None
|
||||
|
||||
class NixlAgentMetadata(
|
||||
msgspec.Struct,
|
||||
@ -352,16 +356,20 @@ class NixlConnectorWorker:
|
||||
def __init__(self, vllm_config: VllmConfig, engine_id: str):
|
||||
if NixlWrapper is None:
|
||||
logger.error("NIXL is not available")
|
||||
raise RuntimeError("NIXL is not available")
|
||||
logger.info("Initializing NIXL wrapper")
|
||||
logger.info("Initializing NIXL worker %s", engine_id)
|
||||
raise RuntimeError("NIXL is not available.")
|
||||
logger.info("Initializing NIXL v%s: worker %s", NIXL_VERSION, engine_id)
|
||||
|
||||
# Config.
|
||||
self.vllm_config = vllm_config
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
|
||||
# Agent.
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
|
||||
import os
|
||||
NIXL_NUM_WORKERS = int(os.getenv("VLLM_NIXL_NUM_WORKERS", "8"))
|
||||
logger.info(f"Using NIXL_NUM_WORKERS={NIXL_NUM_WORKERS} for NIXL agent.")
|
||||
|
||||
config = nixl_agent_config(enable_prog_thread=False, num_threads=NIXL_NUM_WORKERS)
|
||||
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), config)
|
||||
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
|
||||
self._remote_agents: dict[EngineId, dict[int, str]] = defaultdict(dict)
|
||||
|
||||
@ -449,7 +457,8 @@ class NixlConnectorWorker:
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
self._handshake_initiation_executor.shutdown(wait=False)
|
||||
if t_ := getattr(self, "_handshake_initiation_executor", None):
|
||||
t_.shutdown(wait=False)
|
||||
if self._nixl_handshake_listener_t:
|
||||
self._nixl_handshake_listener_t.join(timeout=0)
|
||||
|
||||
@ -1019,10 +1028,16 @@ class NixlConnectorWorker:
|
||||
remote_xfer_side_handle,
|
||||
remote_block_descs_ids,
|
||||
notif_msg=notif_id,
|
||||
skip_desc_merge=True,
|
||||
)
|
||||
|
||||
# Begin async xfer.
|
||||
start = time.perf_counter()
|
||||
self.nixl_wrapper.transfer(handle)
|
||||
end = time.perf_counter()
|
||||
if VLLM_DEBUG_NIXL_XFER_TIME:
|
||||
# Log the time taken for the transfer.
|
||||
logger.info(f"TIME: {end - start}")
|
||||
|
||||
# Use handle to check completion in future step().
|
||||
# TODO (NickLucche) surface xfer elapsed time
|
||||
|
Reference in New Issue
Block a user