[Misc] adjust for ipv6 for mookcacke url parse (#20107)

Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
Ning Xie
2025-07-04 04:27:17 +08:00
committed by GitHub
parent 71d6de3a26
commit 1dba2c4ebe
3 changed files with 99 additions and 27 deletions

View File

@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
deprecate_kwargs, get_open_port, get_tcp_uri,
is_lossless_cast, join_host_port, make_zmq_path,
make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
def test_make_zmq_path():
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
def test_get_tcp_uri():
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
def test_split_host_port():
# valid ipv4
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
# invalid ipv4
with pytest.raises(ValueError):
# multi colon
assert split_host_port("127.0.0.1::5555")
with pytest.raises(ValueError):
# tailing colon
assert split_host_port("127.0.0.1:5555:")
with pytest.raises(ValueError):
# no colon
assert split_host_port("127.0.0.15555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("127.0.0.1:5555a")
# valid ipv6
assert split_host_port("[::1]:5555") == ("::1", 5555)
# invalid ipv6
with pytest.raises(ValueError):
# multi colon
assert split_host_port("[::1]::5555")
with pytest.raises(IndexError):
# no colon
assert split_host_port("[::1]5555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("[::1]:5555a")
def test_join_host_port():
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
assert join_host_port("::1", 5555) == "[::1]:5555"

View File

@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port
logger = init_logger(__name__)
NONE_INT = -150886311
@ -79,18 +80,19 @@ class MooncakeTransferEngine:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
decode_host, base_decode_port = self.config.decode_url.split(':')
prefill_host, base_prefill_port = split_host_port(
self.config.prefill_url)
decode_host, base_decode_port = split_host_port(self.config.decode_url)
# Avoid ports conflict when running prefill and decode on the same node
if prefill_host == decode_host and \
base_prefill_port == base_decode_port:
base_decode_port = str(int(base_decode_port) + 100)
base_decode_port = base_decode_port + 100
prefill_port = int(base_prefill_port) + self.local_rank
decode_port = int(base_decode_port) + self.local_rank
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
self.decode_url = ':'.join([decode_host, str(decode_port)])
prefill_port = base_prefill_port + self.local_rank
decode_port = base_decode_port + self.local_rank
self.prefill_url = join_host_port(prefill_host, prefill_port)
self.decode_url = join_host_port(decode_host, decode_port)
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
self.config.metadata_server, self.config.protocol,
@ -110,22 +112,30 @@ class MooncakeTransferEngine:
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
decode_host, base_decode_port)
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
d_host: str, d_port: str) -> None:
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
d_host: str, d_port: int) -> None:
"""Set up ZeroMQ sockets for sending and receiving data."""
# Offsets < 8 are left for initialization in case tp and pp are enabled
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
p_rank_offset = p_port + 8 + self.local_rank * 2
d_rank_offset = d_port + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
self.sender_socket.bind(
make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.receiver_socket.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.sender_ack.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.receiver_ack.bind(
make_zmq_path("tcp", p_host, p_rank_offset + 2))
else:
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
self.receiver_socket.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.sender_socket.bind(
make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.receiver_ack.bind(
make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.sender_ack.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 2))
def initialize(self, local_hostname: str, metadata_server: str,
protocol: str, device_name: str,

View File

@ -46,7 +46,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, TypeVar, Union, cast, overload)
Optional, Tuple, TypeVar, Union, cast, overload)
from urllib.parse import urlparse
from uuid import uuid4
@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
return False
def split_host_port(host_port: str) -> Tuple[str, int]:
# ipv6
if host_port.startswith('['):
host, port = host_port.rsplit(']', 1)
host = host[1:]
port = port.split(':')[1]
return host, int(port)
else:
host, port = host_port.split(':')
return host, int(port)
def join_host_port(host: str, port: int) -> str:
if is_valid_ipv6_address(host):
return f"[{host}]:{port}"
else:
return f"{host}:{port}"
def get_distributed_init_method(ip: str, port: int) -> str:
return get_tcp_uri(ip, port)
def get_tcp_uri(ip: str, port: int) -> str:
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
if is_valid_ipv6_address(ip):
return f"tcp://[{ip}]:{port}"
else:
return f"tcp://{ip}:{port}"
def get_open_zmq_ipc_path() -> str: