# mypy: allow-untyped-defs
import dataclasses
import json
import logging
import queue
import threading
from typing import Any, Optional

import torch
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)
from torch.distributed.checkpoint._hf_utils import (
    _gen_file_name,
    _HFStorageInfo,
    _metadata_fn,
    CUSTOM_METADATA_KEY,
    SAVED_OFFSETS_KEY,
    SHARDED_DIR_NAME,
    SUFFIX,
)
from torch.distributed.checkpoint.filesystem import SerializationFormat
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    StorageMeta,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
    LoadPlanner,
    ReadItem,
    SavePlan,
    SavePlanner,
    WriteItem,
)
from torch.distributed.checkpoint.storage import WriteResult
from torch.futures import Future


logger: logging.Logger = logging.getLogger(__name__)

__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"]


class HuggingFaceStorageWriter(FileSystemWriter):
    """
    A writer that writes to storage in the huggingface safetensors format.
    """

    def __init__(
        self,
        path: str,
        fqn_to_index_mapping: Optional[dict[str, int]] = None,
        thread_count: int = 1,
        save_distributed: bool = False,
        enable_consolidation: bool = False,
        thread_count_consolidation: int = 1,
    ) -> None:
        """
        Initialize the huggingface writer pointing to path.

        Args:
            path: directory where the checkpoint will be written to.
            fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be
                written to. Indices are from 1 to N, where N is the number of files. If not provided, all
                tensors on the same rank will be written to a single file.
            thread_count: Number of threads to use to write the distributed checkpoint. Defaults to 1.
            save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its
                own shard. Default is False, which assumes rank-0 checkpointing of the full state_dict.
            enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded
                tensors will be saved to path/sharded and the full tensors will be saved to path.
                Defaults to False.
            thread_count_consolidation: Number of threads to use for parallel processing when saving data
                to the consolidated output files. Defaults to 1.
        """

        super().__init__(
            path=path,
            serialization_format=SerializationFormat.SAFETENSORS,
            thread_count=thread_count,
        )
        self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
        self.save_distributed: bool = save_distributed
        self.enable_consolidation: bool = enable_consolidation
        self.consolidated_output_path: Optional[str] = None
        if self.enable_consolidation:
            self.consolidated_output_path = str(self.path)
            self.path = self.fs.concat_path(self.path, SHARDED_DIR_NAME)
        self.thread_count_consolidation = thread_count_consolidation

    def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        new_plans = []
        for i, plan in enumerate(plans, start=1):
            storage_data: dict[str, Any] = {}
            if self.fqn_to_index_mapping is not None:
                storage_data["fqn_to_index_mapping"] = self.fqn_to_index_mapping
            if self.save_distributed:
                storage_data["shard_index"] = i

            new_plans.append(dataclasses.replace(plan, storage_data=storage_data))

        return new_plans

    def write_data(
        self,
        plan: SavePlan,
        planner: SavePlanner,
    ) -> Future[list[WriteResult]]:
        if len(plan.items) == 0:
            fut: Future = Future()
            fut.set_result([])
            return fut

        # storage_plan is a map from key to file index
        storage_data: dict[str, Any] = plan.storage_data
        storage_plan: Optional[dict[str, int]] = None
        shard_index: Optional[int] = None
        if "fqn_to_index_mapping" in storage_data:
            storage_plan = storage_data["fqn_to_index_mapping"]
        if "shard_index" in storage_data:
            shard_index = storage_data["shard_index"]

        buckets = self._split_by_storage_plan(storage_plan, plan.items)
        highest_index = max(storage_plan.values()) if storage_plan is not None else 1

        file_queue: queue.Queue = queue.Queue()
        for file_index, write_items in buckets.items():
            file_name = _gen_file_name(file_index, highest_index, shard_index)
            file_queue.put(
                (self.fs.concat_path(self.path, file_name), file_name, write_items)
            )

        return super()._write_data(planner, file_queue)

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        if self.save_distributed and not self.enable_consolidation:
            # If we are saving a distributed checkpoint without consolidating it,
            # there is no metadata file to write: an fqn-to-file mapping doesn't
            # make sense in this case because an fqn may span multiple files.
            logger.info("Not consolidating sharded checkpoint in finish step.")
            return
        if self.save_distributed:
            fqn_to_index_mapping: dict[str, int] = (
                self.fqn_to_index_mapping
                if self.fqn_to_index_mapping is not None
                else dict.fromkeys(metadata.state_dict_metadata.keys(), 1)
            )

            return consolidate_safetensors_files(
                input_dir=str(self.path),
                output_dir=self.consolidated_output_path,  # type: ignore[arg-type]
                num_threads=self.thread_count_consolidation,
                fqn_to_index_mapping=fqn_to_index_mapping,
            )

        # Write a model.safetensors.index.json file with the fqn-to-file mapping
        # for the rank-0 checkpointing case.
        metadata_to_write = {}
        storage_md = {}
        total_size = 0
        for wr_list in results:
            storage_md.update(
                {wr.index.fqn: wr.storage_data.relative_path for wr in wr_list}
            )
            total_size += sum(wr.storage_data.length for wr in wr_list)
        metadata_to_write["metadata"] = {"total_size": total_size}
        metadata_to_write["weight_map"] = storage_md

        metadata_path = self.fs.concat_path(self.path, f"{_metadata_fn}")
        with self.fs.create_stream(metadata_path, "w") as metadata_file:
            json.dump(metadata_to_write, metadata_file, indent=2)
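
    # For reference, the index file written by finish() above follows the
    # Hugging Face sharded-checkpoint layout. An illustrative (hypothetical)
    # example of its contents:
    #
    #   {
    #     "metadata": {"total_size": 16777216},
    #     "weight_map": {
    #       "embed.weight": "model-00001-of-00002.safetensors",
    #       "lm_head.weight": "model-00002-of-00002.safetensors"
    #     }
    #   }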

    def _split_by_storage_plan(
        self, storage_plan: Optional[dict[str, int]], items: list[WriteItem]
    ) -> dict[int, list[WriteItem]]:
        # storage_plan is a map from key to index
        if storage_plan is None:
            return {1: items}

        buckets = {}
        for item in items:
            key = item.index.fqn

            idx = storage_plan[key]
            if idx not in buckets:
                buckets[idx] = [item]
            else:
                buckets[idx].append(item)

        return buckets

    @property
    def metadata_path(self) -> str:
        return _metadata_fn
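

# Illustrative usage sketch (not part of the original module): a rank-0 save of
# a full state_dict with HuggingFaceStorageWriter. The output path, the model,
# and the two-file fqn_to_index_mapping below are hypothetical.
def _example_rank0_save() -> None:
    import torch.distributed.checkpoint as dcp

    model = torch.nn.Linear(8, 8)
    writer = HuggingFaceStorageWriter(
        path="/tmp/hf_ckpt",  # hypothetical output directory
        # Spread the tensors across two safetensors files; omit the mapping to
        # write everything to a single file instead.
        fqn_to_index_mapping={"weight": 1, "bias": 2},
        thread_count=2,
    )
    # finish() also writes the index file with the fqn-to-file weight map.
    dcp.save(model.state_dict(), storage_writer=writer)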
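

# Illustrative sketch (not part of the original module): an every-rank sharded
# save with consolidation. Assumes torch.distributed is already initialized and
# that `sharded_state_dict` holds this rank's shards; both are hypothetical.
def _example_distributed_save(sharded_state_dict) -> None:
    import torch.distributed.checkpoint as dcp

    writer = HuggingFaceStorageWriter(
        path="/tmp/hf_ckpt",  # consolidated full tensors end up here
        save_distributed=True,  # every rank writes its own shard
        enable_consolidation=True,  # shards are first written to path/sharded
        thread_count_consolidation=4,
    )
    dcp.save(sharded_state_dict, storage_writer=writer)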


class HuggingFaceStorageReader(FileSystemReader):
    """
    A reader that reads a checkpoint in the huggingface safetensors format.
    """

    def __init__(self, path: str, thread_count: int = 1) -> None:
        """
        Initialize the huggingface reader pointing to path.

        Args:
            path: directory where the checkpoint will be read from.
            thread_count: Number of threads to use to read the distributed checkpoint. Defaults to 1.
        """

        super().__init__(path=path)
        self.thread_count = thread_count

    def _process_read_request(self, f, req: ReadItem, planner: LoadPlanner) -> None:
        """Helper function to process a single read request."""
        # Create slices for each dimension based on offsets and lengths
        slices = tuple(
            slice(offset, offset + length)
            for offset, length in zip(req.storage_offsets, req.lengths)
        )
        tensor = f.get_slice(req.storage_index.fqn)[slices]
        target_tensor = planner.resolve_tensor(req).detach()

        if target_tensor.size() != tensor.size():
            raise AssertionError(
                f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
            )

        target_tensor.copy_(tensor)
        planner.commit_tensor(req, target_tensor)
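
    # Note on the lazy read above: safetensors' get_slice() returns a handle
    # that is indexed with Python slices, so only the requested sub-block is
    # read from disk. For example (hypothetical tensor "w" of shape (4, 6)),
    # f.get_slice("w")[slice(1, 3), slice(2, 5)] materializes a (2, 3) block.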

    def _read_files_from_queue(
        self,
        file_queue: queue.Queue,
        result_queue: queue.Queue,
        planner: LoadPlanner,
    ) -> None:
        from safetensors import safe_open  # type: ignore[import]

        try:
            while True:
                file_name, reqs = file_queue.get_nowait()
                with safe_open(filename=file_name, framework="pt") as f:
                    for req in reqs:
                        self._process_read_request(f, req, planner)
                result_queue.put(True)  # Signal that this file has been processed
        except queue.Empty:
            pass

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        from safetensors import safe_open  # type: ignore[import]

        per_file: dict[str, list[ReadItem]] = {}

        for read_item in plan.items:
            item_md: _HFStorageInfo = self.storage_data[read_item.storage_index]
            file_name = item_md.relative_path
            per_file.setdefault(file_name, []).append(read_item)

        if self.thread_count <= 1 or len(per_file) <= 1:
            for file_name, reqs in per_file.items():
                with safe_open(filename=file_name, framework="pt") as f:
                    for req in reqs:
                        self._process_read_request(f, req, planner)
        else:
            # Use parallel implementation with thread pool
            file_queue: queue.Queue = queue.Queue()
            result_queue: queue.Queue = queue.Queue()

            # Fill the queue with files to process
            for file_name, reqs in per_file.items():
                file_queue.put((file_name, reqs))

            # Create and start worker threads
            threads = []
            num_threads = min(self.thread_count, len(per_file))
            for _ in range(num_threads):
                t = threading.Thread(
                    target=self._read_files_from_queue,
                    args=(file_queue, result_queue, planner),
                )
                t.start()
                threads.append(t)

            # Wait for all threads to complete
            for t in threads:
                t.join()

            # Check if all files were processed
            processed_count = 0
            try:
                while True:
                    result_queue.get_nowait()
                    processed_count += 1
            except queue.Empty:
                pass

            if processed_count != len(per_file):
                raise AssertionError(
                    f"Not all files were processed: {processed_count} out of {len(per_file)}"
                )

        fut: Future = Future()
        fut.set_result(None)
        return fut

    # pyrefly: ignore # bad-override
    def read_metadata(self) -> Metadata:
        from safetensors import safe_open  # type: ignore[import]
        from safetensors.torch import _getdtype  # type: ignore[import]

        state_dict_metadata: dict[str, TensorStorageMetadata] = {}
        storage_data: dict[MetadataIndex, _HFStorageInfo] = {}

        safetensors_files = []
        for file in self.fs.ls(self.path):
            if file.endswith(SUFFIX):
                safetensors_files.append(file)

        for safetensor_file in safetensors_files:
            with safe_open(safetensor_file, framework="pt") as f:
                keys = f.keys()
                extra_metadata = f.metadata()

                dcp_sharding_info = None
                if extra_metadata and extra_metadata.get(CUSTOM_METADATA_KEY):
                    dcp_sharding_info = json.loads(
                        extra_metadata.get(CUSTOM_METADATA_KEY)
                    )

                for key in keys:
                    shape = f.get_slice(key).get_shape()
                    dtype = f.get_slice(key).get_dtype()
                    # construct state_dict_metadata
                    if dcp_sharding_info is not None:
                        offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY]
                    else:
                        offset = [0] * len(shape)

                    if key not in state_dict_metadata:
                        state_dict_metadata[key] = TensorStorageMetadata(
                            properties=TensorProperties(dtype=_getdtype(dtype)),
                            size=torch.Size(
                                [saved + offset for saved, offset in zip(shape, offset)]
                            ),
                            chunks=[
                                ChunkStorageMetadata(
                                    offsets=torch.Size(offset),
                                    sizes=torch.Size(shape),
                                )
                            ],
                        )
                    else:
                        state_dict_metadata[key].chunks.append(
                            ChunkStorageMetadata(
                                torch.Size(offset), sizes=torch.Size(shape)
                            )
                        )
                        size = list(state_dict_metadata[key].size)
                        for i in range(len(size)):
                            size[i] = max(size[i], shape[i] + offset[i])
                        state_dict_metadata[key].size = torch.Size(size)

                    # construct storage data
                    if dcp_sharding_info is not None:
                        metadata_index = MetadataIndex(
                            fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY]
                        )
                    else:
                        metadata_index = MetadataIndex(fqn=key, offset=[0] * len(shape))
                    storage_data[metadata_index] = _HFStorageInfo(
                        relative_path=safetensor_file,
                        shape=torch.Size(shape),
                        dtype=_getdtype(dtype),
                    )

        metadata = Metadata(
            state_dict_metadata=state_dict_metadata,  # type: ignore[arg-type]
            storage_data=storage_data,
        )

        if getattr(metadata, "storage_meta", None) is None:
            metadata.storage_meta = StorageMeta()
        metadata.storage_meta.load_id = self.load_id  # type: ignore[union-attr]

        return metadata
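

# Illustrative usage sketch (not part of the original module): loading a
# safetensors checkpoint in this layout back into a model. The checkpoint path
# and the model are hypothetical.
def _example_load() -> None:
    import torch.distributed.checkpoint as dcp

    model = torch.nn.Linear(8, 8)
    state_dict = model.state_dict()
    reader = HuggingFaceStorageReader(path="/tmp/hf_ckpt", thread_count=4)
    # read_metadata() rebuilds DCP metadata from the safetensors headers; the
    # load then copies each requested slice into the existing state_dict.
    dcp.load(state_dict, storage_reader=reader)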
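

# Worked example (hypothetical shapes) of the size bookkeeping in read_metadata:
# two shards of the same fqn, each of shape (2, 4), saved at offsets [0, 0] and
# [2, 0], combine into a tensor with global size (4, 4) and two chunks.
def _example_global_size() -> tuple[int, ...]:
    shards = [((2, 4), (0, 0)), ((2, 4), (2, 0))]  # (shape, offset) per shard
    size = [0, 0]
    for shape, offset in shards:
        for i in range(len(size)):
            size[i] = max(size[i], shape[i] + offset[i])
    return tuple(size)  # -> (4, 4)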