HF component update to not use fsspec components (#159405)

Update HF components to not inherit from fsspec components and instead use filesystem writer/reader. The reason is because there doesn't seem to be much of a need for fsspec, since users are using mounted storage. Using local storage will allow for performance improvements because we can take advantage of the safe_open API provided by HF safetensors (30s vs 4s for load of 8b model), which is signifcant performance wins over reading bytes and converting to tensors which is what we are doing now. Also, we can use the official methods provided by HF instead of relying on reading the metadata by bytes and loading it

Differential Revision: [D78993550](https://our.internmc.facebook.com/intern/diff/D78993550/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/159405
Approved by: https://github.com/saumishr
This commit is contained in:
Ankita George
2025-08-07 06:48:21 -07:00
committed by PyTorch MergeBot
parent 57f738b635
commit 69cc606fda

View File

@ -7,10 +7,10 @@ from typing import Any, Optional
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._consolidate_hf_safetensors import (
consolidate_safetensors_files,
)
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint._hf_utils import (
_gen_file_name,
_get_dtype,
@ -52,7 +52,7 @@ logger: logging.Logger = logging.getLogger(__name__)
__all__ = ["HuggingFaceStorageWriter", "HuggingFaceStorageReader"]
class HuggingFaceStorageWriter(FsspecWriter):
class HuggingFaceStorageWriter(FileSystemWriter):
"""
A writer that writes to a huggingface repository in the huggingface format.
Uses Fsspec back-end to communicate with back-end storage.
@ -64,26 +64,20 @@ class HuggingFaceStorageWriter(FsspecWriter):
path: str,
fqn_to_index_mapping: Optional[dict[str, int]] = None,
thread_count: int = 1,
token: Optional[str] = None,
save_distributed: bool = False,
enable_consolidation: bool = False,
consolidated_output_path: Optional[str] = None,
thread_count_consolidation: int = 1,
) -> None:
"""
Initialize the huggingface writer pointing to path.
Args:
path: hf directory where the checkpoint will be read from.
Needs to have .safetensors files, but can be from any fsspec supported storage,
including localFS and hf://.
This needs to be a remote path if you want to enable consolidation after saving.
path: directory where the checkpoint will be read from.
fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to.
Indices are from 1 to N, where N is the number of files. If not provided,
the tensors will be written to a single file. If none, then all the tensors on the
same rank will be written to the same file.
thread_count: Number of threads to use to write distributed checkpoint. Default to 1.
token: The token to use to authenticate with huggingface hub.
save_distributed: If True, save the checkpoint using distributed APIs where every rank saves its own shard.
Default is False which assumes rank-0 checkpointing of the full state_dict.
enable_consolidation: If True, consolidate the sharded checkpoint after saving. The sharded tensors will be
@ -92,19 +86,11 @@ class HuggingFaceStorageWriter(FsspecWriter):
to consolidated output files. Default to 1.
"""
if token is not None:
super().__init__(
path=path,
token=token,
serialization_format=SerializationFormat.SAFETENSORS,
thread_count=thread_count,
)
else:
super().__init__(
path=path,
serialization_format=SerializationFormat.SAFETENSORS,
thread_count=thread_count,
)
super().__init__(
path=path,
serialization_format=SerializationFormat.SAFETENSORS,
thread_count=thread_count,
)
self.fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
self.save_distributed: bool = save_distributed
self.enable_consolidation: bool = enable_consolidation
@ -215,28 +201,22 @@ class HuggingFaceStorageWriter(FsspecWriter):
return _metadata_fn
class HuggingFaceStorageReader(FsspecReader):
class HuggingFaceStorageReader(FileSystemReader):
"""
A reader that reads from a huggingface repository in the huggingface format.
Uses in Fsspec back-end to communicate with storage.
Fsspec registration of the storage solution is required.
"""
def __init__(self, path: str, token: Optional[str] = None) -> None:
def __init__(self, path: str) -> None:
"""
Initialize the huggingface reader pointing to path.
Args:
path: hf directory where the checkpoint will be read from.
Needs to have .safetensors file, but can be from any fsspec supported storage,
including localFS and hf://.
token: The token to use to authenticate with huggingface hub.
path: directory where the checkpoint will be read from.
"""
if token is not None:
super().__init__(path=path, token=token)
else:
super().__init__(path=path)
super().__init__(path=path)
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
per_file: dict[str, list[ReadItem]] = {}