[Misc] Make download_weights_from_hf more reliable (#23863)

Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
This commit is contained in:
Harry Mellor
2025-08-29 20:37:24 +01:00
committed by GitHub
parent 8c3e199998
commit 5674a40366

View File

@@ -278,33 +278,48 @@ def download_weights_from_hf(
Returns:
str: The path to the downloaded model weights.
"""
assert len(allow_patterns) > 0
local_only = huggingface_hub.constants.HF_HUB_OFFLINE
if not local_only:
# Before we download, we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# Attempt to reduce allow_patterns to a single pattern
# so we only have to call snapshot_download once.
try:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path,
detail=False,
revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
# Use the first pattern found in the HF repo's files.
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
break
except Exception as e:
logger.warning(
"Failed to get file list for '%s'. Trying each pattern in "
"allow_patterns individually until weights have been "
"downloaded. Error: %s", model_name_or_path, e)
logger.info("Using model weights format %s", allow_patterns)
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
start_time = time.perf_counter()
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=local_only,
)
for allow_pattern in allow_patterns:
hf_folder = snapshot_download(
model_name_or_path,
allow_patterns=allow_pattern,
ignore_patterns=ignore_patterns,
cache_dir=cache_dir,
tqdm_class=DisabledTqdm,
revision=revision,
local_files_only=local_only,
)
# If we have downloaded weights for this allow_pattern,
# we don't need to check the rest.
if any(Path(hf_folder).glob(allow_pattern)):
break
time_taken = time.perf_counter() - start_time
if time_taken > 0.5:
logger.info("Time spent downloading weights for %s: %.6f seconds",