Compare commits

...

6 Commits

Author SHA1 Message Date
0f49e915a9 rebase 2025-05-30 14:30:12 -07:00
2f1217f944 benchmarking 2025-05-30 14:27:37 -07:00
e0bf01e87b Script for consolidation of sharded safetensor files
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154743

Script to consolidate sharded safetensors files saved with DCP into full tensors. It relies on file system operations to read and copy bytes directly, instead of the traditional approach of loading, re-sharding, and saving again, because users will have models that are larger than the available memory.
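A minimal usage sketch of the consolidation entry point (the directory paths are placeholders, and the checkpoint directory is assumed to hold a sharded DCP safetensors checkpoint; the two-file mapping mirrors the integration test further below):

    import os

    from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
        consolidate_safetensors_files,
    )

    output_dir = "/tmp/consolidated"
    os.makedirs(output_dir, exist_ok=True)

    # Consolidate every shard into a single model-00001-of-00001.safetensors file.
    consolidate_safetensors_files("/tmp/dcp_checkpoint", output_dir)

    # Or split the consolidated tensors across two output files by FQN
    # (mapping values mirror the integration test below).
    consolidate_safetensors_files(
        "/tmp/dcp_checkpoint",
        output_dir,
        {"dtensor": 1, "dtensor_col": 2},
    )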

Differential Revision: [D75536985](https://our.internmc.facebook.com/intern/diff/D75536985/)
ghstack-source-id: 287291639
2025-05-30 14:18:51 -07:00
3b5ae0e9fc Support re-sharding for safetensors checkpoints
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154519

This change adds support for re-sharding of HuggingFace safetensors checkpoints.
This is done by writing additional metadata when saving each file. The metadata captures the size and offset of each saved shard, which is used on load to create the chunks belonging to the TensorStorageMetadata class, enabling re-sharding.
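A rough sketch of the extra per-file metadata this adds (the key names are taken from this PR stack's diff; the tensor FQN and offsets are made-up example values):

    import json

    # Extra entries written into the safetensors header alongside the usual
    # tensor metadata when saving shards with DCP.
    custom_metadata = {
        "DCP_SHARDING_INFO": json.dumps(
            {"linear_1.weight": {"saved_offsets": [2, 0]}}
        ),
        "DCP_VERSION": "1.0",
    }
    # On load, saved_offsets becomes the offsets of a ChunkStorageMetadata chunk
    # inside the reconstructed TensorStorageMetadata for that tensor.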

Differential Revision: [D75226344](https://our.internmc.facebook.com/intern/diff/D75226344/)
ghstack-source-id: 286572125
2025-05-30 10:40:32 -07:00
5f5f654a3e Updates to HFStorageReader to use TensorStorageMetadata instead of BytesStorageMetadata
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154518

As we prepare to support re-sharding, the current approach of using BytesStorageMetadata to read safetensors won't work anymore. Previously we didn't need to read the metadata from the safetensors file header, because we loaded the contents of the file directly into tensors with safetensors.load(), which handled the metadata and deserialization. Now, in preparation for re-sharding, we read the metadata directly from the header of the safetensors file and store it in TensorStorageMetadata objects so that re-sharding can be performed. Re-sharding won't work yet, because extra metadata needs to be stored on each save; that will be added in a subsequent PR.
This PR also adds an integration test on top of the existing unit tests.
It also removes the HfFileSystem import, because that is only needed when users are using HfFileSystem, and we want to support any backend.
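A minimal sketch of reading the header directly, following the same steps as the _get_safetensors_file_metadata helper in the diff below (the file path is a placeholder):

    import json
    import struct

    # A safetensors file starts with an 8-byte little-endian header length,
    # followed by a JSON header describing each tensor's dtype, shape and data offsets.
    with open("/tmp/model-00001-of-00001.safetensors", "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_len))

    # header maps each tensor FQN to {"dtype": ..., "shape": ..., "data_offsets": ...}
    # (plus an optional "__metadata__" entry); this is the information stored in
    # TensorStorageMetadata objects on read.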
ghstack-source-id: 286649070
@exported-using-ghexport

Differential Revision: [D74891998](https://our.internmc.facebook.com/intern/diff/D74891998/)
2025-05-30 10:40:30 -07:00
21931cbbc6 Changes to HFStorageWriter to support saving shards of tensors
As we move towards supporting saving partial tensors natively with HFStorageWriter, a few simple changes are needed to make this happen.
- The current approach for distributed writes is that every rank has the full tensors, and the writing of these full tensors is split across all available ranks. We're removing this logic from the HFSavePlanner and instead assuming that every rank has a shard, saving every rank's local state.
    -  As a result we can probably remove the HFSavePlanner, but it is kept as a placeholder for now.

- The current file naming doesn't support shards: files are named in the format "model-00001-of-00004.safetensors", so if every rank writes the same file names they will overwrite each other. This change adds a shard-00001 prefix so that the per-rank files don't collide (see the naming sketch after this list).
- Don't save the metadata file model.safetensors.index.json if sharding is enabled. That file expects a 1:1 mapping between tensors and file names, which doesn't hold in the sharded saving approach, so we can just drop it.
- Make the "fqn_to_file_index" map optional. It describes which file each tensor should be saved to; if users don't provide it, all tensors are saved to a single file. If they run into issues, they can choose how to split up their tensors to stay within HuggingFace's 5GB remote storage file size soft limit.

Differential Revision: [D75099862](https://our.internmc.facebook.com/intern/diff/D75099862/)

ghstack-source-id: 286648122
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154742
2025-05-30 10:40:28 -07:00
11 changed files with 1757 additions and 196 deletions

View File

@ -0,0 +1,135 @@
# Owner(s): ["oncall: distributed_checkpointing"]
import os
import sys
import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
consolidate_safetensors_files,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestConsolidateHFSafeTensors(DTensorTestBase):
def _create_d_tensors(self) -> None:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
# Create local tensor with row-wise sharding
rows_per_rank = global_tensor.shape[0] // self.world_size
start_row = self.rank * rows_per_rank
end_row = start_row + rows_per_rank
local_tensor = global_tensor[start_row:end_row].clone()
# Create DTensor with row-wise sharding
dtensor = DTensor.from_local(
local_tensor,
device_mesh=mesh_1d,
placements=[Shard(0)],
shape=global_tensor.shape,
stride=(4, 1),
)
# Create local tensor with column-wise sharding
cols_per_rank = global_tensor.shape[1] // self.world_size
start_col = self.rank * cols_per_rank
end_col = start_col + cols_per_rank
local_tensor_col = global_tensor[:, start_col:end_col].clone()
# Create DTensor with column-wise sharding
dtensor_col = DTensor.from_local(
local_tensor_col,
device_mesh=mesh_1d,
placements=[Shard(1)], # Column-wise sharding
shape=global_tensor.shape,
stride=(4, 1),
)
state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=self.temp_dir, save_sharded=True
),
)
dist.barrier()
os.sync()
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_consolidate_to_one_file(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
checkpoint_dir = self.temp_dir
output_dir = os.path.join(checkpoint_dir, "consolidated")
os.makedirs(output_dir, exist_ok=True)
self._create_d_tensors()
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
if self.rank == 0:
consolidate_safetensors_files(checkpoint_dir, output_dir)
file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors")
loaded_dict = safetensors.torch.load_file(file_path)
self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"})
self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor))
dist.barrier()
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_consolidate_to_two_files(self):
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
checkpoint_dir = self.temp_dir
output_dir = os.path.join(checkpoint_dir, "consolidated")
os.makedirs(output_dir, exist_ok=True)
self._create_d_tensors()
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
if self.rank == 0:
fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2}
consolidate_safetensors_files(
checkpoint_dir, output_dir, fqn_to_index_mapping
)
file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors")
file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors")
loaded_dict = safetensors.torch.load_file(file1_path)
self.assertEqual(loaded_dict.keys(), {"dtensor"})
self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
loaded_dict_col = safetensors.torch.load_file(file2_path)
self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"})
self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor))
dist.barrier()
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,420 @@
# Owner(s): ["oncall: distributed_checkpointing"]
import sys
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
run_tests,
TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
CHECKPOINT_DIR = "checkpoint"
class MyTestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear_1 = torch.nn.Linear(5, 5)
self.linear_2 = torch.nn.Linear(5, 1)
self.emb = torch.nn.EmbeddingBag(5, 10)
class TestSingleRankSaveLoad(TestCase):
@with_temp_dir
def test_save(self) -> None:
try:
from safetensors.torch import load_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=CHECKPOINT_DIR
),
)
state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
@with_temp_dir
def test_load(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
state_dict_to_load = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
@with_temp_dir
def test_load_into_empty_dict(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
state_dict_loaded = _load_state_dict_from_keys(
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))
@with_temp_dir
def test_load_allowing_resize(self) -> None:
try:
from safetensors.torch import save_file
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
state_dict_to_save = MyTestModule().state_dict()
save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
state_dict_to_load = {}
for key in state_dict_to_save.keys():
state_dict_to_load[key] = torch.zeros(1)
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
path=CHECKPOINT_DIR
),
planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
)
self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
for key in state_dict_to_save.keys():
self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))
ONE_D_PLACEMENTS = [
[Shard(0)],
[Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
([Replicate()], [Shard(0)]),
([Shard(0)], [Replicate()]),
]
TWO_D_PLACEMENTS = [
[Replicate(), Replicate()],
[Replicate(), Shard(0)],
[Shard(0), Replicate()],
[Shard(0), Shard(0)],
]
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
for p2 in TWO_D_PLACEMENTS:
if p1 != p2:
TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))
@instantiate_parametrized_tests
class TestDTensorReshardPlacementChange(DTensorTestBase):
"""
Test DCP resharding for DTensor with placement changes, but without world_size or mesh_tensor changes.
"""
@with_comms
@skip_if_lt_x_gpu(2)
@with_temp_dir
def test_1d_to_1d_reshard_placement_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
original_placement, new_placement = one_d_to_one_d_placements
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
device_mesh = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, device_mesh, placements=original_placement
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=CHECKPOINT_DIR,
save_sharded=True,
),
)
zero_dtensor = zeros(
[4, 4], device_mesh=device_mesh, placements=new_placement
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(
CHECKPOINT_DIR,
),
)
# materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
device_mesh,
placements=[Replicate()],
)
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
# redistribute the tensor back to its original placement for comparison.
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
device_mesh,
placements=original_placement,
)
self.assertEqual(
state_dict_to_save["dtensor"].to_local(),
state_dict_to_load["dtensor"].to_local(),
)
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_2d_to_2d_reshard_placement_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
original_placement, new_placement = two_d_to_two_d_placements
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor,
mesh_2d,
placements=original_placement,
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
planner=dist_cp.DefaultSavePlanner(),
)
zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
)
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
mesh_2d,
placements=[Replicate(), Replicate()],
)
self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())
state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
mesh_2d,
placements=original_placement,
)
self.assertEqual(
state_dict_to_save["dtensor"].to_local(),
state_dict_to_load["dtensor"].to_local(),
)
class TestDTensorReshardMeshChange(DTensorTestBase):
"""
Test DCP resharding for DTensor with placement and mesh_tensor changes.
"""
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_1d_to_2d_reshard_mesh_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for placements_1d in ONE_D_PLACEMENTS:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, mesh_1d, placements=placements_1d
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
)
for placements_2d in TWO_D_PLACEMENTS:
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
zero_dtensor = zeros(
[4, 4], device_mesh=mesh_2d, placements=placements_2d
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
planner=dist_cp.DefaultLoadPlanner(),
)
# materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
"dtensor"
].redistribute(
mesh_2d,
placements=[Replicate(), Replicate()],
)
self.assertEqual(
global_tensor, state_dict_to_load["dtensor"].to_local()
)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(4)
def test_2d_to_1d_reshard_mesh_change(self) -> None:
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
CHECKPOINT_DIR = self.temp_dir
for placements_2d in TWO_D_PLACEMENTS:
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
dtensor = distribute_tensor(
global_tensor, mesh_2d, placements=placements_2d
)
state_dict_to_save = {"dtensor": dtensor}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
planner=dist_cp.DefaultSavePlanner(),
)
for placements_1d in ONE_D_PLACEMENTS:
mesh_shape = (self.world_size,)
mesh_1d = init_device_mesh(self.device_type, mesh_shape)
zero_dtensor = zeros(
[4, 4], device_mesh=mesh_1d, placements=placements_1d
)
state_dict_to_load = {"dtensor": zero_dtensor}
dist_cp.load(
state_dict=state_dict_to_load,
storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
planner=dist_cp.DefaultLoadPlanner(),
)
# materialize the whole tensor to compare with the original global_tensor
state_dict_to_load["dtensor"] = state_dict_to_load[
"dtensor"
].redistribute(
mesh_1d,
placements=[Replicate()],
)
self.assertEqual(
global_tensor, state_dict_to_load["dtensor"].to_local()
)
@with_comms
@with_temp_dir
@skip_if_lt_x_gpu(2)
def test_dtensor_checkpoint_resharding_with_empty_shard(self):
"""
Test dtensor checkpoint resharding with dtensor containing empty shards.
"""
try:
import safetensors
except ImportError:
print("safetensors not installed")
sys.exit(0)
tensor = torch.rand(1).cuda()
mesh = init_device_mesh(self.device_type, (self.world_size,))
dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
ref_state_dict = {"dtensor": dtensor}
dist_cp.save(
state_dict=ref_state_dict,
storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True),
)
tensor = torch.rand(1).cuda()
mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
state_dict = {"dtensor": dtensor}
dist_cp.load(
state_dict=state_dict,
storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir),
)
if __name__ == "__main__":
run_tests()

View File

@ -8,10 +8,7 @@ import tempfile
from unittest.mock import MagicMock
import torch
from torch.distributed.checkpoint._hf_planner import (
_FqnToFileMapping,
_HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_storage import (
_HuggingFaceStorageReader,
_HuggingFaceStorageWriter,
@ -21,24 +18,25 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, SavePlan
from torch.distributed.checkpoint.planner_helpers import (
_create_read_items,
_create_write_item_for_tensor,
from torch.distributed.checkpoint.planner import (
LoadItemType,
LoadPlan,
ReadItem,
SavePlan,
)
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.distributed.checkpoint.storage import WriteResult
from torch.testing._internal.common_utils import run_tests, TestCase
class TestHfStorage(TestCase):
def test_write_data_hf(self) -> None:
mock_module = MagicMock()
sys.modules["safetensors"] = mock_module
sys.modules["huggingface_hub"] = mock_module
mock_module = MagicMock()
mock_module.save.return_value = b""
sys.modules["safetensors.torch"] = mock_module
@ -46,7 +44,7 @@ class TestHfStorage(TestCase):
with tempfile.TemporaryDirectory() as path:
writer = _HuggingFaceStorageWriter(
path=path,
fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1},
fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2},
)
writer.fs = FileSystem()
@ -59,7 +57,7 @@ class TestHfStorage(TestCase):
save_plan = SavePlan(
[write_item_1, write_item_2],
storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}),
storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}},
)
save_planner = DefaultSavePlanner()
save_planner.set_up_planner(state_dict=state_dict)
@ -76,7 +74,7 @@ class TestHfStorage(TestCase):
),
size_in_bytes=tensor0.numel() * tensor0.element_size(),
storage_data=_StorageInfo(
relative_path="model-00001-of-00001.safetensors",
relative_path="model-00001-of-00002.safetensors",
offset=0,
length=tensor0.numel() * tensor0.element_size(),
),
@ -87,7 +85,68 @@ class TestHfStorage(TestCase):
),
size_in_bytes=tensor1.numel() * tensor1.element_size(),
storage_data=_StorageInfo(
relative_path="model-00001-of-00001.safetensors",
relative_path="model-00002-of-00002.safetensors",
offset=0,
length=tensor1.numel() * tensor1.element_size(),
),
),
]
self.assertEqual(
actual_write_results,
expected_write_results,
)
def test_write_data_with_sharding(self) -> None:
mock_module = MagicMock()
mock_module.save.return_value = b""
sys.modules["safetensors.torch"] = mock_module
with tempfile.TemporaryDirectory() as path:
writer = _HuggingFaceStorageWriter(
path=path,
save_sharded=True,
)
writer.fs = FileSystem()
tensor0 = torch.rand(4)
tensor1 = torch.rand(10)
write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0)
write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1)
state_dict = {"tensor_0": tensor0, "tensor_1": tensor1}
save_plan = SavePlan(
[write_item_1, write_item_2],
storage_data={"shard_index": 1},
)
save_planner = DefaultSavePlanner()
save_planner.set_up_planner(state_dict=state_dict)
write_results = writer.write_data(save_plan, save_planner)
write_results.wait()
actual_write_results = write_results.value()
expected_write_results = [
WriteResult(
index=MetadataIndex(
fqn="tensor_0", offset=torch.Size([0]), index=None
),
size_in_bytes=tensor0.numel() * tensor0.element_size(),
storage_data=_StorageInfo(
relative_path="shard-00001-model-00001-of-00001.safetensors",
offset=0,
length=tensor0.numel() * tensor0.element_size(),
),
),
WriteResult(
index=MetadataIndex(
fqn="tensor_1", offset=torch.Size([0]), index=None
),
size_in_bytes=tensor1.numel() * tensor1.element_size(),
storage_data=_StorageInfo(
relative_path="shard-00001-model-00001-of-00001.safetensors",
offset=0,
length=tensor1.numel() * tensor1.element_size(),
),
@ -100,43 +159,84 @@ class TestHfStorage(TestCase):
)
def test_read_data_hf(self) -> None:
mock_module = MagicMock()
sys.modules["safetensors"] = mock_module
sys.modules["huggingface_hub"] = mock_module
mock_safetensors = MagicMock()
sys.modules["safetensors"] = mock_safetensors
name = "tensor_0"
tensor_0 = torch.rand(4)
mock_module = MagicMock()
mock_module.load.return_value = {name: tensor_0}
sys.modules["safetensors.torch"] = mock_module
# Create test tensors
tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])
# Mock the deserialize function to return our test tensors
# The format matches what's expected in the read_data method
mock_safetensors.deserialize.return_value = [
("tensor_0", {
"data": tensor_0.numpy().tobytes(),
"dtype": "F32",
"shape": [4]
}),
]
with tempfile.TemporaryDirectory() as path:
# Create the reader
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
file_name = "model-00001-of-00001"
pathlib.Path(os.path.join(path, file_name)).touch()
# Create test file
file_name = "model-00001-of-00001.safetensors"
file_path = os.path.join(path, file_name)
pathlib.Path(file_path).touch()
reader.set_up_storage_reader(
Metadata(
state_dict_metadata={name: BytesStorageMetadata()},
storage_data={name: file_name},
),
is_coordinator=True,
)
# Set up storage data with _StorageInfo objects
storage_data = {
"tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()),
}
read_items = _create_read_items(name, BytesStorageMetadata(), file_name)
reader.storage_data = storage_data
# Create target tensors that will be updated by read_data
target_tensor_0 = torch.zeros(4)
state_dict = {
"tensor_0": target_tensor_0,
}
# Create read items for the load plan
read_items = []
for name, tensor in state_dict.items():
storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
read_items.append(
ReadItem(
type=LoadItemType.TENSOR,
storage_index=storage_index,
dest_index=dest_index,
storage_offsets=[0, 0],
dest_offsets=[0, 0],
lengths=tensor.size(),
)
)
# Create load plan and planner
load_plan = LoadPlan(read_items)
load_planner = _HuggingFaceLoadPlanner()
load_planner.set_up_planner(state_dict={name: torch.rand(4)})
load_planner = DefaultLoadPlanner()
load_planner.set_up_planner(
state_dict=state_dict,
metadata=Metadata(
state_dict_metadata={
"tensor_0": TensorStorageMetadata(
properties=TensorProperties(dtype=torch.float32),
size=torch.Size([4]),
chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))])},
storage_data=storage_data)
)
read_data = reader.read_data(load_plan, load_planner)
read_data.wait()
# Call read_data
future = reader.read_data(load_plan, load_planner)
future.wait()
loaded_tensor = load_planner.original_state_dict[name]
self.assertEqual(loaded_tensor, tensor_0)
# Verify results - the target tensors should now contain the values from our test tensor
self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0))
def test_metadata_hf(self) -> None:
def test_write_metadata_hf(self) -> None:
mock_module = MagicMock()
sys.modules["huggingface_hub"] = mock_module
with tempfile.TemporaryDirectory() as path:
@ -160,7 +260,6 @@ class TestHfStorage(TestCase):
writer = _HuggingFaceStorageWriter(
path=path,
fqn_to_index_mapping=_FqnToFileMapping({}),
)
writer.fs = FileSystem()
writer.finish(
@ -185,26 +284,16 @@ class TestHfStorage(TestCase):
metadata = json.load(f)
self.assertEqual(metadata, expected_metadata)
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
metadata = reader.read_metadata()
self.assertEqual(metadata.storage_data, expected_metadata["weight_map"])
def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
mock_module = MagicMock()
sys.modules["huggingface_hub"] = mock_module
def test_read_metadata_hf(self):
with tempfile.TemporaryDirectory() as path:
reader = _HuggingFaceStorageReader(path=path)
reader.fs = FileSystem()
# there is one safetensor file, but no metadata file,
# so we create metadata from the safetensor file
keys = ["tensor_0", "tensor_1"]
key = "tensor_0"
file_name = "test.safetensors"
with open(os.path.join(path, file_name), "wb") as f:
# write metadata the same way it would be in safetensors file
metadata_contents = json.dumps(
{"tensor_0": "value_0", "tensor_1": "value_1"}
{'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}}
)
metadata_bytes = metadata_contents.encode("utf-8")
@ -216,13 +305,16 @@ class TestHfStorage(TestCase):
self.assertEqual(
metadata.state_dict_metadata,
{
keys[0]: BytesStorageMetadata(),
keys[1]: BytesStorageMetadata(),
key: TensorStorageMetadata(
properties=TensorProperties(dtype=torch.float32),
size=torch.Size([5, 10]),
chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))],
),
},
)
self.assertEqual(
metadata.storage_data,
{keys[0]: file_name, keys[1]: file_name},
{key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)},
)

View File

@ -12,6 +12,7 @@ from .metadata import (
)
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .scripts._consolidate_hf_safetensors import consolidate_safetensors_files
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter

View File

@ -62,25 +62,3 @@ def dedup_save_plans(
)
for plan, item_indexes in zip(all_plans, plan_to_item_indices)
]
def dedup_save_plans_with_fqn_to_index_mapping(
all_plans: list[SavePlan], fqn_to_index_mapping: dict[str, int]
) -> list[SavePlan]:
num_plans = len(all_plans)
to_remove: list[set] = [set() for _ in range(len(all_plans))]
for plan_idx, plan in enumerate(all_plans):
for item_idx, item in enumerate(plan.items):
if (fqn_to_index_mapping[item.index.fqn] - 1) % num_plans != plan_idx:
to_remove[plan_idx].add(item_idx)
for plan_idx, remove_set in enumerate(to_remove):
new_items = [
write_item
for item_idx, write_item in enumerate(all_plans[plan_idx].items)
if item_idx not in remove_set
]
all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans

View File

@ -1,49 +1,34 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from torch.distributed.checkpoint._dedup_save_plans import (
dedup_save_plans_with_fqn_to_index_mapping,
)
from torch.distributed.checkpoint.default_planner import (
create_default_local_load_plan,
DefaultLoadPlanner,
DefaultSavePlanner,
)
from torch.distributed.checkpoint.planner import ReadItem, SavePlan
from torch.distributed.checkpoint.planner import LoadPlan
__all__ = ["_HuggingFaceSavePlanner", "_HuggingFaceLoadPlanner"]
@dataclass
class _FqnToFileMapping:
fqn_to_file_index_mapping: dict[str, int]
class _HuggingFaceSavePlanner(DefaultSavePlanner):
"""
A save planner that dedups the save plans based on the fqn to file index mapping.
A planner to work with HuggingFace's safetensors format.
This is a placeholder, as it is likely that the DefaultSavePlanner is enough.
"""
def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
assert len(all_plans) > 0, "all_plans should not be empty"
assert all_plans[0].storage_data is not None, "storage_data should not be None"
assert isinstance(all_plans[0].storage_data, _FqnToFileMapping), (
"storage_data should be of type _FqnToFileMapping"
)
fqn_to_index_mapping: dict[str, int] = all_plans[
0
].storage_data.fqn_to_file_index_mapping
return dedup_save_plans_with_fqn_to_index_mapping(
all_plans, fqn_to_index_mapping
)
class _HuggingFaceLoadPlanner(DefaultLoadPlanner):
def __init__(self, allow_tensor_resize: bool = False):
super().__init__()
self.allow_tensor_resize = allow_tensor_resize
def resolve_tensor(self, read_item: ReadItem):
return self.lookup_tensor(read_item.dest_index)
def create_local_plan(self) -> LoadPlan:
assert self.metadata is not None
# check_md_size is added to skip the size check when tensor resizing is allowed.
# This will be deprecated in favor of _load_state_dict_from_keys, and then we
# can remove this planner altogether.
return create_default_local_load_plan(
self.state_dict, self.metadata, not self.allow_partial_load, check_md_size=not self.allow_tensor_resize,
)

View File

@ -2,24 +2,24 @@
import dataclasses
import io
import json
import os
import queue
import struct
from typing import Optional
from typing import Any, Optional
import fsspec # type: ignore[import-untyped]
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint._hf_planner import (
_FqnToFileMapping,
_HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint.filesystem import SerializationFormat
from torch.distributed.checkpoint._hf_planner import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, SerializationFormat
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
MetadataIndex,
STORAGE_TYPES,
StorageMeta,
TensorProperties,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
LoadPlan,
@ -37,9 +37,30 @@ __all__ = ["_HuggingFaceStorageWriter", "_HuggingFaceStorageReader"]
_metadata_fn: str = "model.safetensors.index.json"
FILE_NAME = "model-{cpt_idx}-of-{num_shards}"
FILE_NAME = "model-{cpt_idx}-of-{num_files}"
SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}"
SUFFIX = ".safetensors"
# metadata keys
CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO"
DEFAULT_EXTRA_METADATA_KEY = "__metadata__"
SAVED_OFFSETS_KEY = "saved_offsets"
SHAPE_KEY = "shape"
DATA_KEY = "data"
DTYPE_KEY = "dtype"
DATA_OFFSETS_KEY = "data_offsets"
DTYPE_MAP = {
"F16": torch.float16,
"F32": torch.float32,
"F64": torch.float64,
"I8": torch.int8,
"U8": torch.uint8,
"I16": torch.int16,
"I32": torch.int32,
"I64": torch.int64,
"BF16": torch.bfloat16,
}
class _HuggingFaceStorageWriter(FsspecWriter):
"""
@ -50,23 +71,25 @@ class _HuggingFaceStorageWriter(FsspecWriter):
def __init__(
self,
path: str,
fqn_to_index_mapping: dict[str, int],
fqn_to_index_mapping: Optional[dict[str, int]] = None,
token: Optional[str] = None,
save_sharded: bool = False,
) -> None:
"""
Initialize the huggingface writer pointing to path.
Args:
path: hf directory where the checkpoint will be written to. Should begin with hf://.
token: The token to use to authenticate with huggingface hub.
path: directory where the checkpoint will be written to.
Can be any fsspec supported storage, including localFS and hf://.
fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to.
Indices are from 1 to N, where N is the number of files.
Indices are from 1 to N, where N is the number of files. If not provided,
the tensors will be written to a single file.
token: The token to use to authenticate with huggingface hub.
save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard.
Default is False which assumes full tensors are being saved.
"""
from huggingface_hub import HfFileSystem # type: ignore[import-not-found]
if HfFileSystem.protocol not in fsspec.available_protocols():
fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem)
if token is not None:
super().__init__(
@ -79,16 +102,21 @@ class _HuggingFaceStorageWriter(FsspecWriter):
path=path,
serialization_format=SerializationFormat.SAFETENSORS,
)
self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping
def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
plan = super().prepare_local_plan(plan)
return dataclasses.replace(
plan, storage_data=_FqnToFileMapping(self._fqn_to_index_mapping)
)
self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
self._save_sharded = save_sharded
def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
return plans
new_plans = []
for i, plan in enumerate(plans, start=1):
storage_data = {}
if self._fqn_to_index_mapping is not None:
storage_data["fqn_to_file_mapping"] = self._fqn_to_index_mapping
if self._save_sharded:
storage_data["shard_index"] = i
new_plans.append(dataclasses.replace(plan, storage_data=storage_data))
return new_plans
def write_data(
self,
@ -101,14 +129,20 @@ class _HuggingFaceStorageWriter(FsspecWriter):
return fut
# storage_plan is a map from key to file index
storage_plan: dict[str, int] = plan.storage_data.fqn_to_file_index_mapping
storage_data: dict[str, Any] = plan.storage_data
storage_plan: Optional[dict[str, int]] = None
shard_index: Optional[int] = None
if "fqn_to_file_mapping" in storage_data:
storage_plan = storage_data["fqn_to_file_mapping"]
if "shard_index" in storage_data:
shard_index = storage_data["shard_index"]
buckets = self._split_by_storage_plan(storage_plan, plan.items)
highest_index = max(storage_plan.values())
highest_index = max(storage_plan.values()) if storage_plan is not None else 1
file_queue: queue.Queue = queue.Queue()
for file_index, write_items in buckets.items():
file_name = self._gen_file_name(file_index, highest_index)
file_name = _gen_file_name(file_index, highest_index, shard_index)
file_queue.put(
(self.fs.concat_path(self.path, file_name), file_name, write_items)
)
@ -116,6 +150,9 @@ class _HuggingFaceStorageWriter(FsspecWriter):
return super()._write_data(planner, file_queue)
def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
if self._save_sharded:
return
metadata_to_write = {}
storage_md = {}
total_size = 0
@ -132,9 +169,12 @@ class _HuggingFaceStorageWriter(FsspecWriter):
json.dump(metadata_to_write, metadata_file, indent=2)
def _split_by_storage_plan(
self, storage_plan: dict[str, int], items: list[WriteItem]
self, storage_plan: Optional[dict[str, int]], items: list[WriteItem]
) -> dict[int, list[WriteItem]]:
# storage_plan is a map from key to index
if storage_plan is None:
return {1: items}
buckets = {}
for item in items:
key = item.index.fqn
@ -146,14 +186,6 @@ class _HuggingFaceStorageWriter(FsspecWriter):
return buckets
def _gen_file_name(self, index: int, largest_index: int) -> str:
return (
FILE_NAME.format(
cpt_idx=f"{index}".zfill(5), num_shards=f"{largest_index}".zfill(5)
)
+ SUFFIX
)
@property
def metadata_path(self) -> str:
return _metadata_fn
@ -170,48 +202,61 @@ class _HuggingFaceStorageReader(FsspecReader):
Initialize the huggingface reader pointing to path.
Args:
path: hf directory where the checkpoint will be read from. Should begin with hf://.
path: hf directory where the checkpoint will be read from.
Needs to have .safetensors files, but can be from any fsspec supported storage,
including localFS and hf://.
token: The token to use to authenticate with huggingface hub.
"""
from huggingface_hub import HfFileSystem # type: ignore[import-not-found]
if HfFileSystem.protocol not in fsspec.available_protocols():
fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem)
if token is not None:
super().__init__(path=path, token=token)
else:
super().__init__(path=path)
self.storage_data: dict[str, str] = {}
def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
from safetensors.torch import load # type: ignore[import-not-found]
from safetensors import deserialize # type: ignore[import-not-found]
per_file: dict[str, list[ReadItem]] = {}
for read_item in plan.items:
file_name = self.storage_data[read_item.storage_index.fqn]
item_md: _StorageInfo = self.storage_data[read_item.storage_index]
file_name = item_md.relative_path
per_file.setdefault(file_name, []).append(read_item)
for file_name, reqs in per_file.items():
new_path = self.fs.concat_path(self.path, file_name)
with self.fs.create_stream(new_path, "rb") as stream:
loaded_tensors = load(stream.read())
for req in reqs:
tensor = loaded_tensors[req.dest_index.fqn]
with self.fs.create_stream(file_name, "rb") as stream:
# TODO: make this more efficient by doing offset reads instead of a
# full deserialization of the file
deserialized: list[tuple[str, dict[str, Any]]] = deserialize(stream.read())
deserialized_dict: dict[str, dict[str, Any]] = {tensor_info[0]: tensor_info[1] for tensor_info in deserialized}
for req in reqs:
tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
tensor = torch.frombuffer(tensor_bytes, dtype=planner.metadata.state_dict_metadata[req.dest_index.fqn].properties.dtype)
# TODO: update this to req.lengths once allow_tensor_resize is removed; we shouldn't need to look at the deserialized
# dict for metadata, as we've already done that in read_metadata
tensor = tensor.reshape(deserialized_dict[req.dest_index.fqn][SHAPE_KEY])
target_tensor = planner.resolve_tensor(req)
if (
isinstance(planner, _HuggingFaceLoadPlanner)
and planner.allow_tensor_resize
):
# this is to support the case when users are calling load on
# an empty state dict without specifying the correct size of the tensors
# in the state dict. Resizing is a hacky way to support this use case.
# But will migrate users to _load_state_dict_from_keys method and deprecate this.
target_tensor = planner.resolve_tensor(req)
target_tensor.resize_(tensor.size())
target_tensor = target_tensor.detach()
else:
assert target_tensor.size() == tensor.size(), (
f"Tensor size mismatch for {req.dest_index.fqn}: {target_tensor.size()} != {tensor.size()}"
tensor = narrow_tensor_by_index(
tensor, req.storage_offsets, req.lengths
)
target_tensor = target_tensor.detach()
target_tensor = planner.resolve_tensor(req).detach()
assert target_tensor.size() == tensor.size(), (
f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
)
target_tensor.copy_(tensor)
planner.commit_tensor(req, target_tensor)
@ -220,37 +265,57 @@ class _HuggingFaceStorageReader(FsspecReader):
return fut
def read_metadata(self) -> Metadata:
metadata_path = self.fs.concat_path(self.path, _metadata_fn)
state_dict_metadata: dict[str, STORAGE_TYPES] = {}
storage_data: dict[str, str] = {}
storage_data: dict[MetadataIndex, _StorageInfo] = {}
if not self.fs.exists(metadata_path):
# if metadata file doesn't exist, create it from the safetensors file
safetensors_files = []
for file in self.fs.ls(self.path):
if file.endswith(SUFFIX):
safetensors_files.append(file)
safetensors_files = []
for file in self.fs.ls(self.path):
if file.endswith(SUFFIX):
safetensors_files.append(file)
if len(safetensors_files) != 1:
raise ValueError(
f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files"
)
storage_data = {}
with self.fs.create_stream(safetensors_files[0], "rb") as f:
keys = _get_safetensors_file_keys(f)
for safetensor_file in safetensors_files:
with self.fs.create_stream(safetensor_file, "rb") as f:
safetensors_metadata = _get_safetensors_file_metadata(f)
custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)
for key in keys:
state_dict_metadata[key] = BytesStorageMetadata()
storage_data[key] = os.path.basename(safetensors_files[0])
else:
with self.fs.create_stream(metadata_path, "r") as metadata_file:
metadata = json.load(metadata_file)
dcp_sharding_info = None
if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY):
dcp_sharding_info = json.loads(custom_metadata.get(CUSTOM_METADATA_KEY))
for key in metadata["weight_map"].keys():
state_dict_metadata[key] = BytesStorageMetadata()
storage_data = metadata["weight_map"]
for key, val in safetensors_metadata.items():
if key == DEFAULT_EXTRA_METADATA_KEY:
continue
# construct state_dict_metadata
if dcp_sharding_info is not None:
offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY]
else:
offset = [0] * len(val[SHAPE_KEY])
if key not in state_dict_metadata:
state_dict_metadata[key] = TensorStorageMetadata(
properties=TensorProperties(dtype=_get_dtype(val[DTYPE_KEY])),
size=torch.Size([saved + offset for saved, offset in zip(val[SHAPE_KEY], offset)]),
chunks=[ChunkStorageMetadata(offsets=torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]))],
)
else:
state_dict_metadata[key].chunks.append(ChunkStorageMetadata(torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY])))
size = list(state_dict_metadata[key].size)
for i in range(len(size)):
size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i])
state_dict_metadata[key].size = torch.Size(size)
# construct storage data
if dcp_sharding_info is not None:
metadata_index = MetadataIndex(fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY])
else:
metadata_index = MetadataIndex(fqn=key, offset=[0] * len(val[SHAPE_KEY]))
storage_data[metadata_index] = _StorageInfo(
safetensor_file,
val[DATA_OFFSETS_KEY][0],
val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
)
metadata = Metadata(
state_dict_metadata=state_dict_metadata,
storage_data=storage_data,
@ -262,8 +327,21 @@ class _HuggingFaceStorageReader(FsspecReader):
return metadata
def _gen_file_name(index: int, largest_index: int, shard_index: Optional[int] = None) -> str:
if shard_index is not None:
return SHARDED_FILE_NAME.format(
shard_idx=f"{shard_index}".zfill(5), cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5)
) + SUFFIX
else:
return (
FILE_NAME.format(
cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5)
)
+ SUFFIX
)
def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]:
def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> Any:
# this uses the same logic that's done in HF code base
# https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308
# and follows their documentation on how their files are serialized
@ -273,4 +351,19 @@ def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]:
header_len = struct.unpack("<Q", header_len_bytes)[0]
header_json = file_bytes.read(header_len)
metadata = json.loads(header_json)
return list(metadata.keys())
return metadata
def _get_dtype(dtype: str) -> torch.dtype:
try:
dtype = DTYPE_MAP[dtype]
except KeyError:
dtype = torch.get_default_dtype()
return dtype
def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]:
if DEFAULT_EXTRA_METADATA_KEY in metadata:
custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY]
if CUSTOM_METADATA_KEY in custom_metadata:
return json.loads(custom_metadata[CUSTOM_METADATA_KEY])
return None

View File

@ -442,7 +442,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if k in metadata.planner_data:
if metadata.planner_data is not None and k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v
@ -451,7 +451,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
def create_default_local_load_plan(
state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, check_md_size: bool = True
) -> LoadPlan:
requests = []
"""
@ -477,6 +477,7 @@ def create_default_local_load_plan(
isinstance(md, TensorStorageMetadata)
and getattr(obj, "size", None) is not None
and md.size != obj.size()
and check_md_size
):
raise ValueError(
f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",

View File

@ -2,6 +2,7 @@
import collections
import dataclasses
import io
import json
import operator
import os
import pickle
@ -416,6 +417,7 @@ def _write_files_from_queue(
)
tensor_dict = {}
metadata_dict = {}
for tensor, write_item in loader.values():
assert tensor.is_cpu
write_results.append(
@ -429,11 +431,12 @@ def _write_files_from_queue(
)
)
tensor_dict[write_item.index.fqn] = tensor
metadata_dict[write_item.index.fqn] = {"saved_offsets": write_item.tensor_data.chunk.offsets}
if serialization_format == SerializationFormat.SAFETENSORS:
from safetensors.torch import save # type: ignore[import-not-found]
stream.write(save(tensor_dict))
stream.write(save(tensor_dict, metadata={"DCP_SHARDING_INFO": json.dumps(metadata_dict), "DCP_VERSION": "1.0"}))
if use_fsync:
try:

View File

@ -0,0 +1,658 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# pyre-strict
"""
This script consolidates distributed checkpoint (DCP) HuggingFace safetensors files.
It takes sharded safetensors files created by DCP and combines them into one or more
consolidated files. This is useful for converting distributed model checkpoints into
a format that can be loaded by standard HuggingFace tools. The combination is done
through a simple metadata parsing and tensor data copying process.
"""
import argparse
import json
import math
import struct
from dataclasses import dataclass, field
from typing import Any, List, Optional
import fsspec
import torch
from fsspec.core import url_to_fs
from safetensors import deserialize
from torch.distributed.checkpoint._hf_storage import (
_gen_file_name,
_get_dcp_custom_metadata,
_get_dtype,
_get_safetensors_file_metadata,
DATA_KEY,
DATA_OFFSETS_KEY,
DEFAULT_EXTRA_METADATA_KEY,
DTYPE_KEY,
SAVED_OFFSETS_KEY,
SHAPE_KEY,
SUFFIX,
)
@dataclass
class _FqnData:
"""
Dataclass to store information about a tensor (identified by its fully qualified name).
Attributes:
offset_in_file: Byte offset where this tensor's data begins in the output file
shape_in_file: Shape of the tensor in the output file
dtype_size: Size of the tensor's data type in bytes
dtype_str: String representation of the tensor's data type
"""
offset_in_file: int = 0
shape_in_file: list[int] = field(default_factory=list)
dtype_size: int = 0
dtype_str: str = ""
@dataclass
class _OutputFileData:
"""
Dataclass to store information about an output safetensors file.
Attributes:
metadata_size: Size of the metadata section in bytes
fqn_data: Dictionary mapping tensor names to their metadata
"""
metadata_size: int = 0
fqn_data: dict[str, _FqnData] = field(default_factory=dict)
def _parse_input_metadata(
safetensors_metadatas: List[Any], output_files_data: dict[str, _OutputFileData]
) -> None:
"""
Parse metadata from input safetensors files to determine the full tensor shapes and types.
This function analyzes the metadata from all input files to determine the complete shape
of each tensor after consolidation. It updates the output_files_data with this information.
Args:
safetensors_metadatas: List of metadata from input safetensors files
output_files_data: Dictionary mapping output file paths to their metadata
Raises:
ValueError: If no DCP custom metadata is found in a safetensors file
"""
# Dictionary to track the full size of each tensor across all shards
fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}
for safetensors_metadata in safetensors_metadatas:
dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata)
if not dcp_sharding_info:
raise ValueError(
"No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated."
)
for key, val in safetensors_metadata.items():
if key == DEFAULT_EXTRA_METADATA_KEY:
continue
# Get the shape of this tensor shard and its offset in the full tensor
sizes = val[SHAPE_KEY]
offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY]
if key not in fqn_to_size_mapping:
# First time seeing this tensor - calculate its full size by adding offsets to dimensions
cur_size = [size + offset for size, offset in zip(sizes, offsets)]
fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY])
else:
# We've seen this tensor before - update its size if this shard extends beyond current known dimensions
cur_size = fqn_to_size_mapping[key][0]
for i in range(len(sizes)):
cur_size[i] = max(cur_size[i], sizes[i] + offsets[i])
# Now that we know the full size of each tensor, populate the output file data
for fqn, tensor_info in fqn_to_size_mapping.items():
tensor_size = tensor_info[0]
dtype_str = tensor_info[1]
for _, output_data in output_files_data.items():
# Add this tensor to the output file if it's already assigned there or if we're using a single output file
if fqn in output_data.fqn_data or len(output_files_data) == 1:
output_data.fqn_data[fqn] = _FqnData(
shape_in_file=tensor_size,
dtype_size=torch.finfo(_get_dtype(dtype_str)).bits
// 8, # Convert bits to bytes
dtype_str=dtype_str,
)
def _write_metadata(
fs: fsspec.AbstractFileSystem,
output_files_data: dict[str, _OutputFileData],
) -> None:
"""
Write metadata to the beginning of each output safetensors file.
This function writes the metadata section to each output file, including information
about tensor shapes, data types, and offsets. It also updates the offset_in_file
field for each tensor in the output_files_data.
Args:
fs: Filesystem interface for file operations
output_files_data: Dictionary mapping output file paths to their metadata
"""
# Process each output file
for file_path, output_data in output_files_data.items():
with fs.open(file_path, "wb") as f:
metadata = {}
curr_offset = 0
# Calculate offsets for each tensor in the file
for fqn, fqn_data in output_data.fqn_data.items():
# Calculate the end offset by multiplying all dimensions and the data type size
end_offset = (
curr_offset
+ math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size
)
# Store metadata for this tensor
metadata[fqn] = {
SHAPE_KEY: fqn_data.shape_in_file,
DTYPE_KEY: fqn_data.dtype_str,
DATA_OFFSETS_KEY: [
curr_offset,
end_offset,
], # Start and end byte offsets
}
# Store the offset for later use when writing the actual tensor data
fqn_data.offset_in_file = curr_offset
# Update current offset for the next tensor
curr_offset = end_offset
# Convert metadata to JSON and encode as bytes
json_metadata = json.dumps(metadata)
json_bytes = json_metadata.encode("utf-8")
# Write the metadata size as an 8-byte unsigned integer (little-endian)
size_in_bytes = len(json_bytes)
header_len = struct.pack("<Q", size_in_bytes)
# Write the header length and metadata to the file
f.write(header_len)
f.write(json_bytes)
# Store the total metadata size (header + JSON) for later use
output_data.metadata_size = f.tell()
def _write_data(
input_fs: fsspec.AbstractFileSystem,
output_fs: fsspec.AbstractFileSystem,
input_safetensors_files: List[str],
input_metadatas: dict[str, Any],
output_files_data: dict[str, _OutputFileData],
) -> None:
"""
Write tensor data from input files to the output files.
This function reads tensor data from each input file and writes it to the appropriate
position in the output files based on the tensor's offsets.
Args:
input_fs: Filesystem interface for reading the input files
output_fs: Filesystem interface for writing the output files
input_safetensors_files: List of input safetensors file paths
input_metadatas: Dictionary mapping input file paths to their metadata
output_files_data: Dictionary mapping output file paths to their metadata
"""
# Process each input safetensors file
for safetensors_file in input_safetensors_files:
with input_fs.open(safetensors_file, "rb") as f:
# Deserialize the safetensors file to get tensor data
deserialized = deserialize(f.read())
# Process each tensor in the file
for fqn, val in deserialized:
# Get the tensor data as bytes
data_to_write = val[DATA_KEY]
# Get the offsets of this tensor shard within the full tensor
offsets_of_tensor_being_read = _get_dcp_custom_metadata(
input_metadatas[safetensors_file]
)[fqn][SAVED_OFFSETS_KEY]
# Find which output file(s) this tensor belongs to
for output_file, output_data in output_files_data.items():
# Skip if this tensor doesn't belong in this output file
if fqn not in output_data.fqn_data:
continue
# Get metadata for this tensor in the output file
fqn_data = output_data.fqn_data[fqn]
# Write this tensor shard to the appropriate position in the output file
_write_sub_tensor_to_file(
output_fs,
data_to_write,
fqn_data.dtype_size, # Size of each element in bytes
fqn_data.shape_in_file, # Full tensor shape
offsets_of_tensor_being_read, # Where this shard belongs in the full tensor
val[SHAPE_KEY], # Shape of this shard
output_file,
# Calculate the exact byte position where this tensor data should start
output_data.metadata_size + fqn_data.offset_in_file,
)
def _write_row_wise_tensor(
fs: fsspec.AbstractFileSystem,
sub_tensor_bytes: bytearray,
element_size: int,
full_tensor_strides: list[int],
sub_tensor_strides: list[int],
sub_tensor_offsets: list[int],
sub_tensor_shape: list[int],
output_file_path: str,
output_start_byte: int,
):
"""
Writes a row-wise sharded tensor to the output file.
This is an optimized path for tensors that are sharded along the first dimension,
with all other dimensions being complete. This allows writing entire rows at once.
Args:
fs: Filesystem interface for file operations
sub_tensor_bytes: Byte array containing the sub-tensor data
element_size: The size of each element in bytes
full_tensor_strides: Strides of the full tensor
sub_tensor_strides: Strides of the sub-tensor
sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
sub_tensor_shape: The shape of the sub-tensor
output_file_path: The path to the file where the full tensor is stored
output_start_byte: The starting byte of the full tensor in the file
"""
# Open the output file in read+binary mode to allow seeking and writing
with fs.open(output_file_path, "r+b") as out_f:
# Calculate the number of elements in each row
elements_per_row = full_tensor_strides[
0
] # This is the stride of the first dimension
# For each row in the sub-tensor
for row_idx in range(sub_tensor_shape[0]):
# Calculate the row index in the full tensor
full_row_idx = sub_tensor_offsets[0] + row_idx
# Calculate the position in the full tensor
full_pos = full_row_idx * full_tensor_strides[0]
full_byte_offset = output_start_byte + full_pos * element_size
# Calculate the position in the sub-tensor
sub_pos = row_idx * sub_tensor_strides[0]
sub_byte_offset = sub_pos * element_size
# Extract the row data from the sub-tensor
row_size = elements_per_row * element_size
row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size]
# Seek to the correct position in the output file and write the data
out_f.seek(full_byte_offset)
out_f.write(row_data)
def _write_column_wise_tensor(
fs: fsspec.AbstractFileSystem,
sub_tensor_bytes: bytearray,
element_size: int,
tensor_shape: list[int],
sub_tensor_offsets: list[int],
sub_tensor_shape: list[int],
output_file_path: str,
output_start_byte: int,
):
"""
Writes a column-wise sharded 2D tensor to the output file.
This is an optimized path for 2D tensors that are sharded along the second dimension,
with the first dimension being complete. This requires writing column by column.
Args:
fs: Filesystem interface for file operations
sub_tensor_bytes: Byte array containing the sub-tensor data
element_size: The size of each element in bytes
tensor_shape: The shape of the overall tensor
sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
sub_tensor_shape: The shape of the sub-tensor
output_file_path: The path to the file where the full tensor is stored
output_start_byte: The starting byte of the full tensor in the file
"""
# Open the output file in read+binary mode to allow seeking and writing
with fs.open(output_file_path, "r+b") as out_f:
# For each column in the sub-tensor
for col_idx in range(sub_tensor_shape[1]):
# Calculate the column index in the full tensor
full_col_idx = sub_tensor_offsets[1] + col_idx
# For each row in the column
for row_idx in range(sub_tensor_shape[0]):
# Calculate the position in the full tensor
full_pos = row_idx * tensor_shape[1] + full_col_idx
full_byte_offset = output_start_byte + full_pos * element_size
# Calculate the position in the sub-tensor
sub_pos = row_idx * sub_tensor_shape[1] + col_idx
sub_byte_offset = sub_pos * element_size
# Extract the element data from the sub-tensor
element_data = sub_tensor_bytes[
sub_byte_offset : sub_byte_offset + element_size
]
# Seek to the correct position in the output file and write the data
out_f.seek(full_byte_offset)
out_f.write(element_data)
def _write_element_by_element(
fs: fsspec.AbstractFileSystem,
sub_tensor_bytes: bytearray,
element_size: int,
tensor_shape: list[int],
full_tensor_strides: list[int],
sub_tensor_strides: list[int],
sub_tensor_offsets: list[int],
sub_tensor_shape: list[int],
output_file_path: str,
output_start_byte: int,
):
"""
Writes a sub-tensor to the output file using a general element-by-element approach.
This is a general approach that works for any sharding pattern, but is less efficient
than the specialized approaches for row-wise or column-wise sharding.
Args:
fs: Filesystem interface for file operations
sub_tensor_bytes: Byte array containing the sub-tensor data
element_size: The size of each element in bytes
tensor_shape: The shape of the overall tensor
full_tensor_strides: Strides of the full tensor
sub_tensor_strides: Strides of the sub-tensor
sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
sub_tensor_shape: The shape of the sub-tensor
output_file_path: The path to the file where the full tensor is stored
output_start_byte: The starting byte of the full tensor in the file
"""
# Open the output file in read+binary mode to allow seeking and writing
with fs.open(output_file_path, "r+b") as out_f:
# Create a list to hold the current indices for each dimension
indices = [0] * len(tensor_shape)
# Calculate the total number of elements in the sub-tensor
total_elements = 1
for dim_size in sub_tensor_shape:
total_elements *= dim_size
# Process each element in the sub-tensor
for element_idx in range(total_elements):
# Calculate the indices for this element in the sub-tensor
sub_idx = element_idx
for dim in range(len(sub_tensor_shape) - 1, -1, -1):
indices[dim] = sub_idx % sub_tensor_shape[dim]
sub_idx //= sub_tensor_shape[dim]
# Calculate the position of this element in the sub-tensor's byte array
sub_pos = 0
for dim in range(len(sub_tensor_shape)):
sub_pos += indices[dim] * sub_tensor_strides[dim]
sub_byte_offset = sub_pos * element_size
# Calculate the position of this element in the full tensor
full_pos = 0
for dim in range(len(tensor_shape)):
# The global index is the local index plus the offset for this dimension
global_idx = indices[dim] + sub_tensor_offsets[dim]
full_pos += global_idx * full_tensor_strides[dim]
full_byte_offset = output_start_byte + full_pos * element_size
# Extract the element data from the sub-tensor
element_data = sub_tensor_bytes[
sub_byte_offset : sub_byte_offset + element_size
]
# Seek to the correct position in the output file and write the data
out_f.seek(full_byte_offset)
out_f.write(element_data)
def _write_sub_tensor_to_file(
fs: fsspec.AbstractFileSystem,
sub_tensor_bytes: bytearray,
element_size: int,
tensor_shape: list[int],
sub_tensor_offsets: list[int],
sub_tensor_shape: list[int],
output_file_path: str,
output_start_byte: int,
):
"""
Writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets.
This function handles the complex task of placing a tensor shard (sub-tensor) at the correct
position within the consolidated tensor file. It works by calculating the exact byte offsets
for each slice of data and writing them to the appropriate positions. This implementation
supports tensors of any dimensionality with optimized paths for common sharding patterns:
- Row-wise sharding (optimized path)
- Column-wise sharding for 2D tensors (optimized path)
- Any other arbitrary sharding pattern (general element-by-element approach)
Args:
sub_tensor_bytes: Byte array containing the sub-tensor data
element_size: The size of each element in bytes
tensor_shape: The shape of the overall tensor (list)
sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list)
sub_tensor_shape: The shape of the sub-tensor (list)
output_file_path: The path to the file where the full tensor is stored
output_start_byte: The starting byte of the full tensor in the file
"""
# Handle the case of empty tensors
if not tensor_shape or not sub_tensor_shape:
return
# Calculate strides for the full tensor (row-major order, C-style)
# Stride is the number of elements to skip to move to the next element in that dimension
full_tensor_strides = [1] * len(tensor_shape)
for i in range(len(tensor_shape) - 2, -1, -1):
full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1]
# Calculate strides for the sub-tensor (row-major order, C-style)
sub_tensor_strides = [1] * len(sub_tensor_shape)
for i in range(len(sub_tensor_shape) - 2, -1, -1):
sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1]
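# Example: a shape of [2, 3, 4] gives row-major strides [12, 4, 1] (in elements)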
# Check if this is a row-wise sharded tensor
# Row-wise sharding is detected when all dimensions except the first are
# complete and the first dimension is partial
is_row_wise = False
if len(tensor_shape) >= 2:
# Check if all dimensions except the first are complete
all_other_dims_complete = True
for i in range(1, len(tensor_shape)):
if sub_tensor_shape[i] != tensor_shape[i]:
all_other_dims_complete = False
break
# Row-wise sharding: first dimension is partial, all others are complete
is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0]
# Check if this is a column-wise sharded 2D tensor
# Column-wise sharding is detected when the first dimension is complete
# and the second dimension is partial (only for 2D tensors)
is_column_wise = False
if len(tensor_shape) == 2:
is_column_wise = (
sub_tensor_shape[0] == tensor_shape[0]
and sub_tensor_shape[1] < tensor_shape[1]
)
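# Example (illustrative): a full shape [8, 4] with shard [2, 4] is row-wise;
# full shape [8, 4] with shard [8, 2] is column-wise; anything else
# (e.g. shard [4, 2]) falls through to the element-by-element path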
# Call the appropriate function based on the sharding pattern
if is_row_wise:
_write_row_wise_tensor(
fs,
sub_tensor_bytes,
element_size,
full_tensor_strides,
sub_tensor_strides,
sub_tensor_offsets,
sub_tensor_shape,
output_file_path,
output_start_byte,
)
elif is_column_wise:
_write_column_wise_tensor(
fs,
sub_tensor_bytes,
element_size,
tensor_shape,
sub_tensor_offsets,
sub_tensor_shape,
output_file_path,
output_start_byte,
)
else:
_write_element_by_element(
fs,
sub_tensor_bytes,
element_size,
tensor_shape,
full_tensor_strides,
sub_tensor_strides,
sub_tensor_offsets,
sub_tensor_shape,
output_file_path,
output_start_byte,
)
def consolidate_safetensors_files(
input_dir: str,
output_dir: str,
fqn_to_index_mapping: Optional[dict[str, int]] = None,
) -> None:
"""
Main function to consolidate sharded safetensors files into one or more output files.
This function orchestrates the entire consolidation process:
1. Sets up the output file structure based on the fqn_to_index_mapping
2. Finds all safetensors files in the input directory
3. Parses metadata from all input files
4. Writes metadata to the output files
5. Writes tensor data from input files to output files
Args:
input_dir: Directory containing sharded safetensors files
output_dir: Directory where consolidated files will be written
fqn_to_index_mapping: Optional mapping of tensor names to output file indices.
If None, all tensors will be consolidated into a single file.
"""
# Create filesystem using fsspec for file operations
input_fs, _ = url_to_fs(input_dir)
output_fs, _ = url_to_fs(output_dir)
# Initialize the output file structure
output_files_data: dict[str, _OutputFileData] = {}
if fqn_to_index_mapping is None:
# If no mapping is provided, create a single output file
file_name = _gen_file_name(1, 1) # Generate name like "model.safetensors"
output_path = f"{output_dir}/{file_name}"
output_files_data[output_path] = _OutputFileData()
else:
# Create multiple output files based on the provided mapping
for fqn, index in fqn_to_index_mapping.items():
# Generate names like "model-00001-of-00005.safetensors"
file_name = _gen_file_name(index, max(fqn_to_index_mapping.values()))
output_path = f"{output_dir}/{file_name}"
# Create output file data structure if it doesn't exist yet
if output_path not in output_files_data:
output_files_data[output_path] = _OutputFileData(
fqn_data={fqn: _FqnData()}
)
else:
output_files_data[output_path].fqn_data[fqn] = _FqnData()
# Find all safetensors files in the input directory
safetensors_files = []
for file in input_fs.ls(input_dir, detail=False):
if file.endswith(SUFFIX):
safetensors_files.append(file)
# Read metadata from all input files
input_safetensors_metadatas = {}
for safetensor_file in safetensors_files:
with input_fs.open(safetensor_file, "rb") as f:
input_safetensors_metadatas[safetensor_file] = (
_get_safetensors_file_metadata(f)
)
# Step 1: Parse metadata to determine tensor shapes and types
_parse_input_metadata(input_safetensors_metadatas.values(), output_files_data)
# Step 2: Write metadata headers to output files
_write_metadata(output_fs, output_files_data)
# Step 3: Write actual tensor data from input files to output files
_write_data(
input_fs,
output_fs,
safetensors_files,
input_safetensors_metadatas,
output_files_data,
)
def main() -> None:
"""
Command-line entry point for the consolidation script.
Parses command-line arguments and calls consolidate_safetensors_files with the provided parameters.
"""
# Set up command-line argument parser
parser = argparse.ArgumentParser(
description="Consolidate DCP sharded HuggingFace safetensors files"
)
# Define required and optional arguments
parser.add_argument(
"input_path",
type=str,
help="Path to directory containing sharded safetensors files",
)
parser.add_argument(
"output_path",
type=str,
help="Path to write consolidated safetensors files. Must be different from input path",
)
parser.add_argument(
"fqn_to_index_mapping",
type=json.loads,
nargs="?",
default=None,
help="Optional JSON mapping of tensor names to the consolidated file each belongs to. If not provided, all tensors will be consolidated into one file. Expects numbers from 1 to N, where N is the number of files.",
)
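# Example invocation (script name and mapping are illustrative):
#   python consolidate_hf_safetensors.py /ckpt/sharded /ckpt/merged '{"embed.weight": 1}'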
# Parse arguments and call the main function
args = parser.parse_args()
consolidate_safetensors_files(
args.input_path, args.output_path, args.fqn_to_index_mapping
)
if __name__ == "__main__":
main()


@@ -0,0 +1,195 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
# pyre-strict
"""
This script demonstrates how to:
1. Load tensors
2. Split them into d-tensors (distributed tensors)
3. Save them using DCP (Distributed Checkpoint) with HuggingFace integration
4. Run the consolidation function to consolidate the sharded files back to full tensors
"""
import argparse
import os
import time
import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import consolidate_safetensors_files
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
def init_distributed():
"""Initialize the distributed environment."""
if dist.is_available() and not dist.is_initialized():
if torch.cuda.is_available():
dist.init_process_group(backend="nccl")
else:
dist.init_process_group(backend="gloo")
# Set device to current rank
if torch.cuda.is_available():
torch.cuda.set_device(dist.get_rank())
return dist.get_world_size(), dist.get_rank()
def load_and_split_tensors(model_path: str, device_mesh):
"""
Load tensors from a model and split them into d-tensors.
Args:
model_path: Path to the model to load
device_mesh: Device mesh to distribute tensors on
Returns:
state_dict with distributed tensors
"""
# Load the model
if model_path.endswith(".safetensors"):
from safetensors.torch import load_file
state_dict = load_file(model_path)
else:
raise ValueError(f"Expected a .safetensors file, got {model_path}")
# Create distributed tensors
distributed_state_dict = {}
for key, tensor in state_dict.items():
# Choose sharding strategy based on tensor size
if tensor.dim() >= 2 and tensor.size(0) > 10:
# Shard along the first dimension for large tensors
placements = [Shard(0)]
else:
# Replicate small tensors
placements = [Replicate()]
# Distribute the tensor
dtensor = distribute_tensor(tensor, device_mesh, placements=placements)
distributed_state_dict[key] = dtensor
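# e.g. a [4096, 4096] weight becomes a DTensor sharded on dim 0 across the
# mesh, while a [4096] bias stays replicated on every rank (sizes illustrative)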
return distributed_state_dict, state_dict
def save_with_dcp_huggingface(state_dict, output_dir: str):
"""
Save the distributed state dict using DCP with HuggingFace integration.
Args:
state_dict: State dict with distributed tensors
output_dir: Directory to save the checkpoint
"""
os.makedirs(output_dir, exist_ok=True)
# Save the distributed checkpoint
dist_cp.save(
state_dict=state_dict,
storage_writer=dist_cp._HuggingFaceStorageWriter(
path=output_dir,
save_sharded=True, # Save as sharded files
),
)
print(f"Saved distributed checkpoint to {output_dir}")
def save_safetensors(state_dict, output_dir: str):
"""
Save the state dict as safetensors.
Args:
state_dict: State dict with tensors
output_dir: Directory to save the checkpoint
"""
os.makedirs(output_dir, exist_ok=True)
# Save the checkpoint
from safetensors.torch import save_file
save_file(state_dict, os.path.join(output_dir, "checkpoint.safetensors"))
def consolidate_checkpoint(input_dir: str, output_dir: str, fqn_to_index_mapping=None):
"""
Consolidate the sharded checkpoint files into full tensors.
Args:
input_dir: Directory containing sharded checkpoint files
output_dir: Directory to save the consolidated checkpoint
fqn_to_index_mapping: Optional mapping of tensor names to output file indices
"""
os.makedirs(output_dir, exist_ok=True)
# Consolidate the checkpoint
consolidate_safetensors_files(
input_dir=input_dir,
output_dir=output_dir,
fqn_to_index_mapping=fqn_to_index_mapping,
)
print(f"Consolidated checkpoint saved to {output_dir}")
def main():
parser = argparse.ArgumentParser(
description="Split tensors into d-tensors, save with DCP HuggingFace, and consolidate"
)
"""
parser.add_argument(
"--model-path", type=str, default="", help="Path to the model to load"
)
parser.add_argument(
"--sharded-output-dir",
type=str,
default="./sharded_checkpoint",
help="Directory to save the sharded checkpoint",
)
parser.add_argument(
"--consolidated-output-dir",
type=str,
default="./consolidated_checkpoint",
help="Directory to save the consolidated checkpoint",
)
args = parser.parse_args()
"""
# Initialize distributed environment
world_size, rank = init_distributed()
print(f"Running with world_size={world_size}, rank={rank}")
model_path = "/home/ankitageorge/.cache/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/model-00001-of-00030.safetensors"
base_dir = "/data/users/ankitageorge/testing/"
sharded_output_dir = os.path.join(base_dir, "sharded_checkpoint")
dcp_consolidated_output_dir = os.path.join(base_dir, "dcp_consolidated_checkpoint")
consolidated_output_dir = os.path.join(base_dir, "consolidated_checkpoint")
# Initialize device mesh
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device_mesh = init_device_mesh(device_type, (world_size,))
# Load and split tensors
distributed_state_dict, state_dict = load_and_split_tensors(model_path, device_mesh)
start_time = time.time()
if rank == 0:
save_safetensors(state_dict, consolidated_output_dir)
print("time to save as safetensors ", time.time() - start_time)
# Save with DCP HuggingFace
start_time = time.time()
save_with_dcp_huggingface(distributed_state_dict, sharded_output_dir)
# Make sure all processes have finished saving
dist.barrier()
print("time to save with DCP HuggingFace ", time.time() - start_time)
# Only rank 0 needs to consolidate the checkpoint
start_time = time.time()
if rank == 0:
consolidate_checkpoint(sharded_output_dir, dcp_consolidated_output_dir)
print("Time to consolidate checkpoint: ", time.time() - start_time)
if __name__ == "__main__":
main()