mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Bugfix] Fix Ray Metrics API usage (#6354)
This commit is contained in:
@ -1,11 +1,13 @@
|
||||
from typing import List
|
||||
|
||||
import pytest
|
||||
import ray
|
||||
from prometheus_client import REGISTRY
|
||||
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.engine.metrics import RayPrometheusStatLogger
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
MODELS = [
|
||||
@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
|
||||
labels)
|
||||
assert (
|
||||
metric_value == num_requests), "Metrics should be collected"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half"])
|
||||
@pytest.mark.parametrize("max_tokens", [16])
|
||||
def test_engine_log_metrics_ray(
|
||||
example_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
# This test is quite weak - it only checks that we can use
|
||||
# RayPrometheusStatLogger without exceptions.
|
||||
# Checking whether the metrics are actually emitted is unfortunately
|
||||
# non-trivial.
|
||||
|
||||
# We have to run in a Ray task for Ray metrics to be emitted correctly
|
||||
@ray.remote(num_gpus=1)
|
||||
def _inner():
|
||||
|
||||
class _RayPrometheusStatLogger(RayPrometheusStatLogger):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._i = 0
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def log(self, *args, **kwargs):
|
||||
self._i += 1
|
||||
return super().log(*args, **kwargs)
|
||||
|
||||
engine_args = EngineArgs(
|
||||
model=model,
|
||||
dtype=dtype,
|
||||
disable_log_stats=False,
|
||||
)
|
||||
engine = LLMEngine.from_engine_args(engine_args)
|
||||
logger = _RayPrometheusStatLogger(
|
||||
local_interval=0.5,
|
||||
labels=dict(model_name=engine.model_config.served_model_name),
|
||||
max_model_len=engine.model_config.max_model_len)
|
||||
engine.add_logger("ray", logger)
|
||||
for i, prompt in enumerate(example_prompts):
|
||||
engine.add_request(
|
||||
f"request-id-{i}",
|
||||
prompt,
|
||||
SamplingParams(max_tokens=max_tokens),
|
||||
)
|
||||
while engine.has_unfinished_requests():
|
||||
engine.step()
|
||||
assert logger._i > 0, ".log must be called at least once"
|
||||
|
||||
ray.get(_inner.remote())
|
||||
|
@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_timeout import asyncio_timeout
|
||||
from vllm.engine.llm_engine import LLMEngine
|
||||
from vllm.engine.metrics import StatLoggerBase
|
||||
from vllm.executor.ray_utils import initialize_ray_cluster, ray
|
||||
from vllm.inputs import LLMInputs, PromptInputs
|
||||
from vllm.logger import init_logger
|
||||
@ -389,6 +390,7 @@ class AsyncLLMEngine:
|
||||
engine_args: AsyncEngineArgs,
|
||||
start_engine_loop: bool = True,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "AsyncLLMEngine":
|
||||
"""Creates an async LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
@ -451,6 +453,7 @@ class AsyncLLMEngine:
|
||||
max_log_len=engine_args.max_log_len,
|
||||
start_engine_loop=start_engine_loop,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
return engine
|
||||
|
||||
@ -957,3 +960,19 @@ class AsyncLLMEngine:
|
||||
)
|
||||
else:
|
||||
return self.engine.is_tracing_enabled()
|
||||
|
||||
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
|
||||
if self.engine_use_ray:
|
||||
ray.get(
|
||||
self.engine.add_logger.remote( # type: ignore
|
||||
logger_name=logger_name, logger=logger))
|
||||
else:
|
||||
self.engine.add_logger(logger_name=logger_name, logger=logger)
|
||||
|
||||
def remove_logger(self, logger_name: str) -> None:
|
||||
if self.engine_use_ray:
|
||||
ray.get(
|
||||
self.engine.remove_logger.remote( # type: ignore
|
||||
logger_name=logger_name))
|
||||
else:
|
||||
self.engine.remove_logger(logger_name=logger_name)
|
||||
|
@ -379,6 +379,7 @@ class LLMEngine:
|
||||
cls,
|
||||
engine_args: EngineArgs,
|
||||
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
|
||||
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
|
||||
) -> "LLMEngine":
|
||||
"""Creates an LLM engine from the engine arguments."""
|
||||
# Create the engine configs.
|
||||
@ -423,6 +424,7 @@ class LLMEngine:
|
||||
executor_class=executor_class,
|
||||
log_stats=not engine_args.disable_log_stats,
|
||||
usage_context=usage_context,
|
||||
stat_loggers=stat_loggers,
|
||||
)
|
||||
return engine
|
||||
|
||||
|
@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
|
||||
# begin-metrics-definitions
|
||||
class Metrics:
|
||||
labelname_finish_reason = "finished_reason"
|
||||
_base_library = prometheus_client
|
||||
_gauge_cls = prometheus_client.Gauge
|
||||
_counter_cls = prometheus_client.Counter
|
||||
_histogram_cls = prometheus_client.Histogram
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
# Unregister any existing vLLM collectors
|
||||
self._unregister_vllm_metrics()
|
||||
|
||||
# Config Information
|
||||
self.info_cache_config = prometheus_client.Info(
|
||||
name='vllm:cache_config',
|
||||
documentation='information of cache_config')
|
||||
self._create_info_cache_config()
|
||||
|
||||
# System stats
|
||||
# Scheduler State
|
||||
self.gauge_scheduler_running = self._base_library.Gauge(
|
||||
self.gauge_scheduler_running = self._gauge_cls(
|
||||
name="vllm:num_requests_running",
|
||||
documentation="Number of requests currently running on GPU.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_waiting = self._base_library.Gauge(
|
||||
self.gauge_scheduler_waiting = self._gauge_cls(
|
||||
name="vllm:num_requests_waiting",
|
||||
documentation="Number of requests waiting to be processed.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_scheduler_swapped = self._base_library.Gauge(
|
||||
self.gauge_scheduler_swapped = self._gauge_cls(
|
||||
name="vllm:num_requests_swapped",
|
||||
documentation="Number of requests swapped to CPU.",
|
||||
labelnames=labelnames)
|
||||
# KV Cache Usage in %
|
||||
self.gauge_gpu_cache_usage = self._base_library.Gauge(
|
||||
self.gauge_gpu_cache_usage = self._gauge_cls(
|
||||
name="vllm:gpu_cache_usage_perc",
|
||||
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_cpu_cache_usage = self._base_library.Gauge(
|
||||
self.gauge_cpu_cache_usage = self._gauge_cls(
|
||||
name="vllm:cpu_cache_usage_perc",
|
||||
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
|
||||
labelnames=labelnames)
|
||||
|
||||
# Iteration stats
|
||||
self.counter_num_preemption = self._base_library.Counter(
|
||||
self.counter_num_preemption = self._counter_cls(
|
||||
name="vllm:num_preemptions_total",
|
||||
documentation="Cumulative number of preemption from the engine.",
|
||||
labelnames=labelnames)
|
||||
self.counter_prompt_tokens = self._base_library.Counter(
|
||||
self.counter_prompt_tokens = self._counter_cls(
|
||||
name="vllm:prompt_tokens_total",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.counter_generation_tokens = self._base_library.Counter(
|
||||
self.counter_generation_tokens = self._counter_cls(
|
||||
name="vllm:generation_tokens_total",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames)
|
||||
self.histogram_time_to_first_token = self._base_library.Histogram(
|
||||
self.histogram_time_to_first_token = self._histogram_cls(
|
||||
name="vllm:time_to_first_token_seconds",
|
||||
documentation="Histogram of time to first token in seconds.",
|
||||
labelnames=labelnames,
|
||||
@ -86,7 +86,7 @@ class Metrics:
|
||||
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
|
||||
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
|
||||
])
|
||||
self.histogram_time_per_output_token = self._base_library.Histogram(
|
||||
self.histogram_time_per_output_token = self._histogram_cls(
|
||||
name="vllm:time_per_output_token_seconds",
|
||||
documentation="Histogram of time per output token in seconds.",
|
||||
labelnames=labelnames,
|
||||
@ -97,83 +97,157 @@ class Metrics:
|
||||
|
||||
# Request stats
|
||||
# Latency
|
||||
self.histogram_e2e_time_request = self._base_library.Histogram(
|
||||
self.histogram_e2e_time_request = self._histogram_cls(
|
||||
name="vllm:e2e_request_latency_seconds",
|
||||
documentation="Histogram of end to end request latency in seconds.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
|
||||
# Metadata
|
||||
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
|
||||
self.histogram_num_prompt_tokens_request = self._histogram_cls(
|
||||
name="vllm:request_prompt_tokens",
|
||||
documentation="Number of prefill tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_num_generation_tokens_request = \
|
||||
self._base_library.Histogram(
|
||||
self._histogram_cls(
|
||||
name="vllm:request_generation_tokens",
|
||||
documentation="Number of generation tokens processed.",
|
||||
labelnames=labelnames,
|
||||
buckets=build_1_2_5_buckets(max_model_len),
|
||||
)
|
||||
self.histogram_best_of_request = self._base_library.Histogram(
|
||||
self.histogram_best_of_request = self._histogram_cls(
|
||||
name="vllm:request_params_best_of",
|
||||
documentation="Histogram of the best_of request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.histogram_n_request = self._base_library.Histogram(
|
||||
self.histogram_n_request = self._histogram_cls(
|
||||
name="vllm:request_params_n",
|
||||
documentation="Histogram of the n request parameter.",
|
||||
labelnames=labelnames,
|
||||
buckets=[1, 2, 5, 10, 20],
|
||||
)
|
||||
self.counter_request_success = self._base_library.Counter(
|
||||
self.counter_request_success = self._counter_cls(
|
||||
name="vllm:request_success_total",
|
||||
documentation="Count of successfully processed requests.",
|
||||
labelnames=labelnames + [Metrics.labelname_finish_reason])
|
||||
|
||||
# Speculatie decoding stats
|
||||
self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
|
||||
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
|
||||
name="vllm:spec_decode_draft_acceptance_rate",
|
||||
documentation="Speulative token acceptance rate.",
|
||||
labelnames=labelnames)
|
||||
self.gauge_spec_decode_efficiency = self._base_library.Gauge(
|
||||
self.gauge_spec_decode_efficiency = self._gauge_cls(
|
||||
name="vllm:spec_decode_efficiency",
|
||||
documentation="Speculative decoding system efficiency.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_accepted_tokens = (
|
||||
self._base_library.Counter(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames))
|
||||
self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
|
||||
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_accepted_tokens_total",
|
||||
documentation="Number of accepted tokens.",
|
||||
labelnames=labelnames))
|
||||
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
|
||||
name="vllm:spec_decode_num_draft_tokens_total",
|
||||
documentation="Number of draft tokens.",
|
||||
labelnames=labelnames)
|
||||
self.counter_spec_decode_num_emitted_tokens = (
|
||||
self._base_library.Counter(
|
||||
name="vllm:spec_decode_num_emitted_tokens_total",
|
||||
documentation="Number of emitted tokens.",
|
||||
labelnames=labelnames))
|
||||
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
|
||||
name="vllm:spec_decode_num_emitted_tokens_total",
|
||||
documentation="Number of emitted tokens.",
|
||||
labelnames=labelnames))
|
||||
|
||||
# Deprecated in favor of vllm:prompt_tokens_total
|
||||
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
|
||||
self.gauge_avg_prompt_throughput = self._gauge_cls(
|
||||
name="vllm:avg_prompt_throughput_toks_per_s",
|
||||
documentation="Average prefill throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
# Deprecated in favor of vllm:generation_tokens_total
|
||||
self.gauge_avg_generation_throughput = self._base_library.Gauge(
|
||||
self.gauge_avg_generation_throughput = self._gauge_cls(
|
||||
name="vllm:avg_generation_throughput_toks_per_s",
|
||||
documentation="Average generation throughput in tokens/s.",
|
||||
labelnames=labelnames,
|
||||
)
|
||||
|
||||
def _create_info_cache_config(self) -> None:
|
||||
# Config Information
|
||||
self.info_cache_config = prometheus_client.Info(
|
||||
name='vllm:cache_config',
|
||||
documentation='information of cache_config')
|
||||
|
||||
def _unregister_vllm_metrics(self) -> None:
|
||||
for collector in list(self._base_library.REGISTRY._collector_to_names):
|
||||
for collector in list(prometheus_client.REGISTRY._collector_to_names):
|
||||
if hasattr(collector, "_name") and "vllm" in collector._name:
|
||||
self._base_library.REGISTRY.unregister(collector)
|
||||
prometheus_client.REGISTRY.unregister(collector)
|
||||
|
||||
|
||||
# end-metrics-definitions
|
||||
|
||||
|
||||
class _RayGaugeWrapper:
|
||||
"""Wraps around ray.util.metrics.Gauge to provide same API as
|
||||
prometheus_client.Gauge"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None):
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._gauge = ray_metrics.Gauge(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._gauge.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def set(self, value: Union[int, float]):
|
||||
return self._gauge.set(value)
|
||||
|
||||
|
||||
class _RayCounterWrapper:
|
||||
"""Wraps around ray.util.metrics.Counter to provide same API as
|
||||
prometheus_client.Counter"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None):
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._counter = ray_metrics.Counter(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._counter.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def inc(self, value: Union[int, float] = 1.0):
|
||||
if value == 0:
|
||||
return
|
||||
return self._counter.inc(value)
|
||||
|
||||
|
||||
class _RayHistogramWrapper:
|
||||
"""Wraps around ray.util.metrics.Histogram to provide same API as
|
||||
prometheus_client.Histogram"""
|
||||
|
||||
def __init__(self,
|
||||
name: str,
|
||||
documentation: str = "",
|
||||
labelnames: Optional[List[str]] = None,
|
||||
buckets: Optional[List[float]] = None):
|
||||
labelnames_tuple = tuple(labelnames) if labelnames else None
|
||||
self._histogram = ray_metrics.Histogram(name=name,
|
||||
description=documentation,
|
||||
tag_keys=labelnames_tuple,
|
||||
boundaries=buckets)
|
||||
|
||||
def labels(self, **labels):
|
||||
self._histogram.set_default_tags(labels)
|
||||
return self
|
||||
|
||||
def observe(self, value: Union[int, float]):
|
||||
return self._histogram.observe(value)
|
||||
|
||||
|
||||
class RayMetrics(Metrics):
|
||||
@ -181,7 +255,9 @@ class RayMetrics(Metrics):
|
||||
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
|
||||
Provides the same metrics as Metrics but uses Ray's util.metrics library.
|
||||
"""
|
||||
_base_library = ray_metrics
|
||||
_gauge_cls = _RayGaugeWrapper
|
||||
_counter_cls = _RayCounterWrapper
|
||||
_histogram_cls = _RayHistogramWrapper
|
||||
|
||||
def __init__(self, labelnames: List[str], max_model_len: int):
|
||||
if ray_metrics is None:
|
||||
@ -192,8 +268,9 @@ class RayMetrics(Metrics):
|
||||
# No-op on purpose
|
||||
pass
|
||||
|
||||
|
||||
# end-metrics-definitions
|
||||
def _create_info_cache_config(self) -> None:
|
||||
# No-op on purpose
|
||||
pass
|
||||
|
||||
|
||||
def build_1_2_5_buckets(max_value: int) -> List[int]:
|
||||
@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
|
||||
class RayPrometheusStatLogger(PrometheusStatLogger):
|
||||
"""RayPrometheusStatLogger uses Ray metrics instead."""
|
||||
_metrics_cls = RayMetrics
|
||||
|
||||
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
|
||||
return None
|
||||
|
Reference in New Issue
Block a user