Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
@@ -1,5 +1,6 @@
 import enum
 import json
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional, Type, Union
 
@@ -41,6 +42,7 @@ else:
     from transformers import AutoConfig
 
 MISTRAL_CONFIG_NAME = "params.json"
+HF_TOKEN = os.getenv('HF_TOKEN', None)
 
 logger = init_logger(__name__)
 
@@ -77,8 +79,8 @@ class ConfigFormat(str, enum.Enum):
     MISTRAL = "mistral"
 
 
-def file_or_path_exists(model: Union[str, Path], config_name, revision,
-                        token) -> bool:
+def file_or_path_exists(model: Union[str, Path], config_name: str,
+                        revision: Optional[str]) -> bool:
     if Path(model).exists():
         return (Path(model) / config_name).is_file()
 
@@ -93,7 +95,10 @@ def file_or_path_exists(model: Union[str, Path], config_name, revision,
     # NB: file_exists will only check for the existence of the config file on
     # hf_hub. This will fail in offline mode.
     try:
-        return file_exists(model, config_name, revision=revision, token=token)
+        return file_exists(model,
+                           config_name,
+                           revision=revision,
+                           token=HF_TOKEN)
     except huggingface_hub.errors.OfflineModeIsEnabled:
         # Don't raise in offline mode, all we know is that we don't have this
         # file cached.
@@ -161,7 +166,6 @@ def get_config(
     revision: Optional[str] = None,
     code_revision: Optional[str] = None,
     config_format: ConfigFormat = ConfigFormat.AUTO,
-    token: Optional[str] = None,
     **kwargs,
 ) -> PretrainedConfig:
     # Separate model folder from file path for GGUF models
@@ -173,19 +177,20 @@ def get_config(
 
     if config_format == ConfigFormat.AUTO:
         if is_gguf or file_or_path_exists(
-                model, HF_CONFIG_NAME, revision=revision, token=token):
+                model, HF_CONFIG_NAME, revision=revision):
             config_format = ConfigFormat.HF
-        elif file_or_path_exists(model,
-                                 MISTRAL_CONFIG_NAME,
-                                 revision=revision,
-                                 token=token):
+        elif file_or_path_exists(model, MISTRAL_CONFIG_NAME,
+                                 revision=revision):
             config_format = ConfigFormat.MISTRAL
         else:
             # If we're in offline mode and found no valid config format, then
             # raise an offline mode error to indicate to the user that they
             # don't have files cached and may need to go online.
             # This is conveniently triggered by calling file_exists().
-            file_exists(model, HF_CONFIG_NAME, revision=revision, token=token)
+            file_exists(model,
+                        HF_CONFIG_NAME,
+                        revision=revision,
+                        token=HF_TOKEN)
 
             raise ValueError(f"No supported config format found in {model}")
 
@@ -194,7 +199,7 @@ def get_config(
             model,
             revision=revision,
             code_revision=code_revision,
-            token=token,
+            token=HF_TOKEN,
             **kwargs,
         )
 
@@ -206,7 +211,7 @@ def get_config(
                 model,
                 revision=revision,
                 code_revision=code_revision,
-                token=token,
+                token=HF_TOKEN,
                 **kwargs,
             )
         else:
@@ -216,7 +221,7 @@
                     trust_remote_code=trust_remote_code,
                     revision=revision,
                     code_revision=code_revision,
-                    token=token,
+                    token=HF_TOKEN,
                     **kwargs,
                 )
             except ValueError as e:
@@ -234,7 +239,7 @@ def get_config(
                     raise e
 
     elif config_format == ConfigFormat.MISTRAL:
-        config = load_params_config(model, revision, token=token, **kwargs)
+        config = load_params_config(model, revision, token=HF_TOKEN, **kwargs)
     else:
         raise ValueError(f"Unsupported config format: {config_format}")
 
@@ -256,8 +261,7 @@ def get_config(
 
 def get_hf_file_to_dict(file_name: str,
                         model: Union[str, Path],
-                        revision: Optional[str] = 'main',
-                        token: Optional[str] = None):
+                        revision: Optional[str] = 'main'):
     """
     Downloads a file from the Hugging Face Hub and returns
     its contents as a dictionary.
@@ -266,7 +270,6 @@ def get_hf_file_to_dict(file_name: str,
     - file_name (str): The name of the file to download.
     - model (str): The name of the model on the Hugging Face Hub.
     - revision (str): The specific version of the model.
-    - token (str): The Hugging Face authentication token.
 
     Returns:
     - config_dict (dict): A dictionary containing
@@ -276,8 +279,7 @@ def get_hf_file_to_dict(file_name: str,
 
     if file_or_path_exists(model=model,
                            config_name=file_name,
-                           revision=revision,
-                           token=token):
+                           revision=revision):
 
         if not file_path.is_file():
             try:
@@ -296,9 +298,7 @@ def get_hf_file_to_dict(file_name: str,
     return None
 
 
-def get_pooling_config(model: str,
-                       revision: Optional[str] = 'main',
-                       token: Optional[str] = None):
+def get_pooling_config(model: str, revision: Optional[str] = 'main'):
     """
     This function gets the pooling and normalize
     config from the model - only applies to
@@ -315,8 +315,7 @@ def get_pooling_config(model: str,
     """
 
     modules_file_name = "modules.json"
-    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision,
-                                        token)
+    modules_dict = get_hf_file_to_dict(modules_file_name, model, revision)
 
     if modules_dict is None:
         return None
@@ -332,8 +331,7 @@ def get_pooling_config(model: str,
     if pooling:
 
         pooling_file_name = "{}/config.json".format(pooling["path"])
-        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision,
-                                            token)
+        pooling_dict = get_hf_file_to_dict(pooling_file_name, model, revision)
         pooling_type_name = next(
             (item for item, val in pooling_dict.items() if val is True), None)
 
@@ -368,8 +366,8 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
 
 
 def get_sentence_transformer_tokenizer_config(model: str,
-                                              revision: Optional[str] = 'main',
-                                              token: Optional[str] = None):
+                                              revision: Optional[str] = 'main'
+                                              ):
     """
     Returns the tokenization configuration dictionary for a
     given Sentence Transformer BERT model.
@@ -379,7 +377,6 @@ def get_sentence_transformer_tokenizer_config(model: str,
       BERT model.
     - revision (str, optional): The revision of the m
       odel to use. Defaults to 'main'.
-    - token (str): A Hugging Face access token.
 
     Returns:
     - dict: A dictionary containing the configuration parameters
@@ -394,7 +391,7 @@ def get_sentence_transformer_tokenizer_config(model: str,
             "sentence_xlm-roberta_config.json",
             "sentence_xlnet_config.json",
     ]:
-        encoder_dict = get_hf_file_to_dict(config_name, model, revision, token)
+        encoder_dict = get_hf_file_to_dict(config_name, model, revision)
         if encoder_dict:
             break
 
@@ -474,16 +471,14 @@ def maybe_register_config_serialize_by_value() -> None:
                        exc_info=e)
 
 
-def load_params_config(model: Union[str, Path],
-                       revision: Optional[str],
-                       token: Optional[str] = None,
+def load_params_config(model: Union[str, Path], revision: Optional[str],
                        **kwargs) -> PretrainedConfig:
     # This function loads a params.json config which
     # should be used when loading models in mistral format
 
     config_file_name = "params.json"
 
-    config_dict = get_hf_file_to_dict(config_file_name, model, revision, token)
+    config_dict = get_hf_file_to_dict(config_file_name, model, revision)
     assert isinstance(config_dict, dict)
 
     config_mapping = {
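Taken together, the diff removes the explicit token parameter from these helpers and instead reads HF_TOKEN from the environment once at module import, passing it to Hugging Face Hub calls only where authentication is actually needed. Below is a minimal, self-contained sketch of that pattern; it mirrors the file_or_path_exists branches shown above and is an illustration under that assumption, not the complete vLLM module.

import os
from pathlib import Path
from typing import Optional, Union

import huggingface_hub
from huggingface_hub import file_exists

# Resolved once at import time; None simply means no explicit token was set,
# so huggingface_hub falls back to its own credential handling.
HF_TOKEN = os.getenv('HF_TOKEN', None)


def file_or_path_exists(model: Union[str, Path], config_name: str,
                        revision: Optional[str]) -> bool:
    # Local checkouts never need a token.
    if Path(model).exists():
        return (Path(model) / config_name).is_file()
    try:
        # Remote lookups authenticate with the environment token, if any.
        return file_exists(str(model),
                           config_name,
                           revision=revision,
                           token=HF_TOKEN)
    except huggingface_hub.errors.OfflineModeIsEnabled:
        # Offline mode: all we know is that the file is not cached locally.
        return False

With this in place, callers export the token in the environment (e.g. HF_TOKEN=<your token>) rather than threading it through get_config, get_pooling_config, and the other helpers.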