mirror of https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Derive auto max model len state from original value
@@ -37,7 +37,8 @@ Dynamic quantization is also supported via the `quantization` option -- see [her
## Context length and batch size

You can further reduce memory usage by limiting the context length of the model (`max_model_len` option)
and the maximum batch size (`max_num_seqs` option).
and the maximum batch size (`max_num_seqs` option). Setting `max_model_len=-1` lets vLLM automatically
pick the largest context length that fits in GPU memory, up to the model's default maximum.

```python
from vllm import LLM
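
# Hedged illustration, not part of the diff: one way the options documented
# above might be combined. The model name and max_num_seqs value are placeholders.
llm = LLM(
    model="Qwen/Qwen1.5-7B",  # placeholder model
    max_model_len=-1,         # auto: use the largest context length that fits in GPU memory
    max_num_seqs=2,           # cap the batch size to further reduce memory usage
)
```
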
@@ -320,6 +320,9 @@ def test_human_readable_model_len():
    args = parser.parse_args(["--max-model-len", "1024"])
    assert args.max_model_len == 1024

    args = parser.parse_args(["--max-model-len", "-1"])
    assert args.max_model_len == -1

    # Lower
    args = parser.parse_args(["--max-model-len", "1m"])
    assert args.max_model_len == 1_000_000
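
For reference, here is a minimal, hedged sketch (not vLLM's actual parser) of how the human-readable values exercised above could be normalized; the function name and suffix table are assumptions for illustration only.

```python
# Illustrative only: mirrors the cases in test_human_readable_model_len above.
def parse_max_model_len(value: str) -> int:
    suffixes = {"k": 1_000, "m": 1_000_000, "g": 1_000_000_000}
    text = value.strip().lower()
    if text and text[-1] in suffixes:
        return int(float(text[:-1]) * suffixes[text[-1]])
    return int(text)  # plain integers pass through, including -1 for "auto"

assert parse_max_model_len("1024") == 1024
assert parse_max_model_len("-1") == -1          # request automatic sizing
assert parse_max_model_len("1m") == 1_000_000
```
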
@@ -1,5 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
import importlib
from collections.abc import Callable

@@ -20,6 +21,7 @@ from vllm.v1.core.kv_cache_utils import (
    BlockHash,
    FreeKVCacheBlockQueue,
    KVCacheBlock,
    check_enough_kv_cache_memory,
    estimate_max_model_len,
    generate_block_hash_extra_keys,
    generate_scheduler_kv_cache_config,
@@ -1082,6 +1084,51 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len)
    assert estimated_max_len == want_estimated_max_len


def test_auto_max_model_len_adjusts_to_available_memory():
    model_id = "Qwen/Qwen1.5-7B"
    model_config = ModelConfig(
        model_id,
        runner="generate",
        dtype="float16",
        max_model_len=-1,
    )
    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)

    vllm_config = VllmConfig(
        model_config=model_config,
        scheduler_config=scheduler_config,
    )

    kv_cache_spec = {}
    for i in range(32):
        layer_name = f"layer_{i}"
        kv_cache_spec[layer_name] = FullAttentionSpec(
            block_size=16,
            num_kv_heads=32,
            head_size=128,
            dtype=torch.float16,
        )

    available_memory = 5 * GiB_bytes

    expected_config = copy.deepcopy(vllm_config)
    default_max_len = expected_config.model_config.max_model_len
    expected_max_len = estimate_max_model_len(
        expected_config, kv_cache_spec, available_memory
    )
    assert expected_max_len > 0
    assert expected_max_len < default_max_len

    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)

    assert vllm_config.model_config.max_model_len == expected_max_len
    assert vllm_config.scheduler_config.max_model_len == expected_max_len
    assert (
        vllm_config.model_config.get_auto_max_model_len_default()
        == default_max_len
    )


def test_get_max_concurrency_for_kv_cache_config():
    # Create a VllmConfig
    model_id = "Qwen/Qwen1.5-7B"
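
As a sanity check on the numbers in the new test, a rough back-of-the-envelope calculation (my own arithmetic, ignoring block rounding and any non-KV overhead that vLLM accounts for): every token needs K and V entries across 32 layers, each with 32 KV heads of size 128 in float16.

```python
# Rough per-token KV-cache footprint for the FullAttentionSpec used in the test.
num_layers, num_kv_heads, head_size, dtype_bytes = 32, 32, 128, 2  # float16
per_token_bytes = 2 * num_layers * num_kv_heads * head_size * dtype_bytes  # K and V
assert per_token_bytes == 524_288           # ~0.5 MiB per token

available_memory = 5 * 1024**3              # 5 GiB, as in the test
print(available_memory // per_token_bytes)  # 10240 tokens, well below the model's
                                            # default context length, hence the
                                            # expected_max_len < default_max_len assert
```
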
@@ -711,7 +711,16 @@ class ModelConfig:
            self.disable_sliding_window = True

        self.original_max_model_len = self.max_model_len
        self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
        auto_max_model_len_requested = self.original_max_model_len == -1
        max_model_len_for_verification = None
        if not auto_max_model_len_requested:
            max_model_len_for_verification = self.max_model_len

        self.max_model_len = self.get_and_verify_max_len(max_model_len_for_verification)
        if auto_max_model_len_requested:
            self._auto_max_model_len_default = self.max_model_len
        else:
            self._auto_max_model_len_default = None
        # Init multimodal config if needed
        if self._model_info.supports_multimodal:
            if (
@@ -1745,6 +1754,12 @@ class ModelConfig:
        logger.info("Using max model len %s", max_model_len)
        return max_model_len

    def uses_auto_max_model_len(self) -> bool:
        return getattr(self, "original_max_model_len", None) == -1

    def get_auto_max_model_len_default(self) -> int | None:
        return getattr(self, "_auto_max_model_len_default", None)


def get_served_model_name(model: str, served_model_name: str | list[str] | None):
    """
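
A hedged sketch of how downstream code can consume the two new accessors; `model_config` is assumed to be an already-initialized `ModelConfig`, with construction details omitted (see the test earlier in this diff for a full setup).

```python
# Illustrative consumer of the new ModelConfig accessors from this diff.
if model_config.uses_auto_max_model_len():
    # max_model_len=-1 was requested: the derived default is recorded so the
    # KV-cache memory check can later lower max_model_len and report against it.
    default_len = model_config.get_auto_max_model_len_default()
    print(f"auto max_model_len requested; model default is {default_len} tokens")
else:
    # An explicit max_model_len leaves the auto default unset.
    assert model_config.get_auto_max_model_len_default() is None
```
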
@@ -663,10 +663,39 @@ def check_enough_kv_cache_memory(
    needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())

    if needed_memory > available_memory:
        default_max_len = (
            vllm_config.model_config.get_auto_max_model_len_default()
            or max_model_len
        )
        # Estimate the maximum model length that can fit in the available memory
        estimated_max_len = estimate_max_model_len(
            vllm_config, kv_cache_spec, available_memory
        )

        if vllm_config.model_config.uses_auto_max_model_len():
            if estimated_max_len <= 0:
                raise ValueError(
                    "Unable to automatically determine a max model length "
                    "that fits in the available KV cache memory. Try "
                    "increasing `gpu_memory_utilization` or setting an "
                    "explicit --max-model-len value."
                )

            if estimated_max_len == default_max_len:
                logger.info(
                    "Automatic max model length fits available memory (%d tokens).",
                    estimated_max_len,
                )
            else:
                logger.info(
                    "Automatic max model length capped at %d tokens to fit "
                    "available memory (model default %d).",
                    estimated_max_len,
                    default_max_len,
                )
            vllm_config.recalculate_max_model_len(estimated_max_len)
            return

        estimated_msg = ""
        if estimated_max_len > 0:
            estimated_msg = (