mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-21 15:43:52 +08:00

Compare commits
3 Commits

Author | SHA1 | Date
---|---|---
 | 7d092fc32c | 
 | 1a6c27f271 | 
 | 3c6fd286b4 | 
362  benchmarks/benchmark_one_concurrent.py  Normal file
@@ -0,0 +1,362 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import logging
import random
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp  # Import aiohttp
import numpy as np
from tqdm import tqdm

from backend_request_func import RequestFuncInput, RequestFuncOutput
from benchmark_dataset import RandomDataset, SampleRequest

try:
    from vllm.transformers_utils.tokenizer import get_tokenizer
except ImportError:
    from backend_request_func import get_tokenizer

logger = logging.getLogger(__name__)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    mean_ttft_ms: float
    median_ttft_ms: float
    std_ttft_ms: float
    percentiles_ttft_ms: list[tuple[float, float]]
    mean_itl_ms: float
    median_itl_ms: float
    std_itl_ms: float
    percentiles_itl_ms: list[tuple[float, float]]
    mean_e2el_ms: float
    median_e2el_ms: float
    std_e2el_ms: float
    percentiles_e2el_ms: list[tuple[float, float]]


async def reset_cache(reset_url: str):
    """Sends a POST request to reset the prefix cache."""
    logger.debug("Resetting prefix cache at %s", reset_url)
    try:
        async with (
            aiohttp.ClientSession() as session,
            session.post(reset_url) as response,
        ):
            response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)
            logger.debug("Prefix cache reset successful: %s", response.status)
    except aiohttp.ClientConnectorError as e:
        logger.error("Failed to connect to cache reset endpoint %s: %s", reset_url, e)
    except aiohttp.ClientResponseError as e:
        logger.error(
            "Cache reset request failed with status %s: %s", e.status, e.message
        )
    except Exception as e:
        logger.error("An unexpected error occurred during cache reset: %s", e)


async def sequential_benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer,
    input_requests: list[SampleRequest],
    request_func,
    selected_percentiles: list[float],
    cache_reset_url: Optional[str] = None,
):
    """
    Benchmark that processes requests sequentially, waiting for each to complete
    before starting the next one. Resets prefix cache between requests.
    """
    outputs = []

    pbar = tqdm(total=len(input_requests))

    benchmark_start_time = time.perf_counter()

    # Process requests sequentially
    for request in input_requests:
        prompt, prompt_len, output_len = (
            request.prompt,
            request.prompt_len,
            request.expected_output_len,
        )

        logger.info("Sending request with len %s", request.prompt_len)
        logger.debug('Request str: "%s"', request.prompt[:50])
        request_start_time = time.perf_counter()

        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
        )

        output = await request_func(request_func_input=request_func_input)

        request_end_time = time.perf_counter()
        # Add timing information
        if output.success and not hasattr(output, "latency"):
            output.latency = request_end_time - request_start_time
        logger.info("Finished request with latency %.4f s", output.latency)

        outputs.append(output)
        pbar.update(1)

        # Reset the prefix cache between requests, as described in the docstring
        if cache_reset_url:
            await reset_cache(cache_reset_url)

    pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    # Calculate metrics
    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
        selected_percentiles=selected_percentiles,
    )

    print_results(metrics, benchmark_duration)

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "input_lens": [request.prompt_len for request in input_requests],
        "output_lens": [
            output.output_tokens if output.success else 0 for output in outputs
        ],
        "ttfts": [output.ttft for output in outputs if output.success],
        "itls": [output.itl for output in outputs if output.success],
        "generated_texts": [
            output.generated_text for output in outputs if output.success
        ],
        "errors": [output.error for output in outputs if not output.success],
    }

    # Add summary statistics
    for stat_name in ["ttft", "itl", "e2el"]:
        for metric_name in ["mean", "median", "std"]:
            result[f"{metric_name}_{stat_name}_ms"] = getattr(
                metrics, f"{metric_name}_{stat_name}_ms"
            )

        for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            result[f"p{p_word}_{stat_name}_ms"] = value

    return result


def calculate_metrics(
    input_requests: list[SampleRequest],
    outputs: list[RequestFuncOutput],
    dur_s: float,
    tokenizer,
    selected_percentiles: list[float],
) -> BenchmarkMetrics:
    """Calculate benchmark metrics from results."""
    total_input = 0
    completed = 0
    total_output = 0
    ttfts = []
    itls = []
    e2els = []

    for i, output in enumerate(outputs):
        if output.success:
            output_len = output.output_tokens

            if not output_len:
                # Use tokenizer to count output tokens if not provided
                output_len = len(
                    tokenizer(output.generated_text, add_special_tokens=False).input_ids
                )

            total_output += output_len
            total_input += input_requests[i].prompt_len

            if hasattr(output, "ttft") and output.ttft is not None:
                ttfts.append(output.ttft)

            if hasattr(output, "itl") and output.itl:
                # Ensure itl is a list of floats
                if isinstance(output.itl, list):
                    itls.extend(output.itl)
                else:
                    logger.warning(
                        "Expected list for ITL but got %s. Appending as is.",
                        type(output.itl),
                    )
                    itls.append(output.itl)

            if hasattr(output, "latency") and output.latency is not None:
                e2els.append(output.latency)

            completed += 1

    return BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
        median_ttft_ms=np.median(ttfts or [0]) * 1000,
        std_ttft_ms=np.std(ttfts or [0]) * 1000,
        percentiles_ttft_ms=[
            (p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_itl_ms=np.mean(itls or [0]) * 1000,
        median_itl_ms=np.median(itls or [0]) * 1000,
        std_itl_ms=np.std(itls or [0]) * 1000,
        percentiles_itl_ms=[
            (p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
        ],
        mean_e2el_ms=np.mean(e2els or [0]) * 1000,
        median_e2el_ms=np.median(e2els or [0]) * 1000,
        std_e2el_ms=np.std(e2els or [0]) * 1000,
        percentiles_e2el_ms=[
            (p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
        ],
    )


def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
    """Print benchmark results in a formatted way."""
    print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
    print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
    print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
    print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
    print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))

    def print_metric_stats(metric_name, header):
        print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
        print(
            "{:<40} {:<10.2f}".format(
                f"Mean {metric_name} (ms):",
                getattr(metrics, f"mean_{metric_name.lower()}_ms"),
            )
        )
        print(
            "{:<40} {:<10.2f}".format(
                f"Median {metric_name} (ms):",
                getattr(metrics, f"median_{metric_name.lower()}_ms"),
            )
        )

        for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
            p_word = str(int(p)) if int(p) == p else str(p)
            print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))

    print_metric_stats("TTFT", "Time to First Token")
    print_metric_stats("ITL", "Inter-token Latency")
    print_metric_stats("E2EL", "End-to-end Latency")
    print("=" * 60)


async def main_async(args):
    # Import needed functions based on your setup
    from backend_request_func import ASYNC_REQUEST_FUNCS

    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    # Set up API URL
    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    # Set up Cache Reset URL
    cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
    logger.info("Prefix cache reset configured at: %s", cache_reset_url)

    # Get tokenizer
    tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)

    # Get request function
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS[backend]
    else:
        raise ValueError(f"Unknown backend: {backend}")

    input_requests = RandomDataset().sample(
        tokenizer=tokenizer,
        num_requests=args.num_requests,
        prefix_len=0,
        input_len=args.input_len,
        output_len=args.output_len,
        range_ratio=0.0,
    )

    # Run benchmark
    result = await sequential_benchmark(
        backend=backend,
        api_url=api_url,
        model_id=model_id,
        tokenizer=tokenizer,
        input_requests=input_requests,
        request_func=request_func,
        selected_percentiles=[50, 90, 95, 99],
        cache_reset_url=cache_reset_url,
    )

    return result


def main(args):
    print(args)
    random.seed(args.seed)
    np.random.seed(args.seed)

    asyncio.run(main_async(args))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
    parser.add_argument(
        "--backend", type=str, default="vllm", help="Backend to use for requests"
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server base URL (overrides --host and --port)",
    )
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument(
        "--endpoint", type=str, default="/v1/completions", help="API endpoint"
    )
    parser.add_argument("--model", type=str, required=True, help="Name of the model")
    parser.add_argument(
        "--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
    )
    parser.add_argument(
        "--num-requests", type=int, default=100, help="Number of requests to process"
    )
    parser.add_argument(
        "--input-len", type=int, default=128, help="Input len for generated prompts"
    )
    parser.add_argument(
        "--output-len", type=int, default=None, help="Override output len for requests"
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from HuggingFace",
    )

    args = parser.parse_args()
    main(args)
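For orientation, here is a minimal, hypothetical driver for the script above; it is not part of this commit. The Namespace fields simply mirror the argparse flags defined in the __main__ block, the result keys follow the p{percentile}_{metric}_ms naming built in sequential_benchmark, and it assumes the file is importable as benchmark_one_concurrent from the benchmarks/ directory with a server already listening on the default host and port.

import asyncio
from argparse import Namespace

from benchmark_one_concurrent import main_async  # assumed import path

# Values mirror the argparse defaults above; the model name is the one used in
# the Justfile later in this compare and is only an example.
args = Namespace(backend="vllm", base_url=None, host="127.0.0.1", port=8000,
                 endpoint="/v1/completions", model="meta-llama/Llama-3.1-8B-Instruct",
                 tokenizer=None, num_requests=10, input_len=128, output_len=1,
                 seed=42, trust_remote_code=False)

result = asyncio.run(main_async(args))
print(result["mean_ttft_ms"], result["p99_ttft_ms"], result["median_e2el_ms"])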
@@ -20,11 +20,10 @@ from vllm.config import ParallelConfig, VllmConfig, set_current_vllm_config
 from vllm.utils import (CacheInfo, FlexibleArgumentParser, LRUCache,
                         MemorySnapshot, PlaceholderModule, StoreBoolean,
                         bind_kv_cache, common_broadcastable_dtype,
-                        deprecate_kwargs, get_open_port, get_tcp_uri,
-                        is_lossless_cast, join_host_port, make_zmq_path,
-                        make_zmq_socket, memory_profiling,
-                        merge_async_iterators, sha256, split_host_port,
-                        split_zmq_path, supports_kw, swap_dict_values)
+                        deprecate_kwargs, get_open_port, is_lossless_cast,
+                        make_zmq_path, make_zmq_socket, memory_profiling,
+                        merge_async_iterators, sha256, split_zmq_path,
+                        supports_kw, swap_dict_values)
 
 from .utils import create_new_process_for_each_test, error_on_warning
 
@@ -877,44 +876,3 @@ def test_make_zmq_socket_ipv6():
 def test_make_zmq_path():
     assert make_zmq_path("tcp", "127.0.0.1", "5555") == "tcp://127.0.0.1:5555"
     assert make_zmq_path("tcp", "::1", "5555") == "tcp://[::1]:5555"
-
-
-def test_get_tcp_uri():
-    assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
-    assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"
-
-
-def test_split_host_port():
-    # valid ipv4
-    assert split_host_port("127.0.0.1:5555") == ("127.0.0.1", 5555)
-    # invalid ipv4
-    with pytest.raises(ValueError):
-        # multi colon
-        assert split_host_port("127.0.0.1::5555")
-    with pytest.raises(ValueError):
-        # tailing colon
-        assert split_host_port("127.0.0.1:5555:")
-    with pytest.raises(ValueError):
-        # no colon
-        assert split_host_port("127.0.0.15555")
-    with pytest.raises(ValueError):
-        # none int port
-        assert split_host_port("127.0.0.1:5555a")
-
-    # valid ipv6
-    assert split_host_port("[::1]:5555") == ("::1", 5555)
-    # invalid ipv6
-    with pytest.raises(ValueError):
-        # multi colon
-        assert split_host_port("[::1]::5555")
-    with pytest.raises(IndexError):
-        # no colon
-        assert split_host_port("[::1]5555")
-    with pytest.raises(ValueError):
-        # none int port
-        assert split_host_port("[::1]:5555a")
-
-
-def test_join_host_port():
-    assert join_host_port("127.0.0.1", 5555) == "127.0.0.1:5555"
-    assert join_host_port("::1", 5555) == "[::1]:5555"
88  tools/pd_disagg/Justfile  Normal file
@@ -0,0 +1,88 @@
# Needed for the proxy server
vllm-directory := "/home/rshaw/vllm/"

PREFILL_GPU := "0,1,2,3"
DECODE_GPU := "4,5,6,7"

PREFILL_TP := env("PREFILL_TP", "1")
DECODE_TP := env("DECODE_TP", "1")

BLOCK_SIZE := env("BLOCK_SIZE", "128")

MODEL := "meta-llama/Llama-3.1-8B-Instruct"
PROXY_PORT := "8192"
PREFILL_PORT := "8100"
DECODE_PORT := "8200"
PREFILL_NIXL_SIDE_CHANNEL_PORT := "5557"
DECODE_NIXL_SIDE_CHANNEL_PORT := "5568"

prefill:
    VLLM_NIXL_SIDE_CHANNEL_PORT={{PREFILL_NIXL_SIDE_CHANNEL_PORT}} \
    CUDA_VISIBLE_DEVICES={{PREFILL_GPU}} \
    vllm serve {{MODEL}} \
        --port {{PREFILL_PORT}} \
        --tensor-parallel-size {{PREFILL_TP}} \
        --enforce-eager \
        --disable-log-requests \
        --block-size {{BLOCK_SIZE}} \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

decode:
    VLLM_NIXL_SIDE_CHANNEL_PORT={{DECODE_NIXL_SIDE_CHANNEL_PORT}} \
    CUDA_VISIBLE_DEVICES={{DECODE_GPU}} \
    vllm serve {{MODEL}} \
        --port {{DECODE_PORT}} \
        --tensor-parallel-size {{DECODE_TP}} \
        --enforce-eager \
        --disable-log-requests \
        --block-size {{BLOCK_SIZE}} \
        --kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both"}'

proxy:
    python "{{vllm-directory}}tests/v1/kv_connector/nixl_integration/toy_proxy_server.py" \
        --port {{PROXY_PORT}} \
        --prefiller-port {{PREFILL_PORT}} \
        --decoder-port {{DECODE_PORT}}

send_request:
    curl -X POST http://localhost:{{PROXY_PORT}}/v1/completions \
        -H "Content-Type: application/json" \
        -d '{ \
            "model": "{{MODEL}}", \
            "prompt": "Red Hat is the best open source company by far across Linux, K8s, and AI, and vLLM has the greatest community in open source AI software infrastructure. I love vLLM because", \
            "max_tokens": 150, \
            "temperature": 0.7 \
        }'

benchmark NUM_PROMPTS:
    python {{vllm-directory}}/benchmarks/benchmark_serving.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --dataset-name random \
        --random-input-len 30000 \
        --random-output-len 10 \
        --num-prompts {{NUM_PROMPTS}} \
        --seed $(date +%s) \

benchmark_one INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent.py \
        --port {{PROXY_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

benchmark_one_no_pd INPUT_LEN:
    python {{vllm-directory}}benchmarks/benchmark_one_concurrent_req.py \
        --port {{DECODE_PORT}} \
        --model {{MODEL}} \
        --input-len {{INPUT_LEN}} \
        --output-len 1 \
        --num-requests 10 \
        --seed $(date +%s)

eval:
    lm_eval --model local-completions --tasks gsm8k \
        --model_args model={{MODEL}},base_url=http://127.0.0.1:{{PROXY_PORT}}/v1/completions,num_concurrent=100,max_retries=3,tokenized_requests=False \
        --limit 1000
@@ -16,7 +16,6 @@ from safetensors.torch import save as safetensors_save
 from vllm.config import KVTransferConfig
 from vllm.distributed.kv_transfer.kv_pipe.base import KVPipeBase
 from vllm.logger import init_logger
-from vllm.utils import join_host_port, make_zmq_path, split_host_port
 
 logger = init_logger(__name__)
 NONE_INT = -150886311
@@ -80,19 +79,18 @@ class MooncakeTransferEngine:
             logger.error(
                 "An error occurred while loading the configuration: %s", exc)
             raise
-        prefill_host, base_prefill_port = split_host_port(
-            self.config.prefill_url)
-        decode_host, base_decode_port = split_host_port(self.config.decode_url)
+        prefill_host, base_prefill_port = self.config.prefill_url.split(':')
+        decode_host, base_decode_port = self.config.decode_url.split(':')
 
         # Avoid ports conflict when running prefill and decode on the same node
         if prefill_host == decode_host and \
                 base_prefill_port == base_decode_port:
-            base_decode_port = base_decode_port + 100
+            base_decode_port = str(int(base_decode_port) + 100)
 
-        prefill_port = base_prefill_port + self.local_rank
-        decode_port = base_decode_port + self.local_rank
-        self.prefill_url = join_host_port(prefill_host, prefill_port)
-        self.decode_url = join_host_port(decode_host, decode_port)
+        prefill_port = int(base_prefill_port) + self.local_rank
+        decode_port = int(base_decode_port) + self.local_rank
+        self.prefill_url = ':'.join([prefill_host, str(prefill_port)])
+        self.decode_url = ':'.join([decode_host, str(decode_port)])
 
         self.initialize(self.prefill_url if kv_rank == 0 else self.decode_url,
                         self.config.metadata_server, self.config.protocol,
@@ -112,30 +110,22 @@ class MooncakeTransferEngine:
         self._setup_metadata_sockets(kv_rank, prefill_host, base_prefill_port,
                                      decode_host, base_decode_port)
 
-    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: int,
-                                d_host: str, d_port: int) -> None:
+    def _setup_metadata_sockets(self, kv_rank: int, p_host: str, p_port: str,
+                                d_host: str, d_port: str) -> None:
         """Set up ZeroMQ sockets for sending and receiving data."""
         # Offsets < 8 are left for initialization in case tp and pp are enabled
-        p_rank_offset = p_port + 8 + self.local_rank * 2
-        d_rank_offset = d_port + 8 + self.local_rank * 2
+        p_rank_offset = int(p_port) + 8 + self.local_rank * 2
+        d_rank_offset = int(d_port) + 8 + self.local_rank * 2
         if kv_rank == 0:
-            self.sender_socket.bind(
-                make_zmq_path("tcp", p_host, p_rank_offset + 1))
-            self.receiver_socket.connect(
-                make_zmq_path("tcp", d_host, d_rank_offset + 1))
-            self.sender_ack.connect(
-                make_zmq_path("tcp", d_host, d_rank_offset + 2))
-            self.receiver_ack.bind(
-                make_zmq_path("tcp", p_host, p_rank_offset + 2))
+            self.sender_socket.bind(f"tcp://{p_host}:{p_rank_offset + 1}")
+            self.receiver_socket.connect(f"tcp://{d_host}:{d_rank_offset + 1}")
+            self.sender_ack.connect(f"tcp://{d_host}:{d_rank_offset + 2}")
+            self.receiver_ack.bind(f"tcp://{p_host}:{p_rank_offset + 2}")
         else:
-            self.receiver_socket.connect(
-                make_zmq_path("tcp", p_host, p_rank_offset + 1))
-            self.sender_socket.bind(
-                make_zmq_path("tcp", d_host, d_rank_offset + 1))
-            self.receiver_ack.bind(
-                make_zmq_path("tcp", d_host, d_rank_offset + 2))
-            self.sender_ack.connect(
-                make_zmq_path("tcp", p_host, p_rank_offset + 2))
+            self.receiver_socket.connect(f"tcp://{p_host}:{p_rank_offset + 1}")
+            self.sender_socket.bind(f"tcp://{d_host}:{d_rank_offset + 1}")
+            self.receiver_ack.bind(f"tcp://{d_host}:{d_rank_offset + 2}")
+            self.sender_ack.connect(f"tcp://{p_host}:{p_rank_offset + 2}")
 
     def initialize(self, local_hostname: str, metadata_server: str,
                    protocol: str, device_name: str,
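For reference, a short sketch (hypothetical helper, not part of the diff) of the port arithmetic that _setup_metadata_sockets uses after this change, with the base port kept as the string parsed from the "host:port" URL:

def metadata_ports(base_port: str, local_rank: int) -> tuple[int, int]:
    # Offsets < 8 are reserved for initialization (tp/pp), so each rank gets
    # two ports above base + 8: one for data, one for acks.
    rank_offset = int(base_port) + 8 + local_rank * 2
    return rank_offset + 1, rank_offset + 2

# e.g. a prefill base port of "8100" and local_rank 1 -> (8111, 8112)
print(metadata_ports("8100", 1))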
@@ -55,6 +55,9 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                QKVParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.quantization import QuantizationConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
+from vllm.model_executor.layers.quantization.gptq_marlin import (
+    GPTQMarlinConfig)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.module_mapping import MultiModelKeys
 from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -176,7 +179,6 @@ class Glm4vVisionMLP(nn.Module):
         hidden_features: int,
         bias: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
     ):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
@@ -184,12 +186,13 @@ class Glm4vVisionMLP(nn.Module):
             output_sizes=[hidden_features] * 2,
             bias=bias,
             quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj")
-        self.down_proj = RowParallelLinear(hidden_features,
-                                           in_features,
-                                           bias=bias,
-                                           quant_config=quant_config,
-                                           prefix=f"{prefix}.down_proj")
+        )
+        self.down_proj = RowParallelLinear(
+            hidden_features,
+            in_features,
+            bias=bias,
+            quant_config=quant_config,
+        )
         self.act_fn = SiluAndMul()
 
     def forward(self, x: torch.Tensor):
@@ -404,7 +407,6 @@ class Glm4vVisionBlock(nn.Module):
             mlp_hidden_dim,
             bias=False,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp",
         )
 
     def forward(
@@ -1276,7 +1278,7 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.visual = Glm4vVisionTransformer(
             config.vision_config,
             norm_eps=getattr(config, "rms_norm_eps", 1e-5),
-            quant_config=quant_config,
+            quant_config=self._maybe_ignore_quant_config(quant_config),
             prefix=maybe_prefix(prefix, "visual"),
         )
 
@@ -1289,6 +1291,13 @@ class Glm4vForConditionalGeneration(nn.Module, SupportsMultiModal,
         self.make_empty_intermediate_tensors = (
             self.language_model.make_empty_intermediate_tensors)
 
+    def _maybe_ignore_quant_config(self, quant_config: QuantizationConfig):
+        # GPTQ configs do not have a list of ignored modules, however AutoGPTQ
+        # seems to avoid vision encoder sections for some models.
+        if isinstance(quant_config, (GPTQConfig, GPTQMarlinConfig)):
+            return None
+        return quant_config
+
     def _validate_and_reshape_mm_tensor(self, mm_input: object,
                                         name: str) -> torch.Tensor:
         if not isinstance(mm_input, (torch.Tensor, list)):
@@ -33,8 +33,10 @@ from vllm.logger import init_logger
 from vllm.transformers_utils.configs import (ChatGLMConfig, Cohere2Config,
                                              DbrxConfig, DeepseekVLV2Config,
                                              EAGLEConfig, ExaoneConfig,
-                                             JAISConfig, KimiVLConfig,
-                                             MedusaConfig, MiniMaxText01Config,
+                                             H2OVLChatConfig,
+                                             InternVLChatConfig, JAISConfig,
+                                             KimiVLConfig, MedusaConfig,
+                                             MiniMaxText01Config,
                                              MiniMaxVL01Config, MllamaConfig,
                                              MLPSpeculatorConfig, MPTConfig,
                                              NemotronConfig, NVLM_D_Config,
@@ -88,6 +90,8 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     "medusa": MedusaConfig,
     "eagle": EAGLEConfig,
     "exaone": ExaoneConfig,
+    "h2ovl_chat": H2OVLChatConfig,
+    "internvl_chat": InternVLChatConfig,
     "minimax_text_01": MiniMaxText01Config,
     "minimax_vl_01": MiniMaxVL01Config,
     "nemotron": NemotronConfig,
@@ -100,10 +104,6 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = {
     **_CONFIG_REGISTRY_OVERRIDE_HF
 }
 
-_CONFIG_ATTRS_MAPPING: dict[str, str] = {
-    "llm_config": "text_config",
-}
-
 
 class ConfigFormat(str, enum.Enum):
     AUTO = "auto"
@@ -286,18 +286,6 @@ def is_encoder_decoder(config: PretrainedConfig) -> bool:
     return getattr(config, "is_encoder_decoder", False)
 
 
-def _maybe_remap_hf_config_attrs(config: PretrainedConfig) -> PretrainedConfig:
-    """Remap config attributes to match the expected names."""
-    for old_attr, new_attr in _CONFIG_ATTRS_MAPPING.items():
-        if hasattr(config, old_attr):
-            if not hasattr(config, new_attr):
-                config.update({new_attr: getattr(config, old_attr)})
-            delattr(config, old_attr)
-            logger.debug("Remapped config attribute '%s' to '%s'", old_attr,
-                         new_attr)
-    return config
-
-
 def get_config(
     model: Union[str, Path],
     trust_remote_code: bool,
@@ -373,9 +361,6 @@ def get_config(
                 revision=revision,
                 code_revision=code_revision,
                 token=_get_hf_token(),
-                # some old custom model's config needs
-                # `has_no_defaults_at_init=True` to work.
-                has_no_defaults_at_init=trust_remote_code,
                 **kwargs,
             )
         except ValueError as e:
@@ -391,7 +376,6 @@ def get_config(
                 raise RuntimeError(err_msg) from e
             else:
                 raise e
-        config = _maybe_remap_hf_config_attrs(config)
 
     elif config_format == ConfigFormat.MISTRAL:
         config = load_params_config(model, revision, **kwargs)
@@ -11,6 +11,8 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig
 # tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
 # `FalconConfig` class from the official HuggingFace transformers library.
 from vllm.transformers_utils.configs.falcon import RWConfig
+from vllm.transformers_utils.configs.h2ovl import H2OVLChatConfig
+from vllm.transformers_utils.configs.internvl import InternVLChatConfig
 from vllm.transformers_utils.configs.jais import JAISConfig
 from vllm.transformers_utils.configs.kimi_vl import KimiVLConfig
 from vllm.transformers_utils.configs.medusa import MedusaConfig
@@ -36,6 +38,8 @@ __all__ = [
     "DeepseekVLV2Config",
     "MPTConfig",
     "RWConfig",
+    "H2OVLChatConfig",
+    "InternVLChatConfig",
     "JAISConfig",
     "MedusaConfig",
     "EAGLEConfig",
16  vllm/transformers_utils/configs/h2ovl.py  Normal file
@@ -0,0 +1,16 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/h2oai/h2ovl-mississippi-2b/blob/main/configuration_h2ovl_chat.py
# --------------------------------------------------------
# H2OVL-Mississippi
# Copyright (c) 2024 H2O.AI
# Licensed under Apache 2.0 License [see LICENSE for details]
# --------------------------------------------------------

from .internvl import InternVLChatConfig


class H2OVLChatConfig(InternVLChatConfig):
    model_type = "h2ovl_chat"
54  vllm/transformers_utils/configs/internvl.py  Normal file
@@ -0,0 +1,54 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Adapted from
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
# --------------------------------------------------------
# InternVL
# Copyright (c) 2024 OpenGVLab
# Licensed under The MIT License [see LICENSE for details]
# --------------------------------------------------------
from transformers.configuration_utils import PretrainedConfig


class InternVLChatConfig(PretrainedConfig):
    model_type = 'internvl_chat'
    is_composition = True

    def __init__(self,
                 vision_config=None,
                 llm_config=None,
                 use_backbone_lora=0,
                 use_llm_lora=0,
                 select_layer=-1,
                 force_image_size=None,
                 downsample_ratio=0.5,
                 template=None,
                 dynamic_image_size=False,
                 use_thumbnail=False,
                 ps_version='v1',
                 min_dynamic_patch=1,
                 max_dynamic_patch=6,
                 **kwargs):
        super().__init__(**kwargs)

        if vision_config is None:
            vision_config = {}

        if llm_config is None:
            llm_config = {}

        self.vision_config = PretrainedConfig(**vision_config)
        self.text_config = PretrainedConfig(**llm_config)

        self.use_backbone_lora = use_backbone_lora
        self.use_llm_lora = use_llm_lora
        self.select_layer = select_layer
        self.force_image_size = force_image_size
        self.downsample_ratio = downsample_ratio
        self.template = template
        self.dynamic_image_size = dynamic_image_size
        self.use_thumbnail = use_thumbnail
        self.ps_version = ps_version  # pixel shuffle version
        self.min_dynamic_patch = min_dynamic_patch
        self.max_dynamic_patch = max_dynamic_patch
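Illustration only (not part of the diff): the vendored config turns the nested HF dicts into plain PretrainedConfig objects, so downstream code can read text_config attributes directly. The field values below are arbitrary examples.

from vllm.transformers_utils.configs.internvl import InternVLChatConfig

cfg = InternVLChatConfig(vision_config={"image_size": 448},
                         llm_config={"hidden_size": 4096})
print(cfg.model_type)               # internvl_chat
print(cfg.text_config.hidden_size)  # 4096, taken from llm_config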
@@ -8,24 +8,8 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
-from transformers import Qwen2Config
-from transformers.configuration_utils import PretrainedConfig
+from .internvl import InternVLChatConfig
 
 
-class NVLM_D_Config(PretrainedConfig):
+class NVLM_D_Config(InternVLChatConfig):
     model_type = 'NVLM_D'
-    is_composition = True
-
-    def __init__(self, vision_config=None, llm_config=None, **kwargs):
-        super().__init__(**kwargs)
-
-        # Handle vision_config initialization
-        if vision_config is None:
-            vision_config = {}
-
-        # Handle llm_config initialization
-        if llm_config is None:
-            llm_config = {}
-
-        self.vision_config = PretrainedConfig(**vision_config)
-        self.text_config = Qwen2Config(**llm_config)
@@ -46,7 +46,7 @@ from dataclasses import dataclass, field
 from functools import cache, lru_cache, partial, wraps
 from types import MappingProxyType
 from typing import (TYPE_CHECKING, Any, Callable, Generic, Literal, NamedTuple,
-                    Optional, Tuple, TypeVar, Union, cast, overload)
+                    Optional, TypeVar, Union, cast, overload)
 from urllib.parse import urlparse
 from uuid import uuid4
 
@@ -628,34 +628,14 @@ def is_valid_ipv6_address(address: str) -> bool:
     return False
 
 
-def split_host_port(host_port: str) -> Tuple[str, int]:
-    # ipv6
-    if host_port.startswith('['):
-        host, port = host_port.rsplit(']', 1)
-        host = host[1:]
-        port = port.split(':')[1]
-        return host, int(port)
-    else:
-        host, port = host_port.split(':')
-        return host, int(port)
-
-
-def join_host_port(host: str, port: int) -> str:
-    if is_valid_ipv6_address(host):
-        return f"[{host}]:{port}"
-    else:
-        return f"{host}:{port}"
-
-
 def get_distributed_init_method(ip: str, port: int) -> str:
     return get_tcp_uri(ip, port)
 
 
 def get_tcp_uri(ip: str, port: int) -> str:
-    if is_valid_ipv6_address(ip):
-        return f"tcp://[{ip}]:{port}"
-    else:
-        return f"tcp://{ip}:{port}"
+    # Brackets are not permitted in ipv4 addresses,
+    # see https://github.com/python/cpython/issues/103848
+    return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
 
 
 def get_open_zmq_ipc_path() -> str:
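For reference, the reverted one-line helper keeps the same observable behavior that the removed test_get_tcp_uri exercised:

from vllm.utils import get_tcp_uri

assert get_tcp_uri("127.0.0.1", 5555) == "tcp://127.0.0.1:5555"
assert get_tcp_uri("::1", 5555) == "tcp://[::1]:5555"  # IPv6 hosts get brackets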