[V1][Metrics] Add several request timing histograms (#12644)
Signed-off-by: Mark McLoughlin <markmc@redhat.com>
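This change threads a log_stats flag from the frontend through EngineCoreClient, EngineCore, Scheduler, and KVCacheManager, records QUEUED/SCHEDULED events on each request inside the engine core, and uses those events in the frontend to populate five request-phase histograms in the V1 Prometheus logger: e2e latency, queue time, prefill time, decode time, and inference time. A compressed sketch of the interval arithmetic implemented further down in RequestStateStats/IterationStats; the timestamp values are made up for illustration, arrival_time is frontend wall-clock, the rest are engine-core monotonic timestamps:

    # Illustrative numbers only; see RequestStateStats / IterationStats below.
    arrival_time = 1700000000.00   # time.time() when the frontend accepted the request
    queued_ts = 100.00             # time.monotonic() at the QUEUED event
    scheduled_ts = 100.25          # time.monotonic() at the SCHEDULED event
    first_token_ts = 100.75        # engine-core timestamp of the batch with the first new token
    last_token_ts = 101.50         # engine-core timestamp of the final batch with new tokens

    queue_time = scheduled_ts - queued_ts          # -> vllm:request_queue_time_seconds
    prefill_time = first_token_ts - scheduled_ts   # -> vllm:request_prefill_time_seconds
    decode_time = last_token_ts - first_token_ts   # -> vllm:request_decode_time_seconds
    inference_time = last_token_ts - scheduled_ts  # -> vllm:request_inference_time_seconds
    # e2e latency is measured against the wall-clock arrival time in the frontend:
    # e2e_latency = iteration_timestamp - arrival_time  -> vllm:e2e_request_latency_seconds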
@@ -85,6 +85,10 @@ EXPECTED_VALUES = {
    "vllm:time_per_output_token_seconds":
    [("_count", _NUM_REQUESTS * (_NUM_GENERATION_TOKENS_PER_REQUEST - 1))],
    "vllm:e2e_request_latency_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_queue_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_inference_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prefill_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_decode_time_seconds": [("_count", _NUM_REQUESTS)],
    "vllm:request_prompt_tokens":
    [("_sum", _NUM_REQUESTS * _NUM_PROMPT_TOKENS_PER_REQUEST),
     ("_count", _NUM_REQUESTS)],
@@ -169,6 +173,18 @@ EXPECTED_METRICS = [
    "vllm:e2e_request_latency_seconds_sum",
    "vllm:e2e_request_latency_seconds_bucket",
    "vllm:e2e_request_latency_seconds_count",
    "vllm:request_queue_time_seconds_sum",
    "vllm:request_queue_time_seconds_bucket",
    "vllm:request_queue_time_seconds_count",
    "vllm:request_inference_time_seconds_sum",
    "vllm:request_inference_time_seconds_bucket",
    "vllm:request_inference_time_seconds_count",
    "vllm:request_prefill_time_seconds_sum",
    "vllm:request_prefill_time_seconds_bucket",
    "vllm:request_prefill_time_seconds_count",
    "vllm:request_decode_time_seconds_sum",
    "vllm:request_decode_time_seconds_bucket",
    "vllm:request_decode_time_seconds_count",
    "vllm:request_prompt_tokens_sum",
    "vllm:request_prompt_tokens_bucket",
    "vllm:request_prompt_tokens_count",
@@ -220,6 +236,21 @@ EXPECTED_METRICS_V1 = [
    "vllm:time_per_output_token_seconds_sum",
    "vllm:time_per_output_token_seconds_bucket",
    "vllm:time_per_output_token_seconds_count",
    "vllm:e2e_request_latency_seconds_sum",
    "vllm:e2e_request_latency_seconds_bucket",
    "vllm:e2e_request_latency_seconds_count",
    "vllm:request_queue_time_seconds_sum",
    "vllm:request_queue_time_seconds_bucket",
    "vllm:request_queue_time_seconds_count",
    "vllm:request_inference_time_seconds_sum",
    "vllm:request_inference_time_seconds_bucket",
    "vllm:request_inference_time_seconds_count",
    "vllm:request_prefill_time_seconds_sum",
    "vllm:request_prefill_time_seconds_bucket",
    "vllm:request_prefill_time_seconds_count",
    "vllm:request_decode_time_seconds_sum",
    "vllm:request_decode_time_seconds_bucket",
    "vllm:request_decode_time_seconds_count",
]
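Each expected entry above corresponds to one of the three time series a single Prometheus histogram exports: a running _sum, a _count of observations, and one cumulative _bucket per boundary. A self-contained sketch of that mapping with prometheus_client; the metric name and buckets here are illustrative, not vLLM's:

    import prometheus_client

    registry = prometheus_client.CollectorRegistry()  # isolated registry for the demo
    hist = prometheus_client.Histogram(
        name="demo_request_latency_seconds",  # illustrative name only
        documentation="Histogram of request latency in seconds.",
        buckets=[0.3, 0.5, 1.0, 2.5, 5.0, 10.0],
        registry=registry)

    hist.observe(0.7)  # one finished request that took 0.7s

    # The exposition text now contains demo_request_latency_seconds_sum,
    # ..._count, and one ..._bucket sample per boundary (plus +Inf), which is
    # exactly the _sum/_count/_bucket suffixes asserted on above.
    print(prometheus_client.generate_latest(registry).decode())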
@@ -38,7 +38,8 @@ def create_scheduler(
    return Scheduler(scheduler_config,
                     model_config,
                     cache_config,
                     lora_config=None)
                     lora_config=None,
                     log_stats=True)


def create_requests(
@@ -50,7 +50,8 @@ def test_engine_core(monkeypatch):
        executor_class = Executor.get_class(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class)
                                 executor_class=executor_class,
                                 log_stats=True)
        """Test basic request lifecycle."""

        # First request.
@@ -157,7 +158,8 @@ def test_engine_core_advanced_sampling(monkeypatch):
        executor_class = Executor.get_class(vllm_config)

        engine_core = EngineCore(vllm_config=vllm_config,
                                 executor_class=executor_class)
                                 executor_class=executor_class,
                                 log_stats=True)
        """Test basic request lifecycle."""
        # First request.
        request: EngineCoreRequest = make_request()
@@ -94,6 +94,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool):
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
        )

        MAX_TOKENS = 20
@@ -163,6 +164,7 @@ async def test_engine_core_client_asyncio(monkeypatch):
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=True,
        )

        MAX_TOKENS = 20
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import math
import time
from typing import Dict, List, Optional

import pytest
@@ -15,6 +16,7 @@ from vllm.sequence import PromptLogprobs, SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.v1.engine import EngineCoreRequest
from vllm.v1.engine.output_processor import OutputProcessor
from vllm.v1.metrics.stats import IterationStats


def _ref_convert_id_to_token(
@@ -603,6 +605,7 @@ def test_iteration_stats(dummy_test_vectors):
    output_processor = OutputProcessor(dummy_test_vectors.tokenizer_group,
                                       log_stats=True)
    engine_core = MockEngineCore(dummy_test_vectors.generation_tokens)
    engine_core_timestamp = time.monotonic()

    # Make N requests.
    requests = [
@@ -630,8 +633,9 @@ def test_iteration_stats(dummy_test_vectors):

    # First iteration has 2 prefills.
    outputs = engine_core.get_outputs()[:num_active]
    processed_outputs = output_processor.process_outputs(outputs)
    iteration_stats = processed_outputs.iteration_stats
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp,
                                     iteration_stats)
    total_prompt_tokens = sum([
        len(prompt_tokens)
        for prompt_tokens in dummy_test_vectors.prompt_tokens[:num_active]
@@ -642,8 +646,9 @@ def test_iteration_stats(dummy_test_vectors):

    # Just decodes in this step.
    outputs = engine_core.get_outputs()[:num_active]
    processed_outputs = output_processor.process_outputs(outputs)
    iteration_stats = processed_outputs.iteration_stats
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp,
                                     iteration_stats)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active
@@ -652,8 +657,9 @@ def test_iteration_stats(dummy_test_vectors):
    output_processor.add_request(inactive_request)
    num_active += 1
    outputs = engine_core.get_outputs()[:num_active]
    processed_outputs = output_processor.process_outputs(outputs)
    iteration_stats = processed_outputs.iteration_stats
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp,
                                     iteration_stats)
    total_prompt_tokens = len(dummy_test_vectors.prompt_tokens[num_active - 1])

    assert iteration_stats.num_prompt_tokens == total_prompt_tokens
@@ -661,8 +667,9 @@ def test_iteration_stats(dummy_test_vectors):

    # Just decodes in this step.
    outputs = engine_core.get_outputs()[:num_active]
    processed_outputs = output_processor.process_outputs(outputs)
    iteration_stats = processed_outputs.iteration_stats
    iteration_stats = IterationStats()
    output_processor.process_outputs(outputs, engine_core_timestamp,
                                     iteration_stats)

    assert iteration_stats.num_prompt_tokens == 0
    assert iteration_stats.num_generation_tokens == num_active
@@ -26,6 +26,7 @@ class KVCacheManager:
        sliding_window: Optional[int] = None,
        enable_caching: bool = True,
        num_preallocate_tokens: int = 64,
        log_stats: bool = False,
    ) -> None:
        self.block_size = block_size
        self.num_gpu_blocks = num_gpu_blocks
@@ -33,6 +34,8 @@ class KVCacheManager:
        self.max_num_blocks_per_req = cdiv(max_model_len, block_size)
        self.sliding_window = sliding_window
        self.enable_caching = enable_caching
        # FIXME: make prefix cache stats conditional on log_stats
        self.log_stats = log_stats
        # NOTE(woosuk): To avoid frequent block allocation, we preallocate some
        # blocks for each request. For example, when a request reaches the end
        # of its block table, we preallocate N blocks in advance. This way, we
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0

import time
from collections import deque
from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union

@@ -10,7 +11,8 @@ from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager,
from vllm.v1.core.kv_cache_manager import KVCacheManager
from vllm.v1.core.scheduler_output import (CachedRequestData, NewRequestData,
                                           SchedulerOutput)
from vllm.v1.engine import EngineCoreOutput, EngineCoreOutputs
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreOutput, EngineCoreOutputs)
from vllm.v1.metrics.stats import SchedulerStats
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus
@@ -26,10 +28,12 @@ class Scheduler:
        model_config: ModelConfig,
        cache_config: CacheConfig,
        lora_config: Optional[LoRAConfig],
        log_stats: bool,
    ) -> None:
        self.scheduler_config = scheduler_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.log_stats = log_stats

        # Scheduling constraints.
        self.max_num_running_reqs = self.scheduler_config.max_num_seqs
@@ -45,7 +49,8 @@ class Scheduler:
            num_gpu_blocks=num_gpu_blocks,
            max_model_len=self.max_model_len,
            sliding_window=self.cache_config.sliding_window,
            enable_caching=self.cache_config.enable_prefix_caching)
            enable_caching=self.cache_config.enable_prefix_caching,
            log_stats=self.log_stats)
        self.block_size = self.cache_config.block_size

        # req_id -> Request
@@ -107,6 +112,8 @@ class Scheduler:
        scheduled_encoder_inputs: Dict[str, List[int]] = {}
        encoder_budget = self.max_num_encoder_input_tokens

        scheduled_timestamp = time.monotonic()

        # First, schedule the RUNNING requests.
        req_index = 0
        while req_index < len(self.running) and token_budget > 0:
@@ -246,6 +253,7 @@ class Scheduler:
            self.running.append(request)
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
                self.request_scheduled(request, scheduled_timestamp)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
@@ -508,7 +516,8 @@ class Scheduler:
                    finish_reason=request.get_finished_reason(),
                    new_logprobs=new_logprobs,
                    new_prompt_logprobs_tensors=prompt_logprobs_tensors,
                    stop_reason=request.stop_reason))
                    stop_reason=request.stop_reason,
                    events=request.take_events()))

            if not stopped:
                new_running.append(request)
@@ -541,6 +550,7 @@ class Scheduler:
    def add_request(self, request: Request) -> None:
        self.waiting.append(request)
        self.requests[request.request_id] = request
        self.request_queued(request)

    def finish_requests(
        self,
@@ -588,7 +598,22 @@ class Scheduler:
    def reset_prefix_cache(self) -> bool:
        return self.kv_cache_manager.reset_prefix_cache()

    def make_stats(self) -> SchedulerStats:
    def request_queued(self, request: Request):
        if not self.log_stats:
            return
        request.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED))

    def request_scheduled(self, request: Request, timestamp: float):
        if not self.log_stats:
            return
        request.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
                                      timestamp))

    def make_stats(self) -> Optional[SchedulerStats]:
        if not self.log_stats:
            return None
        return SchedulerStats(
            num_running_reqs=len(self.running),
            num_waiting_reqs=len(self.waiting),
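Both helpers are no-ops unless stats logging is enabled, and requests scheduled in the same schedule() pass share the scheduled_timestamp captured at its start. The frontend later pairs the two events into one queue-time observation (see IterationStats.update_from_events further down); a minimal illustration using the EngineCoreEvent API added in this diff, with made-up monotonic timestamps:

    queued = EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, 100.00)
    scheduled = EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED, 100.25)
    # update_from_events() records scheduled.timestamp - queued.timestamp,
    # i.e. 0.25s here, into vllm:request_queue_time_seconds.
    queue_time = scheduled.timestamp - queued.timestamp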
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0

import enum
import time
from typing import List, Optional, Union

import msgspec
@@ -60,6 +61,30 @@ class EngineCoreRequest(
    lora_request: Optional[LoRARequest]


class EngineCoreEventType(enum.IntEnum):
    """The type of engine core request event."""
    QUEUED = 1
    SCHEDULED = 2


class EngineCoreEvent(msgspec.Struct):
    """A timestamped engine core event associated with a request.

    The timestamp is a monotonic timestamp and is used by the engine
    frontend to calculate intervals between engine core events. These
    timestamps should not be compared with timestamps from other processes.
    """
    type: EngineCoreEventType
    timestamp: float

    @classmethod
    def new_event(cls,
                  event_type: EngineCoreEventType,
                  timestamp: Optional[float] = None) -> "EngineCoreEvent":
        timestamp = time.monotonic() if timestamp is None else timestamp
        return cls(event_type, timestamp)

class EngineCoreOutput(
        msgspec.Struct,
        array_like=True,  # type: ignore[call-arg]
@@ -74,6 +99,7 @@ class EngineCoreOutput(

    finish_reason: Optional[FinishReason] = None
    stop_reason: Union[int, str, None] = None
    events: Optional[List[EngineCoreEvent]] = None

    @property
    def finished(self) -> bool:
@@ -91,7 +117,12 @@ class EngineCoreOutputs(

    # [num_reqs]
    outputs: List[EngineCoreOutput]
    scheduler_stats: SchedulerStats
    scheduler_stats: Optional[SchedulerStats]
    timestamp: float = 0.0

    def __post_init__(self):
        if self.timestamp == 0.0:
            self.timestamp = time.monotonic()


class EngineCoreRequestType(enum.Enum):
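Defining the event as a msgspec.Struct keeps it cheap to ship from the engine-core process to the frontend inside each EngineCoreOutput. A small standalone sketch of such a round-trip, using local stand-in types rather than importing vLLM, and assuming msgpack encoding for illustration:

    import enum
    import time

    import msgspec

    class EventType(enum.IntEnum):  # stand-in for EngineCoreEventType
        QUEUED = 1
        SCHEDULED = 2

    class Event(msgspec.Struct):  # stand-in for EngineCoreEvent
        type: EventType
        timestamp: float

    payload = msgspec.msgpack.encode(Event(EventType.QUEUED, time.monotonic()))
    event = msgspec.msgpack.decode(payload, type=Event)
    assert event.type is EventType.QUEUED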
@@ -53,10 +53,12 @@ class AsyncLLM(EngineClient):

        self.log_requests = log_requests
        self.log_stats = log_stats
        self.stat_loggers: List[StatLoggerBase] = [
            LoggingStatLogger(),
            PrometheusStatLogger(vllm_config.model_config),
        ]
        self.stat_loggers: List[StatLoggerBase] = []
        if self.log_stats:
            self.stat_loggers.extend([
                LoggingStatLogger(),
                PrometheusStatLogger(vllm_config.model_config),
            ])

        # Tokenizer (+ ensure liveness if running in another process).
        self.tokenizer = init_tokenizer_from_configs(
@@ -85,6 +87,7 @@ class AsyncLLM(EngineClient):
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=self.log_stats,
        )

        self.output_handler: Optional[asyncio.Task] = None
@@ -246,6 +249,8 @@ class AsyncLLM(EngineClient):
                # 1) Pull EngineCoreOutputs from the EngineCore.
                outputs = await self.engine_core.get_output_async()

                iteration_stats = IterationStats() if self.log_stats else None

                # Split outputs into chunks of at most
                # VLLM_V1_OUTPUT_PROC_CHUNK_SIZE, so that we don't block the
                # event loop for too long.
@@ -257,14 +262,12 @@ class AsyncLLM(EngineClient):
                        outputs.outputs,
                        cdiv(num_outputs, VLLM_V1_OUTPUT_PROC_CHUNK_SIZE))

                iteration_stats = None
                for i, outputs_slice in enumerate(slices):
                    # 2) Process EngineCoreOutputs.
                    processed_outputs = self.output_processor.process_outputs(
                        outputs_slice, iteration_stats)
                        outputs_slice, outputs.timestamp, iteration_stats)
                    # NOTE: RequestOutputs are pushed to their queues.
                    assert not processed_outputs.request_outputs
                    iteration_stats = processed_outputs.iteration_stats

                    # Allow other asyncio tasks to run between chunks
                    if i + 1 < len(slices):
@@ -277,7 +280,6 @@ class AsyncLLM(EngineClient):
                # 4) Logging.
                # TODO(rob): make into a coroutine and launch it in
                # background thread once Prometheus overhead is non-trivial.
                assert iteration_stats is not None
                self._log_stats(
                    scheduler_stats=outputs.scheduler_stats,
                    iteration_stats=iteration_stats,
@@ -299,12 +301,14 @@ class AsyncLLM(EngineClient):

    def _log_stats(
        self,
        scheduler_stats: SchedulerStats,
        iteration_stats: IterationStats,
        scheduler_stats: Optional[SchedulerStats],
        iteration_stats: Optional[IterationStats],
    ):
        if not self.log_stats:
            return

        assert scheduler_stats is not None
        assert iteration_stats is not None
        for logger in self.stat_loggers:
            logger.log(scheduler_stats=scheduler_stats,
                       iteration_stats=iteration_stats)
@@ -38,12 +38,15 @@ class EngineCore:
        self,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ):
        assert vllm_config.model_config.runner_type != "pooling"

        logger.info("Initializing a V1 LLM engine (v%s) with config: %s",
                    VLLM_VERSION, vllm_config)

        self.log_stats = log_stats

        # Setup Model.
        self.model_executor = executor_class(vllm_config)

@@ -59,6 +62,7 @@ class EngineCore:
            model_config=vllm_config.model_config,
            cache_config=vllm_config.cache_config,
            lora_config=vllm_config.lora_config,
            log_stats=self.log_stats,
        )

        self.mm_input_mapper_server = MMInputMapperServer(
@@ -148,11 +152,9 @@ class EngineCoreProc(EngineCore):
        ready_pipe: Connection,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool = False,
        log_stats: bool,
    ):
        super().__init__(vllm_config, executor_class)

        self.log_stats = log_stats
        super().__init__(vllm_config, executor_class, log_stats)

        # Background Threads and Queues for IO. These enable us to
        # overlap ZMQ socket IO with GPU since they release the GIL,
@@ -41,6 +41,7 @@ class EngineCoreClient(ABC):
        asyncio_mode: bool,
        vllm_config: VllmConfig,
        executor_class: Type[Executor],
        log_stats: bool,
    ) -> "EngineCoreClient":

        # TODO: support this for debugging purposes.
@@ -50,12 +51,12 @@ class EngineCoreClient(ABC):
                "is not currently supported.")

        if multiprocess_mode and asyncio_mode:
            return AsyncMPClient(vllm_config, executor_class)
            return AsyncMPClient(vllm_config, executor_class, log_stats)

        if multiprocess_mode and not asyncio_mode:
            return SyncMPClient(vllm_config, executor_class)
            return SyncMPClient(vllm_config, executor_class, log_stats)

        return InprocClient(vllm_config, executor_class)
        return InprocClient(vllm_config, executor_class, log_stats)

    @abstractmethod
    def shutdown(self):
@@ -204,13 +205,13 @@ class MPClient(EngineCoreClient):
class SyncMPClient(MPClient):
    """Synchronous client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig,
                 executor_class: Type[Executor]):
    def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,
            log_stats=log_stats,
        )

    def get_output(self) -> EngineCoreOutputs:
@@ -245,13 +246,13 @@ class SyncMPClient(MPClient):
class AsyncMPClient(MPClient):
    """Asyncio-compatible client for multi-proc EngineCore."""

    def __init__(self, vllm_config: VllmConfig,
                 executor_class: Type[Executor]):
    def __init__(self, vllm_config: VllmConfig, executor_class: Type[Executor],
                 log_stats: bool):
        super().__init__(
            asyncio_mode=True,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=True,
            log_stats=log_stats,
        )

        self.outputs_queue: Optional[asyncio.Queue[bytes]] = None
@@ -73,6 +73,7 @@ class LLMEngine:
            asyncio_mode=False,
            vllm_config=vllm_config,
            executor_class=executor_class,
            log_stats=False,  # FIXME: implement
        )

    @classmethod
@@ -19,7 +19,6 @@ class OutputProcessorOutput:

    request_outputs: List[RequestOutput]
    reqs_to_abort: List[str]
    iteration_stats: IterationStats


class RequestState:
@@ -34,6 +33,7 @@ class RequestState:
        detokenizer: IncrementalDetokenizer,
        arrival_time: float,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ):
        self.request_id = request_id
        self.output_kind = output_kind
@@ -45,14 +45,16 @@ class RequestState:
        self.is_prefilling = True
        self.queue = queue

        self.stats = RequestStateStats(last_token_time=arrival_time)
        self.stats = RequestStateStats(
            arrival_time=arrival_time) if log_stats else None

    @classmethod
    def from_new_request(
        cls,
        tokenizer: AnyTokenizer,
        request: EngineCoreRequest,
        queue: Optional[asyncio.Queue[RequestOutput]] = None,
        queue: Optional[asyncio.Queue[RequestOutput]],
        log_stats: bool,
    ) -> "RequestState":
        return cls(
            request_id=request.request_id,
@@ -69,6 +71,7 @@ class RequestState:
            ),
            arrival_time=request.arrival_time,
            queue=queue,
            log_stats=log_stats,
        )

@@ -112,11 +115,13 @@ class OutputProcessor:
        self.request_states[request_id] = RequestState.from_new_request(
            tokenizer=self.tokenizer.get_lora_tokenizer(request.lora_request),
            request=request,
            queue=queue)
            queue=queue,
            log_stats=self.log_stats)

    def process_outputs(
        self,
        engine_core_outputs: List[EngineCoreOutput],
        engine_core_timestamp: Optional[float] = None,
        iteration_stats: Optional[IterationStats] = None,
    ) -> OutputProcessorOutput:
        """
@@ -145,8 +150,6 @@ class OutputProcessor:

        request_outputs: List[RequestOutput] = []
        reqs_to_abort: List[str] = []
        if not iteration_stats:
            iteration_stats = IterationStats(self.log_stats)
        for engine_core_output in engine_core_outputs:
            req_id = engine_core_output.request_id
            req_state = self.request_states.get(req_id)
@@ -155,10 +158,9 @@ class OutputProcessor:
                continue

            # 1) Compute stats for this iteration.
            iteration_stats.update_from_output(engine_core_output,
                                               req_state.is_prefilling,
                                               req_state.prompt_len,
                                               req_state.stats)
            self._update_stats_from_output(req_state, engine_core_output,
                                           engine_core_timestamp,
                                           iteration_stats)

            new_token_ids = engine_core_output.new_token_ids
            finish_reason = engine_core_output.finish_reason
@@ -205,17 +207,44 @@ class OutputProcessor:
                    # detected stop string, abort needed in EngineCore.
                    reqs_to_abort.append(req_id)

                # Track per-request stats.
                assert finish_reason is not None
                iteration_stats.update_from_finished_request(
                    finish_reason, request_output, req_state.stats)
                # Track per-request stats
                self._update_stats_from_finished(req_state, request_output,
                                                 finish_reason,
                                                 iteration_stats)

        return OutputProcessorOutput(
            request_outputs=request_outputs,
            reqs_to_abort=reqs_to_abort,
            iteration_stats=iteration_stats,
        )

    def _update_stats_from_output(self, req_state: RequestState,
                                  engine_core_output: EngineCoreOutput,
                                  engine_core_timestamp: Optional[float],
                                  iteration_stats: Optional[IterationStats]):
        if iteration_stats is None:
            return

        assert engine_core_timestamp is not None
        assert req_state.stats is not None
        iteration_stats.update_from_output(engine_core_output,
                                           engine_core_timestamp,
                                           req_state.is_prefilling,
                                           req_state.prompt_len,
                                           req_state.stats)

    def _update_stats_from_finished(self, req_state: RequestState,
                                    request_output: RequestOutput,
                                    finish_reason: Optional[FinishReason],
                                    iteration_stats: Optional[IterationStats]):
        if iteration_stats is None:
            return

        assert finish_reason is not None
        assert req_state.stats is not None
        iteration_stats.update_from_finished_request(finish_reason,
                                                     request_output,
                                                     req_state.stats)

    @staticmethod
    def _make_request_output(
        request_state: RequestState,
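Taken together with the AsyncLLM and test changes above, the calling convention is now: the caller creates the IterationStats (or passes None to skip stats entirely) and supplies the engine-core batch timestamp, rather than reading stats back from the return value. Roughly, with names as used elsewhere in this diff and surrounding setup elided:

    iteration_stats = IterationStats() if log_stats else None
    output_processor.process_outputs(outputs.outputs, outputs.timestamp,
                                     iteration_stats)
    # iteration_stats (if not None) now carries TTFT/TPOT samples, queue and
    # prefill intervals, plus FinishedRequestStats for any requests that ended.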
@@ -182,6 +182,45 @@ class PrometheusStatLogger(StatLoggerBase):
            ],
            labelnames=labelnames).labels(*labelvalues)

        request_latency_buckets = [
            0.3, 0.5, 0.8, 1.0, 1.5, 2.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0,
            40.0, 50.0, 60.0
        ]
        self.histogram_e2e_time_request = \
            prometheus_client.Histogram(
                name="vllm:e2e_request_latency_seconds",
                documentation="Histogram of e2e request latency in seconds.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_queue_time_request = \
            prometheus_client.Histogram(
                name="vllm:request_queue_time_seconds",
                documentation=
                "Histogram of time spent in WAITING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_inference_time_request = \
            prometheus_client.Histogram(
                name="vllm:request_inference_time_seconds",
                documentation=
                "Histogram of time spent in RUNNING phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_prefill_time_request = \
            prometheus_client.Histogram(
                name="vllm:request_prefill_time_seconds",
                documentation=
                "Histogram of time spent in PREFILL phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)
        self.histogram_decode_time_request = \
            prometheus_client.Histogram(
                name="vllm:request_decode_time_seconds",
                documentation=
                "Histogram of time spent in DECODE phase for request.",
                buckets=request_latency_buckets,
                labelnames=labelnames).labels(*labelvalues)

    def log(self, scheduler_stats: SchedulerStats,
            iteration_stats: IterationStats):
        """Log to prometheus."""
@@ -201,6 +240,12 @@ class PrometheusStatLogger(StatLoggerBase):

        for finished_request in iteration_stats.finished_requests:
            self.counter_request_success[finished_request.finish_reason].inc()
            self.histogram_e2e_time_request.observe(
                finished_request.e2e_latency)
            self.histogram_inference_time_request.observe(
                finished_request.inference_time)
            self.histogram_decode_time_request.observe(
                finished_request.decode_time)
            self.histogram_num_prompt_tokens_request.observe(
                finished_request.num_prompt_tokens)
            self.histogram_num_generation_tokens_request.observe(
@@ -210,6 +255,10 @@ class PrometheusStatLogger(StatLoggerBase):
            self.histogram_time_to_first_token.observe(ttft)
        for tpot in iteration_stats.time_per_output_tokens_iter:
            self.histogram_time_per_output_token.observe(tpot)
        for queue_time in iteration_stats.queue_times_iter:
            self.histogram_queue_time_request.observe(queue_time)
        for prefill_time in iteration_stats.prefill_times_iter:
            self.histogram_prefill_time_request.observe(prefill_time)

    @staticmethod
    def _unregister_vllm_metrics():
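Because the five request-phase histograms share one bucket layout (request_latency_buckets above), their distributions can be compared directly; for example, a dashboard might chart p99 queue time with a PromQL expression along the lines of histogram_quantile(0.99, rate(vllm:request_queue_time_seconds_bucket[5m])) (query shown for illustration only).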
@@ -6,7 +6,7 @@ from typing import TYPE_CHECKING, List

if TYPE_CHECKING:
    from vllm.outputs import RequestOutput
    from vllm.v1.engine import EngineCoreOutput, FinishReason
    from vllm.v1.engine import EngineCoreEvent, EngineCoreOutput, FinishReason


@dataclass
@@ -41,7 +41,15 @@ class RequestStateStats:
    """Stats that need to be tracked across delta updates."""

    num_generation_tokens: int = 0
    last_token_time: float = 0.0

    # This is an engine frontend timestamp (wall-clock)
    arrival_time: float = 0.0

    # These are engine core timestamps (monotonic)
    queued_ts: float = 0.0
    scheduled_ts: float = 0.0
    first_token_ts: float = 0.0
    last_token_ts: float = 0.0

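For orientation, a summary of which code path sets each new timestamp (paraphrased from the rest of this diff):

    # queued_ts      <- QUEUED event timestamp (Scheduler.add_request)
    # scheduled_ts   <- SCHEDULED event timestamp (Scheduler.schedule)
    # first_token_ts <- EngineCoreOutputs.timestamp of the batch carrying the
    #                   request's first new token
    # last_token_ts  <- EngineCoreOutputs.timestamp of the latest batch that
    #                   carried new tokens for the request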
@dataclass
@@ -49,33 +57,37 @@ class FinishedRequestStats:
    """Stats associated with a finished request."""

    finish_reason: "FinishReason"
    e2e_latency: float = 0.0
    num_prompt_tokens: int = 0
    num_generation_tokens: int = 0
    inference_time: float = 0.0
    decode_time: float = 0.0


class IterationStats:
    """Stats associated with a single set of EngineCoreOutputs."""

    def __init__(self, log_stats: bool):
        self.log_stats = log_stats
    def __init__(self):
        self.iteration_timestamp = time.time()
        self.num_generation_tokens = 0
        self.num_prompt_tokens = 0
        self.finished_requests: List[FinishedRequestStats] = []
        self.time_to_first_tokens_iter: List[float] = []
        self.time_per_output_tokens_iter: List[float] = []
        self.queue_times_iter: List[float] = []
        self.prefill_times_iter: List[float] = []

    def _time_since(self, start: float) -> float:
        """Calculate an interval relative to this iteration's timestamp."""
        return self.iteration_timestamp - start

    def update_from_output(self, output: "EngineCoreOutput",
                           is_prefilling: bool, prompt_len: int,
                           request_state_stats: RequestStateStats):
        if not self.log_stats:
            return

                           engine_core_timestamp: float, is_prefilling: bool,
                           prompt_len: int, req_stats: RequestStateStats):
        num_new_generation_tokens = len(output.new_token_ids)
        now = time.time()
        last_token_latency = now - request_state_stats.last_token_time

        self.num_generation_tokens += num_new_generation_tokens
        if is_prefilling:
        if is_prefilling and num_new_generation_tokens > 0:
            # TODO(andy): we used to assert that num_new_generation_tokens
            # > 0 with an invariant that EngineCore does not stream outputs
            # for partially completed prefills (scheduler.update_from_output
@@ -84,19 +96,58 @@ class IterationStats:
            # partially completed prompt.
            # This will be reverted in a follow up PR and we should re-enable
            # this assertion / invariant.
            if num_new_generation_tokens > 0:
                self.num_prompt_tokens += prompt_len
                self.time_to_first_tokens_iter.append(last_token_latency)
        else:
            self.time_per_output_tokens_iter.append(last_token_latency)
            self.num_prompt_tokens += prompt_len

        request_state_stats.num_generation_tokens += num_new_generation_tokens
        request_state_stats.last_token_time = now
            first_token_latency = self._time_since(req_stats.arrival_time)
            self.time_to_first_tokens_iter.append(first_token_latency)

        req_stats.num_generation_tokens += num_new_generation_tokens

        # Process request-level engine core events
        if output.events is not None:
            self.update_from_events(output.events, is_prefilling, req_stats)

        # Process the batch-level "new tokens" engine core event
        if is_prefilling:
            # TODO: re-enable no-output-for-partial-prefills invariant as above
            if num_new_generation_tokens > 0:
                prefill_interval = \
                    engine_core_timestamp - req_stats.scheduled_ts
                self.prefill_times_iter.append(prefill_interval)
                req_stats.first_token_ts = engine_core_timestamp
        else:
            tpot = engine_core_timestamp - req_stats.last_token_ts
            self.time_per_output_tokens_iter.append(tpot)

        # TODO: re-enable no-output-for-partial-prefills invariant as above
        if num_new_generation_tokens > 0:
            req_stats.last_token_ts = engine_core_timestamp

    def update_from_events(self, events: List["EngineCoreEvent"],
                           is_prefilling: bool, req_stats: RequestStateStats):
        # Avoid circular dependency
        from vllm.v1.engine import EngineCoreEventType
        for event in events:
            if event.type == EngineCoreEventType.QUEUED:
                req_stats.queued_ts = event.timestamp
            elif event.type == EngineCoreEventType.SCHEDULED:
                queued_interval = event.timestamp - req_stats.queued_ts
                self.queue_times_iter.append(queued_interval)
                req_stats.scheduled_ts = event.timestamp

    def update_from_finished_request(self, finish_reason: "FinishReason",
                                     request_output: "RequestOutput",
                                     request_state_stats: RequestStateStats):
        self.finished_requests.append(
            FinishedRequestStats(finish_reason,
                                 len(request_output.prompt_token_ids),
                                 request_state_stats.num_generation_tokens))
                                     req_stats: RequestStateStats):
        e2e_latency = self._time_since(req_stats.arrival_time)

        inference_time = req_stats.last_token_ts - req_stats.scheduled_ts
        decode_time = req_stats.last_token_ts - req_stats.first_token_ts

        finished_req = \
            FinishedRequestStats(finish_reason=finish_reason,
                                 e2e_latency=e2e_latency,
                                 num_prompt_tokens=len(request_output.prompt_token_ids),
                                 num_generation_tokens=req_stats.num_generation_tokens,
                                 inference_time=inference_time,
                                 decode_time=decode_time)
        self.finished_requests.append(finished_req)
@@ -5,8 +5,8 @@ from typing import TYPE_CHECKING, List, Optional, Union

from vllm.lora.request import LoRARequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import RequestMetrics
from vllm.v1.engine import EngineCoreRequest, FinishReason
from vllm.v1.engine import (EngineCoreEvent, EngineCoreEventType,
                            EngineCoreRequest, FinishReason)
from vllm.v1.utils import ConstantList

if TYPE_CHECKING:
@@ -33,14 +33,10 @@ class Request:
        self.sampling_params = sampling_params
        # Because of LoRA, the eos token id can be different for each request.
        self.eos_token_id = eos_token_id
        self.metrics = RequestMetrics(arrival_time=arrival_time,
                                      last_token_time=arrival_time,
                                      first_scheduled_time=None,
                                      first_token_time=None,
                                      time_in_queue=None)
        self.lora_request = lora_request

        self.status = RequestStatus.WAITING
        self.events: List[EngineCoreEvent] = []
        self.stop_reason: Union[int, str, None] = None
        assert sampling_params.max_tokens is not None
        self.max_tokens = sampling_params.max_tokens
@@ -83,6 +79,21 @@ class Request:
            lora_request=request.lora_request,
        )

    def queued(self, timestamp: Optional[float] = None) -> None:
        self.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.QUEUED, timestamp))

    def scheduled(self, timestamp: Optional[float] = None) -> None:
        self.events.append(
            EngineCoreEvent.new_event(EngineCoreEventType.SCHEDULED,
                                      timestamp))

    def take_events(self) -> Optional[List[EngineCoreEvent]]:
        if not self.events:
            return None
        events, self.events = self.events, []
        return events

    def append_output_token_ids(
        self,
        token_ids: Union[int, List[int]],
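take_events() hands the accumulated events to the scheduler's EngineCoreOutput and resets the list, so each event reaches the frontend at most once. A tiny self-contained illustration of that drain-once contract; a stand-in class is used because constructing a real Request needs a full EngineCoreRequest:

    from typing import List, Optional

    class _EventsOnly:
        """Stand-in exposing just the event-draining behaviour shown above."""

        def __init__(self) -> None:
            self.events: List[str] = []

        def take_events(self) -> Optional[List[str]]:
            if not self.events:
                return None
            events, self.events = self.events, []
            return events

    req = _EventsOnly()
    req.events += ["QUEUED", "SCHEDULED"]
    assert req.take_events() == ["QUEUED", "SCHEDULED"]
    assert req.take_events() is None  # already drained, nothing is re-sent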