Revert "HF loads dcp - don't do a full deserialize on every file (#155942)"

This reverts commit 117db5601d78cbc746b35eef71fc815e042e903f.

Reverted https://github.com/pytorch/pytorch/pull/155942 on behalf of https://github.com/jeanschmidt due to Newly introduced tests are red internally, more details on D76442012 ([comment](https://github.com/pytorch/pytorch/pull/155942#issuecomment-3023473036))
This commit is contained in:
PyTorch MergeBot
2025-07-01 11:15:08 +00:00
parent 0bce390269
commit 13bf2655c1
3 changed files with 37 additions and 45 deletions

View File

@ -2,16 +2,14 @@
import json
import os
import pathlib
import sys
import tempfile
from unittest.mock import MagicMock
import torch
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_utils import (
_HFStorageInfo,
NUM_BYTES_FOR_HEADER_LEN,
)
from torch.distributed.checkpoint._hf_utils import _HFStorageInfo
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.hf_storage import (
@ -162,37 +160,30 @@ class TestHfStorage(TestCase):
)
def test_read_data_hf(self) -> None:
mock_safetensors = MagicMock()
sys.modules["safetensors"] = mock_safetensors
# Create test tensors
tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mock the deserialize function to return our test tensors
# The format matches what's expected in the read_data method
mock_safetensors.deserialize.return_value = [
(
"tensor_0",
{"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]},
),
]
with tempfile.TemporaryDirectory() as path:
# Create the reader
reader = HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
# Create test file
file_name = "model-00001-of-00001.safetensors"
file_path = os.path.join(path, file_name)
with open(file_path, "wb") as f:
# write metadata the same way it would be in safetensors file
metadata_contents = json.dumps(
{
"tensor_0": {
"dtype": "F32",
"shape": [1, 4],
"data_offsets": [0, 16],
}
}
)
metadata_bytes = metadata_contents.encode("utf-8")
f.write(
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(metadata_bytes)
f.write(tensor_0.numpy().tobytes())
pathlib.Path(file_path).touch()
# Set up storage data with _StorageInfo objects
storage_data = {
@ -200,7 +191,7 @@ class TestHfStorage(TestCase):
fqn="tensor_0", offset=torch.Size([0]), index=None
): _HFStorageInfo(
file_path,
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
0,
tensor_0.numel() * tensor_0.element_size(),
tensor_0.shape,
tensor_0.dtype,
@ -254,6 +245,7 @@ class TestHfStorage(TestCase):
),
)
# Call read_data
future = reader.read_data(load_plan, load_planner)
future.wait()
@ -331,16 +323,9 @@ class TestHfStorage(TestCase):
)
metadata_bytes = metadata_contents.encode("utf-8")
f.write(
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(len(metadata_bytes).to_bytes(8, byteorder="little"))
f.write(metadata_bytes)
tensor = torch.rand(5, 10)
f.write(tensor.numpy().tobytes())
metadata = reader.read_metadata()
self.assertEqual(
@ -357,7 +342,6 @@ class TestHfStorage(TestCase):
),
},
)
self.assertEqual(
metadata.storage_data,
{
@ -365,7 +349,7 @@ class TestHfStorage(TestCase):
fqn=key, offset=torch.Size([0, 0]), index=None
): _HFStorageInfo(
os.path.join(path, file_name),
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
0,
200,
torch.Size([5, 10]),
torch.float32,

View File

@ -41,8 +41,6 @@ DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO"
FORMAT_KEY = "format"
FORMAT_VALUE = "pt"
NUM_BYTES_FOR_HEADER_LEN = 8
@dataclass
class _HFStorageInfo:
@ -82,11 +80,12 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]:
# and follows their documentation on how their files are serialized
# https://huggingface.co/docs/safetensors/index#format
header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN)
num_bytes_for_header_len = 8
header_len_bytes = file_bytes.read(num_bytes_for_header_len)
header_len = struct.unpack("<Q", header_len_bytes)[0]
header_json = file_bytes.read(header_len)
metadata = json.loads(header_json)
return (metadata, header_len + NUM_BYTES_FOR_HEADER_LEN)
return (metadata, header_len + num_bytes_for_header_len)
def _get_dtype(dtype_str: str) -> torch.dtype:

View File

@ -18,6 +18,7 @@ from torch.distributed.checkpoint._hf_utils import (
_HFStorageInfo,
_metadata_fn,
CUSTOM_METADATA_KEY,
DATA_KEY,
DATA_OFFSETS_KEY,
DEFAULT_EXTRA_METADATA_KEY,
DTYPE_KEY,
@ -233,6 +234,8 @@ class HuggingFaceStorageReader(FsspecReader):
super().__init__(path=path)
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
from safetensors import deserialize # type: ignore[import-not-found]
per_file: dict[str, list[ReadItem]] = {}
for read_item in plan.items:
@ -242,11 +245,17 @@ class HuggingFaceStorageReader(FsspecReader):
for file_name, reqs in per_file.items():
with self.fs.create_stream(file_name, "rb") as stream:
# TODO: make this more efficient by doing offset reads instead of a
# full deserialization of the file
deserialized = deserialize(stream.read())
deserialized_dict: dict[str, dict[str, Any]] = {
tensor_info[0]: tensor_info[1] for tensor_info in deserialized
}
for req in reqs:
item_md = self.storage_data[req.storage_index]
stream.seek(item_md.offset)
tensor_bytes = stream.read(item_md.length)
tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
tensor = torch.frombuffer(
tensor_bytes,
@ -280,7 +289,7 @@ class HuggingFaceStorageReader(FsspecReader):
for safetensor_file in safetensors_files:
with self.fs.create_stream(safetensor_file, "rb") as f:
safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f)
safetensors_metadata, _ = _get_safetensors_file_metadata(f)
custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)
dcp_sharding_info = None
@ -339,7 +348,7 @@ class HuggingFaceStorageReader(FsspecReader):
)
storage_data[metadata_index] = _HFStorageInfo(
relative_path=safetensor_file,
offset=val[DATA_OFFSETS_KEY][0] + metadata_size,
offset=val[DATA_OFFSETS_KEY][0],
length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
shape=torch.Size(val[SHAPE_KEY]),
dtype=_get_dtype(val[DTYPE_KEY]),