[Tests] Update online DP tests to verify that requests are balanced (#20157)

Signed-off-by: Nick Hill <nhill@redhat.com>
Author: Nick Hill
Date: 2025-07-03 07:49:13 +01:00 (committed by GitHub)
parent 363528de27
commit 67d25eca05
3 changed files with 170 additions and 9 deletions

View File

@@ -38,7 +38,7 @@ def default_server_args():
]])
def server(default_server_args, request):
if request.param:
default_server_args.extend(request.param)
default_server_args = default_server_args + request.param
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
yield remote_server
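The switch from in-place extend() to list concatenation is presumably there because default_server_args is a module-scoped fixture: mutating the list it returns would let one server parametrization's extra flags leak into the next. A minimal standalone sketch of that pitfall, using generic flag strings rather than real vLLM arguments:

shared_args = ["--flag-a"]

def build_mutating(extra):
    shared_args.extend(extra)   # mutates the shared list in place
    return shared_args

def build_copying(extra):
    return shared_args + extra  # fresh list; the shared one stays untouched

build_mutating(["--flag-b"])
assert shared_args == ["--flag-a", "--flag-b"]   # extra flag leaked

shared_args = ["--flag-a"]
build_copying(["--flag-b"])
assert shared_args == ["--flag-a"]               # unchanged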

View File

@@ -2,10 +2,12 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import os
import re
import openai # use the official client for correctness check
import pytest
import pytest_asyncio
import requests
from tests.utils import RemoteOpenAIServer
@@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
DP_SIZE = os.getenv("DP_SIZE", "1")
def get_prometheus_metrics(
server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
"""Fetch and parse Prometheus metrics from the /metrics endpoint.
Returns:
Dict mapping metric names to their values grouped by labels.
For example: {"vllm:request_success_total": {
'{engine="0"}': 5.0, '{engine="1"}': 3.0}
}
"""
try:
response = requests.get(server.url_for("metrics"), timeout=10)
response.raise_for_status()
metrics: dict[str, dict[str, float]] = {}
# Regex patterns for Prometheus metrics
metric_with_labels = re.compile(
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
metric_simple = re.compile(
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
for line in response.text.split('\n'):
line = line.strip()
# Skip comments and empty lines
if not line or line.startswith('#'):
continue
# Try to match metric with labels first
match = metric_with_labels.match(line)
if match:
metric_name, labels_part, value_str = match.groups()
try:
value = float(value_str)
if metric_name not in metrics:
metrics[metric_name] = {}
metrics[metric_name][f'{{{labels_part}}}'] = value
except ValueError:
continue
else:
# Try simple metric without labels
match = metric_simple.match(line)
if match:
metric_name, value_str = match.groups()
try:
value = float(value_str)
if metric_name not in metrics:
metrics[metric_name] = {}
metrics[metric_name][''] = value
except ValueError:
continue
return metrics
except Exception as e:
pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
return {}
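As a quick illustration of what this parser produces, here is a hypothetical /metrics excerpt (metric names modeled on the ones used below, values made up) together with the dict it would map to:

sample_metrics_text = '''
# TYPE vllm:request_success_total counter
vllm:request_success_total{engine="0",finished_reason="stop"} 7.0
vllm:request_success_total{engine="1",finished_reason="stop"} 5.0
process_start_time_seconds 1.75e9
'''
# Feeding lines like these through the regexes above yields:
# {
#     "vllm:request_success_total": {
#         '{engine="0",finished_reason="stop"}': 7.0,
#         '{engine="1",finished_reason="stop"}': 5.0,
#     },
#     "process_start_time_seconds": {"": 1.75e9},
# }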
def get_engine_request_counts(
metrics: dict[str, dict[str, float]]) -> dict[str, float]:
"""Extract request counts per engine from Prometheus metrics.
Returns:
Dict mapping engine indices to request counts.
For example: {"0": 15.0, "1": 12.0}
"""
engine_counts = {}
# Look for request success metrics with engine labels
success_metrics = metrics.get("vllm:request_success_total", {})
engine_pattern = re.compile(r'engine="([^"]*)"')
for labels, count in success_metrics.items():
# Extract engine ID from labels using regex
match = engine_pattern.search(labels)
if match:
engine_id = match.group(1)
if engine_id not in engine_counts:
engine_counts[engine_id] = 0.0
engine_counts[engine_id] += count
return engine_counts
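Continuing the hypothetical example, the aggregation keeps only the engine label and sums counts per engine across the remaining label combinations:

success_series = {
    '{engine="0",finished_reason="stop"}': 7.0,
    '{engine="0",finished_reason="length"}': 2.0,
    '{engine="1",finished_reason="stop"}': 5.0,
}
# get_engine_request_counts({"vllm:request_success_total": success_series})
# would return {"0": 9.0, "1": 5.0}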
def check_request_balancing(server: RemoteOpenAIServer):
"""Check request balancing via Prometheus metrics if DP_SIZE > 1.
Args:
server: The RemoteOpenAIServer instance
"""
dp_size = int(DP_SIZE)
if dp_size <= 1:
return
# Get metrics after all requests are completed
metrics = get_prometheus_metrics(server)
engine_counts = get_engine_request_counts(metrics)
# Check that all engines received requests
engines_with_requests = [
engine for engine, count in engine_counts.items() if count > 0
]
assert len(engines_with_requests) == dp_size, (
f"Expected requests to be distributed across all {dp_size} engines,"
f" but only engine(s) {engines_with_requests} received "
f"requests. Engine counts: {engine_counts}")
# Verify that the load is reasonably balanced: each engine must handle
# more than total_requests // (dp_size + 1) requests
total_requests = sum(engine_counts.values())
for count in engine_counts.values():
assert count > total_requests // (dp_size + 1), (
f"requests are imbalanced: {engine_counts}")
@pytest.fixture(scope="module")
def default_server_args():
return [
@@ -50,6 +168,7 @@ async def client(server):
[MODEL_NAME],
)
async def test_single_completion(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None:
async def make_request():
@@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
assert len(results) == num_requests
assert all(completion is not None for completion in results)
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)
@pytest.mark.asyncio
@pytest.mark.parametrize(
@@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
[MODEL_NAME],
)
async def test_completion_streaming(client: openai.AsyncOpenAI,
server: RemoteOpenAIServer,
model_name: str) -> None:
prompt = "What is an LLM?"
@@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
results
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
assert all(results), "Not all streaming requests completed successfully."
# Check request balancing via Prometheus metrics if DP_SIZE > 1
check_request_balancing(server)

View File

@@ -4,24 +4,30 @@
import asyncio
import os
from contextlib import ExitStack
from dataclasses import dataclass
from typing import Optional
import pytest
from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.inputs import PromptType
from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.engine.core_client import DPAsyncMPClient
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
DP_SIZE = int(os.getenv("DP_SIZE", 2))
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
enforce_eager=True,
disable_log_requests=True,
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
data_parallel_size=DP_SIZE,
)
if not current_platform.supports_v1(engine_args.create_model_config()):
@@ -74,12 +80,32 @@ async def generate(
async def test_load(output_kind: RequestOutputKind,
data_parallel_backend: str):
stats_loggers = {}
@dataclass
class SimpleStatsLogger(StatLoggerBase):
init_count: int = 0
finished_req_count: int = 0
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
stats_loggers[engine_index] = self
def record(self, scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats]):
if iteration_stats:
self.finished_req_count += len(
iteration_stats.finished_requests)
def log_engine_initialized(self):
self.init_count += 1
with ExitStack() as after:
prompt = "This is a test of data parallel"
engine_args.data_parallel_backend = data_parallel_backend
engine = AsyncLLM.from_engine_args(engine_args)
engine = AsyncLLM.from_engine_args(engine_args,
stat_loggers=[SimpleStatsLogger])
after.callback(engine.shutdown)
NUM_REQUESTS = 100
@@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
for request_id in request_ids:
tasks.append(
asyncio.create_task(
generate(engine,
request_id,
prompt,
output_kind,
NUM_EXPECTED_TOKENS,
data_parallel_rank=0)))
generate(engine, request_id, prompt, output_kind,
NUM_EXPECTED_TOKENS)))
# Short sleep to ensure that requests are distributed.
await asyncio.sleep(0.01)
# Confirm that we got all the EXPECTED tokens from the requests.
done, pending = await asyncio.wait(tasks,
return_when=asyncio.FIRST_EXCEPTION)
@@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
assert not core_client.engines_running
assert not core_client.reqs_in_flight
# Check that requests were distributed between the engines
print(f"Stats loggers after test: {stats_loggers}")
assert len(stats_loggers) == DP_SIZE
assert stats_loggers[0].init_count == 1
for sl in stats_loggers.values():
slogger: SimpleStatsLogger = sl
assert slogger.finished_req_count > NUM_REQUESTS // (
DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
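With the values used in this test (NUM_REQUESTS = 100, DP_SIZE defaulting to 2), the same kind of loose per-engine bound applies: each SimpleStatsLogger must report more than 100 // 3 == 33 finished requests. A tiny arithmetic check of that bound, with illustrative splits:

num_requests, dp_size = 100, 2
bound = num_requests // (dp_size + 1)   # 33
assert 60 > bound and 40 > bound        # a 60/40 split passes
assert not (10 > bound)                 # a 90/10 split would trip the assertion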