HF loads dcp - don't do a full deserialize on every file (#157715)

Summary: The changes from D76442012 were reverted after the PR landed because aps_models/ads/launchers/pearl/tests/ne/e2e_deterministic_tests:pearl_e2e_ne_tests failed with `Config not loaded due to no timely response from configerator. Likely configerator_proxy or falcon_proxy are not healthy`. That failure is a transient infrastructure issue unrelated to these changes, so this diff re-creates the original one.

Test Plan:
Ensure the existing checkpoint tests pass.

Rollback Plan:

Differential Revision: D77871099

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157715
Approved by: https://github.com/meetv18
This commit is contained in:
Ankita George
2025-07-08 18:13:23 +00:00
committed by PyTorch MergeBot
parent 4f5be56612
commit dea4864ce0
3 changed files with 45 additions and 37 deletions

View File

@ -2,14 +2,16 @@
import json
import os
import pathlib
import sys
import tempfile
from unittest.mock import MagicMock
import torch
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_utils import _HFStorageInfo
from torch.distributed.checkpoint._hf_utils import (
_HFStorageInfo,
NUM_BYTES_FOR_HEADER_LEN,
)
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.hf_storage import (
@ -160,30 +162,37 @@ class TestHfStorage(TestCase):
)
def test_read_data_hf(self) -> None:
mock_safetensors = MagicMock()
sys.modules["safetensors"] = mock_safetensors
# Create test tensors
tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mock the deserialize function to return our test tensors
# The format matches what's expected in the read_data method
mock_safetensors.deserialize.return_value = [
(
"tensor_0",
{"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]},
),
]
with tempfile.TemporaryDirectory() as path:
# Create the reader
reader = HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
# Create test file
file_name = "model-00001-of-00001.safetensors"
file_path = os.path.join(path, file_name)
pathlib.Path(file_path).touch()
with open(file_path, "wb") as f:
# write metadata the same way it would be in safetensors file
metadata_contents = json.dumps(
{
"tensor_0": {
"dtype": "F32",
"shape": [1, 4],
"data_offsets": [0, 16],
}
}
)
metadata_bytes = metadata_contents.encode("utf-8")
f.write(
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(metadata_bytes)
f.write(tensor_0.numpy().tobytes())
# Set up storage data with _StorageInfo objects
storage_data = {
@ -191,7 +200,7 @@ class TestHfStorage(TestCase):
fqn="tensor_0", offset=torch.Size([0]), index=None
): _HFStorageInfo(
file_path,
0,
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
tensor_0.numel() * tensor_0.element_size(),
tensor_0.shape,
tensor_0.dtype,
@ -245,7 +254,6 @@ class TestHfStorage(TestCase):
),
)
# Call read_data
future = reader.read_data(load_plan, load_planner)
future.wait()
@ -323,9 +331,16 @@ class TestHfStorage(TestCase):
)
metadata_bytes = metadata_contents.encode("utf-8")
f.write(len(metadata_bytes).to_bytes(8, byteorder="little"))
f.write(
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(metadata_bytes)
tensor = torch.rand(5, 10)
f.write(tensor.numpy().tobytes())
metadata = reader.read_metadata()
self.assertEqual(
@ -342,6 +357,7 @@ class TestHfStorage(TestCase):
),
},
)
self.assertEqual(
metadata.storage_data,
{
@ -349,7 +365,7 @@ class TestHfStorage(TestCase):
fqn=key, offset=torch.Size([0, 0]), index=None
): _HFStorageInfo(
os.path.join(path, file_name),
0,
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
200,
torch.Size([5, 10]),
torch.float32,

View File

@ -41,6 +41,8 @@ DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO"
FORMAT_KEY = "format"
FORMAT_VALUE = "pt"
NUM_BYTES_FOR_HEADER_LEN = 8
@dataclass
class _HFStorageInfo:
@ -80,12 +82,11 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]:
# and follows their documentation on how their files are serialized
# https://huggingface.co/docs/safetensors/index#format
num_bytes_for_header_len = 8
header_len_bytes = file_bytes.read(num_bytes_for_header_len)
header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN)
header_len = struct.unpack("<Q", header_len_bytes)[0]
header_json = file_bytes.read(header_len)
metadata = json.loads(header_json)
return (metadata, header_len + num_bytes_for_header_len)
return (metadata, header_len + NUM_BYTES_FOR_HEADER_LEN)
def _get_dtype(dtype_str: str) -> torch.dtype:

View File

@ -18,7 +18,6 @@ from torch.distributed.checkpoint._hf_utils import (
_HFStorageInfo,
_metadata_fn,
CUSTOM_METADATA_KEY,
DATA_KEY,
DATA_OFFSETS_KEY,
DEFAULT_EXTRA_METADATA_KEY,
DTYPE_KEY,
@ -234,8 +233,6 @@ class HuggingFaceStorageReader(FsspecReader):
super().__init__(path=path)
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
from safetensors import deserialize # type: ignore[import-not-found]
per_file: dict[str, list[ReadItem]] = {}
for read_item in plan.items:
@ -245,17 +242,11 @@ class HuggingFaceStorageReader(FsspecReader):
for file_name, reqs in per_file.items():
with self.fs.create_stream(file_name, "rb") as stream:
# TODO: make this more efficient by doing offset reads instead of a
# full deserialization of the file
deserialized = deserialize(stream.read())
deserialized_dict: dict[str, dict[str, Any]] = {
tensor_info[0]: tensor_info[1] for tensor_info in deserialized
}
for req in reqs:
item_md = self.storage_data[req.storage_index]
tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
stream.seek(item_md.offset)
tensor_bytes = stream.read(item_md.length)
tensor = torch.frombuffer(
tensor_bytes,
@ -289,7 +280,7 @@ class HuggingFaceStorageReader(FsspecReader):
for safetensor_file in safetensors_files:
with self.fs.create_stream(safetensor_file, "rb") as f:
safetensors_metadata, _ = _get_safetensors_file_metadata(f)
safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f)
custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)
dcp_sharding_info = None
@ -348,7 +339,7 @@ class HuggingFaceStorageReader(FsspecReader):
)
storage_data[metadata_index] = _HFStorageInfo(
relative_path=safetensor_file,
offset=val[DATA_OFFSETS_KEY][0],
offset=val[DATA_OFFSETS_KEY][0] + metadata_size,
length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
shape=torch.Size(val[SHAPE_KEY]),
dtype=_get_dtype(val[DTYPE_KEY]),