Files
vllm-ascend/tests/e2e/multi_node/config/multi_node_config.py
Li Wang 9eb103607f [1/N][CI] Add multi node test (#3359)
### What this PR does / why we need it?
This PR aims to add multi-node testing; as a first step, it adds a
`deepseek-v3` dp+tp+ep test.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: wangli <wangli858794774@gmail.com>
2025-10-11 14:50:46 +08:00

205 lines
6.7 KiB
Python

import json
import logging
import os
from dataclasses import dataclass, field, fields
from pathlib import Path
from typing import Any, Dict, List, Optional, Type, TypeVar, Union
from tests.e2e.multi_node.config.utils import (get_avaliable_port,
get_leader_ip,
get_net_interface)
# Module-level logger for config parsing progress and errors.
LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
# Default location of the multi-node test config, relative to the repo root.
CONFIG_PATH = Path("tests/e2e/multi_node/config/config.json")
# Type variable so BaseConfig.from_config returns the concrete subclass type.
T = TypeVar("T", bound="BaseConfig")
# =========================
# Base Config
# =========================
@dataclass
class BaseConfig:
    """Shared config base.

    Known fields become dataclass attributes; anything else from the raw
    dict is preserved in ``_extra_fields`` so no JSON key is silently lost.
    """

    model: str = "vllm-ascend/DeepSeek-V3-W8A8"
    _extra_fields: Optional[Dict[str, Any]] = None

    @classmethod
    def from_config(cls: Type[T], data: dict[str, Any]) -> T:
        """Create config instance from dict, keeping unknown fields."""
        known = {f.name for f in fields(cls)}
        recognized = {k: v for k, v in data.items() if k in known}
        unknown = {k: v for k, v in data.items() if k not in known}
        instance = cls(**recognized)
        instance._extra_fields = unknown or {}
        return instance

    def to_list(self) -> List[str]:
        """Convert all fields (including _extra_fields) to CLI arguments."""
        cli_args: List[str] = []
        merged = {**vars(self), **(self._extra_fields or {})}
        for name, value in merged.items():
            # Skip the model name, internal storage, and empty/unset values.
            if name in ("model", "_extra_fields") or value in (None, "", [],
                                                               {}):
                continue
            flag = f"--{name.replace('_', '-')}"
            if isinstance(value, bool):
                # Boolean options are presence-only flags; False is omitted.
                if value:
                    cli_args.append(flag)
            elif isinstance(value, dict):
                cli_args.extend([flag, json.dumps(value, ensure_ascii=False)])
            else:
                cli_args.extend([flag, str(value)])
        return cli_args
# =========================
# Server Config
# =========================
@dataclass
class ServerConfig(BaseConfig):
    """Launch parameters for a vLLM server process on a single node."""

    host: str = "0.0.0.0"
    port: int = 8080
    trust_remote_code: bool = True
    enable_expert_parallel: bool = True
    gpu_memory_utilization: float = 0.9
    headless: bool = False
    quantization: Optional[str] = None
    tensor_parallel_size: int = 8
    max_model_len: int = 8192
    max_num_batched_token: int = 8192
    data_parallel_size: int = 4
    data_parallel_size_local: int = 2
    data_parallel_start_rank: int = 0
    data_parallel_rpc_port: int = 13389
    data_parallel_address: Optional[str] = None
    kv_transfer_config: Optional[Dict[str, Any]] = None
    additional_config: Optional[Dict[str, Any]] = None

    def init_dp_param(
        self,
        is_leader: bool,
        is_disaggregate_prefill: bool,
        dp_size: int,
        world_size: int,
    ) -> None:
        """Initialize distributed parallel parameters."""
        interface = get_net_interface()
        if interface is None:
            raise RuntimeError("No available network interface found")
        self.data_parallel_address = interface[0]

        # Disaggregated prefill: each node starts its ranks from zero.
        if is_disaggregate_prefill:
            self.data_parallel_start_rank = 0
            return

        # Followers run headless and connect back to the leader's address.
        if not is_leader:
            self.headless = True
            self.data_parallel_start_rank = dp_size // world_size
            self.data_parallel_address = get_leader_ip()
@dataclass
class PerfConfig(BaseConfig):
    """Client/benchmark parameters.

    Declares no fields of its own: concrete options arrive in the raw JSON
    dict and are kept in ``_extra_fields`` by ``BaseConfig.from_config``.
    """
    pass
@dataclass
class AccuracyConfig:
    """A prompt paired with the output expected from the model."""
    # Prompt text to send to the server.
    prompt: str
    # Output the model is expected to produce for ``prompt``.
    expected_output: str
# =========================
# MultiNode Config
# =========================
@dataclass
class MultiNodeConfig:
    """Top-level description of one multi-node e2e test case."""

    test_name: str = "Unnamed Test"
    disaggregate_prefill: bool = False
    enable_multithread_load: bool = True
    world_size: int = 2
    server_host: str = "0.0.0.0"
    server_port: int = 8888
    server_config: ServerConfig = field(default_factory=ServerConfig)
    perf_config: Optional[PerfConfig] = None
    accuracy_config: Optional[AccuracyConfig] = None

    @classmethod
    def from_config(cls, cfg: Dict[str, Any]) -> "MultiNodeConfig":
        """Create a MultiNodeConfig from raw dict."""
        num_nodes = cfg.get("num_nodes", 2)
        disagg_prefill = cfg.get("disaggregate_prefill", False)
        # LWS_WORKER_INDEX identifies this node; index 0 is the leader.
        node_index = int(os.getenv("LWS_WORKER_INDEX", 0))
        is_leader = node_index == 0

        # Server config: each role (leader/worker) has its own sub-dict.
        server_params = cfg.get("server_parameters", {})
        if not server_params:
            raise ValueError("Missing required key: 'server_parameters'")
        role = "leader_config" if is_leader else "worker_config"
        server_cfg: ServerConfig = ServerConfig.from_config(
            server_params.get(role, {}))

        if cfg.get("enable_multithread_load"):
            server_cfg.model_loader_extra_config = {  # type: ignore[attr-defined]
                "enable_multithread_load": True,
                "num_threads": 8,
            }

        # Distributed parallel parameter initialization.
        server_cfg.init_dp_param(
            is_leader=is_leader,
            is_disaggregate_prefill=disagg_prefill,
            dp_size=server_cfg.data_parallel_size,
            world_size=num_nodes,
        )

        client_params = cfg.get("client_parameters")
        perf_cfg: Optional[PerfConfig] = (PerfConfig.from_config(client_params)
                                          if client_params else None)

        # Network info: clients always target the leader node.
        leader_cfg = server_params.get("leader_config", {})
        server_host = get_leader_ip()
        server_port = (get_avaliable_port() if disagg_prefill else
                       leader_cfg.get("port", 8080))

        return cls(
            test_name=str(cfg.get("test_name", "Unnamed Test")),
            disaggregate_prefill=disagg_prefill,
            enable_multithread_load=cfg.get("enable_multithread_load", False),
            world_size=num_nodes,
            server_config=server_cfg,
            perf_config=perf_cfg,
            server_host=server_host,
            server_port=server_port,
        )
# =========================
# Loader
# =========================
def load_configs(
        path: Union[str, Path] = CONFIG_PATH) -> List[MultiNodeConfig]:
    """Load one or multiple configs from JSON file.

    Args:
        path: JSON file containing either a single config object or a
            list of config objects.

    Returns:
        One ``MultiNodeConfig`` per entry in the file.

    Raises:
        FileNotFoundError: if ``path`` does not exist.
        Exception: re-raised from ``MultiNodeConfig.from_config`` for a
            malformed entry (logged with the entry index first).
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Configuration file not found: {path}")
    raw = json.loads(path.read_text())
    configs_data = raw if isinstance(raw, list) else [raw]
    configs = []
    for idx, item in enumerate(configs_data):
        try:
            configs.append(MultiNodeConfig.from_config(item))
        except Exception:
            # Lazy %-args instead of an f-string (ruff G004): avoids eager
            # formatting, and logger.exception already appends the traceback
            # with the exception text.
            LOG.exception("Failed to parse config #%d", idx)
            raise
    return configs