Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
Further reduce the HTTP calls to huggingface.co (#13107)
committed by GitHub
parent d59def4730
commit 7c4033acd4
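The change below works by listing a repository's files once, memoizing that listing with functools.cache, and answering later file-existence checks from the in-memory list instead of issuing a fresh HTTP request per file. A standalone sketch of the idea (not the vLLM code itself; the repo id is a placeholder):

from functools import cache
from typing import Optional

from huggingface_hub import list_repo_files as hf_list_repo_files


@cache
def cached_repo_files(repo_id: str,
                      revision: Optional[str] = None) -> tuple[str, ...]:
    # One HTTP call per (repo_id, revision); later calls hit the cache.
    return tuple(hf_list_repo_files(repo_id, revision=revision))


def cached_file_exists(repo_id: str,
                       file_name: str,
                       revision: Optional[str] = None) -> bool:
    # A membership test against the cached listing, no extra round trip.
    return file_name in cached_repo_files(repo_id, revision=revision)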
@@ -4,12 +4,14 @@ import enum
 import json
 import os
 import time
+from functools import cache
 from pathlib import Path
-from typing import Any, Dict, Literal, Optional, Type, Union
+from typing import Any, Callable, Dict, Literal, Optional, Type, Union
 
 import huggingface_hub
-from huggingface_hub import (file_exists, hf_hub_download, list_repo_files,
-                             try_to_load_from_cache)
+from huggingface_hub import hf_hub_download
+from huggingface_hub import list_repo_files as hf_list_repo_files
+from huggingface_hub import try_to_load_from_cache
 from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
                                    HFValidationError, LocalEntryNotFoundError,
                                    RepositoryNotFoundError,
@@ -86,6 +88,65 @@ class ConfigFormat(str, enum.Enum):
     MISTRAL = "mistral"
 
 
+def with_retry(func: Callable[[], Any],
+               log_msg: str,
+               max_retries: int = 2,
+               retry_delay: int = 2):
+    for attempt in range(max_retries):
+        try:
+            return func()
+        except Exception as e:
+            if attempt == max_retries - 1:
+                logger.error("%s: %s", log_msg, e)
+                raise
+            logger.error("%s: %s, retrying %d of %d", log_msg, e, attempt + 1,
+                         max_retries)
+            time.sleep(retry_delay)
+            retry_delay *= 2
+
+
+# @cache doesn't cache exceptions
+@cache
+def list_repo_files(
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    repo_type: Optional[str] = None,
+    token: Union[str, bool, None] = None,
+) -> list[str]:
+
+    def lookup_files():
+        try:
+            return hf_list_repo_files(repo_id,
+                                      revision=revision,
+                                      repo_type=repo_type,
+                                      token=token)
+        except huggingface_hub.errors.OfflineModeIsEnabled:
+            # Don't raise in offline mode,
+            # all we know is that we don't have this
+            # file cached.
+            return []
+
+    return with_retry(lookup_files, "Error retrieving file list")
+
+
+def file_exists(
+    repo_id: str,
+    file_name: str,
+    *,
+    repo_type: Optional[str] = None,
+    revision: Optional[str] = None,
+    token: Union[str, bool, None] = None,
+) -> bool:
+
+    file_list = list_repo_files(repo_id,
+                                repo_type=repo_type,
+                                revision=revision,
+                                token=token)
+    return file_name in file_list
+
+
+# In offline mode the result can be a false negative
 def file_or_path_exists(model: Union[str, Path], config_name: str,
                         revision: Optional[str]) -> bool:
     if Path(model).exists():
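A usage sketch for the helpers added above, with those functions in scope and an illustrative repo id: the first call pays for one list_repo_files HTTP request (retried inside with_retry), every later check for the same repo and revision is answered from the @cache, and offline mode yields an empty listing rather than an exception.

# "org/some-model" is a placeholder, not a real checkpoint.
has_config = file_exists("org/some-model", "config.json", revision="main")
has_params = file_exists("org/some-model", "params.json", revision="main")
# Both checks reuse the single cached list_repo_files result.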
@@ -103,31 +164,10 @@ def file_or_path_exists(model: Union[str, Path], config_name: str,
     # hf_hub. This will fail in offline mode.
 
-    # Call HF to check if the file exists
-    # 2 retries and exponential backoff
-    max_retries = 2
-    retry_delay = 2
-    for attempt in range(max_retries):
-        try:
-            return file_exists(model,
-                               config_name,
-                               revision=revision,
-                               token=HF_TOKEN)
-        except huggingface_hub.errors.OfflineModeIsEnabled:
-            # Don't raise in offline mode,
-            # all we know is that we don't have this
-            # file cached.
-            return False
-        except Exception as e:
-            logger.error(
-                "Error checking file existence: %s, retrying %d of %d", e,
-                attempt + 1, max_retries)
-            if attempt == max_retries - 1:
-                logger.error("Error checking file existence: %s", e)
-                raise
-            time.sleep(retry_delay)
-            retry_delay *= 2
-            continue
-    return False
+    return file_exists(str(model),
+                       config_name,
+                       revision=revision,
+                       token=HF_TOKEN)
 
 
 def patch_rope_scaling(config: PretrainedConfig) -> None:
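Design note: the inline retry-and-backoff loop removed here is not reinstated elsewhere; retries now happen once inside the cached list_repo_files (via with_retry), so every caller of file_exists inherits them. For any other one-off Hub call, the helper composes as in this sketch (the wrapped download and log message are illustrative, not part of this commit):

config_path = with_retry(
    lambda: hf_hub_download("org/some-model", "config.json", revision="main"),
    "Error downloading config.json")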
@@ -208,32 +248,7 @@ def get_config(
                                  revision=revision):
             config_format = ConfigFormat.MISTRAL
         else:
-            # If we're in offline mode and found no valid config format, then
-            # raise an offline mode error to indicate to the user that they
-            # don't have files cached and may need to go online.
-            # This is conveniently triggered by calling file_exists().
-
-            # Call HF to check if the file exists
-            # 2 retries and exponential backoff
-            max_retries = 2
-            retry_delay = 2
-            for attempt in range(max_retries):
-                try:
-                    file_exists(model,
-                                HF_CONFIG_NAME,
-                                revision=revision,
-                                token=HF_TOKEN)
-                except Exception as e:
-                    logger.error(
-                        "Error checking file existence: %s, retrying %d of %d",
-                        e, attempt + 1, max_retries)
-                    if attempt == max_retries:
-                        logger.error("Error checking file existence: %s", e)
-                        raise e
-                    time.sleep(retry_delay)
-                    retry_delay *= 2
-
-            raise ValueError(f"No supported config format found in {model}")
+            raise ValueError(f"No supported config format found in {model}.")
 
         if config_format == ConfigFormat.HF:
             config_dict, _ = PretrainedConfig.get_config_dict(
@@ -339,10 +354,11 @@ def get_hf_file_to_dict(file_name: str,
                                        file_name=file_name,
                                        revision=revision)
 
-    if file_path is None and file_or_path_exists(
-            model=model, config_name=file_name, revision=revision):
+    if file_path is None:
         try:
             hf_hub_file = hf_hub_download(model, file_name, revision=revision)
+        except huggingface_hub.errors.OfflineModeIsEnabled:
+            return None
         except (RepositoryNotFoundError, RevisionNotFoundError,
                 EntryNotFoundError, LocalEntryNotFoundError) as e:
             logger.debug("File or repository not found in hf_hub_download", e)
@@ -363,6 +379,7 @@ def get_hf_file_to_dict(file_name: str,
     return None
 
 
+@cache
 def get_pooling_config(model: str, revision: Optional[str] = 'main'):
     """
     This function gets the pooling and normalize
@@ -390,6 +407,8 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
     if modules_dict is None:
         return None
 
+    logger.info("Found sentence-transformers modules configuration.")
+
     pooling = next((item for item in modules_dict
                     if item["type"] == "sentence_transformers.models.Pooling"),
                    None)
@@ -408,6 +427,7 @@ def get_pooling_config(model: str, revision: Optional[str] = 'main'):
     if pooling_type_name is not None:
         pooling_type_name = get_pooling_config_name(pooling_type_name)
 
+        logger.info("Found pooling configuration.")
         return {"pooling_type": pooling_type_name, "normalize": normalize}
 
     return None
@@ -435,6 +455,7 @@ def get_pooling_config_name(pooling_name: str) -> Union[str, None]:
     return None
 
 
+@cache
 def get_sentence_transformer_tokenizer_config(model: str,
                                               revision: Optional[str] = 'main'
                                               ):
@@ -491,6 +512,8 @@ def get_sentence_transformer_tokenizer_config(model: str,
     if not encoder_dict:
         return None
 
+    logger.info("Found sentence-transformers tokenize configuration.")
+
     if all(k in encoder_dict for k in ("max_seq_length", "do_lower_case")):
         return encoder_dict
     return None
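The @cache decorators added to get_pooling_config and get_sentence_transformer_tokenizer_config memoize results per (model, revision) pair, so repeated lookups during engine setup reuse the first answer instead of re-fetching the same JSON files. A small illustration of that behaviour, with a placeholder model id:

first = get_pooling_config("org/some-embedding-model")
second = get_pooling_config("org/some-embedding-model")
# functools.cache returns the identical cached object for equal arguments.
assert first is second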