Derive auto max model len state from original value

@@ -37,7 +37,8 @@ Dynamic quantization is also supported via the `quantization` option -- see [her
 ## Context length and batch size
 
 You can further reduce memory usage by limiting the context length of the model (`max_model_len` option)
-and the maximum batch size (`max_num_seqs` option).
+and the maximum batch size (`max_num_seqs` option). Setting `max_model_len=-1` lets vLLM automatically
+pick the largest context length that fits in GPU memory, up to the model's default maximum.
 
 ```python
 from vllm import LLM
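
For reference, the documented option maps onto the `LLM` constructor as sketched below; the model name and the `max_num_seqs` value are illustrative, and the `-1` handling is the behavior this change introduces.

```python
from vllm import LLM

# Ask vLLM to derive the largest context length that fits in GPU memory,
# capped at the model's default maximum (the behavior added by this change).
llm = LLM(
    model="Qwen/Qwen1.5-7B",  # illustrative model, not prescribed by the docs
    max_model_len=-1,         # -1 requests automatic context-length selection
    max_num_seqs=32,          # illustrative batch-size cap
)
```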

@@ -320,6 +320,9 @@ def test_human_readable_model_len():
     args = parser.parse_args(["--max-model-len", "1024"])
     assert args.max_model_len == 1024
 
+    args = parser.parse_args(["--max-model-len", "-1"])
+    assert args.max_model_len == -1
+
     # Lower
     args = parser.parse_args(["--max-model-len", "1m"])
     assert args.max_model_len == 1_000_000

@@ -1,5 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+import copy
 import importlib
 from collections.abc import Callable

@@ -20,6 +21,7 @@ from vllm.v1.core.kv_cache_utils import (
     BlockHash,
     FreeKVCacheBlockQueue,
     KVCacheBlock,
+    check_enough_kv_cache_memory,
     estimate_max_model_len,
     generate_block_hash_extra_keys,
     generate_scheduler_kv_cache_config,

@@ -1082,6 +1084,51 @@ def test_estimate_max_model_len(model_id, max_model_len, want_estimated_max_len)
     assert estimated_max_len == want_estimated_max_len
 
 
+def test_auto_max_model_len_adjusts_to_available_memory():
+    model_id = "Qwen/Qwen1.5-7B"
+    model_config = ModelConfig(
+        model_id,
+        runner="generate",
+        dtype="float16",
+        max_model_len=-1,
+    )
+    scheduler_config = SchedulerConfig(max_num_batched_tokens=32768)
+
+    vllm_config = VllmConfig(
+        model_config=model_config,
+        scheduler_config=scheduler_config,
+    )
+
+    kv_cache_spec = {}
+    for i in range(32):
+        layer_name = f"layer_{i}"
+        kv_cache_spec[layer_name] = FullAttentionSpec(
+            block_size=16,
+            num_kv_heads=32,
+            head_size=128,
+            dtype=torch.float16,
+        )
+
+    available_memory = 5 * GiB_bytes
+
+    expected_config = copy.deepcopy(vllm_config)
+    default_max_len = expected_config.model_config.max_model_len
+    expected_max_len = estimate_max_model_len(
+        expected_config, kv_cache_spec, available_memory
+    )
+    assert expected_max_len > 0
+    assert expected_max_len < default_max_len
+
+    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
+
+    assert vllm_config.model_config.max_model_len == expected_max_len
+    assert vllm_config.scheduler_config.max_model_len == expected_max_len
+    assert (
+        vllm_config.model_config.get_auto_max_model_len_default()
+        == default_max_len
+    )
+
+
 def test_get_max_concurrency_for_kv_cache_config():
     # Create a VllmConfig
     model_id = "Qwen/Qwen1.5-7B"

@@ -711,7 +711,16 @@ class ModelConfig:
             self.disable_sliding_window = True
 
         self.original_max_model_len = self.max_model_len
-        self.max_model_len = self.get_and_verify_max_len(self.max_model_len)
+        auto_max_model_len_requested = self.original_max_model_len == -1
+        max_model_len_for_verification = None
+        if not auto_max_model_len_requested:
+            max_model_len_for_verification = self.max_model_len
+
+        self.max_model_len = self.get_and_verify_max_len(max_model_len_for_verification)
+        if auto_max_model_len_requested:
+            self._auto_max_model_len_default = self.max_model_len
+        else:
+            self._auto_max_model_len_default = None
         # Init multimodal config if needed
         if self._model_info.supports_multimodal:
             if (

@@ -1745,6 +1754,12 @@ class ModelConfig:
         logger.info("Using max model len %s", max_model_len)
         return max_model_len
 
+    def uses_auto_max_model_len(self) -> bool:
+        return getattr(self, "original_max_model_len", None) == -1
+
+    def get_auto_max_model_len_default(self) -> int | None:
+        return getattr(self, "_auto_max_model_len_default", None)
+
 
 def get_served_model_name(model: str, served_model_name: str | list[str] | None):
     """
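
Taken together with the constructor change above, these helpers let callers ask whether automatic sizing was requested and what verified default it started from. Below is a simplified, hypothetical sketch of that derived state; the real code delegates verification to `get_and_verify_max_len`, and the 32768-token default here is made up for illustration.

```python
# Simplified illustration of the state ModelConfig derives; not the real code path.
def derive_auto_state(requested_len: int, verified_default: int = 32768):
    auto_requested = requested_len == -1
    # Auto mode verifies against the model's default; explicit values pass through.
    max_model_len = verified_default if auto_requested else requested_len
    auto_default = max_model_len if auto_requested else None
    return max_model_len, auto_default


assert derive_auto_state(-1) == (32768, 32768)  # auto: the default is recorded
assert derive_auto_state(4096) == (4096, None)  # explicit: no auto default kept
```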

@@ -663,10 +663,39 @@ def check_enough_kv_cache_memory(
     needed_memory = max_memory_usage_bytes(vllm_config, kv_cache_spec.values())
 
     if needed_memory > available_memory:
+        default_max_len = (
+            vllm_config.model_config.get_auto_max_model_len_default()
+            or max_model_len
+        )
         # Estimate the maximum model length that can fit in the available memory
         estimated_max_len = estimate_max_model_len(
             vllm_config, kv_cache_spec, available_memory
         )
+
+        if vllm_config.model_config.uses_auto_max_model_len():
+            if estimated_max_len <= 0:
+                raise ValueError(
+                    "Unable to automatically determine a max model length "
+                    "that fits in the available KV cache memory. Try "
+                    "increasing `gpu_memory_utilization` or setting an "
+                    "explicit --max-model-len value."
+                )
+
+            if estimated_max_len == default_max_len:
+                logger.info(
+                    "Automatic max model length fits available memory (%d tokens).",
+                    estimated_max_len,
+                )
+            else:
+                logger.info(
+                    "Automatic max model length capped at %d tokens to fit "
+                    "available memory (model default %d).",
+                    estimated_max_len,
+                    default_max_len,
+                )
+            vllm_config.recalculate_max_model_len(estimated_max_len)
+            return
+
         estimated_msg = ""
         if estimated_max_len > 0:
             estimated_msg = (