mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
[Bugfix] Add log prefix in non-dp mode engine core (#21889)
Signed-off-by: wuhang <wuhang6@huawei.com>
This commit is contained in:
@ -2,9 +2,7 @@
|
|||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import signal
|
import signal
|
||||||
import sys
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import uvloop
|
import uvloop
|
||||||
@ -18,10 +16,9 @@ from vllm.entrypoints.openai.cli_args import (make_arg_parser,
|
|||||||
validate_parsed_serve_args)
|
validate_parsed_serve_args)
|
||||||
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
|
from vllm.entrypoints.utils import (VLLM_SUBCMD_PARSER_EPILOG,
|
||||||
show_filtered_argument_or_group_from_help)
|
show_filtered_argument_or_group_from_help)
|
||||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import FlexibleArgumentParser, get_tcp_uri
|
from vllm.utils import FlexibleArgumentParser, decorate_logs, get_tcp_uri
|
||||||
from vllm.v1.engine.core import EngineCoreProc
|
from vllm.v1.engine.core import EngineCoreProc
|
||||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||||
from vllm.v1.executor.abstract import Executor
|
from vllm.v1.executor.abstract import Executor
|
||||||
@ -229,11 +226,7 @@ def run_api_server_worker_proc(listen_address,
|
|||||||
"""Entrypoint for individual API server worker processes."""
|
"""Entrypoint for individual API server worker processes."""
|
||||||
|
|
||||||
# Add process-specific prefix to stdout and stderr.
|
# Add process-specific prefix to stdout and stderr.
|
||||||
from multiprocessing import current_process
|
decorate_logs()
|
||||||
process_name = current_process().name
|
|
||||||
pid = os.getpid()
|
|
||||||
_add_prefix(sys.stdout, process_name, pid)
|
|
||||||
_add_prefix(sys.stderr, process_name, pid)
|
|
||||||
|
|
||||||
uvloop.run(
|
uvloop.run(
|
||||||
run_server_worker(listen_address, sock, args, client_config,
|
run_server_worker(listen_address, sock, args, client_config,
|
||||||
|
@ -11,7 +11,6 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import socket
|
import socket
|
||||||
import sys
|
|
||||||
import tempfile
|
import tempfile
|
||||||
import uuid
|
import uuid
|
||||||
from argparse import Namespace
|
from argparse import Namespace
|
||||||
@ -95,15 +94,15 @@ from vllm.entrypoints.openai.serving_transcription import (
|
|||||||
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
from vllm.entrypoints.openai.tool_parsers import ToolParserManager
|
||||||
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
from vllm.entrypoints.utils import (cli_env_setup, load_aware_call,
|
||||||
log_non_default_args, with_cancellation)
|
log_non_default_args, with_cancellation)
|
||||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.reasoning import ReasoningParserManager
|
from vllm.reasoning import ReasoningParserManager
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
maybe_register_config_serialize_by_value)
|
maybe_register_config_serialize_by_value)
|
||||||
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
from vllm.transformers_utils.tokenizer import MistralTokenizer
|
||||||
from vllm.usage.usage_lib import UsageContext
|
from vllm.usage.usage_lib import UsageContext
|
||||||
from vllm.utils import (Device, FlexibleArgumentParser, get_open_zmq_ipc_path,
|
from vllm.utils import (Device, FlexibleArgumentParser, decorate_logs,
|
||||||
is_valid_ipv6_address, set_process_title, set_ulimit)
|
get_open_zmq_ipc_path, is_valid_ipv6_address,
|
||||||
|
set_process_title, set_ulimit)
|
||||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||||
from vllm.version import __version__ as VLLM_VERSION
|
from vllm.version import __version__ as VLLM_VERSION
|
||||||
|
|
||||||
@ -1808,10 +1807,7 @@ async def run_server(args, **uvicorn_kwargs) -> None:
|
|||||||
"""Run a single-worker API server."""
|
"""Run a single-worker API server."""
|
||||||
|
|
||||||
# Add process-specific prefix to stdout and stderr.
|
# Add process-specific prefix to stdout and stderr.
|
||||||
process_name = "APIServer"
|
decorate_logs("APIServer")
|
||||||
pid = os.getpid()
|
|
||||||
_add_prefix(sys.stdout, process_name, pid)
|
|
||||||
_add_prefix(sys.stderr, process_name, pid)
|
|
||||||
|
|
||||||
listen_address, sock = setup_server(args)
|
listen_address, sock = setup_server(args)
|
||||||
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
|
await run_server_worker(listen_address, sock, args, **uvicorn_kwargs)
|
||||||
|
@ -3,21 +3,20 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from multiprocessing import Queue
|
from multiprocessing import Queue
|
||||||
from multiprocessing.connection import wait
|
from multiprocessing.connection import wait
|
||||||
from multiprocessing.process import BaseProcess
|
from multiprocessing.process import BaseProcess
|
||||||
from typing import (Any, Callable, Dict, Generic, List, Optional, TextIO,
|
from typing import Any, Callable, Dict, Generic, List, Optional, TypeVar, Union
|
||||||
TypeVar, Union)
|
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import _maybe_force_spawn, get_mp_context, run_method
|
from vllm.utils import (_maybe_force_spawn, decorate_logs, get_mp_context,
|
||||||
|
run_method)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@ -25,10 +24,6 @@ T = TypeVar('T')
|
|||||||
|
|
||||||
_TERMINATE = "TERMINATE" # sentinel
|
_TERMINATE = "TERMINATE" # sentinel
|
||||||
|
|
||||||
# ANSI color codes
|
|
||||||
CYAN = '\033[1;36m'
|
|
||||||
RESET = '\033[0;0m'
|
|
||||||
|
|
||||||
JOIN_TIMEOUT_S = 2
|
JOIN_TIMEOUT_S = 2
|
||||||
|
|
||||||
|
|
||||||
@ -213,9 +208,7 @@ def _run_worker_process(
|
|||||||
|
|
||||||
# Add process-specific prefix to stdout and stderr
|
# Add process-specific prefix to stdout and stderr
|
||||||
process_name = get_mp_context().current_process().name
|
process_name = get_mp_context().current_process().name
|
||||||
pid = os.getpid()
|
decorate_logs(process_name)
|
||||||
_add_prefix(sys.stdout, process_name, pid)
|
|
||||||
_add_prefix(sys.stderr, process_name, pid)
|
|
||||||
|
|
||||||
# Initialize worker
|
# Initialize worker
|
||||||
worker = worker_factory(vllm_config, rank)
|
worker = worker_factory(vllm_config, rank)
|
||||||
@ -260,33 +253,6 @@ def _run_worker_process(
|
|||||||
logger.info("Worker exiting")
|
logger.info("Worker exiting")
|
||||||
|
|
||||||
|
|
||||||
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
|
||||||
"""Prepend each output line with process-specific prefix"""
|
|
||||||
|
|
||||||
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
|
||||||
file_write = file.write
|
|
||||||
|
|
||||||
def write_with_prefix(s: str):
|
|
||||||
if not s:
|
|
||||||
return
|
|
||||||
if file.start_new_line: # type: ignore[attr-defined]
|
|
||||||
file_write(prefix)
|
|
||||||
idx = 0
|
|
||||||
while (next_idx := s.find('\n', idx)) != -1:
|
|
||||||
next_idx += 1
|
|
||||||
file_write(s[idx:next_idx])
|
|
||||||
if next_idx == len(s):
|
|
||||||
file.start_new_line = True # type: ignore[attr-defined]
|
|
||||||
return
|
|
||||||
file_write(prefix)
|
|
||||||
idx = next_idx
|
|
||||||
file_write(s[idx:])
|
|
||||||
file.start_new_line = False # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
file.start_new_line = True # type: ignore[attr-defined]
|
|
||||||
file.write = write_with_prefix # type: ignore[method-assign]
|
|
||||||
|
|
||||||
|
|
||||||
def set_multiprocessing_worker_envs(parallel_config):
|
def set_multiprocessing_worker_envs(parallel_config):
|
||||||
""" Set up environment variables that should be used when there are workers
|
""" Set up environment variables that should be used when there are workers
|
||||||
in a multiprocessing environment. This should be called by the parent
|
in a multiprocessing environment. This should be called by the parent
|
||||||
|
@ -47,7 +47,7 @@ from dataclasses import dataclass, field
|
|||||||
from functools import cache, lru_cache, partial, wraps
|
from functools import cache, lru_cache, partial, wraps
|
||||||
from types import MappingProxyType
|
from types import MappingProxyType
|
||||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||||
Optional, Tuple, TypeVar, Union, cast, overload)
|
Optional, TextIO, Tuple, TypeVar, Union, cast, overload)
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
@ -167,6 +167,10 @@ GB_bytes = 1_000_000_000
|
|||||||
GiB_bytes = 1 << 30
|
GiB_bytes = 1 << 30
|
||||||
"""The number of bytes in one gibibyte (GiB)."""
|
"""The number of bytes in one gibibyte (GiB)."""
|
||||||
|
|
||||||
|
# ANSI color codes
|
||||||
|
CYAN = '\033[1;36m'
|
||||||
|
RESET = '\033[0;0m'
|
||||||
|
|
||||||
STR_DTYPE_TO_TORCH_DTYPE = {
|
STR_DTYPE_TO_TORCH_DTYPE = {
|
||||||
"half": torch.half,
|
"half": torch.half,
|
||||||
"bfloat16": torch.bfloat16,
|
"bfloat16": torch.bfloat16,
|
||||||
@ -3258,3 +3262,52 @@ def set_process_title(name: str,
|
|||||||
else:
|
else:
|
||||||
name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
|
name = f"{envs.VLLM_PROCESS_NAME_PREFIX}::{name}"
|
||||||
setproctitle.setproctitle(name)
|
setproctitle.setproctitle(name)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
||||||
|
"""Prepend each output line with process-specific prefix"""
|
||||||
|
|
||||||
|
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
||||||
|
file_write = file.write
|
||||||
|
|
||||||
|
def write_with_prefix(s: str):
|
||||||
|
if not s:
|
||||||
|
return
|
||||||
|
if file.start_new_line: # type: ignore[attr-defined]
|
||||||
|
file_write(prefix)
|
||||||
|
idx = 0
|
||||||
|
while (next_idx := s.find('\n', idx)) != -1:
|
||||||
|
next_idx += 1
|
||||||
|
file_write(s[idx:next_idx])
|
||||||
|
if next_idx == len(s):
|
||||||
|
file.start_new_line = True # type: ignore[attr-defined]
|
||||||
|
return
|
||||||
|
file_write(prefix)
|
||||||
|
idx = next_idx
|
||||||
|
file_write(s[idx:])
|
||||||
|
file.start_new_line = False # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
file.start_new_line = True # type: ignore[attr-defined]
|
||||||
|
file.write = write_with_prefix # type: ignore[method-assign]
|
||||||
|
|
||||||
|
|
||||||
|
def decorate_logs(process_name: Optional[str] = None) -> None:
|
||||||
|
"""
|
||||||
|
Adds a process-specific prefix to each line of output written to stdout and
|
||||||
|
stderr.
|
||||||
|
|
||||||
|
This function is intended to be called before initializing the api_server,
|
||||||
|
engine_core, or worker classes, so that all subsequent output from the
|
||||||
|
process is prefixed with the process name and PID. This helps distinguish
|
||||||
|
log output from different processes in multi-process environments.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
process_name: Optional; the name of the process to use in the prefix.
|
||||||
|
If not provided, the current process name from the multiprocessing
|
||||||
|
context is used.
|
||||||
|
"""
|
||||||
|
if process_name is None:
|
||||||
|
process_name = get_mp_context().current_process().name
|
||||||
|
pid = os.getpid()
|
||||||
|
_add_prefix(sys.stdout, process_name, pid)
|
||||||
|
_add_prefix(sys.stderr, process_name, pid)
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
import os
|
import os
|
||||||
import queue
|
import queue
|
||||||
import signal
|
import signal
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from collections import deque
|
from collections import deque
|
||||||
@ -19,15 +18,14 @@ import zmq
|
|||||||
|
|
||||||
from vllm.config import ParallelConfig, VllmConfig
|
from vllm.config import ParallelConfig, VllmConfig
|
||||||
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
from vllm.distributed import stateless_destroy_torch_distributed_process_group
|
||||||
from vllm.executor.multiproc_worker_utils import _add_prefix
|
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.logging_utils.dump_input import dump_engine_exception
|
from vllm.logging_utils.dump_input import dump_engine_exception
|
||||||
from vllm.lora.request import LoRARequest
|
from vllm.lora.request import LoRARequest
|
||||||
from vllm.tasks import POOLING_TASKS, SupportedTask
|
from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||||
from vllm.transformers_utils.config import (
|
from vllm.transformers_utils.config import (
|
||||||
maybe_register_config_serialize_by_value)
|
maybe_register_config_serialize_by_value)
|
||||||
from vllm.utils import (make_zmq_socket, resolve_obj_by_qualname,
|
from vllm.utils import (decorate_logs, make_zmq_socket,
|
||||||
set_process_title)
|
resolve_obj_by_qualname, set_process_title)
|
||||||
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
from vllm.v1.core.kv_cache_utils import (get_kv_cache_config,
|
||||||
unify_kv_cache_configs)
|
unify_kv_cache_configs)
|
||||||
from vllm.v1.core.sched.interface import SchedulerInterface
|
from vllm.v1.core.sched.interface import SchedulerInterface
|
||||||
@ -649,12 +647,14 @@ class EngineCoreProc(EngineCore):
|
|||||||
"vllm_config"].parallel_config
|
"vllm_config"].parallel_config
|
||||||
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
|
if parallel_config.data_parallel_size > 1 or dp_rank > 0:
|
||||||
set_process_title("DPEngineCore", str(dp_rank))
|
set_process_title("DPEngineCore", str(dp_rank))
|
||||||
|
decorate_logs()
|
||||||
# Set data parallel rank for this engine process.
|
# Set data parallel rank for this engine process.
|
||||||
parallel_config.data_parallel_rank = dp_rank
|
parallel_config.data_parallel_rank = dp_rank
|
||||||
parallel_config.data_parallel_rank_local = local_dp_rank
|
parallel_config.data_parallel_rank_local = local_dp_rank
|
||||||
engine_core = DPEngineCoreProc(*args, **kwargs)
|
engine_core = DPEngineCoreProc(*args, **kwargs)
|
||||||
else:
|
else:
|
||||||
set_process_title("EngineCore")
|
set_process_title("EngineCore")
|
||||||
|
decorate_logs()
|
||||||
engine_core = EngineCoreProc(*args, **kwargs)
|
engine_core = EngineCoreProc(*args, **kwargs)
|
||||||
|
|
||||||
engine_core.run_busy_loop()
|
engine_core.run_busy_loop()
|
||||||
@ -905,8 +905,6 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
log_stats: bool,
|
log_stats: bool,
|
||||||
client_handshake_address: Optional[str] = None,
|
client_handshake_address: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self._decorate_logs()
|
|
||||||
|
|
||||||
# Counts forward-passes of the model so that we can synchronize
|
# Counts forward-passes of the model so that we can synchronize
|
||||||
# finished with DP peers every N steps.
|
# finished with DP peers every N steps.
|
||||||
self.counter = 0
|
self.counter = 0
|
||||||
@ -919,15 +917,6 @@ class DPEngineCoreProc(EngineCoreProc):
|
|||||||
executor_class, log_stats, client_handshake_address,
|
executor_class, log_stats, client_handshake_address,
|
||||||
dp_rank)
|
dp_rank)
|
||||||
|
|
||||||
def _decorate_logs(self):
|
|
||||||
# Add process-specific prefix to stdout and stderr before
|
|
||||||
# we initialize the engine.
|
|
||||||
from multiprocessing import current_process
|
|
||||||
process_name = current_process().name
|
|
||||||
pid = os.getpid()
|
|
||||||
_add_prefix(sys.stdout, process_name, pid)
|
|
||||||
_add_prefix(sys.stderr, process_name, pid)
|
|
||||||
|
|
||||||
def _init_data_parallel(self, vllm_config: VllmConfig):
|
def _init_data_parallel(self, vllm_config: VllmConfig):
|
||||||
|
|
||||||
# Configure GPUs and stateless process group for data parallel.
|
# Configure GPUs and stateless process group for data parallel.
|
||||||
@ -1149,9 +1138,6 @@ class DPEngineCoreActor(DPEngineCoreProc):
|
|||||||
f"{(local_dp_rank + 1) * world_size}) "
|
f"{(local_dp_rank + 1) * world_size}) "
|
||||||
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
|
f"base value: \"{os.getenv(device_control_env_var)}\"") from e
|
||||||
|
|
||||||
def _decorate_logs(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def _perform_handshakes(self, handshake_address: str, identity: bytes,
|
def _perform_handshakes(self, handshake_address: str, identity: bytes,
|
||||||
local_client: bool, vllm_config: VllmConfig,
|
local_client: bool, vllm_config: VllmConfig,
|
||||||
|
@ -4,7 +4,6 @@ import multiprocessing
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
import signal
|
import signal
|
||||||
import sys
|
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
@ -28,10 +27,11 @@ from vllm.distributed.device_communicators.shm_broadcast import (Handle,
|
|||||||
MessageQueue)
|
MessageQueue)
|
||||||
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
|
||||||
from vllm.executor.multiproc_worker_utils import (
|
from vllm.executor.multiproc_worker_utils import (
|
||||||
_add_prefix, set_multiprocessing_worker_envs)
|
set_multiprocessing_worker_envs)
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.utils import (get_distributed_init_method, get_loopback_ip,
|
from vllm.utils import (decorate_logs, get_distributed_init_method,
|
||||||
get_mp_context, get_open_port, set_process_title)
|
get_loopback_ip, get_mp_context, get_open_port,
|
||||||
|
set_process_title)
|
||||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||||
from vllm.v1.outputs import ModelRunnerOutput
|
from vllm.v1.outputs import ModelRunnerOutput
|
||||||
from vllm.worker.worker_base import WorkerWrapperBase
|
from vllm.worker.worker_base import WorkerWrapperBase
|
||||||
@ -382,11 +382,11 @@ class WorkerProc:
|
|||||||
pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
|
pp_str = f"PP{rank // tp_size}" if pp_size > 1 else ""
|
||||||
tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
|
tp_str = f"TP{rank % tp_size}" if tp_size > 1 else ""
|
||||||
suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
|
suffix = f"{pp_str}{'_' if pp_str and tp_str else ''}{tp_str}"
|
||||||
|
process_name = "VllmWorker"
|
||||||
if suffix:
|
if suffix:
|
||||||
set_process_title(suffix, append=True)
|
set_process_title(suffix, append=True)
|
||||||
pid = os.getpid()
|
process_name = f"{process_name} {suffix}"
|
||||||
_add_prefix(sys.stdout, f"VllmWorker rank={rank}", pid)
|
decorate_logs(process_name)
|
||||||
_add_prefix(sys.stderr, f"VllmWorker rank={rank}", pid)
|
|
||||||
|
|
||||||
# Initialize MessageQueue for receiving SchedulerOutput
|
# Initialize MessageQueue for receiving SchedulerOutput
|
||||||
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
self.rpc_broadcast_mq = MessageQueue.create_from_handle(
|
||||||
|
Reference in New Issue
Block a user