mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Misc] adjust for ipv6 for mookcacke url parse (#20107)
Signed-off-by: Andy Xie <andy.xning@gmail.com>
This commit is contained in:
@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
|
||||
MemorySnapshot, PlaceholderModule, StoreBoolean,
|
||||
bind_kv_cache, common_broadcastable_dtype,
|
||||
deprecate_kwargs, get_open_port, is_lossless_cast,
|
||||
make_zmq_path, make_zmq_socket, memory_profiling,
|
||||
merge_async_iterators, sha256, split_zmq_path,
|
||||
supports_kw, swap_dict_values)
|
||||
deprecate_kwargs, get_open_port, get_tcp_uri,
|
||||
is_lossless_cast, join_host_port, make_zmq_path,
|
||||
make_zmq_socket, memory_profiling,
|
||||
merge_async_iterators, sha256, split_host_port,
|
||||
split_zmq_path, supports_kw, swap_dict_values)
|
||||
|
||||
from .utils import create_new_process_for_each_test, error_on_warning
|
||||
|
||||
@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
|
||||
def test_make_zmq_path():
|
||||
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
|
||||
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_get_tcp_uri():
|
||||
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
|
||||
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
|
||||
|
||||
|
||||
def test_split_host_port():
|
||||
# valid ipv4
|
||||
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
|
||||
# invalid ipv4
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("127.0.0.1::5555")
|
||||
with pytest.raises(ValueError):
|
||||
# tailing colon
|
||||
assert split_host_port("127.0.0.1:5555:")
|
||||
with pytest.raises(ValueError):
|
||||
# no colon
|
||||
assert split_host_port("127.0.0.15555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("127.0.0.1:5555a")
|
||||
|
||||
# valid ipv6
|
||||
assert split_host_port("[::1]:5555") == ("::1", 5555)
|
||||
# invalid ipv6
|
||||
with pytest.raises(ValueError):
|
||||
# multi colon
|
||||
assert split_host_port("[::1]::5555")
|
||||
with pytest.raises(IndexError):
|
||||
# no colon
|
||||
assert split_host_port("[::1]5555")
|
||||
with pytest.raises(ValueError):
|
||||
# none int port
|
||||
assert split_host_port("[::1]:5555a")
|
||||
|
||||
|
||||
def test_join_host_port():
|
||||
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
|
||||
assert join_host_port("::1", 5555) == "[::1]:5555"
|
||||
|
@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save
|
||||
from vllm.config import KVTransferConfig
|
||||
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import join_host_port, make_zmq_path, split_host_port
|
||||
|
||||
logger = init_logger(__name__)
|
||||
NONE_INT = -150886311
|
||||
@ -79,18 +80,19 @@ class MooncakeTransferEngine:
|
||||
logger.error(
|
||||
"An error occurred while loading the configuration: %s", exc)
|
||||
raise
|
||||
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
|
||||
decode_host, base_decode_port = self.config.decode_url.split(':')
|
||||
prefill_host, base_prefill_port = split_host_port(
|
||||
self.config.prefill_url)
|
||||
decode_host, base_decode_port = split_host_port(self.config.decode_url)
|
||||
|
||||
# Avoid ports conflict when running prefill and decode on the same node
|
||||
if prefill_host == decode_host and \
|
||||
base_prefill_port == base_decode_port:
|
||||
base_decode_port = str(int(base_decode_port) + 100)
|
||||
base_decode_port = base_decode_port + 100
|
||||
|
||||
prefill_port = int(base_prefill_port) + self.local_rank
|
||||
decode_port = int(base_decode_port) + self.local_rank
|
||||
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
|
||||
self.decode_url = ':'.join([decode_host, str(decode_port)])
|
||||
prefill_port = base_prefill_port + self.local_rank
|
||||
decode_port = base_decode_port + self.local_rank
|
||||
self.prefill_url = join_host_port(prefill_host, prefill_port)
|
||||
self.decode_url = join_host_port(decode_host, decode_port)
|
||||
|
||||
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
|
||||
self.config.metadata_server, self.config.protocol,
|
||||
@ -110,22 +112,30 @@ class MooncakeTransferEngine:
|
||||
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
|
||||
decode_host, base_decode_port)
|
||||
|
||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
|
||||
d_host: str, d_port: str) -> None:
|
||||
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
|
||||
d_host: str, d_port: int) -> None:
|
||||
"""Set up ZeroMQ sockets for sending and receiving data."""
|
||||
# Offsets < 8 are left for initialization in case tp and pp are enabled
|
||||
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
|
||||
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
|
||||
p_rank_offset = p_port + 8 + self.local_rank * 2
|
||||
d_rank_offset = d_port + 8 + self.local_rank * 2
|
||||
if kv_rank == 0:
|
||||
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
self.sender_socket.bind(
|
||||
make_zmq_path("tcp", p_host, p_rank_offset + 1))
|
||||
self.receiver_socket.connect(
|
||||
make_zmq_path("tcp", d_host, d_rank_offset + 1))
|
||||
self.sender_ack.connect(
|
||||
make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||
self.receiver_ack.bind(
|
||||
make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||
else:
|
||||
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
|
||||
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
|
||||
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
|
||||
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
|
||||
self.receiver_socket.connect(
|
||||
make_zmq_path("tcp", p_host, p_rank_offset + 1))
|
||||
self.sender_socket.bind(
|
||||
make_zmq_path("tcp", d_host, d_rank_offset + 1))
|
||||
self.receiver_ack.bind(
|
||||
make_zmq_path("tcp", d_host, d_rank_offset + 2))
|
||||
self.sender_ack.connect(
|
||||
make_zmq_path("tcp", p_host, p_rank_offset + 2))
|
||||
|
||||
def initialize(self, local_hostname: str, metadata_server: str,
|
||||
protocol: str, device_name: str,
|
||||
|
@ -46,7 +46,7 @@ from dataclasses import dataclass, field
|
||||
from functools import cache, lru_cache, partial, wraps
|
||||
from types import MappingProxyType
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
|
||||
Optional, TypeVar, Union, cast, overload)
|
||||
Optional, Tuple, TypeVar, Union, cast, overload)
|
||||
from urllib.parse import urlparse
|
||||
from uuid import uuid4
|
||||
|
||||
@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def split_host_port(host_port: str) -> Tuple[str, int]:
|
||||
# ipv6
|
||||
if host_port.startswith('['):
|
||||
host, port = host_port.rsplit(']', 1)
|
||||
host = host[1:]
|
||||
port = port.split(':')[1]
|
||||
return host, int(port)
|
||||
else:
|
||||
host, port = host_port.split(':')
|
||||
return host, int(port)
|
||||
|
||||
|
||||
def join_host_port(host: str, port: int) -> str:
|
||||
if is_valid_ipv6_address(host):
|
||||
return f"[{host}]:{port}"
|
||||
else:
|
||||
return f"{host}:{port}"
|
||||
|
||||
|
||||
def get_distributed_init_method(ip: str, port: int) -> str:
|
||||
return get_tcp_uri(ip, port)
|
||||
|
||||
|
||||
def get_tcp_uri(ip: str, port: int) -> str:
|
||||
# Brackets are not permitted in ipv4 addresses,
|
||||
# see https://github.com/python/cpython/issues/103848
|
||||
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
|
||||
if is_valid_ipv6_address(ip):
|
||||
return f"tcp://[{ip}]:{port}"
|
||||
else:
|
||||
return f"tcp://{ip}:{port}"
|
||||
|
||||
|
||||
def get_open_zmq_ipc_path() -> str:
|
||||
|
Reference in New Issue
Block a user