[Bugfix] Fix Ray Metrics API usage (#6354)

This commit is contained in:
Antoni Baum
2024-07-17 12:40:10 -07:00
committed by GitHub
parent a38524f338
commit 5f0b9933e6
4 changed files with 195 additions and 40 deletions

View File

@ -1,11 +1,13 @@
from typing import List
import pytest
import ray
from prometheus_client import REGISTRY
from vllm import EngineArgs, LLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.metrics import RayPrometheusStatLogger
from vllm.sampling_params import SamplingParams
MODELS = [
@ -241,3 +243,55 @@ def assert_metrics(engine: LLMEngine, disable_log_stats: bool,
labels)
assert (
metric_value == num_requests), "Metrics should be collected"
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [16])
def test_engine_log_metrics_ray(
example_prompts,
model: str,
dtype: str,
max_tokens: int,
) -> None:
# This test is quite weak - it only checks that we can use
# RayPrometheusStatLogger without exceptions.
# Checking whether the metrics are actually emitted is unfortunately
# non-trivial.
# We have to run in a Ray task for Ray metrics to be emitted correctly
@ray.remote(num_gpus=1)
def _inner():
class _RayPrometheusStatLogger(RayPrometheusStatLogger):
def __init__(self, *args, **kwargs):
self._i = 0
super().__init__(*args, **kwargs)
def log(self, *args, **kwargs):
self._i += 1
return super().log(*args, **kwargs)
engine_args = EngineArgs(
model=model,
dtype=dtype,
disable_log_stats=False,
)
engine = LLMEngine.from_engine_args(engine_args)
logger = _RayPrometheusStatLogger(
local_interval=0.5,
labels=dict(model_name=engine.model_config.served_model_name),
max_model_len=engine.model_config.max_model_len)
engine.add_logger("ray", logger)
for i, prompt in enumerate(example_prompts):
engine.add_request(
f"request-id-{i}",
prompt,
SamplingParams(max_tokens=max_tokens),
)
while engine.has_unfinished_requests():
engine.step()
assert logger._i > 0, ".log must be called at least once"
ray.get(_inner.remote())

View File

@ -12,6 +12,7 @@ from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
from vllm.engine.llm_engine import LLMEngine
from vllm.engine.metrics import StatLoggerBase
from vllm.executor.ray_utils import initialize_ray_cluster, ray
from vllm.inputs import LLMInputs, PromptInputs
from vllm.logger import init_logger
@ -389,6 +390,7 @@ class AsyncLLMEngine:
engine_args: AsyncEngineArgs,
start_engine_loop: bool = True,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "AsyncLLMEngine":
"""Creates an async LLM engine from the engine arguments."""
# Create the engine configs.
@ -451,6 +453,7 @@ class AsyncLLMEngine:
max_log_len=engine_args.max_log_len,
start_engine_loop=start_engine_loop,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine
@ -957,3 +960,19 @@ class AsyncLLMEngine:
)
else:
return self.engine.is_tracing_enabled()
def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
if self.engine_use_ray:
ray.get(
self.engine.add_logger.remote( # type: ignore
logger_name=logger_name, logger=logger))
else:
self.engine.add_logger(logger_name=logger_name, logger=logger)
def remove_logger(self, logger_name: str) -> None:
if self.engine_use_ray:
ray.get(
self.engine.remove_logger.remote( # type: ignore
logger_name=logger_name))
else:
self.engine.remove_logger(logger_name=logger_name)

View File

@ -379,6 +379,7 @@ class LLMEngine:
cls,
engine_args: EngineArgs,
usage_context: UsageContext = UsageContext.ENGINE_CONTEXT,
stat_loggers: Optional[Dict[str, StatLoggerBase]] = None,
) -> "LLMEngine":
"""Creates an LLM engine from the engine arguments."""
# Create the engine configs.
@ -423,6 +424,7 @@ class LLMEngine:
executor_class=executor_class,
log_stats=not engine_args.disable_log_stats,
usage_context=usage_context,
stat_loggers=stat_loggers,
)
return engine

View File

@ -30,55 +30,55 @@ prometheus_client.disable_created_metrics()
# begin-metrics-definitions
class Metrics:
labelname_finish_reason = "finished_reason"
_base_library = prometheus_client
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
def __init__(self, labelnames: List[str], max_model_len: int):
# Unregister any existing vLLM collectors
self._unregister_vllm_metrics()
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
self._create_info_cache_config()
# System stats
# Scheduler State
self.gauge_scheduler_running = self._base_library.Gauge(
self.gauge_scheduler_running = self._gauge_cls(
name="vllm:num_requests_running",
documentation="Number of requests currently running on GPU.",
labelnames=labelnames)
self.gauge_scheduler_waiting = self._base_library.Gauge(
self.gauge_scheduler_waiting = self._gauge_cls(
name="vllm:num_requests_waiting",
documentation="Number of requests waiting to be processed.",
labelnames=labelnames)
self.gauge_scheduler_swapped = self._base_library.Gauge(
self.gauge_scheduler_swapped = self._gauge_cls(
name="vllm:num_requests_swapped",
documentation="Number of requests swapped to CPU.",
labelnames=labelnames)
# KV Cache Usage in %
self.gauge_gpu_cache_usage = self._base_library.Gauge(
self.gauge_gpu_cache_usage = self._gauge_cls(
name="vllm:gpu_cache_usage_perc",
documentation="GPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
self.gauge_cpu_cache_usage = self._base_library.Gauge(
self.gauge_cpu_cache_usage = self._gauge_cls(
name="vllm:cpu_cache_usage_perc",
documentation="CPU KV-cache usage. 1 means 100 percent usage.",
labelnames=labelnames)
# Iteration stats
self.counter_num_preemption = self._base_library.Counter(
self.counter_num_preemption = self._counter_cls(
name="vllm:num_preemptions_total",
documentation="Cumulative number of preemption from the engine.",
labelnames=labelnames)
self.counter_prompt_tokens = self._base_library.Counter(
self.counter_prompt_tokens = self._counter_cls(
name="vllm:prompt_tokens_total",
documentation="Number of prefill tokens processed.",
labelnames=labelnames)
self.counter_generation_tokens = self._base_library.Counter(
self.counter_generation_tokens = self._counter_cls(
name="vllm:generation_tokens_total",
documentation="Number of generation tokens processed.",
labelnames=labelnames)
self.histogram_time_to_first_token = self._base_library.Histogram(
self.histogram_time_to_first_token = self._histogram_cls(
name="vllm:time_to_first_token_seconds",
documentation="Histogram of time to first token in seconds.",
labelnames=labelnames,
@ -86,7 +86,7 @@ class Metrics:
0.001, 0.005, 0.01, 0.02, 0.04, 0.06, 0.08, 0.1, 0.25, 0.5,
0.75, 1.0, 2.5, 5.0, 7.5, 10.0
])
self.histogram_time_per_output_token = self._base_library.Histogram(
self.histogram_time_per_output_token = self._histogram_cls(
name="vllm:time_per_output_token_seconds",
documentation="Histogram of time per output token in seconds.",
labelnames=labelnames,
@ -97,83 +97,157 @@ class Metrics:
# Request stats
# Latency
self.histogram_e2e_time_request = self._base_library.Histogram(
self.histogram_e2e_time_request = self._histogram_cls(
name="vllm:e2e_request_latency_seconds",
documentation="Histogram of end to end request latency in seconds.",
labelnames=labelnames,
buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0])
# Metadata
self.histogram_num_prompt_tokens_request = self._base_library.Histogram(
self.histogram_num_prompt_tokens_request = self._histogram_cls(
name="vllm:request_prompt_tokens",
documentation="Number of prefill tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_num_generation_tokens_request = \
self._base_library.Histogram(
self._histogram_cls(
name="vllm:request_generation_tokens",
documentation="Number of generation tokens processed.",
labelnames=labelnames,
buckets=build_1_2_5_buckets(max_model_len),
)
self.histogram_best_of_request = self._base_library.Histogram(
self.histogram_best_of_request = self._histogram_cls(
name="vllm:request_params_best_of",
documentation="Histogram of the best_of request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.histogram_n_request = self._base_library.Histogram(
self.histogram_n_request = self._histogram_cls(
name="vllm:request_params_n",
documentation="Histogram of the n request parameter.",
labelnames=labelnames,
buckets=[1, 2, 5, 10, 20],
)
self.counter_request_success = self._base_library.Counter(
self.counter_request_success = self._counter_cls(
name="vllm:request_success_total",
documentation="Count of successfully processed requests.",
labelnames=labelnames + [Metrics.labelname_finish_reason])
# Speculatie decoding stats
self.gauge_spec_decode_draft_acceptance_rate = self._base_library.Gauge(
self.gauge_spec_decode_draft_acceptance_rate = self._gauge_cls(
name="vllm:spec_decode_draft_acceptance_rate",
documentation="Speulative token acceptance rate.",
labelnames=labelnames)
self.gauge_spec_decode_efficiency = self._base_library.Gauge(
self.gauge_spec_decode_efficiency = self._gauge_cls(
name="vllm:spec_decode_efficiency",
documentation="Speculative decoding system efficiency.",
labelnames=labelnames)
self.counter_spec_decode_num_accepted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._base_library.Counter(
self.counter_spec_decode_num_accepted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_accepted_tokens_total",
documentation="Number of accepted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_draft_tokens = self._counter_cls(
name="vllm:spec_decode_num_draft_tokens_total",
documentation="Number of draft tokens.",
labelnames=labelnames)
self.counter_spec_decode_num_emitted_tokens = (
self._base_library.Counter(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
self.counter_spec_decode_num_emitted_tokens = (self._counter_cls(
name="vllm:spec_decode_num_emitted_tokens_total",
documentation="Number of emitted tokens.",
labelnames=labelnames))
# Deprecated in favor of vllm:prompt_tokens_total
self.gauge_avg_prompt_throughput = self._base_library.Gauge(
self.gauge_avg_prompt_throughput = self._gauge_cls(
name="vllm:avg_prompt_throughput_toks_per_s",
documentation="Average prefill throughput in tokens/s.",
labelnames=labelnames,
)
# Deprecated in favor of vllm:generation_tokens_total
self.gauge_avg_generation_throughput = self._base_library.Gauge(
self.gauge_avg_generation_throughput = self._gauge_cls(
name="vllm:avg_generation_throughput_toks_per_s",
documentation="Average generation throughput in tokens/s.",
labelnames=labelnames,
)
def _create_info_cache_config(self) -> None:
# Config Information
self.info_cache_config = prometheus_client.Info(
name='vllm:cache_config',
documentation='information of cache_config')
def _unregister_vllm_metrics(self) -> None:
for collector in list(self._base_library.REGISTRY._collector_to_names):
for collector in list(prometheus_client.REGISTRY._collector_to_names):
if hasattr(collector, "_name") and "vllm" in collector._name:
self._base_library.REGISTRY.unregister(collector)
prometheus_client.REGISTRY.unregister(collector)
# end-metrics-definitions
class _RayGaugeWrapper:
"""Wraps around ray.util.metrics.Gauge to provide same API as
prometheus_client.Gauge"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._gauge = ray_metrics.Gauge(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._gauge.set_default_tags(labels)
return self
def set(self, value: Union[int, float]):
return self._gauge.set(value)
class _RayCounterWrapper:
"""Wraps around ray.util.metrics.Counter to provide same API as
prometheus_client.Counter"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._counter = ray_metrics.Counter(name=name,
description=documentation,
tag_keys=labelnames_tuple)
def labels(self, **labels):
self._counter.set_default_tags(labels)
return self
def inc(self, value: Union[int, float] = 1.0):
if value == 0:
return
return self._counter.inc(value)
class _RayHistogramWrapper:
"""Wraps around ray.util.metrics.Histogram to provide same API as
prometheus_client.Histogram"""
def __init__(self,
name: str,
documentation: str = "",
labelnames: Optional[List[str]] = None,
buckets: Optional[List[float]] = None):
labelnames_tuple = tuple(labelnames) if labelnames else None
self._histogram = ray_metrics.Histogram(name=name,
description=documentation,
tag_keys=labelnames_tuple,
boundaries=buckets)
def labels(self, **labels):
self._histogram.set_default_tags(labels)
return self
def observe(self, value: Union[int, float]):
return self._histogram.observe(value)
class RayMetrics(Metrics):
@ -181,7 +255,9 @@ class RayMetrics(Metrics):
RayMetrics is used by RayPrometheusStatLogger to log to Ray metrics.
Provides the same metrics as Metrics but uses Ray's util.metrics library.
"""
_base_library = ray_metrics
_gauge_cls = _RayGaugeWrapper
_counter_cls = _RayCounterWrapper
_histogram_cls = _RayHistogramWrapper
def __init__(self, labelnames: List[str], max_model_len: int):
if ray_metrics is None:
@ -192,8 +268,9 @@ class RayMetrics(Metrics):
# No-op on purpose
pass
# end-metrics-definitions
def _create_info_cache_config(self) -> None:
# No-op on purpose
pass
def build_1_2_5_buckets(max_value: int) -> List[int]:
@ -498,3 +575,6 @@ class PrometheusStatLogger(StatLoggerBase):
class RayPrometheusStatLogger(PrometheusStatLogger):
"""RayPrometheusStatLogger uses Ray metrics instead."""
_metrics_cls = RayMetrics
def info(self, type: str, obj: SupportsMetricsInfo) -> None:
return None