Compare commits

..

3 Commits

SHA1        Message                 Date
7d092fc32c  revert skip-merge-desc  2025-07-03 20:30:45 +00:00
1a6c27f271  updated                 2025-07-03 20:29:33 +00:00
3c6fd286b4  updated                 2025-07-03 18:29:58 +00:00

Author: Robert Shaw <robshaw@redhat.com> (Signed-off-by on all three commits)
11 changed files with 577 additions and 148 deletions

View File

@@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp  # Import aiohttp
import numpy as np
from tqdm import tqdm

from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Sends a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets prefix cache between requests.
    """
    outputs = []

    pbar = tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        # Reset the prefix cache between requests, as described in the docstring
        if cache_reset_url is not None:
            await reset_cache(cache_reset_url)

        outputs.append(output)
        pbar.update(1)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
                )

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl),
                    )
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
        ],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_name.lower()}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_name.lower()}_ms"),
            )
        )

        for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    parser.add_argument(
        "--backend", type=str, default="vllm", help="Backend to use for requests"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server base URL (overrides --host and --port)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint", type=str, default="/v1/completions", help="API endpoint"
    )
    parser.add_argument("--model", type=str, required=True, help="Name of the model")
    parser.add_argument(
        "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
    )
    parser.add_argument(
        "--num-requests", type=int, default=100, help="Number of requests to process"
    )
    parser.add_argument(
        "--input-len", type=int, default=128, help="Input len for generated prompts"
    )
    parser.add_argument(
        "--output-len", type=int, default=None, help="Override output len for requests"
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace",
    )

    args = parser.parse_args()
    main(args)
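For reference, the summary statistics the script prints reduce to a handful of numpy calls over the raw per-request timings collected in the loop above. A minimal standalone sketch of that reduction (the TTFT samples below are made up for illustration, not benchmark output):

import numpy as np

# Hypothetical per-request TTFTs in seconds, as collected by the sequential loop.
ttfts = [0.812, 0.794, 0.901, 0.785, 0.823]
selected_percentiles = [50, 90, 95, 99]

mean_ttft_ms = np.mean(ttfts) * 1000
median_ttft_ms = np.median(ttfts) * 1000
std_ttft_ms = np.std(ttfts) * 1000
percentiles_ttft_ms = [(p, np.percentile(ttfts, p) * 1000) for p in selected_percentiles]

print(f"Mean TTFT (ms):   {mean_ttft_ms:.2f}")
print(f"Median TTFT (ms): {median_ttft_ms:.2f}")
for p, value in percentiles_ttft_ms:
    print(f"P{p} TTFT (ms):    {value:.2f}")

The same mean/median/std/percentile treatment is applied to inter-token latency and end-to-end latency in calculate_metrics.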

View File

@@ -20,11 +20,10 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, get_tcp_uri,
-                        is_lossless_cast, join_host_port, make_zmq_path,
-                        make_zmq_socket, memory_profiling,
-                        merge_async_iterators, sha256, split_host_port,
-                        split_zmq_path, supports_kw, swap_dict_values)
+                        deprecate_kwargs, get_open_port, is_lossless_cast,
+                        make_zmq_path, make_zmq_socket, memory_profiling,
+                        merge_async_iterators, sha256, split_zmq_path,
+                        supports_kw, swap_dict_values)
 
 from .utils import create_new_process_for_each_test, error_on_warning
@@ -877,44 +876,3 @@ def test_make_zmq_socket_ipv6():
 def test_make_zmq_path():
     assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
     assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
-
-
-def test_get_tcp_uri():
-    assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
-    assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
-
-
-def test_split_host_port():
-    # valid ipv4
-    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
-
-    # invalid ipv4
-    with pytest.raises(ValueError):
-        # multi colon
-        assert split_host_port("127.0.0.1::5555")
-    with pytest.raises(ValueError):
-        # tailing colon
-        assert split_host_port("127.0.0.1:5555:")
-    with pytest.raises(ValueError):
-        # no colon
-        assert split_host_port("127.0.0.15555")
-    with pytest.raises(ValueError):
-        # none int port
-        assert split_host_port("127.0.0.1:5555a")
-
-    # valid ipv6
-    assert split_host_port("[::1]:5555") == ("::1", 5555)
-
-    # invalid ipv6
-    with pytest.raises(ValueError):
-        # multi colon
-        assert split_host_port("[::1]::5555")
-    with pytest.raises(IndexError):
-        # no colon
-        assert split_host_port("[::1]5555")
-    with pytest.raises(ValueError):
-        # none int port
-        assert split_host_port("[::1]:5555a")
-
-
-def test_join_host_port():
-    assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
-    assert join_host_port("::1", 5555) == "[::1]:5555"

tools/pd_disagg/Justfile (new file, +88 lines)
View File

@@ -0,0 +1,88 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"

PREFILL_GPU := "0,1,2,3"
DECODE_GPU := "4,5,6,7"

PREFILL_TP := env("PREFILL_TP", "1")
DECODE_TP := env("DECODE_TP", "1")
BLOCK_SIZE := env("BLOCK_SIZE", "128")

MODEL := "meta-llama/Llama-3.1-8B-Instruct"

PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"

PREFILL_NIXL_SIDE_CHANNEL_PORT := "5557"
DECODE_NIXL_SIDE_CHANNEL_PORT := "5568"

prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT={{PREFILL_NIXL_SIDE_CHANNEL_PORT}} \
    CUDA_VISIBLE_DEVICES={{PREFILL_GPU}} \
    vllm serve {{MODEL}} \
        --port {{PREFILL_PORT}} \
        --tensor-parallel-size {{PREFILL_TP}} \
        --enforce-eager \
        --disable-log-requests \
        --block-size {{BLOCK_SIZE}} \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT={{DECODE_NIXL_SIDE_CHANNEL_PORT}} \
    CUDA_VISIBLE_DEVICES={{DECODE_GPU}} \
    vllm serve {{MODEL}} \
        --port {{DECODE_PORT}} \
        --tensor-parallel-size {{DECODE_TP}} \
        --enforce-eager \
        --disable-log-requests \
        --block-size {{BLOCK_SIZE}} \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port {{PROXY_PORT}} \
        --prefiller-port {{PREFILL_PORT}} \
        --decoder-port {{DECODE_PORT}}

send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
        -H "Content-Type: application/json" \
        -d '{ \
            "model": "{{MODEL}}", \
            "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
            "max_tokens": 150, \
            "temperature": 0.7 \
        }'

benchmark NUM_PROMPTS:
    python {{vllm-directory}}benchmarks/benchmark_serving.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --dataset-name random \
        --random-input-len 30000 \
        --random-output-len 10 \
        --num-prompts {{NUM_PROMPTS}} \
        --seed $(date +%s)

benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{DECODE_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

eval:
    lm_eval --model local-completions --tasks gsm8k \
        --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
        --limit 1000
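The send_request recipe is just a completion call against the toy proxy; the same request can be issued from Python with aiohttp, the client library the benchmark script above already uses. A rough sketch, assuming the proxy recipe is running locally on PROXY_PORT 8192 and serving the model named in the Justfile:

import asyncio
import aiohttp

PROXY_URL = "http://localhost:8192/v1/completions"  # PROXY_PORT from the Justfile
MODEL = "meta-llama/Llama-3.1-8B-Instruct"

async def send_request() -> None:
    payload = {
        "model": MODEL,
        "prompt": "I love vLLM because",
        "max_tokens": 150,
        "temperature": 0.7,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post(PROXY_URL, json=payload) as response:
            response.raise_for_status()
            data = await response.json()
            # /v1/completions responses carry the generated text in choices[0]
            print(data["choices"][0]["text"])

if __name__ == "__main__":
    asyncio.run(send_request())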

View File

@@ -16,7 +16,6 @@ from safetensors.torch import save as safetensors_save
 from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
 from vllm.logger import init_logger
-from vllm.utils import join_host_port, make_zmq_path, split_host_port
 
 logger = init_logger(__name__)
 NONE_INT = -150886311
@@ -80,19 +79,18 @@ class MooncakeTransferEngine:
             logger.error(
                 "An error occurred while loading the configuration: %s", exc)
             raise
 
-        prefill_host, base_prefill_port = split_host_port(
-            self.config.prefill_url)
-        decode_host, base_decode_port = split_host_port(self.config.decode_url)
+        prefill_host, base_prefill_port = self.config.prefill_url.split(':')
+        decode_host, base_decode_port = self.config.decode_url.split(':')
 
         # Avoid ports conflict when running prefill and decode on the same node
         if prefill_host == decode_host and \
                 base_prefill_port == base_decode_port:
-            base_decode_port = base_decode_port + 100
+            base_decode_port = str(int(base_decode_port) + 100)
 
-        prefill_port = base_prefill_port + self.local_rank
-        decode_port = base_decode_port + self.local_rank
+        prefill_port = int(base_prefill_port) + self.local_rank
+        decode_port = int(base_decode_port) + self.local_rank
 
-        self.prefill_url = join_host_port(prefill_host, prefill_port)
-        self.decode_url = join_host_port(decode_host, decode_port)
+        self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
+        self.decode_url = ':'.join([decode_host, str(decode_port)])
 
         self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
                         self.config.metadata_server, self.config.protocol,
@@ -112,30 +110,22 @@ class MooncakeTransferEngine:
         self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
                                      decode_host, base_decode_port)
 
-    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
-                                d_host: str, d_port: int) -> None:
+    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
+                                d_host: str, d_port: str) -> None:
         """Set up ZeroMQ sockets for sending and receiving data."""
         # Offsets < 8 are left for initialization in case tp and pp are enabled
-        p_rank_offset = p_port + 8 + self.local_rank * 2
-        d_rank_offset = d_port + 8 + self.local_rank * 2
+        p_rank_offset = int(p_port) + 8 + self.local_rank * 2
+        d_rank_offset = int(d_port) + 8 + self.local_rank * 2
         if kv_rank == 0:
-            self.sender_socket.bind(
-                make_zmq_path("tcp", p_host, p_rank_offset + 1))
-            self.receiver_socket.connect(
-                make_zmq_path("tcp", d_host, d_rank_offset + 1))
-            self.sender_ack.connect(
-                make_zmq_path("tcp", d_host, d_rank_offset + 2))
-            self.receiver_ack.bind(
-                make_zmq_path("tcp", p_host, p_rank_offset + 2))
+            self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
+            self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
+            self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
+            self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
         else:
-            self.receiver_socket.connect(
-                make_zmq_path("tcp", p_host, p_rank_offset + 1))
-            self.sender_socket.bind(
-                make_zmq_path("tcp", d_host, d_rank_offset + 1))
-            self.receiver_ack.bind(
-                make_zmq_path("tcp", d_host, d_rank_offset + 2))
-            self.sender_ack.connect(
-                make_zmq_path("tcp", p_host, p_rank_offset + 2))
+            self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
+            self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
+            self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
+            self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
 
     def initialize(self, local_hostname: str, metadata_server: str,
                    protocol: str, device_name: str,

View File

@@ -55,6 +55,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -176,7 +179,6 @@ class Glm4vVisionMLP(nn.Module):
         hidden_features: int,
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -184,12 +186,13 @@
             output_sizes=[hidden_features] * 2,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+        )
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -404,7 +407,6 @@ class Glm4vVisionBlock(nn.Module):
             mlp_hidden_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
         )
 
     def forward(
@@ -1276,7 +1278,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
-            quant_config=quant_config,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
 
@@ -1289,6 +1291,13 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid vision encoder sections for some models.
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
+
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):

View File

@@ -33,8 +33,10 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                              DbrxConfig, DeepseekVLV2Config,
                                              EAGLEConfig, ExaoneConfig,
-                                             JAISConfig, KimiVLConfig,
-                                             MedusaConfig, MiniMaxText01Config,
+                                             H2OVLChatConfig,
+                                             InternVLChatConfig, JAISConfig,
+                                             KimiVLConfig, MedusaConfig,
+                                             MiniMaxText01Config,
                                              MiniMaxVL01Config, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
@@ -88,6 +90,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
     "exaone": ExaoneConfig,
+    "h2ovl_chat": H2OVLChatConfig,
+    "internvl_chat": InternVLChatConfig,
     "minimax_text_01": MiniMaxText01Config,
     "minimax_vl_01": MiniMaxVL01Config,
     "nemotron": NemotronConfig,
@@ -100,10 +104,6 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
 
-_CONFIG_ATTRS_MAPPING: dict[str, str] = {
-    "llm_config": "text_config",
-}
-
 
 class ConfigFormat(str, enum.Enum):
     AUTO = "auto"
@@ -286,18 +286,6 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)
 
 
-def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
-    """Remap config attributes to match the expected names."""
-    for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
-        if hasattr(config, old_attr):
-            if not hasattr(config, new_attr):
-                config.update({new_attr: getattr(config, old_attr)})
-            delattr(config, old_attr)
-            logger.debug("Remapped config attribute '%s' to '%s'", old_attr,
-                         new_attr)
-    return config
-
-
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -373,9 +361,6 @@ def get_config(
             revision=revision,
             code_revision=code_revision,
             token=_get_hf_token(),
-            # some old custom model's config needs
-            # `has_no_defaults_at_init=True` to work.
-            has_no_defaults_at_init=trust_remote_code,
             **kwargs,
         )
     except ValueError as e:
@@ -391,7 +376,6 @@
                 raise RuntimeError(err_msg) from e
             else:
                 raise e
-        config = _maybe_remap_hf_config_attrs(config)
 
     elif config_format == ConfigFormat.MISTRAL:
         config = load_params_config(model, revision, **kwargs)

View File

@@ -11,6 +11,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
+from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -36,6 +38,8 @@ __all__ = [
     "DeepseekVLV2Config",
     "MPTConfig",
     "RWConfig",
+    "H2OVLChatConfig",
+    "InternVLChatConfig",
     "JAISConfig",
     "MedusaConfig",
     "EAGLEConfig",

View File

@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------
from .internvl import InternVLChatConfig


class H2OVLChatConfig(InternVLChatConfig):
    model_type = "h2ovl_chat"

View File

@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig


class InternVLChatConfig(PretrainedConfig):
    model_type = 'internvl_chat'
    is_composition = True

    def __init__(self,
                 vision_config=None,
                 llm_config=None,
                 use_backbone_lora=0,
                 use_llm_lora=0,
                 select_layer=-1,
                 force_image_size=None,
                 downsample_ratio=0.5,
                 template=None,
                 dynamic_image_size=False,
                 use_thumbnail=False,
                 ps_version='v1',
                 min_dynamic_patch=1,
                 max_dynamic_patch=6,
                 **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}

        if llm_config is None:
            llm_config = {}

        self.vision_config = PretrainedConfig(**vision_config)
        self.text_config = PretrainedConfig(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
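For context, InternVLChatConfig follows the usual composite-config pattern: the nested vision_config / llm_config dictionaries from a checkpoint's config.json are wrapped into PretrainedConfig objects so downstream code can use attribute access. A minimal sketch of that behavior with illustrative values (not taken from any real checkpoint):

from transformers.configuration_utils import PretrainedConfig

# Illustrative nested config, similar in shape to an InternVL-style config.json.
raw_config = {
    "llm_config": {"hidden_size": 4096, "num_hidden_layers": 32},
    "vision_config": {"image_size": 448, "patch_size": 14},
}

# Unknown keyword arguments become attributes on the resulting config objects.
text_config = PretrainedConfig(**raw_config["llm_config"])
vision_config = PretrainedConfig(**raw_config["vision_config"])

print(text_config.hidden_size)   # 4096
print(vision_config.image_size)  # 448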

View File

@@ -8,24 +8,8 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
-from transformers import Qwen2Config
-from transformers.configuration_utils import PretrainedConfig
+from .internvl import InternVLChatConfig
 
 
-class NVLM_D_Config(PretrainedConfig):
+class NVLM_D_Config(InternVLChatConfig):
     model_type = 'NVLM_D'
-    is_composition = True
-
-    def __init__(self, vision_config=None, llm_config=None, **kwargs):
-        super().__init__(**kwargs)
-
-        # Handle vision_config initialization
-        if vision_config is None:
-            vision_config = {}
-
-        # Handle llm_config initialization
-        if llm_config is None:
-            llm_config = {}
-
-        self.vision_config = PretrainedConfig(**vision_config)
-        self.text_config = Qwen2Config(**llm_config)

View File

@@ -46,7 +46,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, TypeVar, Union, cast, overload)
+                    Optional, TypeVar, Union, cast, overload)
 from urllib.parse import urlparse
 from uuid import uuid4
@@ -628,34 +628,14 @@ def is_valid_ipv6_address(address: str) -> bool:
         return False
 
 
-def split_host_port(host_port: str) -> Tuple[str, int]:
-    # ipv6
-    if host_port.startswith('['):
-        host, port = host_port.rsplit(']', 1)
-        host = host[1:]
-        port = port.split(':')[1]
-        return host, int(port)
-    else:
-        host, port = host_port.split(':')
-        return host, int(port)
-
-
-def join_host_port(host: str, port: int) -> str:
-    if is_valid_ipv6_address(host):
-        return f"[{host}]:{port}"
-    else:
-        return f"{host}:{port}"
-
-
 def get_distributed_init_method(ip: str, port: int) -> str:
     return get_tcp_uri(ip, port)
 
 
 def get_tcp_uri(ip: str, port: int) -> str:
-    if is_valid_ipv6_address(ip):
-        return f"tcp://[{ip}]:{port}"
-    else:
-        return f"tcp://{ip}:{port}"
+    # Brackets are not permitted in ipv4 addresses,
+    # see https://github.com/python/cpython/issues/103848
+    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
 def get_open_zmq_ipc_path() -> str:
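The surviving get_tcp_uri helper boils down to one formatting rule: bracket the host only when it contains a colon (IPv6), since brackets are not valid in IPv4 addresses. A standalone sketch of that behavior, with checks mirroring the removed test_get_tcp_uri cases:

def get_tcp_uri(ip: str, port: int) -> str:
    # IPv6 hosts need brackets so the port separator is unambiguous;
    # IPv4 hosts never contain a colon, so they are left as-is.
    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"

assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"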