add tqdm when loading checkpoint shards (#6569)

Co-authored-by: tianyi.zhao <tianyi.zhao@transwarp.io>
Co-authored-by: youkaichao <youkaichao@126.com>
This commit is contained in:
zhaotyer
2024-07-23 11:48:01 +08:00
committed by GitHub
parent 7c2749a4fd
commit e519ae097a

View File

@ -331,7 +331,8 @@ def np_cache_weights_iterator(
with get_lock(model_name_or_path, cache_dir):
if not os.path.exists(weight_names_file):
weight_names: List[str] = []
for bin_file in hf_weights_files:
for bin_file in tqdm(hf_weights_files,
desc="Loading np_cache checkpoint shards"):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
param_path = os.path.join(np_folder, name)
@ -355,7 +356,8 @@ def safetensors_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
for st_file in hf_weights_files:
for st_file in tqdm(hf_weights_files,
desc="Loading safetensors checkpoint shards"):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
@ -366,7 +368,8 @@ def pt_weights_iterator(
hf_weights_files: List[str]
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
for bin_file in hf_weights_files:
for bin_file in tqdm(hf_weights_files,
desc="Loading pt checkpoint shards"):
state = torch.load(bin_file, map_location="cpu")
for name, param in state.items():
yield name, param