Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
[Misc] Fix skipped max-model-len validation when deriving max model length from tokenizer config (#19660)
Signed-off-by: Ye (Charlotte) Qi <yeq@meta.com>
commit b692e9cd07
parent 367871a469
committed by GitHub
@@ -438,3 +438,31 @@ def test_load_config_pt_load_map_location(pt_load_map_location):
     config = VllmConfig(load_config=load_config)
 
     assert config.load_config.pt_load_map_location == pt_load_map_location
+
+
+@pytest.mark.parametrize(
+    ("model_id", "max_model_len", "expected_max_len", "should_raise"), [
+        ("BAAI/bge-reranker-base", None, 512, False),
+        ("BAAI/bge-reranker-base", 256, 256, False),
+        ("BAAI/bge-reranker-base", 513, 512, True),
+    ])
+def test_get_and_verify_max_len(model_id, max_model_len, expected_max_len,
+                                should_raise):
+    """Test get_and_verify_max_len with different configurations."""
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        revision=None,
+    )
+
+    if should_raise:
+        with pytest.raises(ValueError):
+            model_config.get_and_verify_max_len(max_model_len)
+    else:
+        actual_max_len = model_config.get_and_verify_max_len(max_model_len)
+        assert actual_max_len == expected_max_len
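The three parametrized cases mirror the fix: with max_model_len=None the length is derived and comes out as 512 (the tokenizer cap), an explicit 256 below the cap is kept as-is, and 513 above the cap is expected to raise a ValueError instead of being silently clamped.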
@@ -1429,25 +1429,19 @@ class ModelConfig:
         return getattr(self.hf_config, "matryoshka_dimensions", None)
 
     def get_and_verify_max_len(self, max_model_len: int):
+        tokenizer_config = try_get_tokenizer_config(
+            self.tokenizer,
+            trust_remote_code=self.trust_remote_code,
+            revision=self.tokenizer_revision)
         max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
+            tokenizer_config=tokenizer_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
             sliding_window_len=self.get_hf_config_sliding_window(),
             spec_target_max_model_len=self.spec_target_max_model_len,
             encoder_config=self.encoder_config)
 
-        tokenizer_config = try_get_tokenizer_config(
-            self.tokenizer,
-            trust_remote_code=self.trust_remote_code,
-            revision=self.tokenizer_revision)
-
-        if tokenizer_config is None:
-            return max_model_len
-
-        model_max_length = tokenizer_config.get("model_max_length",
-                                                max_model_len)
-        max_model_len = min(max_model_len, model_max_length)
         logger.info("Using max model len %s", max_model_len)
         return max_model_len
 
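The key change is here: tokenizer_config is now fetched up front and passed into _get_and_verify_max_len, so the tokenizer's model_max_length takes part in deriving the limit that the user-supplied max_model_len is validated against; the old post-validation min() clamp (and its early return when no tokenizer config was found) is removed.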
@@ -3283,6 +3277,7 @@ def _get_and_verify_dtype(
 
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
+    tokenizer_config: Optional[dict],
     max_model_len: Optional[int],
     disable_sliding_window: bool,
     sliding_window_len: Optional[Union[int, list[Optional[int]]]],
@@ -3309,7 +3304,7 @@ def _get_and_verify_max_len(
         "max_seq_length",
         "seq_len",
     ]
-    # Choose the smallest "max_length" from the possible keys.
+    # Choose the smallest "max_length" from the possible keys
    max_len_key = None
    for key in possible_keys:
        max_len = getattr(hf_config, key, None)
@@ -3332,6 +3327,13 @@ def _get_and_verify_max_len(
         derived_max_model_len = min(derived_max_model_len,
                                     sliding_window_len_min)
 
+    # Consider model_max_length in tokenizer_config
+    if tokenizer_config:
+        tokenizer_model_max_length = tokenizer_config.get(
+            "model_max_length", derived_max_model_len)
+        derived_max_model_len = min(derived_max_model_len,
+                                    tokenizer_model_max_length)
+
     # If none of the keys were found in the config, use a default and
     # log a warning.
     if derived_max_model_len == float("inf"):
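A small self-contained illustration of the block added above (the helper name fold_tokenizer_cap and the sample values are made up): when model_max_length is missing, the .get() default keeps derived_max_model_len unchanged, and a missing tokenizer config skips the fold entirely.

from typing import Optional


def fold_tokenizer_cap(derived_max_model_len: float,
                       tokenizer_config: Optional[dict]) -> float:
    # Mirrors the added hunk: only consider the cap when a tokenizer
    # config was actually loaded.
    if tokenizer_config:
        tokenizer_model_max_length = tokenizer_config.get(
            "model_max_length", derived_max_model_len)
        derived_max_model_len = min(derived_max_model_len,
                                    tokenizer_model_max_length)
    return derived_max_model_len


print(fold_tokenizer_cap(514, {"model_max_length": 512}))  # 512
print(fold_tokenizer_cap(514, {}))                          # 514 (key absent)
print(fold_tokenizer_cap(514, None))                        # 514 (no tokenizer config)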