[Misc] Make download_weights_from_hf
more reliable (#23863)
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
@ -278,33 +278,48 @@ def download_weights_from_hf(
|
||||
Returns:
|
||||
str: The path to the downloaded model weights.
|
||||
"""
|
||||
assert len(allow_patterns) > 0
|
||||
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
|
||||
if not local_only:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
# Attempt to reduce allow_patterns to a single pattern
|
||||
# so we only have to call snapshot_download once.
|
||||
try:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path,
|
||||
detail=False,
|
||||
revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
# Use the first pattern found in the HF repo's files.
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to get file list for '%s'. Trying each pattern in "
|
||||
"allow_patterns individually until weights have been "
|
||||
"downloaded. Error: %s", model_name_or_path, e)
|
||||
|
||||
logger.info("Using model weights format %s", allow_patterns)
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
start_time = time.perf_counter()
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_patterns,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
for allow_pattern in allow_patterns:
|
||||
hf_folder = snapshot_download(
|
||||
model_name_or_path,
|
||||
allow_patterns=allow_pattern,
|
||||
ignore_patterns=ignore_patterns,
|
||||
cache_dir=cache_dir,
|
||||
tqdm_class=DisabledTqdm,
|
||||
revision=revision,
|
||||
local_files_only=local_only,
|
||||
)
|
||||
# If we have downloaded weights for this allow_pattern,
|
||||
# we don't need to check the rest.
|
||||
if any(Path(hf_folder).glob(allow_pattern)):
|
||||
break
|
||||
time_taken = time.perf_counter() - start_time
|
||||
if time_taken > 0.5:
|
||||
logger.info("Time spent downloading weights for %s: %.6f seconds",
|
||||
|
Reference in New Issue
Block a user