Compare commits

...

2 Commits

Author SHA1 Message Date
98d535eb4f add aggregator interface and abstract common logic
Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-22 13:00:53 -07:00
a46e279909 [Misc][DP] support customized global logger for dp
Signed-off-by: Lu Fang <fanglu@fb.com>

fix the test

Signed-off-by: Lu Fang <fanglu@fb.com>

address comments

Signed-off-by: Lu Fang <fanglu@fb.com>
2025-09-22 11:27:53 -07:00
3 changed files with 195 additions and 47 deletions

View File

@ -1,12 +1,15 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import asyncio
import threading
from typing import Optional
from vllm.config import VllmConfig
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
"""
To run this example, run the following commands simultaneously with
@ -22,37 +25,67 @@ send a request to the instance with DP rank 1.
"""
def _do_background_logging(engine, interval, stop_event):
try:
while not stop_event.is_set():
asyncio.run(engine.do_log_stats())
stop_event.wait(interval)
except Exception as e:
print(f"vLLM background logging shutdown: {e}")
pass
async def main():
engine_args = AsyncEngineArgs(
model="ibm-research/PowerMoE-3b",
data_parallel_size=2,
tensor_parallel_size=1,
dtype="auto",
max_model_len=2048,
data_parallel_address="127.0.0.1",
data_parallel_rpc_port=62300,
data_parallel_size_local=1,
enforce_eager=True,
enable_log_requests=True,
disable_custom_all_reduce=True,
)
engine_client = AsyncLLMEngine.from_engine_args(engine_args)
def per_engine_logger_factory(config: VllmConfig, rank: int) -> LoggingStatLogger:
return LoggingStatLogger(config, rank)
engine_client = AsyncLLMEngine.from_engine_args(
engine_args,
# Example: Using both regular loggers and aggregated logger
stat_loggers=[per_engine_logger_factory, AggregatedStatLogger],
)
stop_logging_event = threading.Event()
logging_thread = threading.Thread(
target=_do_background_logging,
args=(engine_client, 5, stop_logging_event),
daemon=True,
)
logging_thread.start()
sampling_params = SamplingParams(
temperature=0.7,
top_p=0.9,
max_tokens=100,
)
num_prompts = 10
for i in range(num_prompts):
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id=f"abcdef-{i}",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
prompt = "Who won the 2004 World Series?"
final_output: Optional[RequestOutput] = None
async for output in engine_client.generate(
prompt=prompt,
sampling_params=sampling_params,
request_id="abcdef",
data_parallel_rank=1,
):
final_output = output
if final_output:
print(final_output.outputs[0].text)
stop_logging_event.set()
logging_thread.join()
if __name__ == "__main__":

View File

