Mirror of https://github.com/vllm-project/vllm.git
Automatically configure max_num_batched_tokens (#1198)
@@ -266,11 +266,36 @@ class SchedulerConfig:
             and generated text).
     """
 
-    def __init__(self, max_num_batched_tokens: int, max_num_seqs: int,
-                 max_model_len: int) -> None:
-        self.max_num_batched_tokens = max_num_batched_tokens
+    def __init__(
+        self,
+        max_num_batched_tokens: Optional[int],
+        max_num_seqs: int,
+        max_model_len: int,
+    ) -> None:
+        if max_num_batched_tokens is not None:
+            self.max_num_batched_tokens = max_num_batched_tokens
+        else:
+            # If max_model_len is too short, use 2048 as the default value for
+            # higher throughput.
+            self.max_num_batched_tokens = max(max_model_len, 2048)
         self.max_num_seqs = max_num_seqs
         self.max_model_len = max_model_len
+        self._verify_args()
+
+    def _verify_args(self) -> None:
+        if self.max_num_batched_tokens < self.max_model_len:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) is "
+                f"smaller than max_model_len ({self.max_model_len}). "
+                "This effectively limits the maximum sequence length to "
+                "max_num_batched_tokens and makes vLLM reject longer "
+                "sequences. Please increase max_num_batched_tokens or "
+                "decrease max_model_len.")
+        if self.max_num_batched_tokens < self.max_num_seqs:
+            raise ValueError(
+                f"max_num_batched_tokens ({self.max_num_batched_tokens}) must "
+                "be greater than or equal to max_num_seqs "
+                f"({self.max_num_seqs}).")
 
 
 _STR_DTYPE_TO_TORCH_DTYPE = {
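In short, this hunk lets callers pass max_num_batched_tokens=None and have the scheduler pick max(max_model_len, 2048), then validate the result against max_model_len and max_num_seqs. Below is a minimal standalone sketch of that fallback logic; resolve_max_num_batched_tokens is a hypothetical helper written for illustration, not part of vLLM, and its checks only mirror the _verify_args shown above.

from typing import Optional


def resolve_max_num_batched_tokens(
    max_num_batched_tokens: Optional[int],
    max_num_seqs: int,
    max_model_len: int,
) -> int:
    # Hypothetical helper mirroring the SchedulerConfig logic above.
    if max_num_batched_tokens is None:
        # Short-context models still get a 2048-token batch budget so that
        # throughput does not collapse for them.
        max_num_batched_tokens = max(max_model_len, 2048)
    if max_num_batched_tokens < max_model_len:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) is smaller "
            f"than max_model_len ({max_model_len}).")
    if max_num_batched_tokens < max_num_seqs:
        raise ValueError(
            f"max_num_batched_tokens ({max_num_batched_tokens}) must be >= "
            f"max_num_seqs ({max_num_seqs}).")
    return max_num_batched_tokens


# A 512-token model still gets the 2048 default; a 4096-token model gets 4096.
assert resolve_max_num_batched_tokens(None, 256, 512) == 2048
assert resolve_max_num_batched_tokens(None, 256, 4096) == 4096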
@@ -350,14 +375,14 @@ def _get_and_verify_max_len(
         max_len_key = getattr(hf_config, key, None)
         if max_len_key is not None:
             derived_max_model_len = min(derived_max_model_len, max_len_key)
+    if derived_max_model_len == float("inf"):
+        raise ValueError(
+            "The model's config.json must contain one of the following keys "
+            "to determine the original maximum length of the model: "
+            f"{possible_keys}")
 
     rope_scaling = getattr(hf_config, "rope_scaling", None)
     if rope_scaling is not None:
-        if derived_max_model_len == float("inf"):
-            raise ValueError(
-                "When using rope_scaling, the model's config.json must "
-                "contain one of the following keys to determine the original "
-                f"maximum length of the model: {possible_keys}")
         assert "factor" in rope_scaling
         scaling_factor = rope_scaling["factor"]
         derived_max_model_len *= scaling_factor
@@ -371,4 +396,4 @@ def _get_and_verify_max_len(
             " in model's config.json). This may lead to incorrect model "
             "outputs or CUDA errors. Make sure the value is correct and "
             "within the model context size.")
-    return max_model_len
+    return int(max_model_len)
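The two hunks above move the missing-length-key error out of the rope_scaling branch, so it also fires for models without rope scaling, and cast the final value to int because the scaling factor may be a float. A small sketch of the derivation, using a plain dict in place of the Hugging Face config object; the keys and values here are made-up examples.

# Illustrative only: a dict standing in for hf_config, with made-up values.
config = {
    "max_position_embeddings": 4096,
    "rope_scaling": {"type": "linear", "factor": 2.0},
}
possible_keys = ["max_position_embeddings", "max_sequence_length"]

derived_max_model_len = float("inf")
for key in possible_keys:
    if key in config:
        derived_max_model_len = min(derived_max_model_len, config[key])
if derived_max_model_len == float("inf"):
    # With the change above, this now raises even when rope_scaling is absent.
    raise ValueError("config.json has none of the expected length keys")

rope_scaling = config.get("rope_scaling")
if rope_scaling is not None:
    derived_max_model_len *= rope_scaling["factor"]

print(int(derived_max_model_len))  # 8192: 4096 scaled by factor 2.0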
@@ -25,7 +25,7 @@ class EngineArgs:
     block_size: int = 16
     swap_space: int = 4  # GiB
     gpu_memory_utilization: float = 0.90
-    max_num_batched_tokens: int = 2560
+    max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     disable_log_stats: bool = False
     revision: Optional[str] = None
@@ -34,7 +34,6 @@ class EngineArgs:
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
-        self.max_num_seqs = min(self.max_num_seqs, self.max_num_batched_tokens)
 
     @staticmethod
     def add_cli_args(
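With max_num_batched_tokens now defaulting to None in EngineArgs, and the old clamping of max_num_seqs removed from __post_init__, the argument can simply be left unset and resolved later by SchedulerConfig. A short usage sketch, assuming a vLLM build that includes this change; the model name is only an example.

from vllm import EngineArgs

# Leaving max_num_batched_tokens unset keeps it as None; SchedulerConfig
# later resolves it to max(max_model_len, 2048).
args = EngineArgs(model="facebook/opt-125m")
print(args.max_num_batched_tokens)  # None

# Passing an explicit value still works and is validated by SchedulerConfig.
args = EngineArgs(model="facebook/opt-125m", max_num_batched_tokens=4096)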