Revert "HF loads dcp - don't do a full deserialize on every file (#155942)"
This reverts commit 117db5601d78cbc746b35eef71fc815e042e903f. Reverted https://github.com/pytorch/pytorch/pull/155942 on behalf of https://github.com/jeanschmidt due to Newly introduced tests are red internally, more details on D76442012 ([comment](https://github.com/pytorch/pytorch/pull/155942#issuecomment-3023473036))
test/distributed/checkpoint/test_hf_storage.py

@@ -2,16 +2,14 @@
 import json
 import os
+import pathlib
 import sys
 import tempfile
 from unittest.mock import MagicMock
 
 import torch
 from torch.distributed.checkpoint import DefaultLoadPlanner
-from torch.distributed.checkpoint._hf_utils import (
-    _HFStorageInfo,
-    NUM_BYTES_FOR_HEADER_LEN,
-)
+from torch.distributed.checkpoint._hf_utils import _HFStorageInfo
 from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
 from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
 from torch.distributed.checkpoint.hf_storage import (
@@ -162,37 +160,30 @@ class TestHfStorage(TestCase):
         )
 
     def test_read_data_hf(self) -> None:
+        mock_safetensors = MagicMock()
+        sys.modules["safetensors"] = mock_safetensors
 
         # Create test tensors
         tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
 
+        # Mock the deserialize function to return our test tensors
+        # The format matches what's expected in the read_data method
+        mock_safetensors.deserialize.return_value = [
+            (
+                "tensor_0",
+                {"data": tensor_0.numpy().tobytes(), "dtype": "F32", "shape": [4]},
+            ),
+        ]
+
         with tempfile.TemporaryDirectory() as path:
             # Create the reader
             reader = HuggingFaceStorageReader(path=path)
+            reader.fs = FileSystem()
 
             # Create test file
             file_name = "model-00001-of-00001.safetensors"
             file_path = os.path.join(path, file_name)
-            with open(file_path, "wb") as f:
-                # write metadata the same way it would be in safetensors file
-                metadata_contents = json.dumps(
-                    {
-                        "tensor_0": {
-                            "dtype": "F32",
-                            "shape": [1, 4],
-                            "data_offsets": [0, 16],
-                        }
-                    }
-                )
-                metadata_bytes = metadata_contents.encode("utf-8")
-                f.write(
-                    len(metadata_bytes).to_bytes(
-                        NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
-                    )
-                )
-                f.write(metadata_bytes)
-
-                f.write(tensor_0.numpy().tobytes())
+            pathlib.Path(file_path).touch()
 
             # Set up storage data with _StorageInfo objects
             storage_data = {
@@ -200,7 +191,7 @@ class TestHfStorage(TestCase):
                     fqn="tensor_0", offset=torch.Size([0]), index=None
                 ): _HFStorageInfo(
                     file_path,
-                    len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
+                    0,
                     tensor_0.numel() * tensor_0.element_size(),
                     tensor_0.shape,
                     tensor_0.dtype,
@@ -254,6 +245,7 @@ class TestHfStorage(TestCase):
                 ),
             )
 
+            # Call read_data
            future = reader.read_data(load_plan, load_planner)
            future.wait()
 
@@ -331,16 +323,9 @@ class TestHfStorage(TestCase):
                 )
                 metadata_bytes = metadata_contents.encode("utf-8")
 
-                f.write(
-                    len(metadata_bytes).to_bytes(
-                        NUM_BYTES_FOR_HEADER_LEN, byteorder="little"
-                    )
-                )
+                f.write(len(metadata_bytes).to_bytes(8, byteorder="little"))
                 f.write(metadata_bytes)
-
-                tensor = torch.rand(5, 10)
-                f.write(tensor.numpy().tobytes())
 
             metadata = reader.read_metadata()
 
             self.assertEqual(
@@ -357,7 +342,6 @@ class TestHfStorage(TestCase):
                 ),
             },
         )
-
         self.assertEqual(
             metadata.storage_data,
             {
@@ -365,7 +349,7 @@ class TestHfStorage(TestCase):
                     fqn=key, offset=torch.Size([0, 0]), index=None
                 ): _HFStorageInfo(
                     os.path.join(path, file_name),
-                    len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN,
+                    0,
                     200,
                     torch.Size([5, 10]),
                     torch.float32,
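Note on the two expected offsets above: the pre-revert test wrote a real safetensors file by hand, so the expected `_HFStorageInfo` offset was an absolute file position, `len(metadata_bytes) + NUM_BYTES_FOR_HEADER_LEN`; the reverted test only touches an empty file and mocks `safetensors.deserialize`, so the offset is `0`. A minimal standalone sketch of that byte layout, assuming only the documented safetensors format (not the test harness):

    # Hand-build a minimal safetensors file the way the pre-revert test did:
    # an 8-byte little-endian header length, a JSON header, then raw tensor bytes.
    import json
    import struct

    import torch

    tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])  # float32 -> 16 data bytes
    header = json.dumps(
        {"tensor_0": {"dtype": "F32", "shape": [1, 4], "data_offsets": [0, 16]}}
    ).encode("utf-8")
    blob = struct.pack("<Q", len(header)) + header + tensor_0.numpy().tobytes()

    # Pre-revert convention: absolute offset = 8-byte length prefix + header size.
    assert blob[8 + len(header):] == tensor_0.numpy().tobytes()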
torch/distributed/checkpoint/_hf_utils.py

@@ -41,8 +41,6 @@ DCP_SHARDING_INFO_KEY = "DCP_SHARDING_INFO"
 FORMAT_KEY = "format"
 FORMAT_VALUE = "pt"
 
-NUM_BYTES_FOR_HEADER_LEN = 8
-
 
 @dataclass
 class _HFStorageInfo:
@@ -82,11 +80,12 @@ def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> tuple[Any, int]:
     # and follows their documentation on how their files are serialized
     # https://huggingface.co/docs/safetensors/index#format
 
-    header_len_bytes = file_bytes.read(NUM_BYTES_FOR_HEADER_LEN)
+    num_bytes_for_header_len = 8
+    header_len_bytes = file_bytes.read(num_bytes_for_header_len)
     header_len = struct.unpack("<Q", header_len_bytes)[0]
     header_json = file_bytes.read(header_len)
     metadata = json.loads(header_json)
-    return (metadata, header_len + NUM_BYTES_FOR_HEADER_LEN)
+    return (metadata, header_len + num_bytes_for_header_len)
 
 
 def _get_dtype(dtype_str: str) -> torch.dtype:
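The parsing above follows the documented safetensors layout: a little-endian u64 header length, then a JSON header, then the data section. A standalone sketch of the same steps against an in-memory file, using only the standard library (no private torch helpers):

    import io
    import json
    import struct

    header = json.dumps(
        {"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
    ).encode("utf-8")
    f = io.BytesIO(struct.pack("<Q", len(header)) + header + b"\x00" * 8)

    header_len = struct.unpack("<Q", f.read(8))[0]  # little-endian u64 prefix
    metadata = json.loads(f.read(header_len))       # JSON header
    data_start = header_len + 8                     # second element of the returned tuple

    assert data_start == 8 + len(header)
    assert metadata["w"]["data_offsets"] == [0, 8]
    assert f.read(8) == b"\x00" * 8  # data section begins right after the header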
torch/distributed/checkpoint/hf_storage.py

@@ -18,6 +18,7 @@ from torch.distributed.checkpoint._hf_utils import (
     _HFStorageInfo,
     _metadata_fn,
     CUSTOM_METADATA_KEY,
+    DATA_KEY,
     DATA_OFFSETS_KEY,
     DEFAULT_EXTRA_METADATA_KEY,
     DTYPE_KEY,
@@ -233,6 +234,8 @@ class HuggingFaceStorageReader(FsspecReader):
         super().__init__(path=path)
 
     def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
+        from safetensors import deserialize  # type: ignore[import-not-found]
+
         per_file: dict[str, list[ReadItem]] = {}
 
         for read_item in plan.items:
@@ -242,11 +245,17 @@ class HuggingFaceStorageReader(FsspecReader):
 
         for file_name, reqs in per_file.items():
             with self.fs.create_stream(file_name, "rb") as stream:
+                # TODO: make this more efficient by doing offset reads instead of a
+                # full deserialization of the file
+                deserialized = deserialize(stream.read())
+                deserialized_dict: dict[str, dict[str, Any]] = {
+                    tensor_info[0]: tensor_info[1] for tensor_info in deserialized
+                }
+
                 for req in reqs:
                     item_md = self.storage_data[req.storage_index]
 
-                    stream.seek(item_md.offset)
-                    tensor_bytes = stream.read(item_md.length)
+                    tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
 
                     tensor = torch.frombuffer(
                         tensor_bytes,
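The hunk above is the functional core of the revert: the pre-revert `read_data` seeked to each tensor's absolute byte range, while the reverted code deserializes the whole file and indexes by fully qualified name. A side-by-side sketch of the two strategies, assuming a hypothetical `item_md` carrying the absolute `offset`/`length` fields the pre-revert code stored (`deserialize` yields `(name, info)` pairs, as the diff's own dict comprehension shows):

    from safetensors import deserialize

    def read_offset(stream, item_md):
        # Pre-revert: read only the requested byte range (cheap for partial loads).
        stream.seek(item_md.offset)
        return stream.read(item_md.length)

    def read_full(stream, fqn):
        # Post-revert: deserialize the entire file, then index by tensor name.
        deserialized = dict(deserialize(stream.read()))
        return deserialized[fqn]["data"]

The remaining two hunks below adjust `read_metadata` to match.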
@@ -280,7 +289,7 @@ class HuggingFaceStorageReader(FsspecReader):
 
         for safetensor_file in safetensors_files:
             with self.fs.create_stream(safetensor_file, "rb") as f:
-                safetensors_metadata, metadata_size = _get_safetensors_file_metadata(f)
+                safetensors_metadata, _ = _get_safetensors_file_metadata(f)
                 custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)
 
                 dcp_sharding_info = None
@@ -339,7 +348,7 @@ class HuggingFaceStorageReader(FsspecReader):
                 )
                 storage_data[metadata_index] = _HFStorageInfo(
                     relative_path=safetensor_file,
-                    offset=val[DATA_OFFSETS_KEY][0] + metadata_size,
+                    offset=val[DATA_OFFSETS_KEY][0],
                     length=val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
                     shape=torch.Size(val[SHAPE_KEY]),
                     dtype=_get_dtype(val[DTYPE_KEY]),
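The last hunk changes what `_HFStorageInfo.offset` means: pre-revert it was an absolute file position (`data_offsets[0] + metadata_size`, usable directly with `stream.seek()`); post-revert it is relative to the start of the data section, since full deserialization no longer needs a file position. A worked example with assumed sizes:

    # Assume a 120-byte JSON header and data_offsets == [16, 32] for some tensor.
    metadata_size = 8 + 120        # length prefix + JSON header
    relative = 16                  # post-revert offset: data_offsets[0]
    absolute = 16 + metadata_size  # pre-revert offset: 144, valid for stream.seek()
    length = 32 - 16               # identical under both conventions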