@ -18,7 +18,7 @@ from vllm.platforms import current_platform
from vllm.sampling_params import RequestOutputKind
from vllm.utils import set_default_torch_num_threads
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import LoggingStatLogger
from vllm.v1.metrics.loggers import AggregatedStatLogger, LoggingStatLogger
if not current_platform.is_cuda():
pytest.skip(reason="V1 currently only supported on CUDA.",
@ -389,6 +389,15 @@ class MockLoggingStatLogger(LoggingStatLogger):
self.log = MagicMock()
class MockAggregatedStatLogger(AggregatedStatLogger):
def __init__(self,
vllm_config: VllmConfig,
engine_indexes: Optional[list[int]] = None):
super().__init__(vllm_config, engine_indexes)
self.log = MagicMock()
@pytest.mark.asyncio
async def test_customize_loggers(monkeypatch):
"""Test that we can customize the loggers.
@ -415,6 +424,35 @@ async def test_customize_loggers(monkeypatch):
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio
async def test_customize_aggregated_loggers(monkeypatch):
"""Test that we can customize the aggregated loggers.
If a customized logger is provided at the init, it should
be added to the default loggers.
"""
with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1")
with set_default_torch_num_threads(1):
engine = AsyncLLM.from_engine_args(
TEXT_ENGINE_ARGS,
stat_loggers=[MockLoggingStatLogger, MockAggregatedStatLogger],
)
after.callback(engine.shutdown)
await engine.do_log_stats()
stat_loggers = engine.logger_manager.per_engine_logger_dict
assert len(stat_loggers) == 1
assert len(
stat_loggers[0]) == 2 # LoggingStatLogger + MockLoggingStatLogger
aggregated_loggers = engine.logger_manager.aggregated_loggers
assert len(aggregated_loggers) == 1
aggregated_loggers[0].log.assert_called_once()
stat_loggers[0][0].log.assert_called_once()
@pytest.mark.asyncio(scope="module")
async def test_dp_rank_argument(monkeypatch: pytest.MonkeyPatch):
with monkeypatch.context() as m, ExitStack() as after:

View File

@ -18,7 +18,9 @@ from vllm.v1.spec_decode.metrics import SpecDecodingLogging, SpecDecodingProm
logger = init_logger(__name__)
StatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
PerEngineStatLoggerFactory = Callable[[VllmConfig, int], "StatLoggerBase"]
StatLoggerFactory = Union[PerEngineStatLoggerFactory,
type["AggregatedStatLoggerBase"]]
class StatLoggerBase(ABC):
@ -48,6 +50,16 @@ class StatLoggerBase(ABC):
pass
class AggregatedStatLoggerBase(StatLoggerBase):
"""Abstract base class for loggers that
aggregates statistics across multiple engines."""
@abstractmethod
def __init__(self, vllm_config: VllmConfig,
engine_indexes: Optional[list[int]]):
...
class LoggingStatLogger(StatLoggerBase):
def __init__(self, vllm_config: VllmConfig, engine_index: int = 0):
@ -61,6 +73,7 @@ class LoggingStatLogger(StatLoggerBase):
self.spec_decoding_logging = SpecDecodingLogging()
self.last_prompt_throughput: float = 0.0
self.last_generation_throughput: float = 0.0
self.engine_is_idle = False
def _reset(self, now):
self.last_log_time = now
@ -100,25 +113,25 @@ class LoggingStatLogger(StatLoggerBase):
self.last_scheduler_stats = scheduler_stats
def log(self):
def get_log_stats(self):
now = time.monotonic()
prompt_throughput = self._get_throughput(self.num_prompt_tokens, now)
generation_throughput = self._get_throughput(
self.num_generation_tokens, now)
self._reset(now)
scheduler_stats = self.last_scheduler_stats
log_fn = logger.info
if not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput)):
# Avoid log noise on an idle production system
log_fn = logger.debug
self.last_generation_throughput = generation_throughput
self.last_prompt_throughput = prompt_throughput
self.engine_is_idle = not any(
(prompt_throughput, generation_throughput,
self.last_prompt_throughput, self.last_generation_throughput))
def log(self):
self.get_log_stats()
log_fn = logger.info
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
# Format and print output.
log_fn(
"Engine %03d: "
@ -128,11 +141,11 @@ class LoggingStatLogger(StatLoggerBase):
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
self.engine_index,
prompt_throughput,
generation_throughput,
scheduler_stats.num_running_reqs,
scheduler_stats.num_waiting_reqs,
scheduler_stats.kv_cache_usage * 100,
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
@ -145,7 +158,61 @@ class LoggingStatLogger(StatLoggerBase):
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(StatLoggerBase):
class AggregatedStatLogger(LoggingStatLogger, AggregatedStatLoggerBase):
def __init__(self,
vllm_config: VllmConfig,
engine_idxs: Optional[list[int]] = None):
if engine_idxs is None:
engine_idxs = [0]
self.engine_idxs = engine_idxs
LoggingStatLogger.__init__(self, vllm_config, engine_index=-1)
def record(
self,
scheduler_stats: Optional[SchedulerStats],
iteration_stats: Optional[IterationStats],
engine_idx: int = 0,
):
if engine_idx not in self.engine_idxs:
logger.warning("Unexpected engine_idx: %d", engine_idx)
return
LoggingStatLogger.record(self, scheduler_stats, iteration_stats,
engine_idx)
def log(self):
self.get_log_stats()
log_fn = logger.info
if self.engine_is_idle:
# Avoid log noise on an idle production system
log_fn = logger.debug
# Format and print output.
log_fn(
"%s Engines Aggregated: "
"Avg prompt throughput: %.1f tokens/s, "
"Avg generation throughput: %.1f tokens/s, "
"Running: %d reqs, Waiting: %d reqs, "
"GPU KV cache usage: %.1f%%, "
"Prefix cache hit rate: %.1f%%",
len(self.engine_idxs),
self.last_prompt_throughput,
self.last_generation_throughput,
self.last_scheduler_stats.num_running_reqs,
self.last_scheduler_stats.num_waiting_reqs,
self.last_scheduler_stats.kv_cache_usage * 100,
self.prefix_caching_metrics.hit_rate * 100,
)
self.spec_decoding_logging.log(log_fn=log_fn)
def log_engine_initialized(self):
if self.vllm_config.cache_config.num_gpu_blocks:
logger.info(
"%d Engines: vllm cache_config_info with initialization "
"after num_gpu_blocks is: %d", len(self.engine_idxs),
self.vllm_config.cache_config.num_gpu_blocks)
class PrometheusStatLogger(AggregatedStatLoggerBase):
_gauge_cls = prometheus_client.Gauge
_counter_cls = prometheus_client.Counter
_histogram_cls = prometheus_client.Histogram
@ -674,23 +741,32 @@ class StatLoggerManager:
# engine_idx: StatLogger
self.per_engine_logger_dict: dict[int, list[StatLoggerBase]] = {}
prometheus_factory = PrometheusStatLogger
self.aggregated_loggers: list[AggregatedStatLoggerBase] = []
aggregated_loggers_factories = set()
for engine_idx in self.engine_idxs:
loggers: list[StatLoggerBase] = []
for logger_factory in factories:
# If we get a custom prometheus logger, use that
# instead. This is typically used for the ray case.
if (isinstance(logger_factory, type)
and issubclass(logger_factory, PrometheusStatLogger)):
prometheus_factory = logger_factory
continue
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
# If we get a custom prometheus logger or aggregated logger,
# We initialize it separately with all engine idxs.
# A custom prometheus logger is typically used for the ray.
if (isinstance(logger_factory, type) and issubclass(
logger_factory, AggregatedStatLoggerBase)):
aggregated_loggers_factories.add(logger_factory)
else:
loggers.append(logger_factory(vllm_config,
engine_idx)) # type: ignore
self.per_engine_logger_dict[engine_idx] = loggers
# For Prometheus, need to share the metrics between EngineCores.
# If no custom aggregated logger is provide,
# we by default use PrometheusStatLogger
if not aggregated_loggers_factories:
aggregated_loggers_factories.add(PrometheusStatLogger)
# For custom aggregated logger(or default Prometheus Logger)
# need to share the metrics between EngineCores.
# Each EngineCore's metrics are expressed as a unique label.
self.prometheus_logger = prometheus_factory(vllm_config, engine_idxs)
for aggregated_loggers_factory in aggregated_loggers_factories:
self.aggregated_loggers.append(
aggregated_loggers_factory(vllm_config, engine_idxs))
def record(
self,
@ -704,18 +780,19 @@ class StatLoggerManager:
per_engine_loggers = self.per_engine_logger_dict[engine_idx]
for logger in per_engine_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
self.prometheus_logger.record(scheduler_stats, iteration_stats,
engine_idx)
for logger in self.aggregated_loggers:
logger.record(scheduler_stats, iteration_stats, engine_idx)
def log(self):
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log()
for logger in self.aggregated_loggers:
logger.log()
def log_engine_initialized(self):
self.prometheus_logger.log_engine_initialized()
for agg_logger in self.aggregated_loggers:
agg_logger.log_engine_initialized()
for per_engine_loggers in self.per_engine_logger_dict.values():
for logger in per_engine_loggers:
logger.log_engine_initialized()