Compare commits

..

3 Commits

Author SHA1 Message Date
7d092fc32c revert skip-merge-desc
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-03 20:30:45 +00:00
1a6c27f271 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-03 20:29:33 +00:00
3c6fd286b4 updated
Signed-off-by: Robert Shaw <robshaw@redhat.com>
2025-07-03 18:29:58 +00:00
11 changed files with 577 additions and 148 deletions

View File

@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional
import aiohttp # Import aiohttp
import numpy as np
from tqdm import tqdm
from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest
try:
from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
from backend_request_func import get_tokenizer
logger = logging.getLogger(__name__)
@dataclass
class BenchmarkMetrics:
completed: int
total_input: int
total_output: int
mean_ttft_ms: float
median_ttft_ms: float
std_ttft_ms: float
percentiles_ttft_ms: list[tuple[float, float]]
mean_itl_ms: float
median_itl_ms: float
std_itl_ms: float
percentiles_itl_ms: list[tuple[float, float]]
mean_e2el_ms: float
median_e2el_ms: float
std_e2el_ms: float
percentiles_e2el_ms: list[tuple[float, float]]
async def reset_cache(reset_url: str):
"""Sends a POST request to reset the prefix cache."""
logger.debug("Resetting prefix cache at %s", reset_url)
try:
async with (
aiohttp.ClientSession() as session,
session.post(reset_url) as response,
):
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
logger.debug("Prefix cache reset successful: %s", response.status)
except aiohttp.ClientConnectorError as e:
logger.error("Failed to connect to cache reset endpoint %s: %s}", reset_url, e)
except aiohttp.ClientResponseError as e:
logger.error(
"Cache reset request failed with status %s: %s", e.status, e.message
)
except Exception as e:
logger.error("An unexpected error occurred during cache reset: %s", e)
async def sequential_benchmark(
backend: str,
api_url: str,
model_id: str,
tokenizer,
input_requests: list[SampleRequest],
request_func,
selected_percentiles: list[float],
cache_reset_url: Optional[str] = None,
):
"""
Benchmark that processes requests sequentially, waiting for each to complete
before starting the next one. Resets prefix cache between requests.
"""
outputs = []
pbar = tqdm(total=len(input_requests))
benchmark_start_time = time.perf_counter()
# Process requests sequentially
for request in input_requests:
prompt, prompt_len, output_len = (
request.prompt,
request.prompt_len,
request.expected_output_len,
)
logger.info("Sending request with len %s", request.prompt_len)
logger.debug('Request str: "%s"', request.prompt[:50])
request_start_time = time.perf_counter()
request_func_input = RequestFuncInput(
model=model_id,
prompt=prompt,
api_url=api_url,
prompt_len=prompt_len,
output_len=output_len,
)
output = await request_func(request_func_input=request_func_input)
request_end_time = time.perf_counter()
# Add timing information
if output.success and not hasattr(output, "latency"):
output.latency = request_end_time - request_start_time
logger.info("Finished request with latency %.4f s", output.latency)
outputs.append(output)
pbar.update(1)
pbar.close()
benchmark_duration = time.perf_counter() - benchmark_start_time
# Calculate metrics
metrics = calculate_metrics(
input_requests=input_requests,
outputs=outputs,
dur_s=benchmark_duration,
tokenizer=tokenizer,
selected_percentiles=selected_percentiles,
)
print_results(metrics, benchmark_duration)
result = {
"duration": benchmark_duration,
"completed": metrics.completed,
"total_input_tokens": metrics.total_input,
"total_output_tokens": metrics.total_output,
"input_lens": [request.prompt_len for request in input_requests],
"output_lens": [
output.output_tokens if output.success else 0 for output in outputs
],
"ttfts": [output.ttft for output in outputs if output.success],
"itls": [output.itl for output in outputs if output.success],
"generated_texts": [
output.generated_text for output in outputs if output.success
],
"errors": [output.error for output in outputs if not output.success],
}
# Add summary statistics
for stat_name in ["ttft", "itl", "e2el"]:
for metric_name in ["mean", "median", "std"]:
result[f"{metric_name}_{stat_name}_ms"] = getattr(
metrics, f"{metric_name}_{stat_name}_ms"
)
for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
result[f"p{p_word}_{stat_name}_ms"] = value
return result
def calculate_metrics(
input_requests: list[SampleRequest],
outputs: list[RequestFuncOutput],
dur_s: float,
tokenizer,
selected_percentiles: list[float],
) -> BenchmarkMetrics:
"""Calculate benchmark metrics from results."""
total_input = 0
completed = 0
total_output = 0
ttfts = []
itls = []
e2els = []
for i, output in enumerate(outputs):
if output.success:
output_len = output.output_tokens
if not output_len:
# Use tokenizer to count output tokens if not provided
output_len = len(
tokenizer(output.generated_text, add_special_tokens=False).input_ids
)
total_output += output_len
total_input += input_requests[i].prompt_len
if hasattr(output, "ttft") and output.ttft is not None:
ttfts.append(output.ttft)
if hasattr(output, "itl") and output.itl:
# Ensure itl is a list of floats
if isinstance(output.itl, list):
itls.extend(output.itl)
else:
logger.warning(
"Expected list for ITL but got %s. Appending as is.",
type(output.itl),
)
itls.append(output.itl)
if hasattr(output, "latency") and output.latency is not None:
e2els.append(output.latency)
completed += 1
return BenchmarkMetrics(
completed=completed,
total_input=total_input,
total_output=total_output,
mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
median_ttft_ms=np.median(ttfts or [0]) * 1000,
std_ttft_ms=np.std(ttfts or [0]) * 1000,
percentiles_ttft_ms=[
(p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
],
mean_itl_ms=np.mean(itls or [0]) * 1000,
median_itl_ms=np.median(itls or [0]) * 1000,
std_itl_ms=np.std(itls or [0]) * 1000,
percentiles_itl_ms=[
(p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
],
mean_e2el_ms=np.mean(e2els or [0]) * 1000,
median_e2el_ms=np.median(e2els or [0]) * 1000,
std_e2el_ms=np.std(e2els or [0]) * 1000,
percentiles_e2el_ms=[
(p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
],
)
def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
"""Print benchmark results in a formatted way."""
print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
def print_metric_stats(metric_name, header):
print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
print(
"{:<40} {:<10.2f}".format(
f"Mean {metric_name} (ms):",
getattr(metrics, f"mean_{metric_name.lower()}_ms"),
)
)
print(
"{:<40} {:<10.2f}".format(
f"Median {metric_name} (ms):",
getattr(metrics, f"median_{metric_name.lower()}_ms"),
)
)
for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
p_word = str(int(p)) if int(p) == p else str(p)
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
print_metric_stats("TTFT", "Time to First Token")
print_metric_stats("ITL", "Inter-token Latency")
print_metric_stats("E2EL", "End-to-end Latency")
print("=" * 60)
async def main_async(args):
# Import needed functions based on your setup
from backend_request_func import ASYNC_REQUEST_FUNCS
backend = args.backend
model_id = args.model
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
# Set up API URL
if args.base_url is not None:
api_url = f"{args.base_url}{args.endpoint}"
else:
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
# Set up Cache Reset URL
cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
logger.info("Prefix cache reset configured at: %s", cache_reset_url)
# Get tokenizer
tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)
# Get request function
if backend in ASYNC_REQUEST_FUNCS:
request_func = ASYNC_REQUEST_FUNCS[backend]
else:
raise ValueError(f"Unknown backend: {backend}")
input_requests = RandomDataset().sample(
tokenizer=tokenizer,
num_requests=args.num_requests,
prefix_len=0,
input_len=args.input_len,
output_len=args.output_len,
range_ratio=0.0,
)
# Run benchmark
result = await sequential_benchmark(
backend=backend,
api_url=api_url,
model_id=model_id,
tokenizer=tokenizer,
input_requests=input_requests,
request_func=request_func,
selected_percentiles=[50, 90, 95, 99],
cache_reset_url=cache_reset_url,
)
return result
def main(args):
print(args)
random.seed(args.seed)
np.random.seed(args.seed)
asyncio.run(main_async(args))
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
parser.add_argument(
"--backend", type=str, default="vllm", help="Backend to use for requests"
)
parser.add_argument(
"--base-url",
type=str,
default=None,
help="Server base URL (overrides --host and --port)",
)
parser.add_argument("--host", type=str, default="127.0.0.1")
parser.add_argument("--port", type=int, default=8000)
parser.add_argument(
"--endpoint", type=str, default="/v1/completions", help="API endpoint"
)
parser.add_argument("--model", type=str, required=True, help="Name of the model")
parser.add_argument(
"--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
)
parser.add_argument(
"--num-requests", type=int, default=100, help="Number of requests to process"
)
parser.add_argument(
"--input-len", type=int, default=128, help="Input len for generated prompts"
)
parser.add_argument(
"--output-len", type=int, default=None, help="Override output len for requests"
)
parser.add_argument("--seed", type=int, default=42)
parser.add_argument(
"--trust-remote-code",
action="store_true",
help="Trust remote code from HuggingFace",
)
args = parser.parse_args()
main(args)

View File

@ -20,11 +20,10 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
MemorySnapshot, PlaceholderModule, StoreBoolean,
bind_kv_cache, common_broadcastable_dtype,
deprecate_kwargs, get_open_port, get_tcp_uri,
is_lossless_cast, join_host_port, make_zmq_path,
make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_host_port,
split_zmq_path, supports_kw, swap_dict_values)
deprecate_kwargs, get_open_port, is_lossless_cast,
make_zmq_path, make_zmq_socket, memory_profiling,
merge_async_iterators, sha256, split_zmq_path,
supports_kw, swap_dict_values)
from .utils import create_new_process_for_each_test, error_on_warning
@ -877,44 +876,3 @@ def test_make_zmq_socket_ipv6():
def test_make_zmq_path():
assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
def test_get_tcp_uri():
assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
def test_split_host_port():
# valid ipv4
assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
# invalid ipv4
with pytest.raises(ValueError):
# multi colon
assert split_host_port("127.0.0.1::5555")
with pytest.raises(ValueError):
# tailing colon
assert split_host_port("127.0.0.1:5555:")
with pytest.raises(ValueError):
# no colon
assert split_host_port("127.0.0.15555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("127.0.0.1:5555a")
# valid ipv6
assert split_host_port("[::1]:5555") == ("::1", 5555)
# invalid ipv6
with pytest.raises(ValueError):
# multi colon
assert split_host_port("[::1]::5555")
with pytest.raises(IndexError):
# no colon
assert split_host_port("[::1]5555")
with pytest.raises(ValueError):
# none int port
assert split_host_port("[::1]:5555a")
def test_join_host_port():
assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
assert join_host_port("::1", 5555) == "[::1]:5555"

88
tools/pd_disagg/Justfile Normal file
View File

@ -0,0 +1,88 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"
PREFILL_GPU := "0,1,2,3"
DECODE_GPU := "4,5,6,7"
PREFILL_TP := env("PREFILL_TP", "1")
DECODE_TP := env("DECODE_TP", "1")
BLOCK_SIZE := env("BLOCK_SIZE", "128")
MODEL := "meta-llama/Llama-3.1-8B-Instruct"
PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"
PREFILL_NIXL_SIDE_CHANNEL_PORT := "5557"
DECODE_NIXL_SIDE_CHANNEL_PORT := "5568"
prefill:
VLLM_NIXL_SIDE_CHANNEL_PORT={{PREFILL_NIXL_SIDE_CHANNEL_PORT}} \
CUDA_VISIBLE_DEVICES={{PREFILL_GPU}} \
vllm serve {{MODEL}} \
--port {{PREFILL_PORT}} \
--tensor-parallel-size {{PREFILL_TP}} \
--enforce-eager \
--disable-log-requests \
--block-size {{BLOCK_SIZE}} \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
decode:
VLLM_NIXL_SIDE_CHANNEL_PORT={{DECODE_NIXL_SIDE_CHANNEL_PORT}} \
CUDA_VISIBLE_DEVICES={{DECODE_GPU}} \
vllm serve {{MODEL}} \
--port {{DECODE_PORT}} \
--tensor-parallel-size {{DECODE_TP}} \
--enforce-eager \
--disable-log-requests \
--block-size {{BLOCK_SIZE}} \
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'
proxy:
python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
--port {{PROXY_PORT}} \
--prefiller-port {{PREFILL_PORT}} \
--decoder-port {{DECODE_PORT}}
send_request:
curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
-H "Content-Type: application/json" \
-d '{ \
"model": "{{MODEL}}", \
"prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
"max_tokens": 150, \
"temperature": 0.7 \
}'
benchmark NUM_PROMPTS:
python {{vllm-directory}}/benchmarks/benchmark_serving.py \
--port {{PROXY_PORT}} \
--model {{MODEL}} \
--dataset-name random \
--random-input-len 30000 \
--random-output-len 10 \
--num-prompts {{NUM_PROMPTS}} \
--seed $(date +%s) \
benchmark_one INPUT_LEN:
python {{vllm-directory}}benchmarks/benchmark_one_concurrent.py \
--port {{PROXY_PORT}} \
--model {{MODEL}} \
--input-len {{INPUT_LEN}} \
--output-len 1 \
--num-requests 10 \
--seed $(date +%s)
benchmark_one_no_pd INPUT_LEN:
python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
--port {{DECODE_PORT}} \
--model {{MODEL}} \
--input-len {{INPUT_LEN}} \
--output-len 1 \
--num-requests 10 \
--seed $(date +%s)
eval:
lm_eval --model local-completions --tasks gsm8k \
--model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
--limit 1000

