diff --git a/tests/test_utils.py b/tests/test_utils.py index 36db8202ba..a165d2d721 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache, MemorySnapshot, PlaceholderModule, StoreBoolean, bind_kv_cache, common_broadcastable_dtype, - deprecate_kwargs, get_open_port, is_lossless_cast, - make_zmq_path, make_zmq_socket, memory_profiling, - merge_async_iterators, sha256, split_zmq_path, - supports_kw, swap_dict_values) + deprecate_kwargs, get_open_port, get_tcp_uri, + is_lossless_cast, join_host_port, make_zmq_path, + make_zmq_socket, memory_profiling, + merge_async_iterators, sha256, split_host_port, + split_zmq_path, supports_kw, swap_dict_values) from .utils import create_new_process_for_each_test, error_on_warning @@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6(): def test_make_zmq_path(): assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555" assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555" + + +def test_get_tcp_uri(): + assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555" + assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555" + + +def test_split_host_port(): + # valid ipv4 + assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555) + # invalid ipv4 + with pytest.raises(ValueError): + # multi colon + assert split_host_port("127.0.0.1::5555") + with pytest.raises(ValueError): + # tailing colon + assert split_host_port("127.0.0.1:5555:") + with pytest.raises(ValueError): + # no colon + assert split_host_port("127.0.0.15555") + with pytest.raises(ValueError): + # none int port + assert split_host_port("127.0.0.1:5555a") + + # valid ipv6 + assert split_host_port("[::1]:5555") == ("::1", 5555) + # invalid ipv6 + with pytest.raises(ValueError): + # multi colon + assert split_host_port("[::1]::5555") + with pytest.raises(IndexError): + # no colon + assert split_host_port("[::1]5555") + with pytest.raises(ValueError): + # none int port + assert split_host_port("[::1]:5555a") + + +def test_join_host_port(): + assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555" + assert join_host_port("::1", 5555) == "[::1]:5555" diff --git a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py index 9f3494b810..0b560d1b3b 100644 --- a/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +++ b/vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py @@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save from vllm.config import KVTransferConfig from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase from vllm.logger import init_logger +from vllm.utils import join_host_port, make_zmq_path, split_host_port logger = init_logger(__name__) NONE_INT = -150886311 @@ -79,18 +80,19 @@ class MooncakeTransferEngine: logger.error( "An error occurred while loading the configuration: %s", exc) raise - prefill_host, base_prefill_port = self.config.prefill_url.split(':') - decode_host, base_decode_port = self.config.decode_url.split(':') + prefill_host, base_prefill_port = split_host_port( + self.config.prefill_url) + decode_host, base_decode_port = split_host_port(self.config.decode_url) # Avoid ports conflict when running prefill and decode on the same node if prefill_host == decode_host and \ base_prefill_port == base_decode_port: - base_decode_port = str(int(base_decode_port) + 100) + base_decode_port = base_decode_port + 100 - prefill_port = int(base_prefill_port) + self.local_rank - decode_port = int(base_decode_port) + self.local_rank - self.prefill_url = ':'.join([prefill_host, str(prefill_port)]) - self.decode_url = ':'.join([decode_host, str(decode_port)]) + prefill_port = base_prefill_port + self.local_rank + decode_port = base_decode_port + self.local_rank + self.prefill_url = join_host_port(prefill_host, prefill_port) + self.decode_url = join_host_port(decode_host, decode_port) self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url, self.config.metadata_server, self.config.protocol, @@ -110,22 +112,30 @@ class MooncakeTransferEngine: self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port, decode_host, base_decode_port) - def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str, - d_host: str, d_port: str) -> None: + def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int, + d_host: str, d_port: int) -> None: """Set up ZeroMQ sockets for sending and receiving data.""" # Offsets < 8 are left for initialization in case tp and pp are enabled - p_rank_offset = int(p_port) + 8 + self.local_rank * 2 - d_rank_offset = int(d_port) + 8 + self.local_rank * 2 + p_rank_offset = p_port + 8 + self.local_rank * 2 + d_rank_offset = d_port + 8 + self.local_rank * 2 if kv_rank == 0: - self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}") - self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}") - self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}") - self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}") + self.sender_socket.bind( + make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.receiver_socket.connect( + make_zmq_path("tcp", d_host, d_rank_offset + 1)) + self.sender_ack.connect( + make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.receiver_ack.bind( + make_zmq_path("tcp", p_host, p_rank_offset + 2)) else: - self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}") - self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}") - self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}") - self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}") + self.receiver_socket.connect( + make_zmq_path("tcp", p_host, p_rank_offset + 1)) + self.sender_socket.bind( + make_zmq_path("tcp", d_host, d_rank_offset + 1)) + self.receiver_ack.bind( + make_zmq_path("tcp", d_host, d_rank_offset + 2)) + self.sender_ack.connect( + make_zmq_path("tcp", p_host, p_rank_offset + 2)) def initialize(self, local_hostname: str, metadata_server: str, protocol: str, device_name: str, diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index d97d873ccf..ccfbf56927 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -46,7 +46,7 @@ from dataclasses import dataclass, field from functools import cache, lru_cache, partial, wraps from types import MappingProxyType from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple, - Optional, TypeVar, Union, cast, overload) + Optional, Tuple, TypeVar, Union, cast, overload) from urllib.parse import urlparse from uuid import uuid4 @@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool: return False +def split_host_port(host_port: str) -> Tuple[str, int]: + # ipv6 + if host_port.startswith('['): + host, port = host_port.rsplit(']', 1) + host = host[1:] + port = port.split(':')[1] + return host, int(port) + else: + host, port = host_port.split(':') + return host, int(port) + + +def join_host_port(host: str, port: int) -> str: + if is_valid_ipv6_address(host): + return f"[{host}]:{port}" + else: + return f"{host}:{port}" + + def get_distributed_init_method(ip: str, port: int) -> str: return get_tcp_uri(ip, port) def get_tcp_uri(ip: str, port: int) -> str: - # Brackets are not permitted in ipv4 addresses, - # see https://github.com/python/cpython/issues/103848 - return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" + if is_valid_ipv6_address(ip): + return f"tcp://[{ip}]:{port}" + else: + return f"tcp://{ip}:{port}" def get_open_zmq_ipc_path() -> str: