Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 14:53:52 +08:00)
[Model loader]: support multi-thread model weight loading (#23928)
Signed-off-by: Yang Kaiyong <yangkaiyong.yky@antgroup.com>
Signed-off-by: Simon Mo <simon.mo@hey.com>
Co-authored-by: Simon Mo <simon.mo@hey.com>
--- a/vllm/model_executor/model_loader/default_loader.py
+++ b/vllm/model_executor/model_loader/default_loader.py
@@ -18,8 +18,9 @@ from vllm.model_executor.model_loader.weight_utils import (
     download_safetensors_index_file_from_hf, download_weights_from_hf,
     fastsafetensors_weights_iterator, filter_duplicate_safetensors_files,
     filter_files_not_needed_for_inference, maybe_download_from_modelscope,
-    np_cache_weights_iterator, pt_weights_iterator,
-    safetensors_weights_iterator)
+    multi_thread_pt_weights_iterator,
+    multi_thread_safetensors_weights_iterator, np_cache_weights_iterator,
+    pt_weights_iterator, safetensors_weights_iterator)
 from vllm.platforms import current_platform
 
 logger = init_logger(__name__)
@@ -28,6 +29,9 @@ logger = init_logger(__name__)
 class DefaultModelLoader(BaseModelLoader):
     """Model loader that can load different file types from disk."""
 
+    # Default number of threads for multithreaded weight loading.
+    DEFAULT_NUM_THREADS = 8
+
     @dataclasses.dataclass
     class Source:
         """A source for weights."""
@@ -52,9 +56,15 @@ class DefaultModelLoader(BaseModelLoader):
 
     def __init__(self, load_config: LoadConfig):
         super().__init__(load_config)
-        if load_config.model_loader_extra_config:
-            raise ValueError(f"Model loader extra config is not supported for "
-                             f"load format {load_config.load_format}")
+        extra_config = load_config.model_loader_extra_config
+        allowed_keys = {"enable_multithread_load", "num_threads"}
+        unexpected_keys = set(extra_config.keys()) - allowed_keys
+
+        if unexpected_keys:
+            raise ValueError(f"Unexpected extra config keys for load format "
+                             f"{load_config.load_format}: "
+                             f"{unexpected_keys}")
 
     def _prepare_weights(
         self,
@@ -145,6 +155,7 @@ class DefaultModelLoader(BaseModelLoader):
         self, source: "Source"
     ) -> Generator[tuple[str, torch.Tensor], None, None]:
         """Get an iterator for the model weights based on the load format."""
+        extra_config = self.load_config.model_loader_extra_config
         hf_folder, hf_weights_files, use_safetensors = self._prepare_weights(
             source.model_or_path, source.revision, source.fall_back_to_pt,
             source.allow_patterns_overrides)
@@ -165,16 +176,34 @@ class DefaultModelLoader(BaseModelLoader):
                     self.load_config.use_tqdm_on_load,
                 )
             else:
-                weights_iterator = safetensors_weights_iterator(
-                    hf_weights_files,
-                    self.load_config.use_tqdm_on_load,
-                )
-        else:
-            weights_iterator = pt_weights_iterator(
-                hf_weights_files,
-                self.load_config.use_tqdm_on_load,
-                self.load_config.pt_load_map_location,
-            )
+                if extra_config.get("enable_multithread_load"):
+                    weights_iterator = (
+                        multi_thread_safetensors_weights_iterator(
+                            hf_weights_files,
+                            self.load_config.use_tqdm_on_load,
+                            max_workers=extra_config.get(
+                                "num_threads", self.DEFAULT_NUM_THREADS),
+                        ))
+                else:
+                    weights_iterator = safetensors_weights_iterator(
+                        hf_weights_files,
+                        self.load_config.use_tqdm_on_load,
+                    )
+        else:
+            if extra_config.get("enable_multithread_load"):
+                weights_iterator = multi_thread_pt_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                    self.load_config.pt_load_map_location,
+                    max_workers=extra_config.get("num_threads",
+                                                 self.DEFAULT_NUM_THREADS),
+                )
+            else:
+                weights_iterator = pt_weights_iterator(
+                    hf_weights_files,
+                    self.load_config.use_tqdm_on_load,
+                    self.load_config.pt_load_map_location,
+                )
 
         if current_platform.is_tpu():
             from vllm.platforms.tpu import USE_TPU_COMMONS
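
Note: condensed, the dispatch above is a two-by-two choice (safetensors vs. bin/pt files, threaded vs. sequential). The sketch below is an illustrative reduction of that decision, not code from the commit:

    # Illustrative reduction of the iterator dispatch above.
    def pick_iterator(use_safetensors: bool, extra_config: dict) -> str:
        threaded = bool(extra_config.get("enable_multithread_load"))
        if use_safetensors:
            return ("multi_thread_safetensors_weights_iterator"
                    if threaded else "safetensors_weights_iterator")
        return ("multi_thread_pt_weights_iterator"
                if threaded else "pt_weights_iterator")

    assert pick_iterator(True, {}) == "safetensors_weights_iterator"
    assert (pick_iterator(False, {"enable_multithread_load": True})
            == "multi_thread_pt_weights_iterator")
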
--- a/vllm/model_executor/model_loader/weight_utils.py
+++ b/vllm/model_executor/model_loader/weight_utils.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Utilities for downloading and initializing model weights."""
+import concurrent.futures
 import fnmatch
 import glob
 import hashlib
@@ -531,6 +532,36 @@ def safetensors_weights_iterator(
         yield name, param
 
 
+def multi_thread_safetensors_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Iterate over the model safetensor files using multiple threads."""
+
+    def _load_file(st_file: str):
+        result = load_file(st_file, device="cpu")
+        return result
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, st_file)
+            for st_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state_dict = future.result()
+            yield from state_dict.items()
+
+
 def runai_safetensors_weights_iterator(
     hf_weights_files: list[str],
     use_tqdm_on_load: bool,
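
Note: concurrent.futures.as_completed yields each future as its thread finishes, so tensors arrive in completion order rather than file order; that is safe here because the consumer matches weights by parameter name. A minimal standalone sketch of the same pool-and-drain pattern, independent of vLLM:

    # Standalone sketch of the ThreadPoolExecutor + as_completed pattern.
    import concurrent.futures

    def map_unordered(fn, items, max_workers=4):
        """Run fn over items on a thread pool, yielding results as they finish."""
        with concurrent.futures.ThreadPoolExecutor(
                max_workers=max_workers) as ex:
            futures = [ex.submit(fn, item) for item in items]
            for fut in concurrent.futures.as_completed(futures):
                yield fut.result()

    # Output order follows completion, not submission; sort before comparing.
    print(sorted(map_unordered(len, ["alpha", "be", "gamma"])))  # [2, 5, 5]
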
@@ -611,6 +642,39 @@ def pt_weights_iterator(
         del state
 
 
+def multi_thread_pt_weights_iterator(
+    hf_weights_files: list[str],
+    use_tqdm_on_load: bool,
+    pt_load_map_location: Union[str, dict[str, str]] = "cpu",
+    max_workers: int = 4,
+) -> Generator[tuple[str, torch.Tensor], None, None]:
+    """Iterate over the model bin/pt files using multiple threads."""
+
+    def _load_file(bin_file: str):
+        return torch.load(bin_file,
+                          map_location=pt_load_map_location,
+                          weights_only=True)
+
+    with concurrent.futures.ThreadPoolExecutor(
+            max_workers=max_workers) as executor:
+        futures = [
+            executor.submit(_load_file, bin_file)
+            for bin_file in hf_weights_files
+        ]
+        futures_iter = tqdm(
+            concurrent.futures.as_completed(futures),
+            total=len(hf_weights_files),
+            desc="Multi-thread loading pt checkpoint shards",
+            disable=not enable_tqdm(use_tqdm_on_load),
+            bar_format=_BAR_FORMAT,
+        )
+
+        for future in futures_iter:
+            state = future.result()
+            yield from state.items()
+            del state
+
+
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: dict[str, str]) -> list[str]:
     reader = gguf.GGUFReader(gguf_file)
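
Note: up to max_workers whole shards can be resident in host memory at once, so peak CPU RAM grows with the thread count; the del state after each yield drops the reference to a shard once its tensors have been handed out. A sketch of driving the new iterator directly (normally DefaultModelLoader calls it; the shard names below are placeholders):

    # Hypothetical direct use of multi_thread_pt_weights_iterator.
    from vllm.model_executor.model_loader.weight_utils import (
        multi_thread_pt_weights_iterator)

    shards = ["pytorch_model-00001-of-00002.bin",  # placeholder shard names
              "pytorch_model-00002-of-00002.bin"]
    for name, tensor in multi_thread_pt_weights_iterator(
            shards, use_tqdm_on_load=True, max_workers=8):
        print(name, tuple(tensor.shape))
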