mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Tests] Update online DP tests to verify that requests are balanced (#20157)
Signed-off-by: Nick Hill <nhill@redhat.com>
This commit is contained in:
@ -38,7 +38,7 @@ def default_server_args():
|
||||
]])
|
||||
def server(default_server_args, request):
|
||||
if request.param:
|
||||
default_server_args.extend(request.param)
|
||||
default_server_args = default_server_args + request.param
|
||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
||||
yield remote_server
|
||||
|
||||
|
@ -2,10 +2,12 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import os
|
||||
import re
|
||||
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
|
||||
@ -14,6 +16,122 @@ MODEL_NAME = "ibm-research/PowerMoE-3b"
|
||||
DP_SIZE = os.getenv("DP_SIZE", "1")
|
||||
|
||||
|
||||
def get_prometheus_metrics(
|
||||
server: RemoteOpenAIServer) -> dict[str, dict[str, float]]:
|
||||
"""Fetch and parse Prometheus metrics from the /metrics endpoint.
|
||||
|
||||
Returns:
|
||||
Dict mapping metric names to their values grouped by labels.
|
||||
For example: {"vllm:request_success": {
|
||||
"engine=0": 5.0, "engine=1": 3.0}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
response = requests.get(server.url_for("metrics"), timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
metrics: dict[str, dict[str, float]] = {}
|
||||
|
||||
# Regex patterns for Prometheus metrics
|
||||
metric_with_labels = re.compile(
|
||||
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\{([^}]*)\}\s+([\d\.\-\+e]+)$')
|
||||
metric_simple = re.compile(
|
||||
r'^([a-zA-Z_:][a-zA-Z0-9_:]*)\s+([\d\.\-\+e]+)$')
|
||||
|
||||
for line in response.text.split('\n'):
|
||||
line = line.strip()
|
||||
# Skip comments and empty lines
|
||||
if not line or line.startswith('#'):
|
||||
continue
|
||||
|
||||
# Try to match metric with labels first
|
||||
match = metric_with_labels.match(line)
|
||||
if match:
|
||||
metric_name, labels_part, value_str = match.groups()
|
||||
try:
|
||||
value = float(value_str)
|
||||
if metric_name not in metrics:
|
||||
metrics[metric_name] = {}
|
||||
metrics[metric_name][f'{{{labels_part}}}'] = value
|
||||
except ValueError:
|
||||
continue
|
||||
else:
|
||||
# Try simple metric without labels
|
||||
match = metric_simple.match(line)
|
||||
if match:
|
||||
metric_name, value_str = match.groups()
|
||||
try:
|
||||
value = float(value_str)
|
||||
if metric_name not in metrics:
|
||||
metrics[metric_name] = {}
|
||||
metrics[metric_name][''] = value
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return metrics
|
||||
except Exception as e:
|
||||
pytest.fail(f"Failed to fetch Prometheus metrics: {e}")
|
||||
return {}
|
||||
|
||||
|
||||
def get_engine_request_counts(
|
||||
metrics: dict[str, dict[str, float]]) -> dict[str, float]:
|
||||
"""Extract request counts per engine from Prometheus metrics.
|
||||
|
||||
Returns:
|
||||
Dict mapping engine indices to request counts.
|
||||
For example: {"0": 15.0, "1": 12.0}
|
||||
"""
|
||||
engine_counts = {}
|
||||
|
||||
# Look for request success metrics with engine labels
|
||||
success_metrics = metrics.get("vllm:request_success_total", {})
|
||||
engine_pattern = re.compile(r'engine="([^"]*)"')
|
||||
|
||||
for labels, count in success_metrics.items():
|
||||
# Extract engine ID from labels using regex
|
||||
match = engine_pattern.search(labels)
|
||||
if match:
|
||||
engine_id = match.group(1)
|
||||
if engine_id not in engine_counts:
|
||||
engine_counts[engine_id] = 0.0
|
||||
engine_counts[engine_id] += count
|
||||
|
||||
return engine_counts
|
||||
|
||||
|
||||
def check_request_balancing(server: RemoteOpenAIServer):
|
||||
"""Check request balancing via Prometheus metrics if DP_SIZE > 1.
|
||||
|
||||
Args:
|
||||
server: The RemoteOpenAIServer instance
|
||||
"""
|
||||
dp_size = int(DP_SIZE)
|
||||
if dp_size <= 1:
|
||||
return
|
||||
|
||||
# Get metrics after all requests are completed
|
||||
metrics = get_prometheus_metrics(server)
|
||||
engine_counts = get_engine_request_counts(metrics)
|
||||
|
||||
# Check that multiple engines received requests
|
||||
engines_with_requests = [
|
||||
engine for engine, count in engine_counts.items() if count > 0
|
||||
]
|
||||
assert len(engines_with_requests) == dp_size, (
|
||||
f"Expected requests to be distributed across multiple engines,"
|
||||
f" but only engine(s) {engines_with_requests} received "
|
||||
f"requests. Engine counts: {engine_counts}")
|
||||
|
||||
# Verify that the load is reasonably balanced
|
||||
# (no engine should handle all requests)
|
||||
total_requests = sum(engine_counts.values())
|
||||
|
||||
for count in engine_counts.values():
|
||||
assert count > total_requests // (dp_size + 1), (
|
||||
f"requests are imbalanced: {engine_counts}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def default_server_args():
|
||||
return [
|
||||
@ -50,6 +168,7 @@ async def client(server):
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str) -> None:
|
||||
|
||||
async def make_request():
|
||||
@ -97,6 +216,9 @@ async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
assert len(results) == num_requests
|
||||
assert all(completion is not None for completion in results)
|
||||
|
||||
# Check request balancing via Prometheus metrics if DP_SIZE > 1
|
||||
check_request_balancing(server)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
@ -104,6 +226,7 @@ async def test_single_completion(client: openai.AsyncOpenAI,
|
||||
[MODEL_NAME],
|
||||
)
|
||||
async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str) -> None:
|
||||
prompt = "What is an LLM?"
|
||||
|
||||
@ -170,3 +293,6 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
|
||||
results
|
||||
) == num_requests, f"Expected {num_requests} results, got {len(results)}"
|
||||
assert all(results), "Not all streaming requests completed successfully."
|
||||
|
||||
# Check request balancing via Prometheus metrics if DP_SIZE > 1
|
||||
check_request_balancing(server)
|
||||
|
@ -4,24 +4,30 @@
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import ExitStack
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.inputs import PromptType
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.engine.core_client import DPAsyncMPClient
|
||||
from vllm.v1.metrics.loggers import StatLoggerBase
|
||||
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
|
||||
|
||||
DP_SIZE = int(os.getenv("DP_SIZE", 2))
|
||||
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="ibm-research/PowerMoE-3b",
|
||||
enforce_eager=True,
|
||||
disable_log_requests=True,
|
||||
tensor_parallel_size=int(os.getenv("TP_SIZE", 1)),
|
||||
data_parallel_size=int(os.getenv("DP_SIZE", 2)),
|
||||
data_parallel_size=DP_SIZE,
|
||||
)
|
||||
|
||||
if not current_platform.supports_v1(engine_args.create_model_config()):
|
||||
@ -74,12 +80,32 @@ async def generate(
|
||||
async def test_load(output_kind: RequestOutputKind,
|
||||
data_parallel_backend: str):
|
||||
|
||||
stats_loggers = {}
|
||||
|
||||
@dataclass
|
||||
class SimpleStatsLogger(StatLoggerBase):
|
||||
init_count: int = 0
|
||||
finished_req_count: int = 0
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
stats_loggers[engine_index] = self
|
||||
|
||||
def record(self, scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats]):
|
||||
if iteration_stats:
|
||||
self.finished_req_count += len(
|
||||
iteration_stats.finished_requests)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
self.init_count += 1
|
||||
|
||||
with ExitStack() as after:
|
||||
|
||||
prompt = "This is a test of data parallel"
|
||||
|
||||
engine_args.data_parallel_backend = data_parallel_backend
|
||||
engine = AsyncLLM.from_engine_args(engine_args)
|
||||
engine = AsyncLLM.from_engine_args(engine_args,
|
||||
stat_loggers=[SimpleStatsLogger])
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
NUM_REQUESTS = 100
|
||||
@ -92,12 +118,10 @@ async def test_load(output_kind: RequestOutputKind,
|
||||
for request_id in request_ids:
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
generate(engine,
|
||||
request_id,
|
||||
prompt,
|
||||
output_kind,
|
||||
NUM_EXPECTED_TOKENS,
|
||||
data_parallel_rank=0)))
|
||||
generate(engine, request_id, prompt, output_kind,
|
||||
NUM_EXPECTED_TOKENS)))
|
||||
# Short sleep to ensure that requests are distributed.
|
||||
await asyncio.sleep(0.01)
|
||||
# Confirm that we got all the EXPECTED tokens from the requests.
|
||||
done, pending = await asyncio.wait(tasks,
|
||||
return_when=asyncio.FIRST_EXCEPTION)
|
||||
@ -122,3 +146,14 @@ async def test_load(output_kind: RequestOutputKind,
|
||||
|
||||
assert not core_client.engines_running
|
||||
assert not core_client.reqs_in_flight
|
||||
|
||||
# Check that requests were distributed between the engines
|
||||
print(f"Stats loggers after test: {stats_loggers}")
|
||||
assert len(stats_loggers) == DP_SIZE
|
||||
assert stats_loggers[0].init_count == 1
|
||||
|
||||
for sl in stats_loggers.values():
|
||||
slogger: SimpleStatsLogger = sl
|
||||
|
||||
assert slogger.finished_req_count > NUM_REQUESTS // (
|
||||
DP_SIZE + 1), f"requests are imbalanced: {stats_loggers}"
|
||||
|
Reference in New Issue
Block a user