Mirror of https://github.com/vllm-project/vllm.git
Synced 2025-10-21 07:13:52 +08:00

Compare commits (3 commits):

- 2f2fcb31b8
- 1dba2c4ebe
- 71d6de3a26
@@ -20,10 +20,11 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config

 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, is_lossless_cast,
-                        make_zmq_path, make_zmq_socket, memory_profiling,
-                        merge_async_iterators, sha256, split_zmq_path,
-                        supports_kw, swap_dict_values)
+                        deprecate_kwargs, get_open_port, get_tcp_uri,
+                        is_lossless_cast, join_host_port, make_zmq_path,
+                        make_zmq_socket, memory_profiling,
+                        merge_async_iterators, sha256, split_host_port,
+                        split_zmq_path, supports_kw, swap_dict_values)

 from .utils import create_new_process_for_each_test, error_on_warning
@@ -876,3 +877,44 @@ def test_make_zmq_socket_ipv6():
 def test_make_zmq_path():
     assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
     assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
+
+
+def test_get_tcp_uri():
+    assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
+    assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
+
+
+def test_split_host_port():
+    # valid ipv4
+    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
+    # invalid ipv4
+    with pytest.raises(ValueError):
+        # multiple colons
+        split_host_port("127.0.0.1::5555")
+    with pytest.raises(ValueError):
+        # trailing colon
+        split_host_port("127.0.0.1:5555:")
+    with pytest.raises(ValueError):
+        # no colon
+        split_host_port("127.0.0.15555")
+    with pytest.raises(ValueError):
+        # non-integer port
+        split_host_port("127.0.0.1:5555a")
+
+    # valid ipv6
+    assert split_host_port("[::1]:5555") == ("::1", 5555)
+    # invalid ipv6
+    with pytest.raises(ValueError):
+        # multiple colons after the bracket
+        split_host_port("[::1]::5555")
+    with pytest.raises(IndexError):
+        # no colon
+        split_host_port("[::1]5555")
+    with pytest.raises(ValueError):
+        # non-integer port
+        split_host_port("[::1]:5555a")
+
+
+def test_join_host_port():
+    assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
+    assert join_host_port("::1", 5555) == "[::1]:5555"
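Taken together, these tests pin down a round-trip property: join_host_port inverts split_host_port for both address families. A minimal usage sketch (not part of the diff; it uses only the helpers imported above):

    from vllm.utils import join_host_port, split_host_port

    host, port = split_host_port("[::1]:5555")       # ("::1", 5555)
    assert join_host_port(host, port) == "[::1]:5555"

    host, port = split_host_port("127.0.0.1:5555")   # ("127.0.0.1", 5555)
    assert join_host_port(host, port) == "127.0.0.1:5555"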
@@ -16,6 +16,7 @@ from safetensors.torch import save as safetensors_save
 from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
 from vllm.logger import init_logger
+from vllm.utils import join_host_port, make_zmq_path, split_host_port

 logger = init_logger(__name__)
 NONE_INT = -150886311
@@ -79,18 +80,19 @@ class MooncakeTransferEngine:
             logger.error(
                 "An error occurred while loading the configuration: %s", exc)
             raise
-        prefill_host, base_prefill_port = self.config.prefill_url.split(':')
-        decode_host, base_decode_port = self.config.decode_url.split(':')
+        prefill_host, base_prefill_port = split_host_port(
+            self.config.prefill_url)
+        decode_host, base_decode_port = split_host_port(self.config.decode_url)

         # Avoid port conflicts when running prefill and decode on the same node
         if prefill_host == decode_host and \
                 base_prefill_port == base_decode_port:
-            base_decode_port = str(int(base_decode_port) + 100)
+            base_decode_port = base_decode_port + 100

-        prefill_port = int(base_prefill_port) + self.local_rank
-        decode_port = int(base_decode_port) + self.local_rank
-        self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
-        self.decode_url = ':'.join([decode_host, str(decode_port)])
+        prefill_port = base_prefill_port + self.local_rank
+        decode_port = base_decode_port + self.local_rank
+        self.prefill_url = join_host_port(prefill_host, prefill_port)
+        self.decode_url = join_host_port(decode_host, decode_port)

         self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
                         self.config.metadata_server, self.config.protocol,
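A worked sketch of what the rewritten constructor computes, with hypothetical values (both URLs set to "[::1]:7000" on one node, local_rank == 1); the +100 bump keeps the decode ports clear of the prefill ports:

    from vllm.utils import join_host_port, split_host_port

    prefill_host, base_prefill_port = split_host_port("[::1]:7000")  # ("::1", 7000)
    decode_host, base_decode_port = split_host_port("[::1]:7000")    # ("::1", 7000)
    base_decode_port += 100     # same host and base port, so decode moves to 7100
    local_rank = 1
    prefill_url = join_host_port(prefill_host, base_prefill_port + local_rank)  # "[::1]:7001"
    decode_url = join_host_port(decode_host, base_decode_port + local_rank)     # "[::1]:7101"

Because split_host_port returns the port as an int, the str/int round-trips of the old code disappear.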
@@ -110,22 +112,30 @@ class MooncakeTransferEngine:
         self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
                                      decode_host, base_decode_port)

-    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
-                                d_host: str, d_port: str) -> None:
+    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
+                                d_host: str, d_port: int) -> None:
         """Set up ZeroMQ sockets for sending and receiving data."""
         # Offsets < 8 are left for initialization in case tp and pp are enabled
-        p_rank_offset = int(p_port) + 8 + self.local_rank * 2
-        d_rank_offset = int(d_port) + 8 + self.local_rank * 2
+        p_rank_offset = p_port + 8 + self.local_rank * 2
+        d_rank_offset = d_port + 8 + self.local_rank * 2
         if kv_rank == 0:
-            self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
-            self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
-            self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
-            self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
+            self.sender_socket.bind(
+                make_zmq_path("tcp", p_host, p_rank_offset + 1))
+            self.receiver_socket.connect(
+                make_zmq_path("tcp", d_host, d_rank_offset + 1))
+            self.sender_ack.connect(
+                make_zmq_path("tcp", d_host, d_rank_offset + 2))
+            self.receiver_ack.bind(
+                make_zmq_path("tcp", p_host, p_rank_offset + 2))
         else:
-            self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
-            self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
-            self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
-            self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
+            self.receiver_socket.connect(
+                make_zmq_path("tcp", p_host, p_rank_offset + 1))
+            self.sender_socket.bind(
+                make_zmq_path("tcp", d_host, d_rank_offset + 1))
+            self.receiver_ack.bind(
+                make_zmq_path("tcp", d_host, d_rank_offset + 2))
+            self.sender_ack.connect(
+                make_zmq_path("tcp", p_host, p_rank_offset + 2))

     def initialize(self, local_hostname: str, metadata_server: str,
                    protocol: str, device_name: str,
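The offset arithmetic gives each rank its own pair of metadata ports above the base, with offsets 0-7 reserved for initialization. A sketch with hypothetical numbers (base port 7000, local_rank == 1):

    from vllm.utils import make_zmq_path

    p_port, local_rank = 7000, 1
    p_rank_offset = p_port + 8 + local_rank * 2      # 7010
    make_zmq_path("tcp", "::1", p_rank_offset + 1)   # "tcp://[::1]:7011", data socket
    make_zmq_path("tcp", "::1", p_rank_offset + 2)   # "tcp://[::1]:7012", ack socket

Replacing the hand-built f-strings with make_zmq_path is what makes these endpoints IPv6-safe: the helper adds the brackets that a raw "tcp://{host}:{port}" format would omit.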
@@ -55,9 +55,6 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.gptq_marlin import (
-    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -179,6 +176,7 @@ class Glm4vVisionMLP(nn.Module):
         hidden_features: int,
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -186,13 +184,12 @@ class Glm4vVisionMLP(nn.Module):
             output_sizes=[hidden_features] * 2,
             bias=bias,
             quant_config=quant_config,
-        )
-        self.down_proj = RowParallelLinear(
-            hidden_features,
-            in_features,
-            bias=bias,
-            quant_config=quant_config,
-        )
+            prefix=f"{prefix}.gate_up_proj")
+        self.down_proj = RowParallelLinear(hidden_features,
+                                           in_features,
+                                           bias=bias,
+                                           quant_config=quant_config,
+                                           prefix=f"{prefix}.down_proj")
         self.act_fn = SiluAndMul()

     def forward(self, x: torch.Tensor):
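Threading prefix into the parallel linear layers matters because vLLM uses these dotted module paths to match layers against checkpoint weight names and per-module quantization rules. A sketch of the names this change produces, for a hypothetical block path:

    prefix = "visual.blocks.0.mlp"            # hypothetical; supplied by the parent block
    gate_up_name = f"{prefix}.gate_up_proj"   # passed to MergedColumnParallelLinear
    down_name = f"{prefix}.down_proj"         # passed to RowParallelLinear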
@@ -407,6 +404,7 @@ class Glm4vVisionBlock(nn.Module):
             mlp_hidden_dim,
             bias=False,
             quant_config=quant_config,
+            prefix=f"{prefix}.mlp",
         )

     def forward(
@@ -1278,7 +1276,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
-            quant_config=self._maybe_ignore_quant_config(quant_config),
+            quant_config=quant_config,
             prefix=maybe_prefix(prefix, "visual"),
         )
@@ -1291,13 +1289,6 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)

-    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
-        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
-        # seems to avoid vision encoder sections for some models.
-        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
-            return None
-        return quant_config
-
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
@@ -33,10 +33,8 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                              DbrxConfig, DeepseekVLV2Config,
                                              EAGLEConfig, ExaoneConfig,
-                                             H2OVLChatConfig,
-                                             InternVLChatConfig, JAISConfig,
-                                             KimiVLConfig, MedusaConfig,
-                                             MiniMaxText01Config,
+                                             JAISConfig, KimiVLConfig,
+                                             MedusaConfig, MiniMaxText01Config,
                                              MiniMaxVL01Config, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
@@ -90,8 +88,6 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
     "exaone": ExaoneConfig,
-    "h2ovl_chat": H2OVLChatConfig,
-    "internvl_chat": InternVLChatConfig,
     "minimax_text_01": MiniMaxText01Config,
     "minimax_vl_01": MiniMaxVL01Config,
     "nemotron": NemotronConfig,
@@ -104,6 +100,10 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }

+_CONFIG_ATTRS_MAPPING: dict[str, str] = {
+    "llm_config": "text_config",
+}
+

 class ConfigFormat(str, enum.Enum):
     AUTO = "auto"
@@ -286,6 +286,18 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)


+def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
+    """Remap config attributes to match the expected names."""
+    for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
+        if hasattr(config, old_attr):
+            if not hasattr(config, new_attr):
+                config.update({new_attr: getattr(config, old_attr)})
+            delattr(config, old_attr)
+            logger.debug("Remapped config attribute '%s' to '%s'", old_attr,
+                         new_attr)
+    return config
+
+
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
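A minimal sketch of the remap in action, assuming a legacy config object that still exposes llm_config (the attribute value here is hypothetical):

    from transformers import PretrainedConfig

    config = PretrainedConfig()
    config.update({"llm_config": {"hidden_size": 1024}})  # legacy attribute name

    config = _maybe_remap_hf_config_attrs(config)
    assert config.text_config == {"hidden_size": 1024}    # value carried over
    assert not hasattr(config, "llm_config")              # old name removed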
@@ -361,6 +373,9 @@ def get_config(
             revision=revision,
             code_revision=code_revision,
             token=_get_hf_token(),
+            # some old custom models' configs need
+            # `has_no_defaults_at_init=True` to work.
+            has_no_defaults_at_init=trust_remote_code,
             **kwargs,
         )
     except ValueError as e:
@@ -376,6 +391,7 @@ def get_config(
                 raise RuntimeError(err_msg) from e
             else:
                 raise e
+        config = _maybe_remap_hf_config_attrs(config)

     elif config_format == ConfigFormat.MISTRAL:
         config = load_params_config(model, revision, **kwargs)
@@ -11,8 +11,6 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
-from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
-from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -38,8 +36,6 @@ __all__ = [
     "DeepseekVLV2Config",
     "MPTConfig",
     "RWConfig",
-    "H2OVLChatConfig",
-    "InternVLChatConfig",
     "JAISConfig",
     "MedusaConfig",
     "EAGLEConfig",
@@ -1,16 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-# Adapted from
-# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
-# --------------------------------------------------------
-# H2OVL-Mississippi
-# Copyright (c) 2024 H2O.AI
-# Licensed under Apache 2.0 License [see LICENSE for details]
-# --------------------------------------------------------
-
-from .internvl import InternVLChatConfig
-
-
-class H2OVLChatConfig(InternVLChatConfig):
-    model_type = "h2ovl_chat"
@@ -1,54 +0,0 @@
-# SPDX-License-Identifier: Apache-2.0
-# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-
-# Adapted from
-# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
-# --------------------------------------------------------
-# InternVL
-# Copyright (c) 2024 OpenGVLab
-# Licensed under The MIT License [see LICENSE for details]
-# --------------------------------------------------------
-from transformers.configuration_utils import PretrainedConfig
-
-
-class InternVLChatConfig(PretrainedConfig):
-    model_type = 'internvl_chat'
-    is_composition = True
-
-    def __init__(self,
-                 vision_config=None,
-                 llm_config=None,
-                 use_backbone_lora=0,
-                 use_llm_lora=0,
-                 select_layer=-1,
-                 force_image_size=None,
-                 downsample_ratio=0.5,
-                 template=None,
-                 dynamic_image_size=False,
-                 use_thumbnail=False,
-                 ps_version='v1',
-                 min_dynamic_patch=1,
-                 max_dynamic_patch=6,
-                 **kwargs):
-        super().__init__(**kwargs)
-
-        if vision_config is None:
-            vision_config = {}
-
-        if llm_config is None:
-            llm_config = {}
-
-        self.vision_config = PretrainedConfig(**vision_config)
-        self.text_config = PretrainedConfig(**llm_config)
-
-        self.use_backbone_lora = use_backbone_lora
-        self.use_llm_lora = use_llm_lora
-        self.select_layer = select_layer
-        self.force_image_size = force_image_size
-        self.downsample_ratio = downsample_ratio
-        self.template = template
-        self.dynamic_image_size = dynamic_image_size
-        self.use_thumbnail = use_thumbnail
-        self.ps_version = ps_version  # pixel shuffle version
-        self.min_dynamic_patch = min_dynamic_patch
-        self.max_dynamic_patch = max_dynamic_patch
@@ -8,8 +8,24 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
-from .internvl import InternVLChatConfig
+from transformers import Qwen2Config
+from transformers.configuration_utils import PretrainedConfig


-class NVLM_D_Config(InternVLChatConfig):
+class NVLM_D_Config(PretrainedConfig):
     model_type = 'NVLM_D'
+    is_composition = True
+
+    def __init__(self, vision_config=None, llm_config=None, **kwargs):
+        super().__init__(**kwargs)
+
+        # Handle vision_config initialization
+        if vision_config is None:
+            vision_config = {}
+
+        # Handle llm_config initialization
+        if llm_config is None:
+            llm_config = {}
+
+        self.vision_config = PretrainedConfig(**vision_config)
+        self.text_config = Qwen2Config(**llm_config)
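With the InternVLChatConfig base gone, NVLM_D_Config materializes its own sub-configs, and the text side becomes a real Qwen2Config rather than a bare PretrainedConfig. A construction sketch with hypothetical hyperparameters:

    from vllm.transformers_utils.configs import NVLM_D_Config

    cfg = NVLM_D_Config(
        vision_config={"hidden_size": 3200},                        # hypothetical values
        llm_config={"hidden_size": 8192, "num_hidden_layers": 80},
    )
    assert cfg.text_config.hidden_size == 8192   # backed by Qwen2Config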
@@ -46,7 +46,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, TypeVar, Union, cast, overload)
+                    Optional, Tuple, TypeVar, Union, cast, overload)
 from urllib.parse import urlparse
 from uuid import uuid4
@@ -628,14 +628,34 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False


+def split_host_port(host_port: str) -> Tuple[str, int]:
+    # ipv6
+    if host_port.startswith('['):
+        host, port = host_port.rsplit(']', 1)
+        host = host[1:]
+        port = port.split(':')[1]
+        return host, int(port)
+    else:
+        host, port = host_port.split(':')
+        return host, int(port)
+
+
+def join_host_port(host: str, port: int) -> str:
+    if is_valid_ipv6_address(host):
+        return f"[{host}]:{port}"
+    else:
+        return f"{host}:{port}"
+
+
 def get_distributed_init_method(ip: str, port: int) -> str:
     return get_tcp_uri(ip, port)


 def get_tcp_uri(ip: str, port: int) -> str:
     # Brackets are not permitted in ipv4 addresses,
     # see https://github.com/python/cpython/issues/103848
-    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
+    if is_valid_ipv6_address(ip):
+        return f"tcp://[{ip}]:{port}"
+    else:
+        return f"tcp://{ip}:{port}"


 def get_open_zmq_ipc_path() -> str:
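One subtlety the new tests pin down: a bracketed host with no port separator fails with IndexError rather than ValueError, because the text after ']' contains no ':' and indexing the split result is what raises:

    import pytest
    from vllm.utils import split_host_port

    with pytest.raises(IndexError):
        split_host_port("[::1]5555")   # "5555".split(':') == ['5555'], so [1] is out of range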