Compare commits

...

1 Commit

Author SHA1 Message Date
9874169d07 Simplify max model length auto selection 2025-08-29 05:03:25 -04:00
4 changed files with 37 additions and 3 deletions

View File

@@ -337,6 +337,10 @@ def test_human_readable_model_len():
     args = parser.parse_args(["--max-model-len", "10.212345k"])
     assert args.max_model_len == 10212
+    # Auto via -1
+    args = parser.parse_args(["--max-model-len", "-1"])
+    assert args.max_model_len == -1
     # Invalid (do not allow decimals with binary multipliers)
     for invalid in ["1a", "pwd", "10.24", "1.23M"]:
         with pytest.raises(ArgumentError):
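For reference, here is a minimal, hypothetical sketch of a parser with the behaviour this test expects: lower-case suffixes are decimal, upper-case suffixes are binary and reject decimals, and a bare ``-1`` passes through untouched as the auto sentinel. The real helper is ``human_readable_int`` (shown later in this diff); this sketch only mirrors the cases asserted here.

import argparse

def parse_model_len(value: str) -> int:
    """Toy stand-in for human_readable_int, covering only the tested cases."""
    suffixes = {"k": 1_000, "m": 1_000_000, "K": 1_024, "M": 1_024 * 1_024}
    if value.lstrip("-").isdigit():
        # Plain integers, including the -1 auto sentinel.
        return int(value)
    number, suffix = value[:-1], value[-1]
    if suffix not in suffixes:
        raise argparse.ArgumentTypeError(f"Unrecognised value: {value!r}")
    if suffix.isupper() and "." in number:
        # Binary multipliers must be whole numbers, so "1.23M" is rejected.
        raise argparse.ArgumentTypeError(f"Decimals not allowed with {suffix}")
    try:
        return int(float(number) * suffixes[suffix])
    except ValueError:
        raise argparse.ArgumentTypeError(f"Unrecognised value: {value!r}")

assert parse_model_len("10.212345k") == 10212
assert parse_model_len("-1") == -1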

View File

@@ -345,7 +345,13 @@ class ModelConfig:
     format. Examples:\n
     - 1k -> 1000\n
     - 1K -> 1024\n
-    - 25.6k -> 25,600"""
+    - 25.6k -> 25,600\n
+    Pass ``-1`` to automatically choose the largest length that fits
+    in available GPU memory."""
+    auto_max_model_len: bool = False
+    """Automatically determine the maximum model length that fits in GPU
+    memory. Enabled when ``--max-model-len`` is ``-1``."""
     spec_target_max_model_len: Optional[int] = None
     """Specify the maximum length for spec decoding draft models."""
     quantization: SkipValidation[Optional[QuantizationMethods]] = None
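Taken together, the updated docstring and the new flag describe three modes: an explicit limit, ``None`` (derive the limit from the model), and auto-selection via ``-1``. Below is a hedged sketch of one plausible resolution order for a consumer of the two fields, assuming the memory-based estimate never exceeds the model-derived maximum; ``resolve_max_model_len`` and its arguments are hypothetical names, not vLLM APIs.

from typing import Optional

def resolve_max_model_len(max_model_len: Optional[int],
                          auto_max_model_len: bool,
                          derived_from_model: int,
                          fits_in_memory: int) -> int:
    """Hypothetical resolution of the effective context length."""
    if auto_max_model_len:
        # --max-model-len -1: take whatever the KV-cache estimate says fits.
        return fits_in_memory
    if max_model_len is None:
        # Unset: fall back to the length derived from the model's config.
        return derived_from_model
    # Explicit value (possibly given as "25.6k" on the CLI): use it as given.
    return max_model_len

assert resolve_max_model_len(None, True, 131072, 40960) == 40960
assert resolve_max_model_len(None, False, 131072, 40960) == 131072
assert resolve_max_model_len(25600, False, 131072, 40960) == 25600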

View File

@@ -227,7 +227,9 @@ def _compute_kwargs(cls: ConfigType) -> dict[str, Any]:
         elif contains_type(type_hints, int):
             kwargs[name]["type"] = int
             # Special case for large integers
-            if name in {"max_model_len", "max_num_batched_tokens"}:
+            if name == "max_model_len":
+                kwargs[name]["type"] = human_readable_int
+            elif name == "max_num_batched_tokens":
                 kwargs[name]["type"] = human_readable_int
         elif contains_type(type_hints, float):
             kwargs[name]["type"] = float
@@ -945,6 +947,12 @@ class EngineArgs:
             self.mm_encoder_tp_mode = "data"
+        max_model_len = self.max_model_len
+        auto_max_model_len = False
+        if max_model_len is not None and max_model_len < 0:
+            auto_max_model_len = True
+            max_model_len = None
+
         return ModelConfig(
             model=self.model,
             hf_config_path=self.hf_config_path,
@@ -964,7 +972,8 @@ class EngineArgs:
             hf_token=self.hf_token,
             hf_overrides=self.hf_overrides,
             tokenizer_revision=self.tokenizer_revision,
-            max_model_len=self.max_model_len,
+            max_model_len=max_model_len,
+            auto_max_model_len=auto_max_model_len,
             quantization=self.quantization,
             enforce_eager=self.enforce_eager,
             max_seq_len_to_capture=self.max_seq_len_to_capture,
@@ -1847,3 +1856,4 @@ def human_readable_int(value):
     # Regular plain number.
     return int(value)
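Read together, the three hunks above form the CLI half of the feature: argparse keeps parsing ``--max-model-len`` with ``human_readable_int`` (a bare ``-1`` already parses as a plain integer), and ``create_model_config`` converts that sentinel into ``auto_max_model_len=True`` with ``max_model_len=None`` before building ``ModelConfig``. A self-contained toy of that flow; ``toy_human_readable_int`` and ``build_model_config_kwargs`` are illustrative names, not vLLM APIs.

import argparse
from typing import Any, Optional

def toy_human_readable_int(value: str) -> int:
    # Stand-in for vLLM's human_readable_int; plain integers only in this toy.
    return int(value)

def build_model_config_kwargs(cli_max_model_len: Optional[int]) -> dict[str, Any]:
    # Mirrors the create_model_config hunk: a negative value becomes the auto flag.
    max_model_len = cli_max_model_len
    auto_max_model_len = False
    if max_model_len is not None and max_model_len < 0:
        auto_max_model_len = True
        max_model_len = None
    return {"max_model_len": max_model_len,
            "auto_max_model_len": auto_max_model_len}

parser = argparse.ArgumentParser()
parser.add_argument("--max-model-len", type=toy_human_readable_int, default=None)
args = parser.parse_args(["--max-model-len", "-1"])
assert build_model_config_kwargs(args.max_model_len) == {
    "max_model_len": None, "auto_max_model_len": True}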

View File

@@ -655,6 +655,9 @@ def estimate_max_model_len(vllm_config: VllmConfig,
             left = mid + 1
         else:
             right = mid - 1
+    # Restore the original max_model_len before returning.
+    vllm_config.model_config.max_model_len = current_max
     return result
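The two added lines address a side effect of the search loop: the binary search evidently mutates ``vllm_config.model_config.max_model_len`` while probing candidate lengths, and the new lines restore the previously saved ``current_max`` so the caller's value is not clobbered by whatever candidate was tried last. A minimal sketch of that save/probe/restore pattern, using toy objects rather than the real ``VllmConfig``:

from dataclasses import dataclass

@dataclass
class ToyModelConfig:
    max_model_len: int

def estimate_max_len(config: ToyModelConfig, fits, upper_bound: int) -> int:
    """Binary-search the largest length for which fits(config) is True."""
    current_max = config.max_model_len          # save the caller's value
    left, right, result = 1, upper_bound, -1
    while left <= right:
        mid = (left + right) // 2
        config.max_model_len = mid              # probe with a candidate length
        if fits(config):
            result = mid
            left = mid + 1
        else:
            right = mid - 1
    # Restore the original max_model_len before returning.
    config.max_model_len = current_max
    return result

cfg = ToyModelConfig(max_model_len=8192)
assert estimate_max_len(cfg, lambda c: c.max_model_len <= 5000, 8192) == 5000
assert cfg.max_model_len == 8192  # the caller's value is unchanged
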
@@ -690,6 +693,17 @@ def check_enough_kv_cache_memory(vllm_config: VllmConfig,
     # Estimate the maximum model length that can fit in the available memory
     estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec,
                                                available_memory)
+    if vllm_config.model_config.auto_max_model_len:
+        if estimated_max_len <= 0:
+            raise ValueError("No available memory for the cache blocks. "
+                             "Try increasing `gpu_memory_utilization` when "
+                             "initializing the engine.")
+        logger.info(
+            "Setting max_model_len to %s based on available memory.",
+            estimated_max_len)
+        vllm_config.recalculate_max_model_len(estimated_max_len)
+        return
+
     estimated_msg = ""
     if estimated_max_len > 0:
         estimated_msg = (
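The new branch changes what happens when auto-selection is active: instead of falling through to the usual out-of-memory error, the estimate (if positive) is adopted and ``vllm_config.recalculate_max_model_len`` recomputes the dependent settings. Below is a simplified, hypothetical sketch of that control flow; only the log and the "no available memory" message echo the diff, while the names, signature, and the non-auto error path are illustrative.

import logging
logger = logging.getLogger("toy")

def check_enough_memory(needed: int, available: int, estimated_max_len: int,
                        auto_max_model_len: bool, apply_len) -> None:
    # Toy control flow mirroring the new branch in check_enough_kv_cache_memory.
    if auto_max_model_len:
        if estimated_max_len <= 0:
            # Not even a minimal max_model_len fits: auto-selection cannot help.
            raise ValueError("No available memory for the cache blocks. "
                             "Try increasing `gpu_memory_utilization` when "
                             "initializing the engine.")
        logger.info("Setting max_model_len to %s based on available memory.",
                    estimated_max_len)
        apply_len(estimated_max_len)   # stands in for recalculate_max_model_len
        return
    # Fixed max_model_len: fall through to the existing error path (elided in
    # the diff), which may include the estimate as a hint in its message.
    if needed > available:
        raise ValueError(f"Not enough KV cache memory (estimated max length: "
                         f"{estimated_max_len}).")

chosen: list[int] = []
check_enough_memory(needed=10, available=5, estimated_max_len=40960,
                    auto_max_model_len=True, apply_len=chosen.append)
assert chosen == [40960]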