Revert "HF loads dcp - don't do a full deserialize on every file (#155942)"

This reverts commit 117db5601d78cbc746b35eef71fc815e042e903f.

Reverted https://github.com/pytorch/pytorch/pull/155942 on behalf of https://github.com/jeanschmidt because the newly introduced tests are failing internally; more details are in D76442012 ([comment](https://github.com/pytorch/pytorch/pull/155942#issuecomment-3023473036)).
This commit is contained in:
PyTorch MergeBot
2025-07-01 11:15:08 +00:00
parent 0bce390269
commit 13bf2655c1
3 changed files with 37 additions and 45 deletions

View File

@ -2,16 +2,14 @@
import json import json
import os import os
import pathlib
import sys import sys
import tempfile import tempfile
from unittest.mock import MagicMock from unittest.mock import MagicMock
import torch import torch
from torch.distributed.checkpoint import DefaultLoadPlanner from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_utils import ( from torch.distributed.checkpoint._hf_utils import _HFStorageInfo
_HFStorageInfo,
NUM_BYTES_FOR_HEADER_LEN,
)
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.hf_storage import ( from torch.distributed.checkpoint.hf_storage import (
@ -162,37 +160,30 @@ class TestHfStorage(TestCase):
) )
def test_read_data_hf(self) -> None: def test_read_data_hf(self) -> None:
mock_safetensors = MagicMock()
sys.modules["safetensors"] = mock_safetensors
# Create test tensors # Create test tensors
tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mock the deserialize function to return our test tensors
# The format matches what's expected in the read_data method
mock_safetensors.deserialize.return_value = [
(
"tensor_0",
{"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]},
),
]
with tempfile.TemporaryDirectory() as path: with tempfile.TemporaryDirectory() as path:
# Create the reader # Create the reader
reader = HuggingFaceStorageReader(path=path) reader = HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
# Create test file # Create test file
file_name = "model-00001-of-00001.safetensors" file_name = "model-00001-of-00001.safetensors"
file_path = os.path.join(path, file_name) file_path = os.path.join(path, file_name)
pathlib.Path(file_path).touch()
with open(file_path, "wb") as f:
# write metadata the same way it would be in safetensors file
metadata_contents = json.dumps(
{
"tensor_0": {
"dtype": "F32",
"shape": [1, 4],
"data_offsets": [0, 16],
}
}
)
metadata_bytes = metadata_contents.encode("utf-8")
f.write(
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(metadata_bytes)
f.write(tensor_0.numpy().tobytes())
# Set up storage data with _StorageInfo objects # Set up storage data with _StorageInfo objects
storage_data = { storage_data = {
@ -200,7 +191,7 @@ class TestHfStorage(TestCase):
fqn="tensor_0", offset=torch.Size([0]), index=None fqn="tensor_0", offset=torch.Size([0]), index=None
): _HFStorageInfo( ): _HFStorageInfo(
file_path, file_path,
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, 0,
tensor_0.numel() * tensor_0.element_size(), tensor_0.numel() * tensor_0.element_size(),
tensor_0.shape, tensor_0.shape,
tensor_0.dtype, tensor_0.dtype,
@ -254,6 +245,7 @@ class TestHfStorage(TestCase):
), ),
) )
# Call read_data
future = reader.read_data(load_plan, load_planner) future = reader.read_data(load_plan, load_planner)
future.wait() future.wait()
@ -331,16 +323,9 @@ class TestHfStorage(TestCase):
) )
metadata_bytes = metadata_contents.encode("utf-8") metadata_bytes = metadata_contents.encode("utf-8")
f.write( f.write(len(metadata_bytes).to_bytes(8, byteorder="little"))
len(metadata_bytes).to_bytes(
NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
)
)
f.write(metadata_bytes) f.write(metadata_bytes)
tensor = torch.rand(5, 10)
f.write(tensor.numpy().tobytes())
metadata = reader.read_metadata() metadata = reader.read_metadata()
self.assertEqual( self.assertEqual(
@ -357,7 +342,6 @@ class TestHfStorage(TestCase):
), ),
}, },
) )
self.assertEqual( self.assertEqual(
metadata.storage_data, metadata.storage_data,
{ {
@ -365,7 +349,7 @@ class TestHfStorage(TestCase):
fqn=key, offset=torch.Size([0, 0]), index=None fqn=key, offset=torch.Size([0, 0]), index=None
): _HFStorageInfo( ): _HFStorageInfo(
os.path.join(path, file_name), os.path.join(path, file_name),
len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN, 0,
200, 200,
torch.Size([5, 10]), torch.Size([5, 10]),
torch.float32, torch.float32,

View File

@ -41,8 +41,6 @@ DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO"
FORMAT_KEY = "format" FORMAT_KEY = "format"
FORMAT_VALUE = "pt" FORMAT_VALUE = "pt"
NUM_BYTES_FOR_HEADER_LEN = 8
@dataclass @dataclass
class _HFStorageInfo: class _HFStorageInfo:
@ -82,11 +80,12 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]:
# and follows their documentation on how their files are serialized # and follows their documentation on how their files are serialized
# https://huggingface.co/docs/safetensors/index#format # https://huggingface.co/docs/safetensors/index#format
header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN) num_bytes_for_header_len = 8
header_len_bytes = file_bytes.read(num_bytes_for_header_len)
header_len = struct.unpack("<Q", header_len_bytes)[0] header_len = struct.unpack("<Q", header_len_bytes)[0]
header_json = file_bytes.read(header_len) header_json = file_bytes.read(header_len)
metadata = json.loads(header_json) metadata = json.loads(header_json)
return (metadata, header_len + NUM_BYTES_FOR_HEADER_LEN) return (metadata, header_len + num_bytes_for_header_len)
def _get_dtype(dtype_str: str) -> torch.dtype: def _get_dtype(dtype_str: str) -> torch.dtype:

View File

@ -18,6 +18,7 @@ from torch.distributed.checkpoint._hf_utils import (
_HFStorageInfo, _HFStorageInfo,
_metadata_fn, _metadata_fn,
CUSTOM_METADATA_KEY, CUSTOM_METADATA_KEY,
DATA_KEY,
DATA_OFFSETS_KEY, DATA_OFFSETS_KEY,
DEFAULT_EXTRA_METADATA_KEY, DEFAULT_EXTRA_METADATA_KEY,
DTYPE_KEY, DTYPE_KEY,
@ -233,6 +234,8 @@ class HuggingFaceStorageReader(FsspecReader):
super().__init__(path=path) super().__init__(path=path)
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
from safetensors import deserialize # type: ignore[import-not-found]
per_file: dict[str, list[ReadItem]] = {} per_file: dict[str, list[ReadItem]] = {}
for read_item in plan.items: for read_item in plan.items:
@ -242,11 +245,17 @@ class HuggingFaceStorageReader(FsspecReader):
for file_name, reqs in per_file.items(): for file_name, reqs in per_file.items():
with self.fs.create_stream(file_name, "rb") as stream: with self.fs.create_stream(file_name, "rb") as stream:
# TODO: make this more efficient by doing offset reads instead of a
# full deserialization of the file
deserialized = deserialize(stream.read())
deserialized_dict: dict[str, dict[str, Any]] = {
tensor_info[0]: tensor_info[1] for tensor_info in deserialized
}
for req in reqs: for req in reqs:
item_md = self.storage_data[req.storage_index] item_md = self.storage_data[req.storage_index]
stream.seek(item_md.offset) tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
tensor_bytes = stream.read(item_md.length)
tensor = torch.frombuffer( tensor = torch.frombuffer(
tensor_bytes, tensor_bytes,
@ -280,7 +289,7 @@ class HuggingFaceStorageReader(FsspecReader):
for safetensor_file in safetensors_files: for safetensor_file in safetensors_files:
with self.fs.create_stream(safetensor_file, "rb") as f: with self.fs.create_stream(safetensor_file, "rb") as f:
safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f) safetensors_metadata, _ = _get_safetensors_file_metadata(f)
custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)
dcp_sharding_info = None dcp_sharding_info = None
@ -339,7 +348,7 @@ class HuggingFaceStorageReader(FsspecReader):
) )
storage_data[metadata_index] = _HFStorageInfo( storage_data[metadata_index] = _HFStorageInfo(
relative_path=safetensor_file, relative_path=safetensor_file,
offset=val[DATA_OFFSETS_KEY][0] + metadata_size, offset=val[DATA_OFFSETS_KEY][0],
length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
shape=torch.Size(val[SHAPE_KEY]), shape=torch.Size(val[SHAPE_KEY]),
dtype=_get_dtype(val[DTYPE_KEY]), dtype=_get_dtype(val[DTYPE_KEY]),