[Model loader]: support multi-thread model weight loading (#23928)

Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
This commit is contained in:
Yang Kaiyong
2025-09-09 02:49:39 +08:00
committed by GitHub
parent 7be141b2c5
commit 43d9ad03ba
2 changed files with 105 additions and 12 deletions

View File

@ -18,8 +18,9 @@ from vllm.model_executor.model_loader.weight_utils import (
download_safetensors_index_file_from_hf, download_weights_from_hf,
fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
filter_files_not_needed_for_inference, maybe_download_from_modelscope,
np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
multi_thread_pt_weights_iterator,
multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
pt_weights_iterator, safetensors_weights_iterator)
from vllm.platforms import current_platform
logger = init_logger(__name__)
@ -28,6 +29,9 @@ logger = init_logger(__name__)
class DefaultModelLoader(BaseModelLoader):
"""Model loader that can load different file types from disk."""
# default number of thread when enable multithread weight loading
DEFAULT_NUM_THREADS = 8
@dataclasses.dataclass
class Source:
"""A source for weights."""
@ -52,9 +56,15 @@ class DefaultModelLoader(BaseModelLoader):
def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
if load_config.model_loader_extra_config:
raise ValueError(f"Model loader extra config is not supported for "
f"load format {load_config.load_format}")
extra_config = load_config.model_loader_extra_config
allowed_keys = {"enable_multithread_load", "num_threads"}
unexpected_keys = set(extra_config.keys()) - allowed_keys
if unexpected_keys:
raise ValueError(f"Unexpected extra config keys for load format "
f"{load_config.load_format}: "
f"{unexpected_keys}")
def _prepare_weights(
self,
@ -145,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader):
self, source: "Source"
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
extra_config = self.load_config.model_loader_extra_config
hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
source.model_or_path, source.revision, source.fall_back_to_pt,
source.allow_patterns_overrides)
@ -165,16 +176,34 @@ class DefaultModelLoader(BaseModelLoader):
self.load_config.use_tqdm_on_load,
)
else:
weights_iterator = safetensors_weights_iterator(
if extra_config.get("enable_multithread_load"):
weights_iterator = (
multi_thread_safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
max_workers=extra_config.get(
"num_threads", self.DEFAULT_NUM_THREADS),
))
else:
weights_iterator = safetensors_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
)
else:
if extra_config.get("enable_multithread_load"):
weights_iterator = multi_thread_pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
max_workers=extra_config.get("num_threads",
self.DEFAULT_NUM_THREADS),
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
else:
weights_iterator = pt_weights_iterator(
hf_weights_files,
self.load_config.use_tqdm_on_load,
self.load_config.pt_load_map_location,
)
if current_platform.is_tpu():
from vllm.platforms.tpu import USE_TPU_COMMONS

View File

@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Utilities for downloading and initializing model weights."""
import concurrent.futures
import fnmatch
import glob
import hashlib
@ -531,6 +532,36 @@ def safetensors_weights_iterator(
yield name, param
def multi_thread_safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Multi-Thread iterate over the weights in the model safetensor files."""
def _load_file(st_file: str):
result = load_file(st_file, device="cpu")
return result
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers) as executor:
futures = [
executor.submit(_load_file, st_file)
for st_file in hf_weights_files
]
futures_iter = tqdm(
concurrent.futures.as_completed(futures),
total=len(hf_weights_files),
desc="Multi-thread loading shards",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
)
for future in futures_iter:
state_dict = future.result()
yield from state_dict.items()
def runai_safetensors_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
@ -611,6 +642,39 @@ def pt_weights_iterator(
del state
def multi_thread_pt_weights_iterator(
hf_weights_files: list[str],
use_tqdm_on_load: bool,
pt_load_map_location: Union[str, dict[str, str]] = "cpu",
max_workers: int = 4,
) -> Generator[tuple[str, torch.Tensor], None, None]:
"""Multi-Thread iterate over the weights in the model bin/pt files."""
def _load_file(bin_file: str):
return torch.load(bin_file,
map_location=pt_load_map_location,
weights_only=True)
with concurrent.futures.ThreadPoolExecutor(
max_workers=max_workers) as executor:
futures = [
executor.submit(_load_file, bin_file)
for bin_file in hf_weights_files
]
futures_iter = tqdm(
concurrent.futures.as_completed(futures),
total=len(hf_weights_files),
desc="Multi-thread loading pt checkpoint shards",
disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
)
for future in futures_iter:
state = future.result()
yield from state.items()
del state
def get_gguf_extra_tensor_names(
gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
reader = gguf.GGUFReader(gguf_file)