Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
add tqdm when loading checkpoint shards (#6569)
Co-authored-by: tianyi.zhao <tianyi.zhao@transwarp.io>
Co-authored-by: youkaichao <youkaichao@126.com>
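All three hunks below apply the same one-line pattern: the plain loop over shard files is wrapped in tqdm with a label passed via desc, while the loop body is left untouched. The pattern in isolation (the shard names here are placeholders, not from the commit):

    from tqdm import tqdm

    items = ["shard-00001", "shard-00002"]  # placeholder shard names
    for item in tqdm(items, desc="Loading checkpoint shards"):
        pass  # existing per-shard work goes here, unchanged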
@@ -331,7 +331,8 @@ def np_cache_weights_iterator(
     with get_lock(model_name_or_path, cache_dir):
         if not os.path.exists(weight_names_file):
             weight_names: List[str] = []
-            for bin_file in hf_weights_files:
+            for bin_file in tqdm(hf_weights_files,
+                                 desc="Loading np_cache checkpoint shards"):
                 state = torch.load(bin_file, map_location="cpu")
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
@@ -355,7 +356,8 @@ def safetensors_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model safetensor files."""
-    for st_file in hf_weights_files:
+    for st_file in tqdm(hf_weights_files,
+                        desc="Loading safetensors checkpoint shards"):
         with safe_open(st_file, framework="pt") as f:
             for name in f.keys():  # noqa: SIM118
                 param = f.get_tensor(name)
@@ -366,7 +368,8 @@ def pt_weights_iterator(
     hf_weights_files: List[str]
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
     """Iterate over the weights in the model bin/pt files."""
-    for bin_file in hf_weights_files:
+    for bin_file in tqdm(hf_weights_files,
+                         desc="Loading pt checkpoint shards"):
         state = torch.load(bin_file, map_location="cpu")
         for name, param in state.items():
             yield name, param
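Because each of these functions is a generator, the tqdm bar advances lazily as the caller pulls tensors, not when the function is first called. A hedged end-to-end sketch of that behavior (the fabricated shard file and this standalone copy of the iterator are illustrative, not vLLM's code):

    from typing import Generator, List, Tuple

    import torch
    from tqdm import tqdm

    def pt_weights_iterator(
        hf_weights_files: List[str]
    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
        # Illustrative copy of the iterator's shape after this commit:
        # tqdm wraps the shard list, tensors are yielded one by one.
        for bin_file in tqdm(hf_weights_files,
                             desc="Loading pt checkpoint shards"):
            state = torch.load(bin_file, map_location="cpu")
            for name, param in state.items():
                yield name, param

    # Fabricated single-shard checkpoint so the sketch runs end to end.
    torch.save({"w": torch.zeros(2, 2)}, "shard-00001.bin")
    for name, param in pt_weights_iterator(["shard-00001.bin"]):
        print(name, tuple(param.shape))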