View File

@ -16,7 +16,6 @@ from safetensors.torch import save as safetensors_save
from vllm.config import KVTransferConfig
from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
from vllm.logger import init_logger
from vllm.utils import join_host_port, make_zmq_path, split_host_port
logger = init_logger(__name__)
NONE_INT = -150886311
@ -80,19 +79,18 @@ class MooncakeTransferEngine:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
prefill_host, base_prefill_port = split_host_port(
self.config.prefill_url)
decode_host, base_decode_port = split_host_port(self.config.decode_url)
prefill_host, base_prefill_port = self.config.prefill_url.split(':')
decode_host, base_decode_port = self.config.decode_url.split(':')
# Avoid ports conflict when running prefill and decode on the same node
if prefill_host == decode_host and \
base_prefill_port == base_decode_port:
base_decode_port = base_decode_port + 100
base_decode_port = str(int(base_decode_port) + 100)
prefill_port = base_prefill_port + self.local_rank
decode_port = base_decode_port + self.local_rank
self.prefill_url = join_host_port(prefill_host, prefill_port)
self.decode_url = join_host_port(decode_host, decode_port)
prefill_port = int(base_prefill_port) + self.local_rank
decode_port = int(base_decode_port) + self.local_rank
self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
self.decode_url = ':'.join([decode_host, str(decode_port)])
self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
self.config.metadata_server, self.config.protocol,
@ -112,30 +110,22 @@ class MooncakeTransferEngine:
self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
decode_host, base_decode_port)
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
d_host: str, d_port: int) -> None:
def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
d_host: str, d_port: str) -> None:
"""Set up ZeroMQ sockets for sending and receiving data."""
# Offsets < 8 are left for initialization in case tp and pp are enabled
p_rank_offset = p_port + 8 + self.local_rank * 2
d_rank_offset = d_port + 8 + self.local_rank * 2
p_rank_offset = int(p_port) + 8 + self.local_rank * 2
d_rank_offset = int(d_port) + 8 + self.local_rank * 2
if kv_rank == 0:
self.sender_socket.bind(
make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.receiver_socket.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.sender_ack.connect(
make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.receiver_ack.bind(
make_zmq_path("tcp", p_host, p_rank_offset + 2))
self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
else:
self.receiver_socket.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 1))
self.sender_socket.bind(
make_zmq_path("tcp", d_host, d_rank_offset + 1))
self.receiver_ack.bind(
make_zmq_path("tcp", d_host, d_rank_offset + 2))
self.sender_ack.connect(
make_zmq_path("tcp", p_host, p_rank_offset + 2))
self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
def initialize(self, local_hostname: str, metadata_server: str,
protocol: str, device_name: str,

View File

@ -55,6 +55,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.gptq import GPTQConfig
from vllm.model_executor.layers.quantization.gptq_marlin import (
GPTQMarlinConfig)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.multimodal import MULTIMODAL_REGISTRY
@ -176,7 +179,6 @@ class Glm4vVisionMLP(nn.Module):
hidden_features: int,
bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
@ -184,12 +186,13 @@ class Glm4vVisionMLP(nn.Module):
output_sizes=[hidden_features] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj")
self.down_proj = RowParallelLinear(hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.down_proj")
)
self.down_proj = RowParallelLinear(
hidden_features,
in_features,
bias=bias,
quant_config=quant_config,
)
self.act_fn = SiluAndMul()
def forward(self, x: torch.Tensor):
@ -404,7 +407,6 @@ class Glm4vVisionBlock(nn.Module):
mlp_hidden_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.mlp",
)
def forward(
@ -1276,7 +1278,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
self.visual = Glm4vVisionTransformer(
config.vision_config,
norm_eps=getattr(config, "rms_norm_eps", 1e-5),
quant_config=quant_config,
quant_config=self._maybe_ignore_quant_config(quant_config),
prefix=maybe_prefix(prefix, "visual"),
)
@ -1289,6 +1291,13 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
self.make_empty_intermediate_tensors = (
self.language_model.make_empty_intermediate_tensors)
def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
# GPTQ configs do not have a list of ignored modules, however AutoGPTQ
# seems to avoid vision encoder sections for some models.
if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
return None
return quant_config
def _validate_and_reshape_mm_tensor(self, mm_input: object,
name: str) -> torch.Tensor:
if not isinstance(mm_input, (torch.Tensor, list)):

View File

@ -33,8 +33,10 @@ from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
DbrxConfig, DeepseekVLV2Config,
EAGLEConfig, ExaoneConfig,
JAISConfig, KimiVLConfig,
MedusaConfig, MiniMaxText01Config,
H2OVLChatConfig,
InternVLChatConfig, JAISConfig,
KimiVLConfig, MedusaConfig,
MiniMaxText01Config,
MiniMaxVL01Config, MllamaConfig,
MLPSpeculatorConfig, MPTConfig,
NemotronConfig, NVLM_D_Config,
@ -88,6 +90,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
"medusa": MedusaConfig,
"eagle": EAGLEConfig,
"exaone": ExaoneConfig,
"h2ovl_chat": H2OVLChatConfig,
"internvl_chat": InternVLChatConfig,
"minimax_text_01": MiniMaxText01Config,
"minimax_vl_01": MiniMaxVL01Config,
"nemotron": NemotronConfig,
@ -100,10 +104,6 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
**_CONFIG_REGISTRY_OVERRIDE_HF
}
_CONFIG_ATTRS_MAPPING: dict[str, str] = {
"llm_config": "text_config",
}
class ConfigFormat(str, enum.Enum):
AUTO = "auto"
@ -286,18 +286,6 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
return getattr(config, "is_encoder_decoder", False)
def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
"""Remap config attributes to match the expected names."""
for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
if hasattr(config, old_attr):
if not hasattr(config, new_attr):
config.update({new_attr: getattr(config, old_attr)})
delattr(config, old_attr)
logger.debug("Remapped config attribute '%s' to '%s'", old_attr,
new_attr)
return config
def get_config(
model: Union[str, Path],
trust_remote_code: bool,
@ -373,9 +361,6 @@ def get_config(
revision=revision,
code_revision=code_revision,
token=_get_hf_token(),
# some old custom model's config needs
# `has_no_defaults_at_init=True` to work.
has_no_defaults_at_init=trust_remote_code,
**kwargs,
)
except ValueError as e:
@ -391,7 +376,6 @@ def get_config(
raise RuntimeError(err_msg) from e
else:
raise e
config = _maybe_remap_hf_config_attrs(config)
elif config_format == ConfigFormat.MISTRAL:
config = load_params_config(model, revision, **kwargs)

View File

@ -11,6 +11,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
from vllm.transformers_utils.configs.falcon import RWConfig
from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
from vllm.transformers_utils.configs.jais import JAISConfig
from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
from vllm.transformers_utils.configs.medusa import MedusaConfig
@ -36,6 +38,8 @@ __all__ = [
"DeepseekVLV2Config",
"MPTConfig",
"RWConfig",
"H2OVLChatConfig",
"InternVLChatConfig",
"JAISConfig",
"MedusaConfig",
"EAGLEConfig",

View File

@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from .internvl import InternVLChatConfig
class H2OVLChatConfig(InternVLChatConfig):
model_type = "h2ovl_chat"

View File

@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# Adapted from
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig
class InternVLChatConfig(PretrainedConfig):
model_type = 'internvl_chat'
is_composition = True
def __init__(self,
vision_config=None,
llm_config=None,
use_backbone_lora=0,
use_llm_lora=0,
select_layer=-1,
force_image_size=None,
downsample_ratio=0.5,
template=None,
dynamic_image_size=False,
use_thumbnail=False,
ps_version='v1',
min_dynamic_patch=1,
max_dynamic_patch=6,
**kwargs):
super().__init__(**kwargs)
if vision_config is None:
vision_config = {}
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = PretrainedConfig(**llm_config)
self.use_backbone_lora = use_backbone_lora
self.use_llm_lora = use_llm_lora
self.select_layer = select_layer
self.force_image_size = force_image_size
self.downsample_ratio = downsample_ratio
self.template = template
self.dynamic_image_size = dynamic_image_size
self.use_thumbnail = use_thumbnail
self.ps_version = ps_version # pixel shuffle version
self.min_dynamic_patch = min_dynamic_patch
self.max_dynamic_patch = max_dynamic_patch

View File

@ -8,24 +8,8 @@
# Copyright (c) 2024 NVIDIA
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from transformers import Qwen2Config
from transformers.configuration_utils import PretrainedConfig
from .internvl import InternVLChatConfig
class NVLM_D_Config(PretrainedConfig):
class NVLM_D_Config(InternVLChatConfig):
model_type = 'NVLM_D'
is_composition = True
def __init__(self, vision_config=None, llm_config=None, **kwargs):
super().__init__(**kwargs)
# Handle vision_config initialization
if vision_config is None:
vision_config = {}
# Handle llm_config initialization
if llm_config is None:
llm_config = {}
self.vision_config = PretrainedConfig(**vision_config)
self.text_config = Qwen2Config(**llm_config)

View File

@ -46,7 +46,7 @@ from dataclasses import dataclass, field
from functools import cache, lru_cache, partial, wraps
from types import MappingProxyType
from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
Optional, Tuple, TypeVar, Union, cast, overload)
Optional, TypeVar, Union, cast, overload)
from urllib.parse import urlparse
from uuid import uuid4
@ -628,34 +628,14 @@ def is_valid_ipv6_address(address: str) -> bool:
return False
def split_host_port(host_port: str) -> Tuple[str, int]:
# ipv6
if host_port.startswith('['):
host, port = host_port.rsplit(']', 1)
host = host[1:]
port = port.split(':')[1]
return host, int(port)
else:
host, port = host_port.split(':')
return host, int(port)
def join_host_port(host: str, port: int) -> str:
if is_valid_ipv6_address(host):
return f"[{host}]:{port}"
else:
return f"{host}:{port}"
def get_distributed_init_method(ip: str, port: int) -> str:
return get_tcp_uri(ip, port)
def get_tcp_uri(ip: str, port: int) -> str:
if is_valid_ipv6_address(ip):
return f"tcp://[{ip}]:{port}"
else:
return f"tcp://{ip}:{port}"
# Brackets are not permitted in ipv4 addresses,
# see https://github.com/python/cpython/issues/103848
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
def get_open_zmq_ipc_path() -> str: