[V1][Metrics] Implement vllm:lora_requests_info metric (#13504)

Mark McLoughlin
2025-02-25 04:01:33 +00:00
committed by GitHub
parent ab1091d5f2
commit bc32bc73aa
3 changed files with 121 additions and 10 deletions
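For context, the new gauge is registered by the V1 PrometheusStatLogger and shows up with the rest of the vLLM metrics on the server's /metrics endpoint. A quick end-to-end check might look like the following sketch (hypothetical, not part of the commit; it assumes a vLLM OpenAI-compatible server with LoRA enabled is already listening on localhost:8000):

import requests

# Scrape the Prometheus endpoint and print only the new LoRA gauge samples.
metrics = requests.get("http://localhost:8000/metrics", timeout=5).text
for line in metrics.splitlines():
    if line.startswith("vllm:lora_requests_info"):
        print(line)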

View File

@@ -11,7 +11,8 @@ from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.v1.engine import EngineCoreOutput, EngineCoreRequest, FinishReason
 from vllm.v1.engine.detokenizer import IncrementalDetokenizer
 from vllm.v1.engine.logprobs import LogprobsProcessor
-from vllm.v1.metrics.stats import IterationStats, RequestStateStats
+from vllm.v1.metrics.stats import (IterationStats, LoRARequestStates,
+                                   RequestStateStats)


 @dataclass
@@ -26,6 +27,7 @@ class RequestState:
     def __init__(
         self,
         request_id: str,
+        lora_name: Optional[str],
         output_kind: RequestOutputKind,
         prompt: Optional[str],
         prompt_token_ids: List[int],
@@ -36,6 +38,7 @@ class RequestState:
         log_stats: bool,
     ):
         self.request_id = request_id
+        self.lora_name = lora_name
         self.output_kind = output_kind
         self.prompt = prompt
         self.prompt_token_ids = prompt_token_ids
@@ -58,6 +61,8 @@ class RequestState:
     ) -> "RequestState":
         return cls(
             request_id=request.request_id,
+            lora_name=(request.lora_request.name
+                       if request.lora_request is not None else None),
             output_kind=request.sampling_params.output_kind,
             prompt=request.prompt,
             prompt_token_ids=request.prompt_token_ids,
@@ -86,6 +91,7 @@ class OutputProcessor:
         self.log_stats = log_stats
         self.tokenizer = tokenizer
         self.request_states: Dict[str, RequestState] = {}
+        self.lora_states = LoRARequestStates()

     def is_request_active(self, request_id: str) -> bool:
         return request_id in self.request_states
@@ -101,7 +107,9 @@ class OutputProcessor:
         request_ids: List[str],
     ) -> None:
         for request_id in request_ids:
-            self.request_states.pop(request_id, None)
+            req_state = self.request_states.pop(request_id, None)
+            if req_state is not None:
+                self.lora_states.abort_request(req_state)

     def add_request(
         self,
@@ -112,11 +120,13 @@ class OutputProcessor:
         if request_id in self.request_states:
             raise ValueError(f"Request id {request_id} already running.")

-        self.request_states[request_id] = RequestState.from_new_request(
+        req_state = RequestState.from_new_request(
             tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
             request=request,
             queue=queue,
             log_stats=self.log_stats)
+        self.request_states[request_id] = req_state
+        self.lora_states.add_request(req_state)

     def process_outputs(
         self,
@@ -214,6 +224,8 @@ class OutputProcessor:
                                                      finish_reason,
                                                      iteration_stats)

+        self.lora_states.update_iteration_stats(iteration_stats)
+
         return OutputProcessorOutput(
             request_outputs=request_outputs,
             reqs_to_abort=reqs_to_abort,
@@ -226,13 +238,15 @@ class OutputProcessor:
         if iteration_stats is None:
             return

+        lora_stats = self.lora_states.get_stats(req_state)
+
         assert engine_core_timestamp is not None
         assert req_state.stats is not None
         iteration_stats.update_from_output(engine_core_output,
                                            engine_core_timestamp,
                                            req_state.is_prefilling,
                                            req_state.prompt_len,
-                                           req_state.stats)
+                                           req_state.stats, lora_stats)

     def _update_stats_from_finished(self, req_state: RequestState,
                                     request_output: RequestOutput,
@@ -246,6 +260,7 @@ class OutputProcessor:
         iteration_stats.update_from_finished_request(finish_reason,
                                                      request_output,
                                                      req_state.stats)
+        self.lora_states.finish_request(req_state)

     @staticmethod
     def _make_request_output(

View File

@@ -2,7 +2,7 @@
 import time
 from abc import ABC, abstractmethod
-from typing import Dict, List
+from typing import Dict, List, Optional

 import numpy as np
 import prometheus_client
@@ -233,6 +233,22 @@ class PrometheusStatLogger(StatLoggerBase):
                 buckets=request_latency_buckets,
                 labelnames=labelnames).labels(*labelvalues)

+        self.gauge_lora_info: Optional[prometheus_client.Gauge] = None
+        if vllm_config.lora_config is not None:
+            self.labelname_max_lora = "max_lora"
+            self.labelname_waiting_lora_adapters = "waiting_lora_adapters"
+            self.labelname_running_lora_adapters = "running_lora_adapters"
+            self.max_lora = vllm_config.lora_config.max_loras
+            self.gauge_lora_info = \
+                prometheus_client.Gauge(
+                    name="vllm:lora_requests_info",
+                    documentation="Running stats on lora requests.",
+                    labelnames=[
+                        self.labelname_max_lora,
+                        self.labelname_waiting_lora_adapters,
+                        self.labelname_running_lora_adapters,
+                    ])
+
         self.log_metrics_info("cache_config", vllm_config.cache_config)

     def log_metrics_info(self, type: str, config_obj: SupportsMetricsInfo):
@@ -295,6 +311,19 @@ class PrometheusStatLogger(StatLoggerBase):
         for prefill_time in iteration_stats.prefill_times_iter:
             self.histogram_prefill_time_request.observe(prefill_time)

+        if self.gauge_lora_info is not None:
+            running_lora_adapters = \
+                ",".join(iteration_stats.running_lora_adapters.keys())
+            waiting_lora_adapters = \
+                ",".join(iteration_stats.waiting_lora_adapters.keys())
+            lora_info_labels = {
+                self.labelname_running_lora_adapters: running_lora_adapters,
+                self.labelname_waiting_lora_adapters: waiting_lora_adapters,
+                self.labelname_max_lora: self.max_lora,
+            }
+            self.gauge_lora_info.labels(**lora_info_labels)\
+                .set_to_current_time()
+
     @staticmethod
     def _unregister_vllm_metrics():
         # Unregister any existing vLLM collectors (for CI/CD
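To illustrate what the logger above exports, here is a small standalone sketch (not part of this commit; the adapter names, the max_lora value, and the private registry are made up) that builds an identically shaped gauge and prints its exposition-format output. The waiting/running adapter names are packed into comma-separated label values, and the sample value is simply the time the labels were last refreshed:

import prometheus_client

registry = prometheus_client.CollectorRegistry()
gauge_lora_info = prometheus_client.Gauge(
    name="vllm:lora_requests_info",
    documentation="Running stats on lora requests.",
    labelnames=["max_lora", "waiting_lora_adapters", "running_lora_adapters"],
    registry=registry)

# Hypothetical snapshot: two adapters running, one waiting, max_loras=4.
gauge_lora_info.labels(
    max_lora="4",
    waiting_lora_adapters="sql-lora",
    running_lora_adapters="sql-lora,code-lora").set_to_current_time()

print(prometheus_client.generate_latest(registry).decode())
# The output ends with something like:
# vllm:lora_requests_info{max_lora="4",waiting_lora_adapters="sql-lora",running_lora_adapters="sql-lora,code-lora"} 1.74e+09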

View File

@@ -2,11 +2,12 @@
 import time
 from dataclasses import dataclass, field
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, Dict, List, Optional, Set

 if TYPE_CHECKING:
     from vllm.outputs import RequestOutput
     from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason
+    from vllm.v1.engine.output_processor import RequestState


 @dataclass
@@ -36,6 +37,12 @@ class SchedulerStats:
         default_factory=PrefixCacheStats)


+@dataclass
+class LoRAStats:
+    waiting_requests: Set[str] = field(default_factory=set)
+    running_requests: Set[str] = field(default_factory=set)
+
+
 @dataclass
 class RequestStateStats:
     """Stats that need to be tracked across delta updates."""
@@ -76,6 +83,8 @@ class IterationStats:
         self.time_per_output_tokens_iter: List[float] = []
         self.queue_times_iter: List[float] = []
         self.prefill_times_iter: List[float] = []
+        self.waiting_lora_adapters: Dict[str, int] = {}
+        self.running_lora_adapters: Dict[str, int] = {}

     def _time_since(self, start: float) -> float:
         """Calculate an interval relative to this iteration's timestamp."""
@@ -83,7 +92,8 @@ class IterationStats:
     def update_from_output(self, output: "EngineCoreOutput",
                            engine_core_timestamp: float, is_prefilling: bool,
-                           prompt_len: int, req_stats: RequestStateStats):
+                           prompt_len: int, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         num_new_generation_tokens = len(output.new_token_ids)

         self.num_generation_tokens += num_new_generation_tokens
@@ -105,7 +115,8 @@ class IterationStats:
         # Process request-level engine core events
         if output.events is not None:
-            self.update_from_events(output.events, is_prefilling, req_stats)
+            self.update_from_events(output.request_id, output.events,
+                                    is_prefilling, req_stats, lora_stats)

         # Process the batch-level "new tokens" engine core event
         if is_prefilling:
@@ -123,17 +134,21 @@ class IterationStats:
         if num_new_generation_tokens > 0:
             req_stats.last_token_ts = engine_core_timestamp

-    def update_from_events(self, events: List["EngineCoreEvent"],
-                           is_prefilling: bool, req_stats: RequestStateStats):
+    def update_from_events(self, req_id: str, events: List["EngineCoreEvent"],
+                           is_prefilling: bool, req_stats: RequestStateStats,
+                           lora_stats: Optional[LoRAStats]):
         # Avoid circular dependency
         from vllm.v1.engine import EngineCoreEventType

         for event in events:
             if event.type == EngineCoreEventType.QUEUED:
                 req_stats.queued_ts = event.timestamp
+                if lora_stats is not None:
+                    lora_stats.waiting_requests.add(req_id)
             elif event.type == EngineCoreEventType.SCHEDULED:
                 queued_interval = event.timestamp - req_stats.queued_ts
                 self.queue_times_iter.append(queued_interval)
                 req_stats.scheduled_ts = event.timestamp
+                LoRARequestStates.scheduled_request(lora_stats, req_id)

     def update_from_finished_request(self, finish_reason: "FinishReason",
                                      request_output: "RequestOutput",
@@ -151,3 +166,55 @@ class IterationStats:
                                  inference_time=inference_time,
                                  decode_time=decode_time)
         self.finished_requests.append(finished_req)
+
+
+class LoRARequestStates:
+    """Per-LoRA request state stats."""
+
+    def __init__(self):
+        self.lora_name_to_stats: Dict[str, LoRAStats] = {}
+
+    def get_stats(self, req_state: 'RequestState') -> Optional[LoRAStats]:
+        if req_state.lora_name is None:
+            return None
+        if req_state.lora_name not in self.lora_name_to_stats:
+            self.lora_name_to_stats[req_state.lora_name] = LoRAStats()
+        return self.lora_name_to_stats[req_state.lora_name]
+
+    def add_request(self, req_state: 'RequestState'):
+        if (lora_stats := self.get_stats(req_state)) is not None:
+            lora_stats.waiting_requests.add(req_state.request_id)
+
+    def finish_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.running_requests.remove(req_state.request_id)
+
+    def abort_request(self, req_state: 'RequestState'):
+        if req_state.lora_name is None:
+            return
+        lora_stats = self.lora_name_to_stats[req_state.lora_name]
+        lora_stats.waiting_requests.discard(req_state.request_id)
+        lora_stats.running_requests.discard(req_state.request_id)
+
+    # Break the pattern of the other lifecycle methods so we can
+    # call this from IterationStats.update_from_events()
+    @staticmethod
+    def scheduled_request(lora_stats: Optional[LoRAStats], request_id: str):
+        if lora_stats is None:
+            return
+        lora_stats.waiting_requests.remove(request_id)
+        lora_stats.running_requests.add(request_id)
+
+    def update_iteration_stats(self,
+                               iteration_stats: Optional[IterationStats]):
+        if iteration_stats is None:
+            return
+        for lora_name, stats in self.lora_name_to_stats.items():
+            if stats.waiting_requests:
+                iteration_stats.waiting_lora_adapters[lora_name] = \
+                    len(stats.waiting_requests)
+            if stats.running_requests:
+                iteration_stats.running_lora_adapters[lora_name] = \
+                    len(stats.running_requests)
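Finally, a rough self-contained sketch (illustrative only; the adapter and request ids are invented and the helper functions merely mirror the class above) of the bookkeeping LoRARequestStates does: a request sits in its adapter's waiting set from the QUEUED event until SCHEDULED, then in the running set until it finishes or is aborted, and update_iteration_stats() turns the set sizes into the per-adapter dicts that the Prometheus logger joins into label values:

from dataclasses import dataclass, field
from typing import Dict, Set


@dataclass
class LoRAStats:  # same shape as the dataclass added in this commit
    waiting_requests: Set[str] = field(default_factory=set)
    running_requests: Set[str] = field(default_factory=set)


lora_name_to_stats: Dict[str, LoRAStats] = {}


def queued(lora_name: str, req_id: str) -> None:
    # add_request() / QUEUED event: the request waits on its adapter.
    stats = lora_name_to_stats.setdefault(lora_name, LoRAStats())
    stats.waiting_requests.add(req_id)


def scheduled(lora_name: str, req_id: str) -> None:
    # scheduled_request(): move the request from waiting to running.
    stats = lora_name_to_stats[lora_name]
    stats.waiting_requests.remove(req_id)
    stats.running_requests.add(req_id)


queued("sql-lora", "req-0")
queued("sql-lora", "req-1")
scheduled("sql-lora", "req-0")

# update_iteration_stats(): per-adapter counts for this iteration.
waiting = {name: len(s.waiting_requests)
           for name, s in lora_name_to_stats.items() if s.waiting_requests}
running = {name: len(s.running_requests)
           for name, s in lora_name_to_stats.items() if s.running_requests}
print(waiting, running)  # {'sql-lora': 1} {'sql-lora': 1}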