mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Chore] Separate out vllm.utils.network_utils
(#27164)
Signed-off-by: iAmir97 <Amir.balwel@embeddedllm.com> Co-authored-by: iAmir97 <Amir.balwel@embeddedllm.com>
This commit is contained in:
@ -33,7 +33,7 @@ import os
|
||||
from time import sleep
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -38,7 +38,7 @@ from rlhf_utils import stateless_init_process_group
|
||||
from transformers import AutoModelForCausalLM
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
|
||||
class MyLLM(LLM):
|
||||
|
@ -19,7 +19,7 @@ from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
from vllm import initialize_ray_cluster
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.executor.ray_utils import _wait_until_pg_removed
|
||||
from vllm.utils import get_ip
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"
|
||||
|
||||
|
@ -7,7 +7,7 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import _node_count
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
if __name__ == "__main__":
|
||||
dist.init_process_group(backend="gloo")
|
||||
|
@ -7,7 +7,7 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.parallel_state import in_the_same_node_as
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import get_ip, get_open_port
|
||||
from vllm.utils.network_utils import get_ip, get_open_port
|
||||
|
||||
if __name__ == "__main__":
|
||||
dist.init_process_group(backend="gloo")
|
||||
|
@ -10,7 +10,8 @@ import torch.distributed as dist
|
||||
|
||||
from vllm.distributed.device_communicators.shm_broadcast import MessageQueue
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
from vllm.utils import update_environment_variables
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
|
||||
def get_arrays(n: int, seed: int = 0) -> list[np.ndarray]:
|
||||
|
@ -10,10 +10,8 @@ import torch
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.utils import (
|
||||
get_open_port,
|
||||
update_environment_variables,
|
||||
)
|
||||
from vllm.utils import update_environment_variables
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
|
@ -9,7 +9,7 @@ import time
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from ...utils import get_open_port
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM"
|
||||
|
||||
|
@ -12,7 +12,7 @@ from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import init_distributed_environment, initialize_model_parallel
|
||||
from vllm.utils import get_open_port
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
|
@ -15,7 +15,8 @@ from torch.distributed import ProcessGroup
|
||||
from torch.multiprocessing import spawn # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import ParamSpec
|
||||
|
||||
from vllm.utils import get_open_port, has_deep_ep
|
||||
from vllm.utils import has_deep_ep
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (
|
||||
|
@ -8,7 +8,7 @@ from vllm import LLM, EngineArgs
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.model_executor.model_loader import tensorizer as tensorizer_mod
|
||||
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.executor.abstract import UniProcExecutor
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
@ -19,7 +19,8 @@ from vllm.model_executor.models.vision import (
|
||||
run_dp_sharded_vision_model,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_port, update_environment_variables
|
||||
from vllm.utils import update_environment_variables
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
|
||||
pytestmark = pytest.mark.cpu_test
|
||||
|
||||
|
@ -46,9 +46,9 @@ from vllm.platforms import current_platform
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
from vllm.utils import (
|
||||
FlexibleArgumentParser,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm.utils.mem_constants import GB_bytes
|
||||
from vllm.utils.network_utils import get_open_port
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
if current_platform.is_rocm():
|
||||
|
126
tests/utils_/test_network_utils.py
Normal file
126
tests/utils_/test_network_utils.py
Normal file
@ -0,0 +1,126 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import socket
|
||||
|
||||
import pytest
|
||||
import zmq
|
||||
|
||||
from vllm.utils.network_utils import (
|
||||
get_open_port,
|
||||
get_tcp_uri,
|
||||
join_host_port,
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
)
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
# make sure we can get multiple ports, even if the env var is set
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
|
||||
s1.bind(("localhost", get_open_port()))
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
|
||||
s2.bind(("localhost", get_open_port()))
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
|
||||
s3.bind(("localhost", get_open_port()))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,expected",
|
||||
[
|
||||
("ipc://some_path", ("ipc", "some_path", "")),
|
||||
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
|
||||
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
|
||||
("inproc://some_identifier", ("inproc", "some_identifier", "")),
|
||||
],
|
||||
)
|
||||
def test_split_zmq_path(path, expected):
|
||||
assert split_zmq_path(path) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_path",
|
||||
[
|
||||
"invalid_path", # Missing scheme
|
||||
"tcp://127.0.0.1", # Missing port
|
||||
"tcp://[::1]", # Missing port for IPv6
|
||||
"tcp://:5555", # Missing host
|
||||
],
|
||||
)
|
||||
def test_split_zmq_path_invalid(invalid_path):
|
||||
with pytest.raises(ValueError):
|
||||
split_zmq_path(invalid_path)
|
||||
|
||||
|
||||
def test_make_zmq_socket_ipv6():
|
||||
# Check if IPv6 is supported by trying to create an IPv6 socket
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
sock.close()
|
||||
except OSError:
|
||||
pytest.skip("IPv6 is not supported on this system")
|
||||
|
||||
ctx = zmq.Context()
|
||||
ipv6_path = "tcp://[::]:5555" # IPv6 loopback address
|
||||
socket_type = zmq.REP # Example socket type
|
||||
|
||||
# Create the socket
|
||||
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
|
||||
|
||||
# Verify that the IPV6 option is set
|
||||
assert zsock.getsockopt(zmq.IPV6) == 1, (
|
||||
"IPV6 option should be enabled for IPv6 addresses"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
zsock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def test_make_zmq_path():
|
||||
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
||||
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_get_tcp_uri():
|
||||
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
|
||||
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_split_host_port():
|
||||
# valid ipv4
|
||||
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
|
||||
# invalid ipv4
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("127.0.0.1::5555")
|
||||
with pytest.raises(ValueError):
|
||||
# tailing colon
|
||||
assert split_host_port("127.0.0.1:5555:")
|
||||
with pytest.raises(ValueError):
|
||||
# no colon
|
||||
assert split_host_port("127.0.0.15555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("127.0.0.1:5555a")
|
||||
|
||||
# valid ipv6
|
||||
assert split_host_port("[::1]:5555") == ("::1", 5555)
|
||||
# invalid ipv6
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("[::1]::5555")
|
||||
with pytest.raises(IndexError):
|
||||
# no colon
|
||||
assert split_host_port("[::1]5555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("[::1]:5555a")
|
||||
|
||||
|
||||
def test_join_host_port():
|
||||
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
|
||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
@ -6,7 +6,6 @@ import hashlib
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import socket
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@ -14,7 +13,6 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
import torch
|
||||
import yaml
|
||||
import zmq
|
||||
from transformers import AutoTokenizer
|
||||
from vllm_test_utils.monitor import monitor
|
||||
|
||||
@ -24,13 +22,6 @@ from vllm.transformers_utils.detokenizer_utils import convert_ids_list_to_tokens
|
||||
from vllm.utils import (
|
||||
FlexibleArgumentParser,
|
||||
bind_kv_cache,
|
||||
get_open_port,
|
||||
get_tcp_uri,
|
||||
join_host_port,
|
||||
make_zmq_path,
|
||||
make_zmq_socket,
|
||||
split_host_port,
|
||||
split_zmq_path,
|
||||
unique_filepath,
|
||||
)
|
||||
from vllm.utils.hashing import sha256
|
||||
@ -43,18 +34,6 @@ from vllm.utils.mem_utils import MemorySnapshot, memory_profiling
|
||||
from ..utils import create_new_process_for_each_test, flat_product
|
||||
|
||||
|
||||
def test_get_open_port(monkeypatch: pytest.MonkeyPatch):
|
||||
with monkeypatch.context() as m:
|
||||
m.setenv("VLLM_PORT", "5678")
|
||||
# make sure we can get multiple ports, even if the env var is set
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s1:
|
||||
s1.bind(("localhost", get_open_port()))
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s2:
|
||||
s2.bind(("localhost", get_open_port()))
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s3:
|
||||
s3.bind(("localhost", get_open_port()))
|
||||
|
||||
|
||||
# Tests for FlexibleArgumentParser
|
||||
@pytest.fixture
|
||||
def parser():
|
||||
@ -573,104 +552,6 @@ def test_sha256(input: tuple):
|
||||
assert digest != sha256(input + (1,))
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path,expected",
|
||||
[
|
||||
("ipc://some_path", ("ipc", "some_path", "")),
|
||||
("tcp://127.0.0.1:5555", ("tcp", "127.0.0.1", "5555")),
|
||||
("tcp://[::1]:5555", ("tcp", "::1", "5555")), # IPv6 address
|
||||
("inproc://some_identifier", ("inproc", "some_identifier", "")),
|
||||
],
|
||||
)
|
||||
def test_split_zmq_path(path, expected):
|
||||
assert split_zmq_path(path) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"invalid_path",
|
||||
[
|
||||
"invalid_path", # Missing scheme
|
||||
"tcp://127.0.0.1", # Missing port
|
||||
"tcp://[::1]", # Missing port for IPv6
|
||||
"tcp://:5555", # Missing host
|
||||
],
|
||||
)
|
||||
def test_split_zmq_path_invalid(invalid_path):
|
||||
with pytest.raises(ValueError):
|
||||
split_zmq_path(invalid_path)
|
||||
|
||||
|
||||
def test_make_zmq_socket_ipv6():
|
||||
# Check if IPv6 is supported by trying to create an IPv6 socket
|
||||
try:
|
||||
sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
|
||||
sock.close()
|
||||
except socket.error:
|
||||
pytest.skip("IPv6 is not supported on this system")
|
||||
|
||||
ctx = zmq.Context()
|
||||
ipv6_path = "tcp://[::]:5555" # IPv6 loopback address
|
||||
socket_type = zmq.REP # Example socket type
|
||||
|
||||
# Create the socket
|
||||
zsock: zmq.Socket = make_zmq_socket(ctx, ipv6_path, socket_type)
|
||||
|
||||
# Verify that the IPV6 option is set
|
||||
assert zsock.getsockopt(zmq.IPV6) == 1, (
|
||||
"IPV6 option should be enabled for IPv6 addresses"
|
||||
)
|
||||
|
||||
# Clean up
|
||||
zsock.close()
|
||||
ctx.term()
|
||||
|
||||
|
||||
def test_make_zmq_path():
|
||||
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
||||
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_get_tcp_uri():
|
||||
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
|
||||
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_split_host_port():
|
||||
# valid ipv4
|
||||
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
|
||||
# invalid ipv4
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("127.0.0.1::5555")
|
||||
with pytest.raises(ValueError):
|
||||
# tailing colon
|
||||
assert split_host_port("127.0.0.1:5555:")
|
||||
with pytest.raises(ValueError):
|
||||
# no colon
|
||||
assert split_host_port("127.0.0.15555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("127.0.0.1:5555a")
|
||||
|
||||
# valid ipv6
|
||||
assert split_host_port("[::1]:5555") == ("::1", 5555)
|
||||
# invalid ipv6
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("[::1]::5555")
|
||||
with pytest.raises(IndexError):
|
||||
# no colon
|
||||
assert split_host_port("[::1]5555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("[::1]:5555a")
|
||||
|
||||
|
||||
def test_join_host_port():
|
||||
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
|
||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||
|
||||
|
||||
def test_convert_ids_list_to_tokens():
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")
|
||||
token_ids = tokenizer.encode("Hello, world!")
|
||||
|
@ -18,7 +18,7 @@ from vllm.model_executor.layers.batch_invariant import (
|
||||
vllm_is_batch_invariant,
|
||||
)
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import get_open_ports_list
|
||||
from vllm.utils.network_utils import get_open_ports_list
|
||||
from vllm.utils.torch_utils import cuda_device_count_stateless
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -27,7 +27,7 @@ from zmq import ( # type: ignore
|
||||
import vllm.envs as envs
|
||||
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (
|
||||
from vllm.utils.network_utils import (
|
||||
get_ip,
|
||||
get_open_port,
|
||||
get_open_zmq_ipc_path,
|
||||
|
@ -40,7 +40,7 @@ from vllm.distributed.utils import divide
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.utils.network_utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
|
||||
|
@ -25,7 +25,7 @@ from vllm.distributed.device_communicators.pynccl_wrapper import (
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.p2p.tensor_memory_pool import ( # noqa: E501
|
||||
TensorMemoryPool,
|
||||
)
|
||||
from vllm.utils import get_ip
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.utils.torch_utils import current_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -15,7 +15,7 @@ from safetensors.torch import save as safetensors_save
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import join_host_port, make_zmq_path, split_host_port
|
||||
from vllm.utils.network_utils import join_host_port, make_zmq_path, split_host_port
|
||||
|
||||
logger = init_logger(__name__)
|
||||
NONE_INT = -150886311
|
||||
|
@ -49,10 +49,8 @@ from vllm.distributed.device_communicators.base_device_communicator import (
|
||||
)
|
||||
from vllm.distributed.utils import StatelessProcessGroup
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import (
|
||||
get_distributed_init_method,
|
||||
)
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.network_utils import get_distributed_init_method
|
||||
from vllm.utils.torch_utils import (
|
||||
direct_register_custom_op,
|
||||
supports_custom_op,
|
||||
|
@ -29,7 +29,7 @@ from torch.distributed.rendezvous import rendezvous
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_tcp_uri
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.utils.torch_utils import is_torch_equal_or_newer
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -81,8 +81,9 @@ from vllm.transformers_utils.config import (
|
||||
maybe_override_with_speculators,
|
||||
)
|
||||
from vllm.transformers_utils.utils import check_gguf_file
|
||||
from vllm.utils import FlexibleArgumentParser, get_ip, is_in_ray_actor
|
||||
from vllm.utils import FlexibleArgumentParser, is_in_ray_actor
|
||||
from vllm.utils.mem_constants import GiB_bytes
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.sample.logits_processor import LogitsProcessor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
@ -21,9 +21,9 @@ from vllm.usage.usage_lib import UsageContext
|
||||
from vllm.utils import (
|
||||
FlexibleArgumentParser,
|
||||
decorate_logs,
|
||||
get_tcp_uri,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.utils.network_utils import get_tcp_uri
|
||||
from vllm.v1.engine.core import EngineCoreProc
|
||||
from vllm.v1.engine.utils import CoreEngineProcManager, launch_core_engines
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
|
@ -18,7 +18,7 @@ from vllm.entrypoints.constants import (
|
||||
)
|
||||
from vllm.entrypoints.ssl import SSLCertRefresher
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import find_process_using_port
|
||||
from vllm.utils.network_utils import find_process_using_port
|
||||
from vllm.v1.engine.exceptions import EngineDeadError, EngineGenerateError
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -115,9 +115,9 @@ from vllm.utils import (
|
||||
Device,
|
||||
FlexibleArgumentParser,
|
||||
decorate_logs,
|
||||
is_valid_ipv6_address,
|
||||
set_ulimit,
|
||||
)
|
||||
from vllm.utils.network_utils import is_valid_ipv6_address
|
||||
from vllm.v1.engine.exceptions import EngineDeadError
|
||||
from vllm.v1.metrics.prometheus import get_prometheus_registry
|
||||
from vllm.version import __version__ as VLLM_VERSION
|
||||
|
@ -19,12 +19,12 @@ from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (
|
||||
from vllm.utils.asyncio import make_async
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_ip,
|
||||
get_open_port,
|
||||
)
|
||||
from vllm.utils.asyncio import make_async
|
||||
from vllm.v1.outputs import SamplerOutput
|
||||
|
||||
if ray is not None:
|
||||
|
@ -15,7 +15,7 @@ from vllm.executor.msgspec_utils import decode_hook, encode_hook
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.sequence import ExecuteModelRequest, IntermediateTensors
|
||||
from vllm.utils import get_ip
|
||||
from vllm.utils.network_utils import get_ip
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
||||
|
@ -13,7 +13,8 @@ import torch.distributed as dist
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.executor_base import ExecutorBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_distributed_init_method, get_ip, get_open_port, run_method
|
||||
from vllm.utils import run_method
|
||||
from vllm.utils.network_utils import get_distributed_init_method, get_ip, get_open_port
|
||||
from vllm.v1.engine import ReconfigureDistributedRequest, ReconfigureRankType
|
||||
from vllm.v1.outputs import AsyncModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerWrapperBase
|
||||
|
@ -7,12 +7,10 @@ import enum
|
||||
import getpass
|
||||
import importlib
|
||||
import inspect
|
||||
import ipaddress
|
||||
import json
|
||||
import multiprocessing
|
||||
import os
|
||||
import signal
|
||||
import socket
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
@ -33,15 +31,12 @@ from argparse import (
|
||||
from collections import defaultdict
|
||||
from collections.abc import (
|
||||
Callable,
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from concurrent.futures.process import ProcessPoolExecutor
|
||||
from functools import cache, partial, wraps
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TextIO, TypeVar
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import cloudpickle
|
||||
import psutil
|
||||
@ -49,34 +44,36 @@ import regex as re
|
||||
import setproctitle
|
||||
import torch
|
||||
import yaml
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import enable_trace_function_call, init_logger
|
||||
from vllm.ray.lazy_utils import is_in_ray_actor
|
||||
|
||||
_DEPRECATED_PROFILING = {"cprofile", "cprofile_context"}
|
||||
_DEPRECATED_MAPPINGS = {
|
||||
"cprofile": "profiling",
|
||||
"cprofile_context": "profiling",
|
||||
"get_open_port": "network_utils",
|
||||
}
|
||||
|
||||
|
||||
def __getattr__(name: str) -> Any: # noqa: D401 - short deprecation docstring
|
||||
"""Module-level getattr to handle deprecated profiling utilities."""
|
||||
if name in _DEPRECATED_PROFILING:
|
||||
"""Module-level getattr to handle deprecated utilities."""
|
||||
if name in _DEPRECATED_MAPPINGS:
|
||||
submodule_name = _DEPRECATED_MAPPINGS[name]
|
||||
warnings.warn(
|
||||
f"vllm.utils.{name} is deprecated and will be removed in a future version. "
|
||||
f"Use vllm.utils.profiling.{name} instead.",
|
||||
f"Use vllm.utils.{submodule_name}.{name} instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
import vllm.utils.profiling as _prof
|
||||
|
||||
return getattr(_prof, name)
|
||||
module = __import__(f"vllm.utils.{submodule_name}", fromlist=[submodule_name])
|
||||
return getattr(module, name)
|
||||
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
||||
|
||||
|
||||
def __dir__() -> list[str]:
|
||||
# expose deprecated names in dir() for better UX/tab-completion
|
||||
return sorted(list(globals().keys()) + list(_DEPRECATED_PROFILING))
|
||||
return sorted(list(globals().keys()) + list(_DEPRECATED_MAPPINGS.keys()))
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -149,197 +146,6 @@ def random_uuid() -> str:
|
||||
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
|
||||
for sock in sockets:
|
||||
if sock is not None:
|
||||
sock.close(linger=0)
|
||||
|
||||
|
||||
def get_ip() -> str:
|
||||
host_ip = envs.VLLM_HOST_IP
|
||||
if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
|
||||
logger.warning(
|
||||
"The environment variable HOST_IP is deprecated and ignored, as"
|
||||
" it is often used by Docker and other software to"
|
||||
" interact with the container's network stack. Please "
|
||||
"use VLLM_HOST_IP instead to set the IP address for vLLM processes"
|
||||
" to communicate with each other."
|
||||
)
|
||||
if host_ip:
|
||||
return host_ip
|
||||
|
||||
# IP is not set, try to get it from the network interface
|
||||
|
||||
# try ipv4
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
try:
|
||||
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# try ipv6
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
||||
# Google's public DNS server, see
|
||||
# https://developers.google.com/speed/public-dns/docs/using#addresses
|
||||
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
warnings.warn(
|
||||
"Failed to get the IP address, using 0.0.0.0 by default."
|
||||
"The value can be set by the environment variable"
|
||||
" VLLM_HOST_IP or HOST_IP.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "0.0.0.0"
|
||||
|
||||
|
||||
def test_loopback_bind(address, family):
|
||||
try:
|
||||
s = socket.socket(family, socket.SOCK_DGRAM)
|
||||
s.bind((address, 0)) # Port 0 = auto assign
|
||||
s.close()
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def get_loopback_ip() -> str:
|
||||
loopback_ip = envs.VLLM_LOOPBACK_IP
|
||||
if loopback_ip:
|
||||
return loopback_ip
|
||||
|
||||
# VLLM_LOOPBACK_IP is not set, try to get it based on network interface
|
||||
|
||||
if test_loopback_bind("127.0.0.1", socket.AF_INET):
|
||||
return "127.0.0.1"
|
||||
elif test_loopback_bind("::1", socket.AF_INET6):
|
||||
return "::1"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Neither 127.0.0.1 nor ::1 are bound to a local interface. "
|
||||
"Set the VLLM_LOOPBACK_IP environment variable explicitly."
|
||||
)
|
||||
|
||||
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def split_host_port(host_port: str) -> tuple[str, int]:
|
||||
# ipv6
|
||||
if host_port.startswith("["):
|
||||
host, port = host_port.rsplit("]", 1)
|
||||
host = host[1:]
|
||||
port = port.split(":")[1]
|
||||
return host, int(port)
|
||||
else:
|
||||
host, port = host_port.split(":")
|
||||
return host, int(port)
|
||||
|
||||
|
||||
def join_host_port(host: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"[{host}]:{port}"
|
||||
else:
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
return get_tcp_uri(ip, port)
|
||||
|
||||
|
||||
def get_tcp_uri(ip: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(ip):
|
||||
return f"tcp://[{ip}]:{port}"
|
||||
else:
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_zmq_ipc_path() -> str:
|
||||
base_rpc_path = envs.VLLM_RPC_BASE_PATH
|
||||
return f"ipc://{base_rpc_path}/{uuid4()}"
|
||||
|
||||
|
||||
def get_open_zmq_inproc_path() -> str:
|
||||
return f"inproc://{uuid4()}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
"""
|
||||
Get an open port for the vLLM process to listen on.
|
||||
An edge case to handle, is when we run data parallel,
|
||||
we need to avoid ports that are potentially used by
|
||||
the data parallel master process.
|
||||
Right now we reserve 10 ports for the data parallel master
|
||||
process. Currently it uses 2 ports.
|
||||
"""
|
||||
if "VLLM_DP_MASTER_PORT" in os.environ:
|
||||
dp_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
reserved_port_range = range(dp_master_port, dp_master_port + 10)
|
||||
while True:
|
||||
candidate_port = _get_open_port()
|
||||
if candidate_port not in reserved_port_range:
|
||||
return candidate_port
|
||||
return _get_open_port()
|
||||
|
||||
|
||||
def get_open_ports_list(count: int = 5) -> list[int]:
|
||||
"""Get a list of open ports."""
|
||||
ports = set[int]()
|
||||
while len(ports) < count:
|
||||
ports.add(get_open_port())
|
||||
return list(ports)
|
||||
|
||||
|
||||
def _get_open_port() -> int:
|
||||
port = envs.VLLM_PORT
|
||||
if port is not None:
|
||||
while True:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", port))
|
||||
return port
|
||||
except OSError:
|
||||
port += 1 # Increment port number if already in use
|
||||
logger.info("Port %d is already in use, trying port %d", port - 1, port)
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def find_process_using_port(port: int) -> psutil.Process | None:
|
||||
# TODO: We can not check for running processes with network
|
||||
# port on macOS. Therefore, we can not have a full graceful shutdown
|
||||
# of vLLM. For now, let's not look for processes in this case.
|
||||
# Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
|
||||
if sys.platform.startswith("darwin"):
|
||||
return None
|
||||
|
||||
our_pid = os.getpid()
|
||||
for conn in psutil.net_connections():
|
||||
if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid):
|
||||
try:
|
||||
return psutil.Process(conn.pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def update_environment_variables(envs: dict[str, str]):
|
||||
for k, v in envs.items():
|
||||
if k in os.environ and os.environ[k] != v:
|
||||
@ -1119,122 +925,6 @@ def get_exception_traceback():
|
||||
return err_str
|
||||
|
||||
|
||||
def split_zmq_path(path: str) -> tuple[str, str, str]:
|
||||
"""Split a zmq path into its parts."""
|
||||
parsed = urlparse(path)
|
||||
if not parsed.scheme:
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
scheme = parsed.scheme
|
||||
host = parsed.hostname or ""
|
||||
port = str(parsed.port or "")
|
||||
|
||||
if scheme == "tcp" and not all((host, port)):
|
||||
# The host and port fields are required for tcp
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
if scheme != "tcp" and port:
|
||||
# port only makes sense with tcp
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
return scheme, host, port
|
||||
|
||||
|
||||
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
|
||||
"""Make a ZMQ path from its parts.
|
||||
|
||||
Args:
|
||||
scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
|
||||
host: The host - can be an IPv4 address, IPv6 address, or hostname.
|
||||
port: Optional port number, only used for TCP sockets.
|
||||
|
||||
Returns:
|
||||
A properly formatted ZMQ path string.
|
||||
"""
|
||||
if port is None:
|
||||
return f"{scheme}://{host}"
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"{scheme}://[{host}]:{port}"
|
||||
return f"{scheme}://{host}:{port}"
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
|
||||
def make_zmq_socket(
|
||||
ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined]
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: bool | None = None,
|
||||
identity: bytes | None = None,
|
||||
linger: int | None = None,
|
||||
) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
socket = ctx.socket(socket_type)
|
||||
|
||||
# Calculate buffer size based on system memory
|
||||
total_mem = mem.total / 1024**3
|
||||
available_mem = mem.available / 1024**3
|
||||
# For systems with substantial memory (>32GB total, >16GB available):
|
||||
# - Set a large 0.5GB buffer to improve throughput
|
||||
# For systems with less memory:
|
||||
# - Use system default (-1) to avoid excessive memory consumption
|
||||
buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1
|
||||
|
||||
if bind is None:
|
||||
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
|
||||
|
||||
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||
|
||||
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||
|
||||
if identity is not None:
|
||||
socket.setsockopt(zmq.IDENTITY, identity)
|
||||
|
||||
if linger is not None:
|
||||
socket.setsockopt(zmq.LINGER, linger)
|
||||
|
||||
if socket_type == zmq.XPUB:
|
||||
socket.setsockopt(zmq.XPUB_VERBOSE, True)
|
||||
|
||||
# Determine if the path is a TCP socket with an IPv6 address.
|
||||
# Enable IPv6 on the zmq socket if so.
|
||||
scheme, host, _ = split_zmq_path(path)
|
||||
if scheme == "tcp" and is_valid_ipv6_address(host):
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
|
||||
if bind:
|
||||
socket.bind(path)
|
||||
else:
|
||||
socket.connect(path)
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: bool | None = None,
|
||||
linger: int = 0,
|
||||
identity: bytes | None = None,
|
||||
) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=linger)
|
||||
|
||||
|
||||
def _maybe_force_spawn():
|
||||
"""Check if we need to force the use of the `spawn` multiprocessing start
|
||||
method.
|
||||
|
331
vllm/utils/network_utils.py
Normal file
331
vllm/utils/network_utils.py
Normal file
@ -0,0 +1,331 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import ipaddress
|
||||
import os
|
||||
import socket
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
Iterator,
|
||||
Sequence,
|
||||
)
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
import psutil
|
||||
import zmq
|
||||
import zmq.asyncio
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
def close_sockets(sockets: Sequence[zmq.Socket | zmq.asyncio.Socket]):
|
||||
for sock in sockets:
|
||||
if sock is not None:
|
||||
sock.close(linger=0)
|
||||
|
||||
|
||||
def get_ip() -> str:
|
||||
host_ip = envs.VLLM_HOST_IP
|
||||
if "HOST_IP" in os.environ and "VLLM_HOST_IP" not in os.environ:
|
||||
logger.warning(
|
||||
"The environment variable HOST_IP is deprecated and ignored, as"
|
||||
" it is often used by Docker and other software to"
|
||||
" interact with the container's network stack. Please "
|
||||
"use VLLM_HOST_IP instead to set the IP address for vLLM processes"
|
||||
" to communicate with each other."
|
||||
)
|
||||
if host_ip:
|
||||
return host_ip
|
||||
|
||||
# IP is not set, try to get it from the network interface
|
||||
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
|
||||
s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# try ipv6
|
||||
try:
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as s:
|
||||
# Google's public DNS server, see
|
||||
# https://developers.google.com/speed/public-dns/docs/using#addresses
|
||||
s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable
|
||||
return s.getsockname()[0]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
warnings.warn(
|
||||
"Failed to get the IP address, using 0.0.0.0 by default."
|
||||
"The value can be set by the environment variable"
|
||||
" VLLM_HOST_IP or HOST_IP.",
|
||||
stacklevel=2,
|
||||
)
|
||||
return "0.0.0.0"
|
||||
|
||||
|
||||
def test_loopback_bind(address, family):
|
||||
try:
|
||||
s = socket.socket(family, socket.SOCK_DGRAM)
|
||||
s.bind((address, 0)) # Port 0 = auto assign
|
||||
s.close()
|
||||
return True
|
||||
except OSError:
|
||||
return False
|
||||
|
||||
|
||||
def get_loopback_ip() -> str:
|
||||
loopback_ip = envs.VLLM_LOOPBACK_IP
|
||||
if loopback_ip:
|
||||
return loopback_ip
|
||||
|
||||
# VLLM_LOOPBACK_IP is not set, try to get it based on network interface
|
||||
|
||||
if test_loopback_bind("127.0.0.1", socket.AF_INET):
|
||||
return "127.0.0.1"
|
||||
elif test_loopback_bind("::1", socket.AF_INET6):
|
||||
return "::1"
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Neither 127.0.0.1 nor ::1 are bound to a local interface. "
|
||||
"Set the VLLM_LOOPBACK_IP environment variable explicitly."
|
||||
)
|
||||
|
||||
|
||||
def is_valid_ipv6_address(address: str) -> bool:
|
||||
try:
|
||||
ipaddress.IPv6Address(address)
|
||||
return True
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def split_host_port(host_port: str) -> tuple[str, int]:
|
||||
# ipv6
|
||||
if host_port.startswith("["):
|
||||
host, port = host_port.rsplit("]", 1)
|
||||
host = host[1:]
|
||||
port = port.split(":")[1]
|
||||
return host, int(port)
|
||||
else:
|
||||
host, port = host_port.split(":")
|
||||
return host, int(port)
|
||||
|
||||
|
||||
def join_host_port(host: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"[{host}]:{port}"
|
||||
else:
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
return get_tcp_uri(ip, port)
|
||||
|
||||
|
||||
def get_tcp_uri(ip: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(ip):
|
||||
return f"tcp://[{ip}]:{port}"
|
||||
else:
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_zmq_ipc_path() -> str:
|
||||
base_rpc_path = envs.VLLM_RPC_BASE_PATH
|
||||
return f"ipc://{base_rpc_path}/{uuid4()}"
|
||||
|
||||
|
||||
def get_open_zmq_inproc_path() -> str:
|
||||
return f"inproc://{uuid4()}"
|
||||
|
||||
|
||||
def get_open_port() -> int:
|
||||
"""
|
||||
Get an open port for the vLLM process to listen on.
|
||||
An edge case to handle, is when we run data parallel,
|
||||
we need to avoid ports that are potentially used by
|
||||
the data parallel master process.
|
||||
Right now we reserve 10 ports for the data parallel master
|
||||
process. Currently it uses 2 ports.
|
||||
"""
|
||||
if "VLLM_DP_MASTER_PORT" in os.environ:
|
||||
dp_master_port = envs.VLLM_DP_MASTER_PORT
|
||||
reserved_port_range = range(dp_master_port, dp_master_port + 10)
|
||||
while True:
|
||||
candidate_port = _get_open_port()
|
||||
if candidate_port not in reserved_port_range:
|
||||
return candidate_port
|
||||
return _get_open_port()
|
||||
|
||||
|
||||
def get_open_ports_list(count: int = 5) -> list[int]:
|
||||
"""Get a list of open ports."""
|
||||
ports = set[int]()
|
||||
while len(ports) < count:
|
||||
ports.add(get_open_port())
|
||||
return list(ports)
|
||||
|
||||
|
||||
def _get_open_port() -> int:
|
||||
port = envs.VLLM_PORT
|
||||
if port is not None:
|
||||
while True:
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", port))
|
||||
return port
|
||||
except OSError:
|
||||
port += 1 # Increment port number if already in use
|
||||
logger.info("Port %d is already in use, trying port %d", port - 1, port)
|
||||
# try ipv4
|
||||
try:
|
||||
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
except OSError:
|
||||
# try ipv6
|
||||
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
||||
s.bind(("", 0))
|
||||
return s.getsockname()[1]
|
||||
|
||||
|
||||
def find_process_using_port(port: int) -> psutil.Process | None:
|
||||
# TODO: We can not check for running processes with network
|
||||
# port on macOS. Therefore, we can not have a full graceful shutdown
|
||||
# of vLLM. For now, let's not look for processes in this case.
|
||||
# Ref: https://www.florianreinhard.de/accessdenied-in-psutil/
|
||||
if sys.platform.startswith("darwin"):
|
||||
return None
|
||||
|
||||
our_pid = os.getpid()
|
||||
for conn in psutil.net_connections():
|
||||
if conn.laddr.port == port and (conn.pid is not None and conn.pid != our_pid):
|
||||
try:
|
||||
return psutil.Process(conn.pid)
|
||||
except psutil.NoSuchProcess:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
def split_zmq_path(path: str) -> tuple[str, str, str]:
|
||||
"""Split a zmq path into its parts."""
|
||||
parsed = urlparse(path)
|
||||
if not parsed.scheme:
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
scheme = parsed.scheme
|
||||
host = parsed.hostname or ""
|
||||
port = str(parsed.port or "")
|
||||
|
||||
if scheme == "tcp" and not all((host, port)):
|
||||
# The host and port fields are required for tcp
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
if scheme != "tcp" and port:
|
||||
# port only makes sense with tcp
|
||||
raise ValueError(f"Invalid zmq path: {path}")
|
||||
|
||||
return scheme, host, port
|
||||
|
||||
|
||||
def make_zmq_path(scheme: str, host: str, port: int | None = None) -> str:
|
||||
"""Make a ZMQ path from its parts.
|
||||
|
||||
Args:
|
||||
scheme: The ZMQ transport scheme (e.g. tcp, ipc, inproc).
|
||||
host: The host - can be an IPv4 address, IPv6 address, or hostname.
|
||||
port: Optional port number, only used for TCP sockets.
|
||||
|
||||
Returns:
|
||||
A properly formatted ZMQ path string.
|
||||
"""
|
||||
if port is None:
|
||||
return f"{scheme}://{host}"
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"{scheme}://[{host}]:{port}"
|
||||
return f"{scheme}://{host}:{port}"
|
||||
|
||||
|
||||
# Adapted from: https://github.com/sgl-project/sglang/blob/v0.4.1/python/sglang/srt/utils.py#L783 # noqa: E501
|
||||
def make_zmq_socket(
|
||||
ctx: zmq.asyncio.Context | zmq.Context, # type: ignore[name-defined]
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: bool | None = None,
|
||||
identity: bytes | None = None,
|
||||
linger: int | None = None,
|
||||
) -> zmq.Socket | zmq.asyncio.Socket: # type: ignore[name-defined]
|
||||
"""Make a ZMQ socket with the proper bind/connect semantics."""
|
||||
|
||||
mem = psutil.virtual_memory()
|
||||
socket = ctx.socket(socket_type)
|
||||
|
||||
# Calculate buffer size based on system memory
|
||||
total_mem = mem.total / 1024**3
|
||||
available_mem = mem.available / 1024**3
|
||||
# For systems with substantial memory (>32GB total, >16GB available):
|
||||
# - Set a large 0.5GB buffer to improve throughput
|
||||
# For systems with less memory:
|
||||
# - Use system default (-1) to avoid excessive memory consumption
|
||||
buf_size = int(0.5 * 1024**3) if total_mem > 32 and available_mem > 16 else -1
|
||||
|
||||
if bind is None:
|
||||
bind = socket_type not in (zmq.PUSH, zmq.SUB, zmq.XSUB)
|
||||
|
||||
if socket_type in (zmq.PULL, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.RCVHWM, 0)
|
||||
socket.setsockopt(zmq.RCVBUF, buf_size)
|
||||
|
||||
if socket_type in (zmq.PUSH, zmq.DEALER, zmq.ROUTER):
|
||||
socket.setsockopt(zmq.SNDHWM, 0)
|
||||
socket.setsockopt(zmq.SNDBUF, buf_size)
|
||||
|
||||
if identity is not None:
|
||||
socket.setsockopt(zmq.IDENTITY, identity)
|
||||
|
||||
if linger is not None:
|
||||
socket.setsockopt(zmq.LINGER, linger)
|
||||
|
||||
if socket_type == zmq.XPUB:
|
||||
socket.setsockopt(zmq.XPUB_VERBOSE, True)
|
||||
|
||||
# Determine if the path is a TCP socket with an IPv6 address.
|
||||
# Enable IPv6 on the zmq socket if so.
|
||||
scheme, host, _ = split_zmq_path(path)
|
||||
if scheme == "tcp" and is_valid_ipv6_address(host):
|
||||
socket.setsockopt(zmq.IPV6, 1)
|
||||
|
||||
if bind:
|
||||
socket.bind(path)
|
||||
else:
|
||||
socket.connect(path)
|
||||
|
||||
return socket
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_socket_ctx(
|
||||
path: str,
|
||||
socket_type: Any,
|
||||
bind: bool | None = None,
|
||||
linger: int = 0,
|
||||
identity: bytes | None = None,
|
||||
) -> Iterator[zmq.Socket]:
|
||||
"""Context manager for a ZMQ socket"""
|
||||
|
||||
ctx = zmq.Context() # type: ignore[attr-defined]
|
||||
try:
|
||||
yield make_zmq_socket(ctx, path, socket_type, bind=bind, identity=identity)
|
||||
except KeyboardInterrupt:
|
||||
logger.debug("Got Keyboard Interrupt.")
|
||||
|
||||
finally:
|
||||
ctx.destroy(linger=linger)
|
@ -10,7 +10,8 @@ import zmq
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import get_mp_context, make_zmq_socket, set_process_title
|
||||
from vllm.utils import get_mp_context, set_process_title
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.engine import EngineCoreOutputs, EngineCoreRequestType
|
||||
from vllm.v1.serial_utils import MsgpackDecoder
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
@ -30,12 +30,12 @@ from vllm.tasks import POOLING_TASKS, SupportedTask
|
||||
from vllm.transformers_utils.config import maybe_register_config_serialize_by_value
|
||||
from vllm.utils import (
|
||||
decorate_logs,
|
||||
make_zmq_socket,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.utils.gc_utils import maybe_attach_gc_debug_callback
|
||||
from vllm.utils.hashing import get_hash_fn_by_name
|
||||
from vllm.utils.import_utils import resolve_obj_by_qualname
|
||||
from vllm.utils.network_utils import make_zmq_socket
|
||||
from vllm.v1.core.kv_cache_utils import (
|
||||
BlockHash,
|
||||
generate_scheduler_kv_cache_config,
|
||||
|
@ -23,13 +23,13 @@ from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.lora.request import LoRARequest
|
||||
from vllm.tasks import SupportedTask
|
||||
from vllm.utils import (
|
||||
from vllm.utils.asyncio import in_loop
|
||||
from vllm.utils.network_utils import (
|
||||
close_sockets,
|
||||
get_open_port,
|
||||
get_open_zmq_inproc_path,
|
||||
make_zmq_socket,
|
||||
)
|
||||
from vllm.utils.asyncio import in_loop
|
||||
from vllm.v1.engine import (
|
||||
EngineCoreOutputs,
|
||||
EngineCoreRequest,
|
||||
|
@ -20,7 +20,8 @@ from vllm.config import CacheConfig, ParallelConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.utils import get_mp_context, get_open_zmq_ipc_path, zmq_socket_ctx
|
||||
from vllm.utils import get_mp_context
|
||||
from vllm.utils.network_utils import get_open_zmq_ipc_path, zmq_socket_ctx
|
||||
from vllm.v1.engine.coordinator import DPCoordinator
|
||||
from vllm.v1.executor.abstract import Executor
|
||||
from vllm.v1.utils import get_engine_client_zmq_addr, shutdown
|
||||
|
@ -38,11 +38,13 @@ from vllm.logger import init_logger
|
||||
from vllm.utils import (
|
||||
_maybe_force_spawn,
|
||||
decorate_logs,
|
||||
get_mp_context,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.utils.network_utils import (
|
||||
get_distributed_init_method,
|
||||
get_loopback_ip,
|
||||
get_mp_context,
|
||||
get_open_port,
|
||||
set_process_title,
|
||||
)
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.executor.abstract import Executor, FailureCallback
|
||||
|
@ -25,12 +25,8 @@ from torch.autograd.profiler import record_function
|
||||
import vllm.envs as envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.usage.usage_lib import UsageContext, is_usage_stats_enabled, usage_message
|
||||
from vllm.utils import (
|
||||
get_open_port,
|
||||
get_open_zmq_ipc_path,
|
||||
get_tcp_uri,
|
||||
kill_process_tree,
|
||||
)
|
||||
from vllm.utils import kill_process_tree
|
||||
from vllm.utils.network_utils import get_open_port, get_open_zmq_ipc_path, get_tcp_uri
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import numpy as np
|
||||
|
Reference in New Issue
Block a user