mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[P/D] Move FakeNixlWrapper to test dir (#21328)
Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
This commit is contained in:
@ -1,10 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import os
|
||||
import tempfile
|
||||
import textwrap
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@ -16,30 +21,118 @@ from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
|
||||
KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
|
||||
NixlConnectorWorker)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
from .utils import create_request, create_scheduler, create_vllm_config
|
||||
|
||||
|
||||
def _make_stub_pkg() -> str:
|
||||
"""Return a directory that makes
|
||||
`from nixl._api import nixl_agent` resolve to our FakeNixlWrapper."""
|
||||
td = tempfile.mkdtemp()
|
||||
pkg_root = os.path.join(td, "nixl", "_api")
|
||||
os.makedirs(pkg_root, exist_ok=True)
|
||||
class FakeNixlWrapper:
|
||||
"""Mock implementation of NixlWrapper for testing.
|
||||
|
||||
stub = textwrap.dedent("""\
|
||||
# Forward the real FakeNixlWrapper that the driver already defined.
|
||||
print("In fake package")
|
||||
from vllm.mocks.mock_nixl_connector import FakeNixlWrapper as nixl_agent
|
||||
""")
|
||||
with open(os.path.join(pkg_root, "__init__.py"), "w") as f:
|
||||
f.write(stub)
|
||||
We don't inherit from nixl._api.nixl_agent because nixl may not be
|
||||
installed.
|
||||
|
||||
Note: The complete source of this class is also used in the
|
||||
`_make_fake_nixl_pkg` function to create a fake nixl package
|
||||
for Ray workers.
|
||||
"""
|
||||
|
||||
# touch parent package
|
||||
open(os.path.join(td, "nixl", "__init__.py"), "w").close()
|
||||
return td
|
||||
AGENT_METADATA = b"fake_agent_metadata"
|
||||
REMOTE_AGENT_NAME = "remote_agent"
|
||||
|
||||
def __init__(self, agent_name: str, *args, **kwargs):
|
||||
self._cycles_before_xfer_done = 0
|
||||
self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(
|
||||
lambda: 0)
|
||||
|
||||
def get_reg_descs(self, caches_data, memory_type: str) -> list:
|
||||
return [str(uuid.uuid4()) for _ in caches_data]
|
||||
|
||||
def register_memory(self, descs) -> None:
|
||||
pass
|
||||
|
||||
def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
|
||||
return [str(uuid.uuid4()) for _ in blocks_data]
|
||||
|
||||
def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
|
||||
return uuid.uuid4().int
|
||||
|
||||
def get_agent_metadata(self) -> bytes:
|
||||
return self.AGENT_METADATA
|
||||
|
||||
def add_remote_agent(self, agent_metadata: bytes) -> str:
|
||||
return self.REMOTE_AGENT_NAME
|
||||
|
||||
def get_new_notifs(self) -> dict[str, list[bytes]]:
|
||||
# Used to collect done_sending, which we don't test yet.
|
||||
return {}
|
||||
|
||||
def check_xfer_state(self, handle: int) -> str:
|
||||
if self._check_xfer_state_cycles[
|
||||
handle] >= self._cycles_before_xfer_done:
|
||||
return "DONE"
|
||||
self._check_xfer_state_cycles[handle] += 1
|
||||
return "PROC"
|
||||
|
||||
def release_xfer_handle(self, handle: int) -> None:
|
||||
pass
|
||||
|
||||
def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
|
||||
pass
|
||||
|
||||
def make_prepped_xfer(self,
|
||||
xfer_type: str,
|
||||
local_xfer_side_handle: int,
|
||||
local_block_descs_ids: list[int],
|
||||
remote_xfer_side_handle: int,
|
||||
remote_block_descs_ids: list[int],
|
||||
notif_msg: Optional[bytes] = None) -> int:
|
||||
return uuid.uuid4().int
|
||||
|
||||
def transfer(self, handle: int) -> str:
|
||||
return "PROC"
|
||||
|
||||
############################################################
|
||||
# The following methods are for changing the behavior during testing.
|
||||
############################################################
|
||||
|
||||
def set_cycles_before_xfer_done(self, cycles: int):
|
||||
"""Set the number of cycles before a transfer is considered done."""
|
||||
self._cycles_before_xfer_done = cycles
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _make_fake_nixl_pkg():
    """Context manager that creates a temporary package making
    `from nixl._api import nixl_agent` resolve to our FakeNixlWrapper.

    The temporary directory receives two files:

    * ``nixl/__init__.py`` -- empty, only marks ``nixl`` as a package.
    * ``nixl/_api/__init__.py`` -- a standalone copy of the FakeNixlWrapper
      source (obtained via ``inspect.getsource``) preceded by the imports
      the class needs, and followed by ``nixl_agent = FakeNixlWrapper``.

    Yields:
        The path of the temporary directory containing the ``nixl`` package.

    Automatically cleans up the temporary directory when done.
    """
    with tempfile.TemporaryDirectory() as td:
        pkg_root = os.path.join(td, "nixl", "_api")
        os.makedirs(pkg_root, exist_ok=True)

        # Get the source code of FakeNixlWrapper class and dedent it so it
        # can be pasted at module level of the generated file.
        fake_nixl_source = textwrap.dedent(inspect.getsource(FakeNixlWrapper))

        # NOTE: the import header below must cover every module-level name
        # the class body uses, because the generated module is standalone.
        stub = (
            "# Copy of FakeNixlWrapper implementation for Ray workers\n"
            "import uuid\n"
            "from collections import defaultdict\n"
            "from typing import Optional\n"
            "\n"
            f"{fake_nixl_source}\n"
            "\n"
            "# Export as nixl_agent\n"
            "nixl_agent = FakeNixlWrapper\n"
        )

        # Python source files are decoded as UTF-8 on import, so write them
        # as UTF-8 regardless of the platform's default encoding.
        with open(os.path.join(pkg_root, "__init__.py"), "w",
                  encoding="utf-8") as f:
            f.write(stub)

        # touch parent package
        with open(os.path.join(td, "nixl", "__init__.py"), "w",
                  encoding="utf-8"):
            pass

        yield td
