mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
2 Commits
v0.11.0
...
support_gl
Author | SHA1 | Date | |
---|---|---|---|
98d535eb4f | |||
a46e279909 |
@ -1,12 +1,15 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import asyncio
|
||||
import threading
|
||||
from typing import Optional
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||
|
||||
"""
|
||||
To run this example, run the following commands simultaneously with
|
||||
@ -22,37 +25,67 @@ send a request to the instance with DP rank 1.
|
||||
"""
|
||||
|
||||
|
||||
def _do_background_logging(engine, interval, stop_event):
|
||||
try:
|
||||
while not stop_event.is_set():
|
||||
asyncio.run(engine.do_log_stats())
|
||||
stop_event.wait(interval)
|
||||
except Exception as e:
|
||||
print(f"vLLM background logging shutdown: {e}")
|
||||
pass
|
||||
|
||||
|
||||
async def main():
|
||||
engine_args = AsyncEngineArgs(
|
||||
model="ibm-research/PowerMoE-3b",
|
||||
data_parallel_size=2,
|
||||
tensor_parallel_size=1,
|
||||
dtype="auto",
|
||||
max_model_len=2048,
|
||||
data_parallel_address="127.0.0.1",
|
||||
data_parallel_rpc_port=62300,
|
||||
data_parallel_size_local=1,
|
||||
enforce_eager=True,
|
||||
enable_log_requests=True,
|
||||
disable_custom_all_reduce=True,
|
||||
)
|
||||
|
||||
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
|
||||
return LoggingStatLogger(config, rank)
|
||||
|
||||
engine_client = AsyncLLMEngine.from_engine_args(
|
||||
engine_args,
|
||||
# Example: Using both regular loggers and aggregated logger
|
||||
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
|
||||
)
|
||||
stop_logging_event = threading.Event()
|
||||
logging_thread = threading.Thread(
|
||||
target=_do_background_logging,
|
||||
args=(engine_client, 5, stop_logging_event),
|
||||
daemon=True,
|
||||
)
|
||||
logging_thread.start()
|
||||
sampling_params = SamplingParams(
|
||||
temperature=0.7,
|
||||
top_p=0.9,
|
||||
max_tokens=100,
|
||||
)
|
||||
num_prompts = 10
|
||||
for i in range(num_prompts):
|
||||
prompt = "Who won the 2004 World Series?"
|
||||
final_output: Optional[RequestOutput] = None
|
||||
async for output in engine_client.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id=f"abcdef-{i}",
|
||||
data_parallel_rank=1,
|
||||
):
|
||||
final_output = output
|
||||
if final_output:
|
||||
print(final_output.outputs[0].text)
|
||||
|
||||
prompt = "Who won the 2004 World Series?"
|
||||
final_output: Optional[RequestOutput] = None
|
||||
async for output in engine_client.generate(
|
||||
prompt=prompt,
|
||||
sampling_params=sampling_params,
|
||||
request_id="abcdef",
|
||||
data_parallel_rank=1,
|
||||
):
|
||||
final_output = output
|
||||
if final_output:
|
||||
print(final_output.outputs[0].text)
|
||||
stop_logging_event.set()
|
||||
logging_thread.join()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -18,7 +18,7 @@ from vllm.platforms import current_platform
|
||||
from vllm.sampling_params import RequestOutputKind
|
||||
from vllm.utils import set_default_torch_num_threads
|
||||
from vllm.v1.engine.async_llm import AsyncLLM
|
||||
from vllm.v1.metrics.loggers import LoggingStatLogger
|
||||
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
|
||||
|
||||
if not current_platform.is_cuda():
|
||||
pytest.skip(reason="V1 currently only supported on CUDA.",
|
||||
@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
|
||||
self.log = MagicMock()
|
||||
|
||||
|
||||
class MockAggregatedStatLogger(AggregatedStatLogger):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]] = None):
|
||||
super().__init__(vllm_config, engine_indexes)
|
||||
self.log = MagicMock()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_customize_loggers(monkeypatch):
|
||||
"""Test that we can customize the loggers.
|
||||
@ -415,6 +424,35 @@ async def test_customize_loggers(monkeypatch):
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_customize_aggregated_loggers(monkeypatch):
|
||||
"""Test that we can customize the aggregated loggers.
|
||||
If a customized logger is provided at the init, it should
|
||||
be added to the default loggers.
|
||||
"""
|
||||
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
m.setenv("VLLM_USE_V1", "1")
|
||||
|
||||
with set_default_torch_num_threads(1):
|
||||
engine = AsyncLLM.from_engine_args(
|
||||
TEXT_ENGINE_ARGS,
|
||||
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
|
||||
)
|
||||
after.callback(engine.shutdown)
|
||||
|
||||
await engine.do_log_stats()
|
||||
|
||||
stat_loggers = engine.logger_manager.per_engine_logger_dict
|
||||
assert len(stat_loggers) == 1
|
||||
assert len(
|
||||
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
|
||||
aggregated_loggers = engine.logger_manager.aggregated_loggers
|
||||
assert len(aggregated_loggers) == 1
|
||||
aggregated_loggers[0].log.assert_called_once()
|
||||
stat_loggers[0][0].log.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio(scope="module")
|
||||
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m, ExitStack() as after:
|
||||
|
@ -18,7 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
|
||||
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
|
||||
type["AggregatedStatLoggerBase"]]
|
||||
|
||||
|
||||
class StatLoggerBase(ABC):
|
||||
@ -48,6 +50,16 @@ class StatLoggerBase(ABC):
|
||||
pass
|
||||
|
||||
|
||||
class AggregatedStatLoggerBase(StatLoggerBase):
|
||||
"""Abstract base class for loggers that
|
||||
aggregates statistics across multiple engines."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, vllm_config: VllmConfig,
|
||||
engine_indexes: Optional[list[int]]):
|
||||
...
|
||||
|
||||
|
||||
class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
|
||||
@ -61,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.spec_decoding_logging = SpecDecodingLogging()
|
||||
self.last_prompt_throughput: float = 0.0
|
||||
self.last_generation_throughput: float = 0.0
|
||||
self.engine_is_idle = False
|
||||
|
||||
def _reset(self, now):
|
||||
self.last_log_time = now
|
||||
@ -100,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
|
||||
self.last_scheduler_stats = scheduler_stats
|
||||
|
||||
def log(self):
|
||||
def get_log_stats(self):
|
||||
now = time.monotonic()
|
||||
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
|
||||
generation_throughput = self._get_throughput(
|
||||
self.num_generation_tokens, now)
|
||||
|
||||
self._reset(now)
|
||||
|
||||
scheduler_stats = self.last_scheduler_stats
|
||||
|
||||
log_fn = logger.info
|
||||
if not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput)):
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
self.last_generation_throughput = generation_throughput
|
||||
self.last_prompt_throughput = prompt_throughput
|
||||
self.engine_is_idle = not any(
|
||||
(prompt_throughput, generation_throughput,
|
||||
self.last_prompt_throughput, self.last_generation_throughput))
|
||||
|
||||
def log(self):
|
||||
self.get_log_stats()
|
||||
log_fn = logger.info
|
||||
if self.engine_is_idle:
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"Engine %03d: "
|
||||
@ -128,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
self.engine_index,
|
||||
prompt_throughput,
|
||||
generation_throughput,
|
||||
scheduler_stats.num_running_reqs,
|
||||
scheduler_stats.num_waiting_reqs,
|
||||
scheduler_stats.kv_cache_usage * 100,
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput,
|
||||
self.last_scheduler_stats.num_running_reqs,
|
||||
self.last_scheduler_stats.num_waiting_reqs,
|
||||
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
@ -145,7 +158,61 @@ class LoggingStatLogger(StatLoggerBase):
|
||||
self.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
|
||||
class PrometheusStatLogger(StatLoggerBase):
|
||||
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
engine_idxs: Optional[list[int]] = None):
|
||||
if engine_idxs is None:
|
||||
engine_idxs = [0]
|
||||
self.engine_idxs = engine_idxs
|
||||
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
|
||||
|
||||
def record(
|
||||
self,
|
||||
scheduler_stats: Optional[SchedulerStats],
|
||||
iteration_stats: Optional[IterationStats],
|
||||
engine_idx: int = 0,
|
||||
):
|
||||
if engine_idx not in self.engine_idxs:
|
||||
logger.warning("Unexpected engine_idx: %d", engine_idx)
|
||||
return
|
||||
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
|
||||
engine_idx)
|
||||
|
||||
def log(self):
|
||||
self.get_log_stats()
|
||||
log_fn = logger.info
|
||||
if self.engine_is_idle:
|
||||
# Avoid log noise on an idle production system
|
||||
log_fn = logger.debug
|
||||
# Format and print output.
|
||||
log_fn(
|
||||
"%s Engines Aggregated: "
|
||||
"Avg prompt throughput: %.1f tokens/s, "
|
||||
"Avg generation throughput: %.1f tokens/s, "
|
||||
"Running: %d reqs, Waiting: %d reqs, "
|
||||
"GPU KV cache usage: %.1f%%, "
|
||||
"Prefix cache hit rate: %.1f%%",
|
||||
len(self.engine_idxs),
|
||||
self.last_prompt_throughput,
|
||||
self.last_generation_throughput,
|
||||
self.last_scheduler_stats.num_running_reqs,
|
||||
self.last_scheduler_stats.num_waiting_reqs,
|
||||
self.last_scheduler_stats.kv_cache_usage * 100,
|
||||
self.prefix_caching_metrics.hit_rate * 100,
|
||||
)
|
||||
self.spec_decoding_logging.log(log_fn=log_fn)
|
||||
|
||||
def log_engine_initialized(self):
|
||||
if self.vllm_config.cache_config.num_gpu_blocks:
|
||||
logger.info(
|
||||
"%d Engines: vllm cache_config_info with initialization "
|
||||
"after num_gpu_blocks is: %d", len(self.engine_idxs),
|
||||
self.vllm_config.cache_config.num_gpu_blocks)
|
||||
|
||||
|
||||
class PrometheusStatLogger(AggregatedStatLoggerBase):
|
||||
_gauge_cls = prometheus_client.Gauge
|
||||
_counter_cls = prometheus_client.Counter
|
||||
_histogram_cls = prometheus_client.Histogram
|
||||
@ -674,23 +741,32 @@ class StatLoggerManager:
|
||||
|
||||
# engine_idx: StatLogger
|
||||
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
|
||||
prometheus_factory = PrometheusStatLogger
|
||||
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
|
||||
|
||||
aggregated_loggers_factories = set()
|
||||
for engine_idx in self.engine_idxs:
|
||||
loggers: list[StatLoggerBase] = []
|
||||
for logger_factory in factories:
|
||||
# If we get a custom prometheus logger, use that
|
||||
# instead. This is typically used for the ray case.
|
||||
if (isinstance(logger_factory, type)
|
||||
and issubclass(logger_factory, PrometheusStatLogger)):
|
||||
prometheus_factory = logger_factory
|
||||
continue
|
||||
loggers.append(logger_factory(vllm_config,
|
||||
engine_idx)) # type: ignore
|
||||
# If we get a custom prometheus logger or aggregated logger,
|
||||
# We initialize it separately with all engine idxs.
|
||||
# A custom prometheus logger is typically used for the ray.
|
||||
if (isinstance(logger_factory, type) and issubclass(
|
||||
logger_factory, AggregatedStatLoggerBase)):
|
||||
aggregated_loggers_factories.add(logger_factory)
|
||||
else:
|
||||
loggers.append(logger_factory(vllm_config,
|
||||
engine_idx)) # type: ignore
|
||||
self.per_engine_logger_dict[engine_idx] = loggers
|
||||
|
||||
# For Prometheus, need to share the metrics between EngineCores.
|
||||
# If no custom aggregated logger is provide,
|
||||
# we by default use PrometheusStatLogger
|
||||
if not aggregated_loggers_factories:
|
||||
aggregated_loggers_factories.add(PrometheusStatLogger)
|
||||
# For custom aggregated logger(or default Prometheus Logger)
|
||||
# need to share the metrics between EngineCores.
|
||||
# Each EngineCore's metrics are expressed as a unique label.
|
||||
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
|
||||
for aggregated_loggers_factory in aggregated_loggers_factories:
|
||||
self.aggregated_loggers.append(
|
||||
aggregated_loggers_factory(vllm_config, engine_idxs))
|
||||
|
||||
def record(
|
||||
self,
|
||||
@ -704,18 +780,19 @@ class StatLoggerManager:
|
||||
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
|
||||
for logger in per_engine_loggers:
|
||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||
|
||||
self.prometheus_logger.record(scheduler_stats, iteration_stats,
|
||||
engine_idx)
|
||||
for logger in self.aggregated_loggers:
|
||||
logger.record(scheduler_stats, iteration_stats, engine_idx)
|
||||
|
||||
def log(self):
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log()
|
||||
for logger in self.aggregated_loggers:
|
||||
logger.log()
|
||||
|
||||
def log_engine_initialized(self):
|
||||
self.prometheus_logger.log_engine_initialized()
|
||||
|
||||
for agg_logger in self.aggregated_loggers:
|
||||
agg_logger.log_engine_initialized()
|
||||
for per_engine_loggers in self.per_engine_logger_dict.values():
|
||||
for logger in per_engine_loggers:
|
||||
logger.log_engine_initialized()
|
||||
|
Reference in New Issue
Block a user