[P/D] Move FakeNixlWrapper to test dir (#21328)

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Rui Qiao
2025-07-24 08:53:45 -07:00
committed by GitHub
parent d9f9a3fd96
commit 1e9ea8e69d
4 changed files with 139 additions and 114 deletions

View File

@@ -1,10 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import inspect
import os
import tempfile
import textwrap
import time
import uuid
from collections import defaultdict
from typing import Optional
from unittest.mock import patch
import pytest
@@ -16,30 +21,118 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
NixlConnectorWorker)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams

from .utils import create_request, create_scheduler, create_vllm_config


class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.

    Note: The complete source of this class is also used in the
    `_make_fake_nixl_pkg` function to create a fake nixl package
    for Ray workers.
    """
AGENT_METADATA = b"fake_agent_metadata"
REMOTE_AGENT_NAME = "remote_agent"
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in blocks_data]
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
return uuid.uuid4().int
def get_agent_metadata(self) -> bytes:
return self.AGENT_METADATA
def add_remote_agent(self, agent_metadata: bytes) -> str:
return self.REMOTE_AGENT_NAME
def get_new_notifs(self) -> dict[str, list[bytes]]:
# Used to collect done_sending, which we don't test yet.
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
def release_xfer_handle(self, handle: int) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
return "PROC"
############################################################
# The following methods are for changing the mock's behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
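A minimal usage sketch of this knob (illustrative only, not part of the test suite; it assumes the class above is in scope): with `set_cycles_before_xfer_done(2)`, `check_xfer_state` reports "PROC" twice before settling on "DONE".

wrapper = FakeNixlWrapper("agent-0")
wrapper.set_cycles_before_xfer_done(2)
handle = wrapper.make_prepped_xfer("READ", 0, [0, 1], 0, [0, 1])

# Poll until the mock reports completion.
states = []
while not states or states[-1] != "DONE":
    states.append(wrapper.check_xfer_state(handle))
assert states == ["PROC", "PROC", "DONE"]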
@contextlib.contextmanager
def _make_fake_nixl_pkg():
"""Context manager that creates a temporary package making
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.
Automatically cleans up the temporary directory when done.
"""
with tempfile.TemporaryDirectory() as td:
pkg_root = os.path.join(td, "nixl", "_api")
os.makedirs(pkg_root, exist_ok=True)
# Get the source code of the FakeNixlWrapper class and dedent it
fake_nixl_source = inspect.getsource(FakeNixlWrapper)
fake_nixl_source = textwrap.dedent(fake_nixl_source)
stub = f"""\
# Copy of FakeNixlWrapper implementation for Ray workers
import uuid
from collections import defaultdict
from typing import Optional
{fake_nixl_source}
# Export as nixl_agent
nixl_agent = FakeNixlWrapper
"""
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
f.write(stub)
# touch parent package
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
yield td
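A quick local sanity check of the generated package might look like this (a sketch, not part of the diff; it assumes the helper above is in scope and spawns a fresh interpreter so the stub, which sits on PYTHONPATH ahead of site-packages, shadows any real nixl install):

import os
import subprocess
import sys

with _make_fake_nixl_pkg() as td:
    # Run a child interpreter with the temporary package on its import path.
    result = subprocess.run(
        [sys.executable, "-c",
         "from nixl._api import nixl_agent; print(nixl_agent.__name__)"],
        env={**os.environ, "PYTHONPATH": td},
        capture_output=True, text=True, check=True)
    assert result.stdout.strip() == "FakeNixlWrapper"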
def test_basic_interface():
@@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
kv_connector="NixlConnector",
kv_role="kv_both",
)
llm_kwargs = {
"model": model_name,
"enforce_eager": True,
"gpu_memory_utilization": 0.5,
"kv_transfer_config": kv_transfer_config,
"distributed_executor_backend": distributed_executor_backend,
}
timeout = 6
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
    # Build runtime_env only if we're using Ray
    if distributed_executor_backend == "ray":
        with _make_fake_nixl_pkg() as working_dir:
            runtime_env = {
                "working_dir": working_dir,  # ship fake nixl package
                "env_vars": {
                    "VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
                },
            }
            ray.init(runtime_env=runtime_env)
            _run_abort_timeout_test(llm_kwargs, timeout)
    else:
        _run_abort_timeout_test(llm_kwargs, timeout)

def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
"""Helper function to run the abort timeout test logic."""
llm = LLM(**llm_kwargs)
remote_prefill_opts = {
"do_remote_decode": True,
"do_remote_prefill": False,

View File

@@ -120,8 +120,8 @@ class KVOutputAggregator:
output corresponding to Rank 0 for scheduler."""
def __init__(self, world_size: int):
        # Complete transfer tracker. Used to track finished requests
        # [req_id -> n_remaining_workers]
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
@@ -134,12 +134,10 @@
remaining_count_dict: dict[str, int],
finished_set: set[str]) -> None:
for req_id in req_ids or ():
            remaining_count_dict[req_id] -= 1
            if remaining_count_dict[req_id] == 0:
                finished_set.add(req_id)
                del remaining_count_dict[req_id]
finished_sending = set[str]()
finished_recving = set[str]()
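The refactor above keeps the same invariant: each of the `world_size` workers decrements a per-request counter exactly once, and a request joins the finished set only when the last worker reports it. A minimal, self-contained sketch of that bookkeeping (names here are local to the example, not vLLM APIs):

from collections import defaultdict

world_size = 4
remaining = defaultdict(lambda: world_size)
finished: set[str] = set()

for _worker in range(world_size):
    for req_id in ["req-0"]:       # every worker reports req-0 exactly once
        remaining[req_id] -= 1
        if remaining[req_id] == 0:
            finished.add(req_id)
            del remaining[req_id]  # drop the counter once fully finished

assert finished == {"req-0"} and not remaining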

View File

View File

@@ -1,76 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import uuid
from collections import defaultdict
from typing import Optional
class FakeNixlWrapper:
"""Mock implementation of NixlWrapper for testing.
We don't inherit from nixl._api.nixl_agent because nixl may not be
installed.
"""
AGENT_METADATA = b"fake_agent_metadata"
REMOTE_AGENT_NAME = "remote_agent"
def __init__(self, agent_name: str, *args, **kwargs):
self._cycles_before_xfer_done = 0
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
lambda: 0)
def get_reg_descs(self, caches_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in caches_data]
def register_memory(self, descs) -> None:
pass
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
return [str(uuid.uuid4()) for _ in blocks_data]
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
return uuid.uuid4().int
def get_agent_metadata(self) -> bytes:
return self.AGENT_METADATA
def add_remote_agent(self, agent_metadata: bytes) -> str:
return self.REMOTE_AGENT_NAME
def get_new_notifs(self) -> dict[str, list[bytes]]:
# Used to collect done_sending, which we don't test yet.
return {}
def check_xfer_state(self, handle: int) -> str:
if self._check_xfer_state_cycles[
handle] >= self._cycles_before_xfer_done:
return "DONE"
self._check_xfer_state_cycles[handle] += 1
return "PROC"
def release_xfer_handle(self, handle: int) -> None:
pass
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
pass
def make_prepped_xfer(self,
xfer_type: str,
local_xfer_side_handle: int,
local_block_descs_ids: list[int],
remote_xfer_side_handle: int,
remote_block_descs_ids: list[int],
notif_msg: Optional[bytes] = None) -> int:
return uuid.uuid4().int
def transfer(self, handle: int) -> str:
return "PROC"
############################################################
# The following methods are for changing the behavior during testing.
############################################################
def set_cycles_before_xfer_done(self, cycles: int):
"""Set the number of cycles before a transfer is considered done."""
self._cycles_before_xfer_done = cycles
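The fake package exists because Ray workers cannot import the test module (or the now-deleted vllm.mocks module), so the wrapper's source is embedded in a stub package and shipped via runtime_env's working_dir, which Ray makes importable on the workers. A rough sketch of that flow, assuming Ray is installed and `_make_fake_nixl_pkg` from the test file above is in scope (illustrative only, not part of this commit):

import ray

with _make_fake_nixl_pkg() as working_dir:
    ray.init(runtime_env={"working_dir": working_dir})

    @ray.remote
    def probe() -> str:
        # On the worker, the shipped stub package shadows any real nixl install.
        from nixl._api import nixl_agent
        return nixl_agent.__name__

    assert ray.get(probe.remote()) == "FakeNixlWrapper"
    ray.shutdown()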