|
||||
|
||||
|
||||
def test_basic_interface():
|
||||
@ -351,27 +444,37 @@ def test_abort_timeout_on_prefiller(monkeypatch, distributed_executor_backend):
|
||||
kv_connector="NixlConnector",
|
||||
kv_role="kv_both",
|
||||
)
|
||||
llm_kwargs = {
|
||||
"model": model_name,
|
||||
"enforce_eager": True,
|
||||
"gpu_memory_utilization": 0.5,
|
||||
"kv_transfer_config": kv_transfer_config,
|
||||
"distributed_executor_backend": distributed_executor_backend,
|
||||
}
|
||||
|
||||
timeout = 6
|
||||
monkeypatch.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
|
||||
monkeypatch.setenv("VLLM_NIXL_ABORT_REQUEST_TIMEOUT", str(timeout))
|
||||
|
||||
# Build runtime_env only if we’re using Ray
|
||||
# Build runtime_env only if we're using Ray
|
||||
if distributed_executor_backend == "ray":
|
||||
runtime_env = {
|
||||
"working_dir": _make_stub_pkg(), # ship stub package
|
||||
"env_vars": {
|
||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
||||
},
|
||||
}
|
||||
ray.init(runtime_env=runtime_env)
|
||||
with _make_fake_nixl_pkg() as working_dir:
|
||||
runtime_env = {
|
||||
"working_dir": working_dir, # ship fake nixl package
|
||||
"env_vars": {
|
||||
"VLLM_NIXL_ABORT_REQUEST_TIMEOUT": str(timeout),
|
||||
},
|
||||
}
|
||||
ray.init(runtime_env=runtime_env)
|
||||
|
||||
llm = LLM(
|
||||
model=model_name,
|
||||
enforce_eager=True,
|
||||
gpu_memory_utilization=0.5,
|
||||
kv_transfer_config=kv_transfer_config,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
)
|
||||
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||
else:
|
||||
_run_abort_timeout_test(llm_kwargs, timeout)
|
||||
|
||||
|
||||
def _run_abort_timeout_test(llm_kwargs: dict, timeout: int):
|
||||
"""Helper function to run the abort timeout test logic."""
|
||||
llm = LLM(**llm_kwargs)
|
||||
remote_prefill_opts = {
|
||||
"do_remote_decode": True,
|
||||
"do_remote_prefill": False,
|
||||
|
@ -120,8 +120,8 @@ class KVOutputAggregator:
|
||||
output corresponding to Rank 0 for scheduler."""
|
||||
|
||||
def __init__(self, world_size: int):
|
||||
# Complete transfer tracker. Used by to track finished requests
|
||||
# [req_id -> n_finished_workers]
|
||||
# Complete transfer tracker. Used to track finished requests
|
||||
# [req_id -> n_remaining_workers]
|
||||
self._recv_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||
self._send_remaining_count = defaultdict[str, int](lambda: world_size)
|
||||
|
||||
@ -134,12 +134,10 @@ class KVOutputAggregator:
|
||||
remaining_count_dict: dict[str, int],
|
||||
finished_set: set[str]) -> None:
|
||||
for req_id in req_ids or ():
|
||||
new_count = remaining_count_dict[req_id] - 1
|
||||
if new_count == 0:
|
||||
remaining_count_dict[req_id] -= 1
|
||||
if remaining_count_dict[req_id] == 0:
|
||||
finished_set.add(req_id)
|
||||
del remaining_count_dict[req_id]
|
||||
else:
|
||||
remaining_count_dict[req_id] = new_count
|
||||
|
||||
finished_sending = set[str]()
|
||||
finished_recving = set[str]()
|
||||
|
@ -1,76 +0,0 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class FakeNixlWrapper:
    """Mock implementation of NixlWrapper for testing.

    We don't inherit from nixl._api.nixl_agent because nixl may not be
    installed.

    Nothing is actually registered or transferred: descriptors and handles
    are freshly generated UUID values, and transfer completion is simulated
    by counting how many times a handle has been polled.
    """

    AGENT_METADATA = b"fake_agent_metadata"
    REMOTE_AGENT_NAME = "remote_agent"

    def __init__(self, agent_name: str, *args, **kwargs):
        # Number of "PROC" results each handle reports before "DONE".
        self._cycles_before_xfer_done = 0
        # handle -> number of "PROC" results already returned for it.
        self._check_xfer_state_cycles: defaultdict[int, int] = defaultdict(int)

    def get_reg_descs(self, caches_data, memory_type: str) -> list:
        """Return one random descriptor string per entry of `caches_data`."""
        return [str(uuid.uuid4()) for _ in caches_data]

    def register_memory(self, descs) -> None:
        """No-op."""
        pass

    def get_xfer_descs(self, blocks_data, memory_type: str) -> list:
        """Return one random descriptor string per entry of `blocks_data`."""
        return [str(uuid.uuid4()) for _ in blocks_data]

    def prep_xfer_dlist(self, agent_name: str, descs: list) -> int:
        """Return a random integer handle."""
        return uuid.uuid4().int

    def get_agent_metadata(self) -> bytes:
        """Return the fixed `AGENT_METADATA` bytes."""
        return self.AGENT_METADATA

    def add_remote_agent(self, agent_metadata: bytes) -> str:
        """Return the fixed `REMOTE_AGENT_NAME`, ignoring the metadata."""
        return self.REMOTE_AGENT_NAME

    def get_new_notifs(self) -> dict[str, list[bytes]]:
        """Return an empty notification mapping."""
        # Used to collect done_sending, which we don't test yet.
        return {}

    def check_xfer_state(self, handle: int) -> str:
        """Return "PROC" until `handle` has been polled the configured
        number of times, then "DONE" on every later call."""
        if self._check_xfer_state_cycles[
                handle] >= self._cycles_before_xfer_done:
            return "DONE"
        self._check_xfer_state_cycles[handle] += 1
        return "PROC"

    def release_xfer_handle(self, handle: int) -> None:
        """No-op."""
        pass

    def send_notif(self, agent_name: str, notif_msg: bytes) -> None:
        """No-op."""
        pass

    def make_prepped_xfer(self,
                          xfer_type: str,
                          local_xfer_side_handle: int,
                          local_block_descs_ids: list[int],
                          remote_xfer_side_handle: int,
                          remote_block_descs_ids: list[int],
                          notif_msg: Optional[bytes] = None) -> int:
        """Return a random integer handle; all arguments are ignored."""
        return uuid.uuid4().int

    def transfer(self, handle: int) -> str:
        """Always report the transfer as in progress ("PROC")."""
        return "PROC"

    ############################################################
    # The following methods are for changing the behavior during testing.
    ############################################################

    def set_cycles_before_xfer_done(self, cycles: int):
        """Set the number of cycles before a transfer is considered done."""
        self._cycles_before_xfer_done = cycles
|
Reference in New Issue
Block a user