mirror of
https://github.com/vllm-project/vllm-ascend.git
synced 2025-10-20 13:43:53 +08:00
[1/N][CI] Add multi node test (#3359)
### What this PR does / why we need it? This pr purpose to add multi-node test, on the first step, add `deepseek-v3` dp+tp+ep test ### Does this PR introduce _any_ user-facing change? ### How was this patch tested? - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
1
.github/actionlint.yaml
vendored
1
.github/actionlint.yaml
vendored
@ -18,3 +18,4 @@ self-hosted-runner:
|
||||
- linux-amd64-cpu-0
|
||||
- linux-amd64-cpu-8
|
||||
- linux-amd64-cpu-16
|
||||
- linux-aarch64-a3-0
|
||||
|
109
.github/workflows/multi_node_test.yaml
vendored
Normal file
109
.github/workflows/multi_node_test.yaml
vendored
Normal file
@ -0,0 +1,109 @@
|
||||
name: 'e2e test / multi-dp'
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: "0 */4 * * *"
|
||||
workflow_dispatch:
|
||||
|
||||
# Bash shells do not use ~/.profile or ~/.bashrc so these shells need to be explicitly
|
||||
# declared as "shell: bash -el {0}" on steps that need to be properly activated.
|
||||
# It's used to activate ascend-toolkit environment variables.
|
||||
defaults:
|
||||
run:
|
||||
shell: bash -el {0}
|
||||
|
||||
# only cancel in-progress runs of the same workflow
|
||||
# and ignore the lint / 8 cards test type
|
||||
concurrency:
|
||||
group: ${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
e2e:
|
||||
# This is a runner with no NPU for k8s controller
|
||||
runs-on: linux-aarch64-a3-0
|
||||
container:
|
||||
image: m.daocloud.io/quay.io/ascend/cann:8.2.rc1-a3-ubuntu22.04-py3.11
|
||||
env:
|
||||
KUBECONFIG: /tmp/kubeconfig
|
||||
KUBECTL: /root/.cache/.kube/kubectl
|
||||
NAMESPACE: vllm-project
|
||||
LEADER_POD: vllm-0
|
||||
steps:
|
||||
- name: Install system denpendencies
|
||||
run: |
|
||||
# configure apt and pip source
|
||||
sed -i 's|ports.ubuntu.com|mirrors.tuna.tsinghua.edu.cn|g' /etc/apt/sources.list
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
|
||||
apt-get update -y && apt-get install -y git curl
|
||||
|
||||
TOKEN=`echo -n "x-access-token:${{ secrets.ADMIN_PTA }}" | base64`
|
||||
git config --global http.https://gh-proxy.test.osinfra.cn/.extraheader "AUTHORIZATION: basic $TOKEN"
|
||||
|
||||
- name: Install kubectl
|
||||
run: |
|
||||
install -o root -g root -m 0755 $KUBECTL /usr/local/bin/kubectl
|
||||
|
||||
# get kubeconfig from secret
|
||||
echo "${{ secrets.KUBECONFIG_B64 }}" | base64 -d > $KUBECONFIG
|
||||
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Prepare scripts
|
||||
run: |
|
||||
# prepare for lws entrypoint scripts
|
||||
install -D tests/e2e/multi_node/scripts/run.sh /root/.cache/tests/run.sh
|
||||
|
||||
- name: Launch cluster
|
||||
run: |
|
||||
kubectl apply -f tests/e2e/multi_node/scripts/lws.yaml
|
||||
|
||||
- name: Waiting for pod ready
|
||||
run: |
|
||||
echo "waiting for Pod [$LEADER_POD] in namespace [$NAMESPACE] to Ready..."
|
||||
|
||||
while true; do
|
||||
# get pod status
|
||||
READY_STATUS=$(kubectl get pod "$LEADER_POD" -n "$NAMESPACE" -o jsonpath='{.status.containerStatuses[*].ready}')
|
||||
|
||||
if [[ "$READY_STATUS" == "true" ]]; then
|
||||
echo "✅ Pod [$LEADER_POD] is Ready!"
|
||||
break
|
||||
else
|
||||
echo "Pod [$LEADER_POD] not ready, waiting..."
|
||||
sleep 3
|
||||
fi
|
||||
done
|
||||
|
||||
- name: Stream logs and monitor pod health
|
||||
run: |
|
||||
set -euo pipefail
|
||||
|
||||
echo "🚀 Start streaming logs for Pod [$LEADER_POD] ..."
|
||||
kubectl logs -f "$LEADER_POD" -n "$NAMESPACE" &
|
||||
LOG_PID=$!
|
||||
|
||||
echo "Start monitoring Pod [$LEADER_POD] status ..."
|
||||
while true; do
|
||||
STATUS=$(kubectl get pod "$LEADER_POD" -n "$NAMESPACE" -o jsonpath='{.status.phase}')
|
||||
if [[ "$STATUS" != "Running" && "$STATUS" != "Succeeded" ]]; then
|
||||
echo "❌ Pod [$LEADER_POD] exited abnormally with status: $STATUS"
|
||||
kubectl describe pod "$LEADER_POD" -n "$NAMESPACE" || true
|
||||
kubectl logs "$LEADER_POD" -n "$NAMESPACE" --previous --all-containers || true
|
||||
kill $LOG_PID || true
|
||||
exit 1
|
||||
fi
|
||||
sleep 5
|
||||
done &
|
||||
|
||||
MONITOR_PID=$!
|
||||
wait $LOG_PID || true
|
||||
kill $MONITOR_PID || true
|
||||
|
||||
- name: Post process
|
||||
if: always()
|
||||
run: |
|
||||
kubectl get pods -n $NAMESPACE
|
||||
kubectl delete -f tests/e2e/multi_node/scripts/lws.yaml
|
@ -19,11 +19,18 @@
|
||||
|
||||
import contextlib
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import httpx
|
||||
import numpy as np
|
||||
import openai
|
||||
import pytest
|
||||
import requests
|
||||
import torch
|
||||
from modelscope import snapshot_download # type: ignore[import-untyped]
|
||||
from PIL import Image
|
||||
@ -33,9 +40,14 @@ from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
|
||||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config.model import TaskOption, _get_and_verify_dtype
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.entrypoints.cli.serve import ServeSubcommand
|
||||
from vllm.inputs import TextPrompt
|
||||
from vllm.model_executor.model_loader import get_model_loader
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||||
from vllm.utils import FlexibleArgumentParser, get_open_port
|
||||
|
||||
from tests.e2e.model_utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs)
|
||||
@ -76,6 +88,181 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
|
||||
torch.npu.reset_peak_memory_stats()
|
||||
|
||||
|
||||
class RemoteOpenAIServer:
|
||||
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
|
||||
|
||||
def _start_server(self, model: str, vllm_serve_args: list[str],
|
||||
env_dict: Optional[dict[str, str]]) -> None:
|
||||
"""Subclasses override this method to customize server process launch
|
||||
"""
|
||||
env = os.environ.copy()
|
||||
# the current process might initialize npu,
|
||||
# to be safe, we should use spawn method
|
||||
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
||||
if env_dict is not None:
|
||||
env.update(env_dict)
|
||||
self.proc: subprocess.Popen = subprocess.Popen(
|
||||
["vllm", "serve", model, *vllm_serve_args],
|
||||
env=env,
|
||||
stdout=sys.stdout,
|
||||
stderr=sys.stderr,
|
||||
)
|
||||
|
||||
def __init__(self,
|
||||
model: str,
|
||||
server_host: str,
|
||||
server_port: int,
|
||||
vllm_serve_args: list[str],
|
||||
*,
|
||||
env_dict: Optional[dict[str, str]] = None,
|
||||
seed: Optional[int] = 0,
|
||||
auto_port: bool = True,
|
||||
max_wait_seconds: Optional[float] = None,
|
||||
override_hf_configs: Optional[dict[str, Any]] = None) -> None:
|
||||
if auto_port:
|
||||
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
|
||||
raise ValueError("You have manually specified the port "
|
||||
"when `auto_port=True`.")
|
||||
|
||||
# No need for a port if using unix sockets
|
||||
if "--uds" not in vllm_serve_args:
|
||||
# Don't mutate the input args
|
||||
vllm_serve_args = vllm_serve_args + [
|
||||
"--port", str(get_open_port())
|
||||
]
|
||||
if seed is not None:
|
||||
if "--seed" in vllm_serve_args:
|
||||
raise ValueError("You have manually specified the seed "
|
||||
f"when `seed={seed}`.")
|
||||
|
||||
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
|
||||
|
||||
if override_hf_configs is not None:
|
||||
vllm_serve_args = vllm_serve_args + [
|
||||
"--hf-overrides",
|
||||
json.dumps(override_hf_configs)
|
||||
]
|
||||
|
||||
parser = FlexibleArgumentParser(
|
||||
description="vLLM's remote OpenAI server.")
|
||||
subparsers = parser.add_subparsers(required=False, dest="subparser")
|
||||
parser = ServeSubcommand().subparser_init(subparsers)
|
||||
args = parser.parse_args([*vllm_serve_args])
|
||||
self.uds = args.uds
|
||||
if args.uds:
|
||||
self.host = None
|
||||
self.port = None
|
||||
else:
|
||||
self.host = str(server_host)
|
||||
self.port = int(server_port)
|
||||
|
||||
self.show_hidden_metrics = \
|
||||
args.show_hidden_metrics_for_version is not None
|
||||
|
||||
# download the model before starting the server to avoid timeout
|
||||
is_local = os.path.isdir(model)
|
||||
if not is_local:
|
||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
||||
model_config = engine_args.create_model_config()
|
||||
load_config = engine_args.create_load_config()
|
||||
|
||||
model_loader = get_model_loader(load_config)
|
||||
model_loader.download_model(model_config)
|
||||
|
||||
self._start_server(model, vllm_serve_args, env_dict)
|
||||
max_wait_seconds = max_wait_seconds or 7200
|
||||
self._wait_for_server(url=self.url_for("health"),
|
||||
timeout=max_wait_seconds)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
self.proc.terminate()
|
||||
try:
|
||||
self.proc.wait(8)
|
||||
except subprocess.TimeoutExpired:
|
||||
# force kill if needed
|
||||
self.proc.kill()
|
||||
|
||||
def _poll(self) -> Optional[int]:
|
||||
"""Subclasses override this method to customize process polling"""
|
||||
return self.proc.poll()
|
||||
|
||||
def hang_until_terminated(self) -> None:
|
||||
"""
|
||||
Wait until the server process terminates.
|
||||
This is for headless mode, where the api server
|
||||
process only exists in the leader node.
|
||||
"""
|
||||
if self.uds:
|
||||
client = httpx.Client(transport=httpx.HTTPTransport(uds=self.uds))
|
||||
else:
|
||||
client = requests
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
resp = client.get(self.url_for("health"), timeout=5)
|
||||
if resp.status_code != 200:
|
||||
break
|
||||
time.sleep(5)
|
||||
except Exception:
|
||||
break
|
||||
finally:
|
||||
if isinstance(client, httpx.Client):
|
||||
client.close()
|
||||
|
||||
def _wait_for_server(self, *, url: str, timeout: float):
|
||||
# run health check
|
||||
start = time.time()
|
||||
client = (httpx.Client(transport=httpx.HTTPTransport(
|
||||
uds=self.uds)) if self.uds else requests)
|
||||
while True:
|
||||
try:
|
||||
if client.get(url).status_code == 200:
|
||||
break
|
||||
except Exception:
|
||||
# this exception can only be raised by requests.get,
|
||||
# which means the server is not ready yet.
|
||||
# the stack trace is not useful, so we suppress it
|
||||
# by using `raise from None`.
|
||||
result = self._poll()
|
||||
if result is not None and result != 0:
|
||||
raise RuntimeError("Server exited unexpectedly.") from None
|
||||
|
||||
time.sleep(1)
|
||||
if time.time() - start > timeout:
|
||||
raise RuntimeError(
|
||||
"Server failed to start in time.") from None
|
||||
|
||||
@property
|
||||
def url_root(self) -> str:
|
||||
return (f"http://{self.uds.split('/')[-1]}"
|
||||
if self.uds else f"http://{self.host}:{self.port}")
|
||||
|
||||
def url_for(self, *parts: str) -> str:
|
||||
return self.url_root + "/" + "/".join(parts)
|
||||
|
||||
def get_client(self, **kwargs):
|
||||
if "timeout" not in kwargs:
|
||||
kwargs["timeout"] = 600
|
||||
return openai.OpenAI(
|
||||
base_url=self.url_for("v1"),
|
||||
api_key=self.DUMMY_API_KEY,
|
||||
max_retries=0,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_async_client(self, **kwargs):
|
||||
if "timeout" not in kwargs:
|
||||
kwargs["timeout"] = 600
|
||||
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
|
||||
api_key=self.DUMMY_API_KEY,
|
||||
max_retries=0,
|
||||
**kwargs)
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
|
||||
def __init__(
|
||||
@ -289,7 +476,6 @@ class VllmRunner:
|
||||
class HfRunner:
|
||||
|
||||
def get_default_device(self):
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
return ("cpu"
|
||||
if current_platform.is_cpu() else current_platform.device_type)
|
||||
|
0
tests/e2e/multi_node/__init__.py
Normal file
0
tests/e2e/multi_node/__init__.py
Normal file
0
tests/e2e/multi_node/config/__init__.py
Normal file
0
tests/e2e/multi_node/config/__init__.py
Normal file
41
tests/e2e/multi_node/config/config.json
Normal file
41
tests/e2e/multi_node/config/config.json
Normal file
@ -0,0 +1,41 @@
|
||||
[
|
||||
{
|
||||
"test_name": "test_deepseek_v3",
|
||||
"disaggregate_prefill": false,
|
||||
"enable_multithread_load": false,
|
||||
"num_nodes": 2,
|
||||
"server_parameters": {
|
||||
"leader_config": {
|
||||
"model": "vllm-ascend/DeepSeek-V3-W8A8",
|
||||
"additional_config": {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": true
|
||||
},
|
||||
"torchair_graph_config": {
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"worker_config": {
|
||||
"model": "vllm-ascend/DeepSeek-V3-W8A8",
|
||||
"additional_config": {
|
||||
"ascend_scheduler_config": {
|
||||
"enabled": true
|
||||
},
|
||||
"torchair_graph_config": {
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "vllm-ascend/DeepSeek-V3-W8A8",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "/root/.cache/datasets/ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"request_rate": 1
|
||||
},
|
||||
"accuracy_parameters": {}
|
||||
}
|
||||
]
|
204
tests/e2e/multi_node/config/multi_node_config.py
Normal file
204
tests/e2e/multi_node/config/multi_node_config.py
Normal file
@ -0,0 +1,204 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field, fields
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from tests.e2e.multi_node.config.utils import (get_avaliable_port,
|
||||
get_leader_ip,
|
||||
get_net_interface)
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
CONFIG_PATH = Path("tests/e2e/multi_node/config/config.json")
|
||||
|
||||
T = TypeVar("T", bound="BaseConfig")
|
||||
|
||||
|
||||
# =========================
|
||||
# Base Config
|
||||
# =========================
|
||||
@dataclass
|
||||
class BaseConfig:
|
||||
model: str = "vllm-ascend/DeepSeek-V3-W8A8"
|
||||
_extra_fields: Optional[Dict[str, Any]] = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls: Type[T], data: dict[str, Any]) -> T:
|
||||
"""Create config instance from dict, keeping unknown fields."""
|
||||
field_names = {f.name for f in fields(cls)}
|
||||
valid_fields = {k: v for k, v in data.items() if k in field_names}
|
||||
extra_fields = {k: v for k, v in data.items() if k not in field_names}
|
||||
obj = cls(**valid_fields)
|
||||
obj._extra_fields = extra_fields or {}
|
||||
return obj
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
"""Convert all fields (including _extra_fields) to CLI arguments."""
|
||||
args: List[str] = []
|
||||
all_items = {**vars(self), **(self._extra_fields or {})}
|
||||
|
||||
for key, value in all_items.items():
|
||||
if key in ("model", "_extra_fields") or value in (None, "", [],
|
||||
{}):
|
||||
continue
|
||||
key = key.replace("_", "-")
|
||||
|
||||
if isinstance(value, bool):
|
||||
if value:
|
||||
args.append(f"--{key}")
|
||||
elif isinstance(value, dict):
|
||||
args += [f"--{key}", json.dumps(value, ensure_ascii=False)]
|
||||
else:
|
||||
args += [f"--{key}", str(value)]
|
||||
return args
|
||||
|
||||
|
||||
# =========================
|
||||
# Server Config
|
||||
# =========================
|
||||
@dataclass
|
||||
class ServerConfig(BaseConfig):
|
||||
host: str = "0.0.0.0"
|
||||
port: int = 8080
|
||||
trust_remote_code: bool = True
|
||||
enable_expert_parallel: bool = True
|
||||
gpu_memory_utilization: float = 0.9
|
||||
headless: bool = False
|
||||
quantization: Optional[str] = None
|
||||
tensor_parallel_size: int = 8
|
||||
max_model_len: int = 8192
|
||||
max_num_batched_token: int = 8192
|
||||
data_parallel_size: int = 4
|
||||
data_parallel_size_local: int = 2
|
||||
data_parallel_start_rank: int = 0
|
||||
data_parallel_rpc_port: int = 13389
|
||||
data_parallel_address: Optional[str] = None
|
||||
kv_transfer_config: Optional[Dict[str, Any]] = None
|
||||
additional_config: Optional[Dict[str, Any]] = None
|
||||
|
||||
def init_dp_param(
|
||||
self,
|
||||
is_leader: bool,
|
||||
is_disaggregate_prefill: bool,
|
||||
dp_size: int,
|
||||
world_size: int,
|
||||
) -> None:
|
||||
"""Initialize distributed parallel parameters."""
|
||||
iface = get_net_interface()
|
||||
if iface is None:
|
||||
raise RuntimeError("No available network interface found")
|
||||
self.data_parallel_address = iface[0]
|
||||
|
||||
if is_disaggregate_prefill:
|
||||
self.data_parallel_start_rank = 0
|
||||
return
|
||||
|
||||
if not is_leader:
|
||||
self.headless = True
|
||||
self.data_parallel_start_rank = dp_size // world_size
|
||||
self.data_parallel_address = get_leader_ip()
|
||||
|
||||
|
||||
@dataclass
|
||||
class PerfConfig(BaseConfig):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class AccuracyConfig:
|
||||
prompt: str
|
||||
expected_output: str
|
||||
|
||||
|
||||
# =========================
|
||||
# MultiNode Config
|
||||
# =========================
|
||||
@dataclass
|
||||
class MultiNodeConfig:
|
||||
test_name: str = "Unnamed Test"
|
||||
disaggregate_prefill: bool = False
|
||||
enable_multithread_load: bool = True
|
||||
world_size: int = 2
|
||||
server_host: str = "0.0.0.0"
|
||||
server_port: int = 8888
|
||||
server_config: ServerConfig = field(default_factory=ServerConfig)
|
||||
perf_config: Optional[PerfConfig] = None
|
||||
accuracy_config: Optional[AccuracyConfig] = None
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, cfg: Dict[str, Any]) -> "MultiNodeConfig":
|
||||
"""Create a MultiNodeConfig from raw dict."""
|
||||
num_nodes = cfg.get("num_nodes", 2)
|
||||
is_disaggregate_prefill = cfg.get("disaggregate_prefill", False)
|
||||
node_index = int(os.getenv("LWS_WORKER_INDEX", 0))
|
||||
is_leader = node_index == 0
|
||||
|
||||
# server config
|
||||
server_cfg_data = cfg.get("server_parameters", {})
|
||||
if not server_cfg_data:
|
||||
raise ValueError("Missing required key: 'server_parameters'")
|
||||
|
||||
role_key = "leader_config" if is_leader else "worker_config"
|
||||
server_cfg_dict = server_cfg_data.get(role_key, {})
|
||||
server_cfg: ServerConfig = ServerConfig.from_config(server_cfg_dict)
|
||||
|
||||
if cfg.get("enable_multithread_load"):
|
||||
server_cfg.model_loader_extra_config = { # type: ignore[attr-defined]
|
||||
"enable_multithread_load": True,
|
||||
"num_threads": 8,
|
||||
}
|
||||
|
||||
# distributed param init
|
||||
server_cfg.init_dp_param(
|
||||
is_leader=is_leader,
|
||||
is_disaggregate_prefill=is_disaggregate_prefill,
|
||||
dp_size=server_cfg.data_parallel_size,
|
||||
world_size=num_nodes,
|
||||
)
|
||||
|
||||
perf_cfg: Optional[PerfConfig] = (PerfConfig.from_config(
|
||||
cfg.get("client_parameters", {})) if cfg.get("client_parameters")
|
||||
else None)
|
||||
|
||||
# network info
|
||||
leader_cfg = server_cfg_data.get("leader_config", {})
|
||||
server_host = get_leader_ip()
|
||||
server_port = (get_avaliable_port() if is_disaggregate_prefill else
|
||||
leader_cfg.get("port", 8080))
|
||||
|
||||
return cls(
|
||||
test_name=str(cfg.get("test_name", "Unnamed Test")),
|
||||
disaggregate_prefill=is_disaggregate_prefill,
|
||||
enable_multithread_load=cfg.get("enable_multithread_load", False),
|
||||
world_size=num_nodes,
|
||||
server_config=server_cfg,
|
||||
perf_config=perf_cfg,
|
||||
server_host=server_host,
|
||||
server_port=server_port,
|
||||
)
|
||||
|
||||
|
||||
# =========================
|
||||
# Loader
|
||||
# =========================
|
||||
def load_configs(
|
||||
path: Union[str, Path] = CONFIG_PATH) -> List[MultiNodeConfig]:
|
||||
"""Load one or multiple configs from JSON file."""
|
||||
path = Path(path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"Configuration file not found: {path}")
|
||||
|
||||
raw = json.loads(path.read_text())
|
||||
configs_data = raw if isinstance(raw, list) else [raw]
|
||||
|
||||
configs = []
|
||||
for idx, item in enumerate(configs_data):
|
||||
try:
|
||||
configs.append(MultiNodeConfig.from_config(item))
|
||||
except Exception as e:
|
||||
LOG.exception(f"Failed to parse config #{idx}: {e}")
|
||||
raise
|
||||
return configs
|
68
tests/e2e/multi_node/config/utils.py
Normal file
68
tests/e2e/multi_node/config/utils.py
Normal file
@ -0,0 +1,68 @@
|
||||
import os
|
||||
import socket
|
||||
import subprocess
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import psutil
|
||||
|
||||
|
||||
def get_leader_ip():
|
||||
leader_dns = os.getenv("LWS_LEADER_ADDRESS")
|
||||
assert leader_dns is not None, "cannot find leader address"
|
||||
return socket.gethostbyname(leader_dns)
|
||||
|
||||
|
||||
def get_avaliable_port(start_port: int = 6000, end_port: int = 7000) -> int:
|
||||
import socket
|
||||
for port in range(start_port, end_port):
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
try:
|
||||
s.bind(("", port))
|
||||
return port
|
||||
except OSError:
|
||||
continue
|
||||
raise RuntimeError("No available port found")
|
||||
|
||||
|
||||
def get_net_interface(ip: Optional[str] = None) -> Optional[Tuple[str, str]]:
|
||||
"""
|
||||
Returns specified IP and its network interface.
|
||||
If no IP is provided, uses the first from hostname -I.
|
||||
"""
|
||||
if ip is None:
|
||||
ips = subprocess.check_output(["hostname",
|
||||
"-I"]).decode().strip().split()
|
||||
if not ips:
|
||||
return None
|
||||
ip = ips[0]
|
||||
|
||||
for iface, addrs in psutil.net_if_addrs().items():
|
||||
for addr in addrs:
|
||||
if addr.family == socket.AF_INET and addr.address == ip:
|
||||
return ip, iface
|
||||
return None
|
||||
|
||||
|
||||
def get_default_envs() -> dict[str, str]:
|
||||
"""Returns default network and system environment variables."""
|
||||
result = get_net_interface()
|
||||
if result is None:
|
||||
raise RuntimeError("Failed to get default network IP and interface")
|
||||
ip, nic_name = result
|
||||
|
||||
return {
|
||||
"HCCL_IF_IP": ip,
|
||||
"GLOO_SOCKET_IFNAME": nic_name,
|
||||
"TP_SOCKET_IFNAME": nic_name,
|
||||
"HCCL_SOCKET_IFNAME": nic_name,
|
||||
"OMP_PROC_BIND": "false",
|
||||
"OMP_NUM_THREADS": "100",
|
||||
"VLLM_USE_V1": "1",
|
||||
"HCCL_BUFFSIZE": "1024",
|
||||
"VLLM_USE_MODELSCOPE": "true",
|
||||
"NUMEXPR_MAX_THREADS": "100",
|
||||
}
|
||||
|
||||
|
||||
def generate_ranktable():
|
||||
pass
|
142
tests/e2e/multi_node/scripts/lws.yaml
Normal file
142
tests/e2e/multi_node/scripts/lws.yaml
Normal file
@ -0,0 +1,142 @@
|
||||
apiVersion: leaderworkerset.x-k8s.io/v1
|
||||
kind: LeaderWorkerSet
|
||||
metadata:
|
||||
name: vllm
|
||||
namespace: vllm-project
|
||||
spec:
|
||||
replicas: 1
|
||||
leaderWorkerTemplate:
|
||||
size: 2
|
||||
restartPolicy: RecreateGroupOnPodRestart
|
||||
leaderTemplate:
|
||||
metadata:
|
||||
labels:
|
||||
role: leader
|
||||
spec:
|
||||
containers:
|
||||
- name: vllm-leader
|
||||
image: m.daocloud.io/quay.io/ascend/cann:8.2.rc1-a3-ubuntu22.04-py3.11
|
||||
env:
|
||||
- name: VLLM_USE_MODELSCOPE
|
||||
value: "true"
|
||||
- name: WORKSPACE
|
||||
value: "/root/workspace"
|
||||
- name: WORLD_SIZE
|
||||
value: "2"
|
||||
- name: NPU_PER_NODE
|
||||
value: "16"
|
||||
# Set vLLM version and vLLM-Ascend version here, once there is a new release, update here.
|
||||
- name: VLLM_VERSION
|
||||
value: "v0.11.0"
|
||||
- name: VLLM_ASCEND_VERSION
|
||||
value: "main"
|
||||
- name: MOONCAKE_VERSION
|
||||
value: "06cc217504a6f1b0cdaa26b096b985651b262748"
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- |
|
||||
bash /root/.cache/tests/run.sh
|
||||
resources:
|
||||
limits:
|
||||
huawei.com/ascend-1980: "16"
|
||||
memory: 512Gi
|
||||
ephemeral-storage: 100Gi
|
||||
requests:
|
||||
huawei.com/ascend-1980: "16"
|
||||
ephemeral-storage: 100Gi
|
||||
cpu: 125
|
||||
ports:
|
||||
- containerPort: 8080
|
||||
# readinessProbe:
|
||||
# tcpSocket:
|
||||
# port: 8080
|
||||
# initialDelaySeconds: 15
|
||||
# periodSeconds: 10
|
||||
volumeMounts:
|
||||
- mountPath: /root/.cache
|
||||
name: shared-volume
|
||||
- mountPath: /usr/local/Ascend/driver/tools
|
||||
name: driver-tools
|
||||
- mountPath: /dev/shm
|
||||
name: dshm
|
||||
volumes:
|
||||
- name: dshm
|
||||
emptyDir:
|
||||
medium: Memory
|
||||
sizeLimit: 15Gi
|
||||
- name: shared-volume
|
||||
persistentVolumeClaim:
|
||||
claimName: nv-action-vllm-benchmarks-v2
|
||||
- name: driver-tools
|
||||
hostPath:
|
||||
path: /usr/local/Ascend/driver/tools
|
||||
workerTemplate:
|
||||
spec:
|
||||
containers:
|
||||
- name: vllm-worker
|
||||
image: m.daocloud.io/quay.io/ascend/cann:8.2.rc1-a3-ubuntu22.04-py3.11
|
||||
env:
|
||||
- name: VLLM_USE_MODELSCOPE
|
||||
value: "true"
|
||||
- name: WORKSPACE
|
||||
value: "/root/workspace"
|
||||
- name: WORLD_SIZE
|
||||
value: "2"
|
||||
- name: NPU_PER_NODE
|
||||
value: "16"
|
||||
# Set vLLM version and vLLM-Ascend version here, once there is a new release, update here.
|
||||
- name: VLLM_VERSION
|
||||
value: "v0.11.0"
|
||||
- name: VLLM_ASCEND_VERSION
|
||||
value: "main"
|
||||
- name: MOONCAKE_VERSION
|
||||
value: "06cc217504a6f1b0cdaa26b096b985651b262748"
|
||||
command:
|
||||
- sh
|
||||
- -c
|
||||
- |
|
||||
bash /root/.cache/tests/run.sh
|
||||
resources:
|
||||
limits:
|
||||
huawei.com/ascend-1980: "16"
|
||||
memory: 512Gi
|
||||
ephemeral-storage: 100Gi
|
||||
requests:
|
||||
huawei.com/ascend-1980: "16"
|
||||
ephemeral-storage: 100Gi
|
||||
cpu: 125
|
||||
volumeMounts:
|
||||
- mountPath: /root/.cache
|
||||
name: shared-volume
|
||||
- mountPath: /usr/local/Ascend/driver/tools
|
||||
name: driver-tools
|
||||
- mountPath: /dev/shm
|
||||
name: dshm
|
||||
volumes:
|
||||
- name: dshm
|
||||
emptyDir:
|
||||
medium: Memory
|
||||
sizeLimit: 15Gi
|
||||
- name: shared-volume
|
||||
persistentVolumeClaim:
|
||||
claimName: nv-action-vllm-benchmarks-v2
|
||||
- name: driver-tools
|
||||
hostPath:
|
||||
path: /usr/local/Ascend/driver/tools
|
||||
---
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: vllm-leader
|
||||
namespace: vllm-project
|
||||
spec:
|
||||
ports:
|
||||
- name: http
|
||||
port: 8080
|
||||
protocol: TCP
|
||||
targetPort: 8080
|
||||
selector:
|
||||
leaderworkerset.sigs.k8s.io/name: vllm
|
||||
role: leader
|
||||
type: ClusterIP
|
96
tests/e2e/multi_node/scripts/run.sh
Normal file
96
tests/e2e/multi_node/scripts/run.sh
Normal file
@ -0,0 +1,96 @@
|
||||
#!/bin/bash
|
||||
set -euo pipefail
|
||||
|
||||
export SRC_DIR="$WORKSPACE/source_code"
|
||||
|
||||
check_npu_info() {
|
||||
echo "====> Check NPU info"
|
||||
npu-smi info
|
||||
cat "/usr/local/Ascend/ascend-toolkit/latest/$(uname -i)-linux/ascend_toolkit_install.info"
|
||||
}
|
||||
|
||||
check_and_config() {
|
||||
echo "====> Configure mirrors and git proxy"
|
||||
git config --global url."https://gh-proxy.test.osinfra.cn/https://github.com/".insteadOf "https://github.com/"
|
||||
pip config set global.index-url https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
|
||||
export PIP_EXTRA_INDEX_URL=https://mirrors.huaweicloud.com/ascend/repos/pypi
|
||||
}
|
||||
|
||||
checkout_src() {
|
||||
echo "====> Checkout source code"
|
||||
mkdir -p "$SRC_DIR"
|
||||
|
||||
# vllm-ascend
|
||||
if [ ! -d "$SRC_DIR/vllm-ascend" ]; then
|
||||
git clone --depth 1 -b $VLLM_ASCEND_VERSION https://github.com/vllm-project/vllm-ascend.git "$SRC_DIR/vllm-ascend"
|
||||
fi
|
||||
|
||||
# vllm
|
||||
if [ ! -d "$SRC_DIR/vllm" ]; then
|
||||
git clone -b $VLLM_VERSION https://github.com/vllm-project/vllm.git "$SRC_DIR/vllm"
|
||||
fi
|
||||
|
||||
#mooncake
|
||||
if [ ! -d "$SRC_DIR/Mooncake" ]; then
|
||||
git clone https://github.com/kvcache-ai/Mooncake.git "$SRC_DIR/Mooncake"
|
||||
cd "$SRC_DIR/Mooncake"
|
||||
git checkout 06cc217504a6f1b0cdaa26b096b985651b262748
|
||||
cd -
|
||||
fi
|
||||
}
|
||||
|
||||
install_sys_dependencies() {
|
||||
echo "====> Install system dependencies"
|
||||
apt-get update -y
|
||||
|
||||
DEP_LIST=()
|
||||
while IFS= read -r line; do
|
||||
[[ -n "$line" && ! "$line" =~ ^# ]] && DEP_LIST+=("$line")
|
||||
done < "$SRC_DIR/vllm-ascend/packages.txt"
|
||||
|
||||
apt-get install -y "${DEP_LIST[@]}" gcc g++ cmake libnuma-dev iproute2
|
||||
}
|
||||
|
||||
install_vllm() {
|
||||
echo "====> Install vllm and vllm-ascend"
|
||||
VLLM_TARGET_DEVICE=empty pip install -e "$SRC_DIR/vllm"
|
||||
pip install -e "$SRC_DIR/vllm-ascend"
|
||||
pip install modelscope
|
||||
# Install for pytest
|
||||
pip install -r "$SRC_DIR/vllm-ascend/requirements-dev.txt"
|
||||
}
|
||||
|
||||
install_mooncake() {
|
||||
echo "====> Install mooncake"
|
||||
apt-get update
|
||||
apt install -y --allow-change-held-packages python3 python-is-python3
|
||||
apt-get install -y --no-install-recommends mpich libmpich-dev
|
||||
cd $SRC_DIR/Mooncake
|
||||
sed -i '/option(USE_ASCEND_DIRECT)/s/OFF/ON/' mooncake-common/common.cmake
|
||||
bash dependencies.sh --yes
|
||||
mkdir build
|
||||
cd -
|
||||
cd $SRC_DIR/Mooncake/build
|
||||
cmake ..
|
||||
make -j
|
||||
make install
|
||||
cd -
|
||||
}
|
||||
|
||||
run_tests() {
|
||||
echo "====> Run tests"
|
||||
cd "$SRC_DIR/vllm-ascend"
|
||||
pytest -sv tests/e2e/multi_node/test_multi_dp.py
|
||||
}
|
||||
|
||||
main() {
|
||||
check_npu_info
|
||||
check_and_config
|
||||
checkout_src
|
||||
install_sys_dependencies
|
||||
install_vllm
|
||||
#install_mooncake
|
||||
run_tests
|
||||
}
|
||||
|
||||
main "$@"
|
49
tests/e2e/multi_node/test_multi_dp.py
Normal file
49
tests/e2e/multi_node/test_multi_dp.py
Normal file
@ -0,0 +1,49 @@
|
||||
import subprocess
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.e2e.conftest import RemoteOpenAIServer
|
||||
from tests.e2e.multi_node.config.multi_node_config import (MultiNodeConfig,
|
||||
load_configs)
|
||||
from tests.e2e.multi_node.config.utils import get_default_envs
|
||||
|
||||
configs = load_configs()
|
||||
|
||||
|
||||
def get_benchmark_cmd(model: str, base_url: str, args: list[str]):
|
||||
"""vllm bench serve <model> --base-url <url> ..."""
|
||||
return [
|
||||
"vllm", "bench", "serve", "--model", model, "--base-url", base_url
|
||||
] + args
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config", configs)
|
||||
def test_multi_dp(config: MultiNodeConfig) -> None:
|
||||
env_dict = get_default_envs()
|
||||
|
||||
server_config = config.server_config
|
||||
perf_config = config.perf_config
|
||||
model_name = server_config.model
|
||||
assert model_name is not None, "Model name must be specified"
|
||||
|
||||
server_args = server_config.to_list()
|
||||
|
||||
with RemoteOpenAIServer(
|
||||
model_name,
|
||||
config.server_host,
|
||||
config.server_port,
|
||||
server_args,
|
||||
env_dict=env_dict,
|
||||
auto_port=False,
|
||||
seed=1024,
|
||||
max_wait_seconds=1000,
|
||||
) as remote_server:
|
||||
base_url = remote_server.url_root
|
||||
assert perf_config is not None, "Perf config must be specified for perf tests"
|
||||
perf_cmd = get_benchmark_cmd(server_config.model, base_url,
|
||||
perf_config.to_list())
|
||||
if server_config.headless:
|
||||
remote_server.hang_until_terminated()
|
||||
else:
|
||||
# run perf benchmark
|
||||
subprocess.run(perf_cmd, check=True)
|
Reference in New Issue
Block a user