Mirror of https://github.com/pytorch/pytorch.git
Synced 2025-10-21 21:49:24 +08:00

Compare commits: 6 commits (ciflow/tru... → benchmarki...)

0f49e915a9
2f1217f944
e0bf01e87b
3b5ae0e9fc
5f5f654a3e
21931cbbc6
test/distributed/checkpoint/test_consolidate_hf_safetensors.py (new file, 135 lines)
@@ -0,0 +1,135 @@
# Owner(s): ["oncall: distributed_checkpointing"]

import os
import sys

import torch
import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class TestConsolidateHFSafeTensors(DTensorTestBase):
    def _create_d_tensors(self) -> None:
        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (self.world_size,)
        mesh_1d = init_device_mesh(self.device_type, mesh_shape)

        # Create local tensor with row-wise sharding
        rows_per_rank = global_tensor.shape[0] // self.world_size
        start_row = self.rank * rows_per_rank
        end_row = start_row + rows_per_rank
        local_tensor = global_tensor[start_row:end_row].clone()

        # Create DTensor with row-wise sharding
        dtensor = DTensor.from_local(
            local_tensor,
            device_mesh=mesh_1d,
            placements=[Shard(0)],
            shape=global_tensor.shape,
            stride=(4, 1),
        )

        # Create local tensor with column-wise sharding
        cols_per_rank = global_tensor.shape[1] // self.world_size
        start_col = self.rank * cols_per_rank
        end_col = start_col + cols_per_rank
        local_tensor_col = global_tensor[:, start_col:end_col].clone()

        # Create DTensor with column-wise sharding
        dtensor_col = DTensor.from_local(
            local_tensor_col,
            device_mesh=mesh_1d,
            placements=[Shard(1)],  # Column-wise sharding
            shape=global_tensor.shape,
            stride=(4, 1),
        )

        state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col}
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp._HuggingFaceStorageWriter(
                path=self.temp_dir, save_sharded=True
            ),
        )
        dist.barrier()
        os.sync()

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_consolidate_to_one_file(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        checkpoint_dir = self.temp_dir
        output_dir = os.path.join(checkpoint_dir, "consolidated")
        os.makedirs(output_dir, exist_ok=True)

        self._create_d_tensors()

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)

        if self.rank == 0:
            consolidate_safetensors_files(checkpoint_dir, output_dir)

            file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors")
            loaded_dict = safetensors.torch.load_file(file_path)
            self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"})
            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))
            self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor))
        dist.barrier()

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_consolidate_to_two_files(self):
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        checkpoint_dir = self.temp_dir
        output_dir = os.path.join(checkpoint_dir, "consolidated")
        os.makedirs(output_dir, exist_ok=True)

        self._create_d_tensors()

        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)

        if self.rank == 0:
            fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2}
            consolidate_safetensors_files(
                checkpoint_dir, output_dir, fqn_to_index_mapping
            )

            file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors")
            file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors")

            loaded_dict = safetensors.torch.load_file(file1_path)
            self.assertEqual(loaded_dict.keys(), {"dtensor"})
            self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor))

            loaded_dict_col = safetensors.torch.load_file(file2_path)
            self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"})
            self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor))
        dist.barrier()


if __name__ == "__main__":
    run_tests()
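For orientation, a minimal sketch of the workflow this test exercises outside the test harness; the wrapper function and directory arguments are illustrative and not part of the change:

import torch.distributed.checkpoint as dist_cp
from torch import distributed as dist
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)


def save_and_consolidate(state_dict, checkpoint_dir, output_dir):
    # Every rank writes its own shard-*.safetensors file into checkpoint_dir.
    dist_cp.save(
        state_dict=state_dict,
        storage_writer=dist_cp._HuggingFaceStorageWriter(
            path=checkpoint_dir, save_sharded=True
        ),
    )
    # Make sure all ranks have finished writing before consolidating.
    dist.barrier()
    # A single process stitches the shards into model-00001-of-00001.safetensors.
    if dist.get_rank() == 0:
        consolidate_safetensors_files(checkpoint_dir, output_dir)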
test/distributed/checkpoint/test_hf_safetensor_e2e.py (new file, 420 lines)
@@ -0,0 +1,420 @@
# Owner(s): ["oncall: distributed_checkpointing"]

import sys

import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    run_tests,
    TestCase,
)
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


CHECKPOINT_DIR = "checkpoint"


class MyTestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear_1 = torch.nn.Linear(5, 5)
        self.linear_2 = torch.nn.Linear(5, 1)
        self.emb = torch.nn.EmbeddingBag(5, 10)


class TestSingleRankSaveLoad(TestCase):
    @with_temp_dir
    def test_save(self) -> None:
        try:
            from safetensors.torch import load_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp._HuggingFaceStorageWriter(
                path=CHECKPOINT_DIR
            ),
        )

        state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")
        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))

    @with_temp_dir
    def test_load(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        state_dict_to_load = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))

    @with_temp_dir
    def test_load_into_empty_dict(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        state_dict_loaded = _load_state_dict_from_keys(
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key]))

    @with_temp_dir
    def test_load_allowing_resize(self) -> None:
        try:
            from safetensors.torch import save_file
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        state_dict_to_save = MyTestModule().state_dict()
        save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors")

        state_dict_to_load = {}
        for key in state_dict_to_save.keys():
            state_dict_to_load[key] = torch.zeros(1)

        dist_cp.load(
            state_dict=state_dict_to_load,
            storage_reader=dist_cp._HuggingFaceStorageReader(
                path=CHECKPOINT_DIR
            ),
            planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
        )

        self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys()))
        for key in state_dict_to_save.keys():
            self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key]))


ONE_D_PLACEMENTS = [
    [Shard(0)],
    [Replicate()],
]
ONE_D_TO_ONE_D_PLACEMENTS = [
    ([Replicate()], [Shard(0)]),
    ([Shard(0)], [Replicate()]),
]

TWO_D_PLACEMENTS = [
    [Replicate(), Replicate()],
    [Replicate(), Shard(0)],
    [Shard(0), Replicate()],
    [Shard(0), Shard(0)],
]
TWO_D_TO_TWO_D_PLACEMENTS = []
for p1 in TWO_D_PLACEMENTS:
    for p2 in TWO_D_PLACEMENTS:
        if p1 != p2:
            TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2))


@instantiate_parametrized_tests
class TestDTensorReshardPlacementChange(DTensorTestBase):
    """
    Test DCP reshard for DTensor with placement changes, without a world_size or mesh_tensor change.
    """

    @with_comms
    @skip_if_lt_x_gpu(2)
    @with_temp_dir
    def test_1d_to_1d_reshard_placement_change(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir

        for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS:
            original_placement, new_placement = one_d_to_one_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            device_mesh = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, device_mesh, placements=original_placement
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp._HuggingFaceStorageWriter(
                    path=CHECKPOINT_DIR,
                    save_sharded=True,
                ),
            )

            zero_dtensor = zeros(
                [4, 4], device_mesh=device_mesh, placements=new_placement
            )
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp._HuggingFaceStorageReader(
                    CHECKPOINT_DIR,
                ),
            )

            # materialize the whole tensor to compare with the original global_tensor
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=[Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            # redistribute the tensor back to its original placement for comparison.
            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                device_mesh,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )

    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_2d_to_2d_reshard_placement_change(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir
        for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS:
            original_placement, new_placement = two_d_to_two_d_placements

            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor,
                mesh_2d,
                placements=original_placement,
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
                planner=dist_cp.DefaultSavePlanner(),
            )

            zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement)
            state_dict_to_load = {"dtensor": zero_dtensor}

            dist_cp.load(
                state_dict=state_dict_to_load,
                storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
            )

            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=[Replicate(), Replicate()],
            )
            self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local())

            state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute(
                mesh_2d,
                placements=original_placement,
            )
            self.assertEqual(
                state_dict_to_save["dtensor"].to_local(),
                state_dict_to_load["dtensor"].to_local(),
            )


class TestDTensorReshardMeshChange(DTensorTestBase):
    """
    Test DCP reshard for DTensor with placement changes and a mesh_tensor change.
    """

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_1d_to_2d_reshard_mesh_change(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir
        for placements_1d in ONE_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (self.world_size,)
            mesh_1d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_1d, placements=placements_1d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
            )

            for placements_2d in TWO_D_PLACEMENTS:
                mesh_shape = (2, self.world_size // 2)
                mesh_2d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_2d, placements=placements_2d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_2d,
                    placements=[Replicate(), Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(4)
    def test_2d_to_1d_reshard_mesh_change(self) -> None:
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        CHECKPOINT_DIR = self.temp_dir
        for placements_2d in TWO_D_PLACEMENTS:
            global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
            mesh_shape = (2, self.world_size // 2)
            mesh_2d = init_device_mesh(self.device_type, mesh_shape)
            dtensor = distribute_tensor(
                global_tensor, mesh_2d, placements=placements_2d
            )
            state_dict_to_save = {"dtensor": dtensor}

            dist_cp.save(
                state_dict=state_dict_to_save,
                storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True),
                planner=dist_cp.DefaultSavePlanner(),
            )

            for placements_1d in ONE_D_PLACEMENTS:
                mesh_shape = (self.world_size,)
                mesh_1d = init_device_mesh(self.device_type, mesh_shape)

                zero_dtensor = zeros(
                    [4, 4], device_mesh=mesh_1d, placements=placements_1d
                )
                state_dict_to_load = {"dtensor": zero_dtensor}

                dist_cp.load(
                    state_dict=state_dict_to_load,
                    storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR),
                    planner=dist_cp.DefaultLoadPlanner(),
                )

                # materialize the whole tensor to compare with the original global_tensor
                state_dict_to_load["dtensor"] = state_dict_to_load[
                    "dtensor"
                ].redistribute(
                    mesh_1d,
                    placements=[Replicate()],
                )
                self.assertEqual(
                    global_tensor, state_dict_to_load["dtensor"].to_local()
                )

    @with_comms
    @with_temp_dir
    @skip_if_lt_x_gpu(2)
    def test_dtensor_checkpoint_resharding_with_empty_shard(self):
        """
        Test dtensor checkpoint resharding with dtensor containing empty shards.
        """
        try:
            import safetensors
        except ImportError:
            print("safetensors not installed")
            sys.exit(0)

        tensor = torch.rand(1).cuda()
        mesh = init_device_mesh(self.device_type, (self.world_size,))
        dtensor = distribute_tensor(tensor, mesh, [Shard(0)])
        ref_state_dict = {"dtensor": dtensor}

        dist_cp.save(
            state_dict=ref_state_dict,
            storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True),
        )

        tensor = torch.rand(1).cuda()
        mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2))
        dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)])
        state_dict = {"dtensor": dtensor}
        dist_cp.load(
            state_dict=state_dict,
            storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir),
        )


if __name__ == "__main__":
    run_tests()
@@ -8,10 +8,7 @@ import tempfile
from unittest.mock import MagicMock

import torch
from torch.distributed.checkpoint._hf_planner import (
    _FqnToFileMapping,
    _HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint import DefaultLoadPlanner
from torch.distributed.checkpoint._hf_storage import (
    _HuggingFaceStorageReader,
    _HuggingFaceStorageWriter,
@@ -21,24 +18,25 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadPlan, SavePlan
from torch.distributed.checkpoint.planner_helpers import (
    _create_read_items,
    _create_write_item_for_tensor,
from torch.distributed.checkpoint.planner import (
    LoadItemType,
    LoadPlan,
    ReadItem,
    SavePlan,
)
from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.distributed.checkpoint.storage import WriteResult
from torch.testing._internal.common_utils import run_tests, TestCase


class TestHfStorage(TestCase):
    def test_write_data_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["safetensors"] = mock_module
        sys.modules["huggingface_hub"] = mock_module

        mock_module = MagicMock()
        mock_module.save.return_value = b""
        sys.modules["safetensors.torch"] = mock_module
@@ -46,7 +44,7 @@ class TestHfStorage(TestCase):
        with tempfile.TemporaryDirectory() as path:
            writer = _HuggingFaceStorageWriter(
                path=path,
                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1},
                fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2},
            )
            writer.fs = FileSystem()

@@ -59,7 +57,7 @@ class TestHfStorage(TestCase):

            save_plan = SavePlan(
                [write_item_1, write_item_2],
                storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}),
                storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}},
            )
            save_planner = DefaultSavePlanner()
            save_planner.set_up_planner(state_dict=state_dict)
@@ -76,7 +74,7 @@ class TestHfStorage(TestCase):
                    ),
                    size_in_bytes=tensor0.numel() * tensor0.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="model-00001-of-00001.safetensors",
                        relative_path="model-00001-of-00002.safetensors",
                        offset=0,
                        length=tensor0.numel() * tensor0.element_size(),
                    ),
@@ -87,7 +85,68 @@ class TestHfStorage(TestCase):
                    ),
                    size_in_bytes=tensor1.numel() * tensor1.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="model-00001-of-00001.safetensors",
                        relative_path="model-00002-of-00002.safetensors",
                        offset=0,
                        length=tensor1.numel() * tensor1.element_size(),
                    ),
                ),
            ]

            self.assertEqual(
                actual_write_results,
                expected_write_results,
            )

    def test_write_data_with_sharding(self) -> None:
        mock_module = MagicMock()
        mock_module.save.return_value = b""
        sys.modules["safetensors.torch"] = mock_module

        with tempfile.TemporaryDirectory() as path:
            writer = _HuggingFaceStorageWriter(
                path=path,
                save_sharded=True,
            )
            writer.fs = FileSystem()

            tensor0 = torch.rand(4)
            tensor1 = torch.rand(10)
            write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0)
            write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1)

            state_dict = {"tensor_0": tensor0, "tensor_1": tensor1}

            save_plan = SavePlan(
                [write_item_1, write_item_2],
                storage_data={"shard_index": 1},
            )
            save_planner = DefaultSavePlanner()
            save_planner.set_up_planner(state_dict=state_dict)

            write_results = writer.write_data(save_plan, save_planner)

            write_results.wait()
            actual_write_results = write_results.value()

            expected_write_results = [
                WriteResult(
                    index=MetadataIndex(
                        fqn="tensor_0", offset=torch.Size([0]), index=None
                    ),
                    size_in_bytes=tensor0.numel() * tensor0.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="shard-00001-model-00001-of-00001.safetensors",
                        offset=0,
                        length=tensor0.numel() * tensor0.element_size(),
                    ),
                ),
                WriteResult(
                    index=MetadataIndex(
                        fqn="tensor_1", offset=torch.Size([0]), index=None
                    ),
                    size_in_bytes=tensor1.numel() * tensor1.element_size(),
                    storage_data=_StorageInfo(
                        relative_path="shard-00001-model-00001-of-00001.safetensors",
                        offset=0,
                        length=tensor1.numel() * tensor1.element_size(),
                    ),
@@ -100,43 +159,84 @@ class TestHfStorage(TestCase):
            )

    def test_read_data_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["safetensors"] = mock_module
        sys.modules["huggingface_hub"] = mock_module
        mock_safetensors = MagicMock()
        sys.modules["safetensors"] = mock_safetensors

        name = "tensor_0"
        tensor_0 = torch.rand(4)
        mock_module = MagicMock()
        mock_module.load.return_value = {name: tensor_0}
        sys.modules["safetensors.torch"] = mock_module
        # Create test tensors
        tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0])

        # Mock the deserialize function to return our test tensors
        # The format matches what's expected in the read_data method
        mock_safetensors.deserialize.return_value = [
            ("tensor_0", {
                "data": tensor_0.numpy().tobytes(),
                "dtype": "F32",
                "shape": [4]
            }),
        ]

        with tempfile.TemporaryDirectory() as path:
            # Create the reader
            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            file_name = "model-00001-of-00001"

            pathlib.Path(os.path.join(path, file_name)).touch()
            # Create test file
            file_name = "model-00001-of-00001.safetensors"
            file_path = os.path.join(path, file_name)
            pathlib.Path(file_path).touch()

            reader.set_up_storage_reader(
                Metadata(
                    state_dict_metadata={name: BytesStorageMetadata()},
                    storage_data={name: file_name},
                ),
                is_coordinator=True,
            )
            # Set up storage data with _StorageInfo objects
            storage_data = {
                "tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()),
            }

            read_items = _create_read_items(name, BytesStorageMetadata(), file_name)

            reader.storage_data = storage_data

            # Create target tensors that will be updated by read_data
            target_tensor_0 = torch.zeros(4)
            state_dict = {
                "tensor_0": target_tensor_0,
            }

            # Create read items for the load plan
            read_items = []
            for name, tensor in state_dict.items():
                storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
                dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None)
                read_items.append(
                    ReadItem(
                        type=LoadItemType.TENSOR,
                        storage_index=storage_index,
                        dest_index=dest_index,
                        storage_offsets=[0, 0],
                        dest_offsets=[0, 0],
                        lengths=tensor.size(),
                    )
                )

            # Create load plan and planner
            load_plan = LoadPlan(read_items)
            load_planner = _HuggingFaceLoadPlanner()
            load_planner.set_up_planner(state_dict={name: torch.rand(4)})
            load_planner = DefaultLoadPlanner()
            load_planner.set_up_planner(
                state_dict=state_dict,
                metadata=Metadata(
                    state_dict_metadata={
                        "tensor_0": TensorStorageMetadata(
                            properties=TensorProperties(dtype=torch.float32),
                            size=torch.Size([4]),
                            chunks=[ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4]))])},
                    storage_data=storage_data)
            )

            read_data = reader.read_data(load_plan, load_planner)
            read_data.wait()
            # Call read_data
            future = reader.read_data(load_plan, load_planner)
            future.wait()

            loaded_tensor = load_planner.original_state_dict[name]
            self.assertEqual(loaded_tensor, tensor_0)
            # Verify results - the target tensors should now contain the values from our test tensor
            self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0))

    def test_metadata_hf(self) -> None:
    def test_write_metadata_hf(self) -> None:
        mock_module = MagicMock()
        sys.modules["huggingface_hub"] = mock_module
        with tempfile.TemporaryDirectory() as path:
@@ -160,7 +260,6 @@ class TestHfStorage(TestCase):

            writer = _HuggingFaceStorageWriter(
                path=path,
                fqn_to_index_mapping=_FqnToFileMapping({}),
            )
            writer.fs = FileSystem()
            writer.finish(
@@ -185,26 +284,16 @@ class TestHfStorage(TestCase):
                metadata = json.load(f)
            self.assertEqual(metadata, expected_metadata)

            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            metadata = reader.read_metadata()
            self.assertEqual(metadata.storage_data, expected_metadata["weight_map"])

    def test_read_metadata_when_metadata_file_does_not_exist(self) -> None:
        mock_module = MagicMock()
        sys.modules["huggingface_hub"] = mock_module

    def test_read_metadata_hf(self):
        with tempfile.TemporaryDirectory() as path:
            reader = _HuggingFaceStorageReader(path=path)
            reader.fs = FileSystem()
            # there is one safetensor file, but no metadata file,
            # so we create metadata from the safetensor file
            keys = ["tensor_0", "tensor_1"]

            key = "tensor_0"
            file_name = "test.safetensors"
            with open(os.path.join(path, file_name), "wb") as f:
                # write metadata the same way it would be in safetensors file
                metadata_contents = json.dumps(
                    {"tensor_0": "value_0", "tensor_1": "value_1"}
                    {"tensor_0": {"dtype": "F32", "shape": [5, 10], "data_offsets": [0, 200]}}
                )
                metadata_bytes = metadata_contents.encode("utf-8")

@@ -216,13 +305,16 @@ class TestHfStorage(TestCase):
            self.assertEqual(
                metadata.state_dict_metadata,
                {
                    keys[0]: BytesStorageMetadata(),
                    keys[1]: BytesStorageMetadata(),
                    key: TensorStorageMetadata(
                        properties=TensorProperties(dtype=torch.float32),
                        size=torch.Size([5, 10]),
                        chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))],
                    ),
                },
            )
            self.assertEqual(
                metadata.storage_data,
                {keys[0]: file_name, keys[1]: file_name},
                {key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)},
            )
@@ -12,6 +12,7 @@ from .metadata import (
)
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .scripts._consolidate_hf_safetensors import consolidate_safetensors_files
from .state_dict_loader import load, load_state_dict
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter
@@ -62,25 +62,3 @@ def dedup_save_plans(
        )
        for plan, item_indexes in zip(all_plans, plan_to_item_indices)
    ]


def dedup_save_plans_with_fqn_to_index_mapping(
    all_plans: list[SavePlan], fqn_to_index_mapping: dict[str, int]
) -> list[SavePlan]:
    num_plans = len(all_plans)

    to_remove: list[set] = [set() for _ in range(len(all_plans))]
    for plan_idx, plan in enumerate(all_plans):
        for item_idx, item in enumerate(plan.items):
            if (fqn_to_index_mapping[item.index.fqn] - 1) % num_plans != plan_idx:
                to_remove[plan_idx].add(item_idx)

    for plan_idx, remove_set in enumerate(to_remove):
        new_items = [
            write_item
            for item_idx, write_item in enumerate(all_plans[plan_idx].items)
            if item_idx not in remove_set
        ]
        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)

    return all_plans
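To make the modulo rule in dedup_save_plans_with_fqn_to_index_mapping concrete, a small worked example; the FQNs and the two-plan setup are made up:

# An item survives on plan_idx only if (file_index - 1) % num_plans == plan_idx.
fqn_to_index_mapping = {"a": 1, "b": 2, "c": 3}
num_plans = 2  # e.g. two ranks, each contributing one SavePlan
for fqn, file_index in fqn_to_index_mapping.items():
    owner = (file_index - 1) % num_plans
    print(f"{fqn}: written by plan {owner}")
# a: written by plan 0, b: written by plan 1, c: written by plan 0
# Each tensor is therefore written by exactly one rank, round-robin by file index.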
@@ -1,49 +1,34 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass

from torch.distributed.checkpoint._dedup_save_plans import (
    dedup_save_plans_with_fqn_to_index_mapping,
)
from torch.distributed.checkpoint.default_planner import (
    create_default_local_load_plan,
    DefaultLoadPlanner,
    DefaultSavePlanner,
)
from torch.distributed.checkpoint.planner import ReadItem, SavePlan
from torch.distributed.checkpoint.planner import LoadPlan


__all__ = ["_HuggingFaceSavePlanner", "_HuggingFaceLoadPlanner"]


@dataclass
class _FqnToFileMapping:
    fqn_to_file_index_mapping: dict[str, int]


class _HuggingFaceSavePlanner(DefaultSavePlanner):
    """
    A save planner that dedups the save plans based on the fqn to file index mapping.
    A planner to work with HuggingFace's safetensors format.
    This is a placeholder, as it is likely that the DefaultSavePlanner is enough.
    """

    def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]:
        assert len(all_plans) > 0, "all_plans should not be empty"
        assert all_plans[0].storage_data is not None, "storage_data should not be None"
        assert isinstance(all_plans[0].storage_data, _FqnToFileMapping), (
            "storage_data should be of type _FqnToFileMapping"
        )

        fqn_to_index_mapping: dict[str, int] = all_plans[
            0
        ].storage_data.fqn_to_file_index_mapping

        return dedup_save_plans_with_fqn_to_index_mapping(
            all_plans, fqn_to_index_mapping
        )


class _HuggingFaceLoadPlanner(DefaultLoadPlanner):
    def __init__(self, allow_tensor_resize: bool = False):
        super().__init__()
        self.allow_tensor_resize = allow_tensor_resize

    def resolve_tensor(self, read_item: ReadItem):
        return self.lookup_tensor(read_item.dest_index)
    def create_local_plan(self) -> LoadPlan:
        assert self.metadata is not None

        # check_md_size is added to avoid the check if we're allowing tensor resize.
        # This will be deprecated in favor of _load_state_dict_from_keys and then we
        # can remove this planner altogether.
        return create_default_local_load_plan(
            self.state_dict, self.metadata, not self.allow_partial_load, check_md_size=not self.allow_tensor_resize,
        )
@@ -2,24 +2,24 @@
import dataclasses
import io
import json
import os
import queue
import struct
from typing import Optional
from typing import Any, Optional

import fsspec  # type: ignore[import-untyped]
import torch
from torch.distributed._shard._utils import narrow_tensor_by_index

from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
from torch.distributed.checkpoint._hf_planner import (
    _FqnToFileMapping,
    _HuggingFaceLoadPlanner,
)
from torch.distributed.checkpoint.filesystem import SerializationFormat
from torch.distributed.checkpoint._hf_planner import _HuggingFaceLoadPlanner
from torch.distributed.checkpoint.filesystem import _StorageInfo, SerializationFormat
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    Metadata,
    MetadataIndex,
    STORAGE_TYPES,
    StorageMeta,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import (
    LoadPlan,
@@ -37,9 +37,30 @@ __all__ = ["_HuggingFaceStorageWriter", "_HuggingFaceStorageReader"]

_metadata_fn: str = "model.safetensors.index.json"

FILE_NAME = "model-{cpt_idx}-of-{num_shards}"
FILE_NAME = "model-{cpt_idx}-of-{num_files}"
SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}"
SUFFIX = ".safetensors"

# metadata keys
CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO"
DEFAULT_EXTRA_METADATA_KEY = "__metadata__"
SAVED_OFFSETS_KEY = "saved_offsets"
SHAPE_KEY = "shape"
DATA_KEY = "data"
DTYPE_KEY = "dtype"
DATA_OFFSETS_KEY = "data_offsets"

DTYPE_MAP = {
    "F16": torch.float16,
    "F32": torch.float32,
    "F64": torch.float64,
    "I8": torch.int8,
    "U8": torch.uint8,
    "I16": torch.int16,
    "I32": torch.int32,
    "I64": torch.int64,
    "BF16": torch.bfloat16,
}

class _HuggingFaceStorageWriter(FsspecWriter):
    """
@@ -50,23 +71,25 @@ class _HuggingFaceStorageWriter(FsspecWriter):
    def __init__(
        self,
        path: str,
        fqn_to_index_mapping: dict[str, int],
        fqn_to_index_mapping: Optional[dict[str, int]] = None,
        token: Optional[str] = None,
        save_sharded: bool = False,
    ) -> None:
        """
        Initialize the huggingface writer pointing to path.

        Args:
            path: hf directory where the checkpoint will be written to. Should begin with hf://.
            token: The token to use to authenticate with huggingface hub.
            path: directory where the checkpoint will be written to.
                Can be from any fsspec supported storage,
                including localFS and hf://.
            fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to.
                Indices are from 1 to N, where N is the number of files.
                Indices are from 1 to N, where N is the number of files. If not provided,
                the tensors will be written to a single file.
            token: The token to use to authenticate with huggingface hub.
            save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard.
                Default is False which assumes full tensors are being saved.

        """
        from huggingface_hub import HfFileSystem  # type: ignore[import-not-found]

        if HfFileSystem.protocol not in fsspec.available_protocols():
            fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem)

        if token is not None:
            super().__init__(
@@ -79,16 +102,21 @@ class _HuggingFaceStorageWriter(FsspecWriter):
            path=path,
            serialization_format=SerializationFormat.SAFETENSORS,
        )
        self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping

    def prepare_local_plan(self, plan: SavePlan) -> SavePlan:
        plan = super().prepare_local_plan(plan)
        return dataclasses.replace(
            plan, storage_data=_FqnToFileMapping(self._fqn_to_index_mapping)
        )
        self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping
        self._save_sharded = save_sharded

    def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]:
        return plans
        new_plans = []
        for i, plan in enumerate(plans, start=1):
            storage_data = {}
            if self._fqn_to_index_mapping is not None:
                storage_data["fqn_to_file_mapping"] = self._fqn_to_index_mapping
            if self._save_sharded:
                storage_data["shard_index"] = i

            new_plans.append(dataclasses.replace(plan, storage_data=storage_data))

        return new_plans

    def write_data(
        self,
@@ -101,14 +129,20 @@ class _HuggingFaceStorageWriter(FsspecWriter):
            return fut

        # storage_plan is a map from key to file index
        storage_plan: dict[str, int] = plan.storage_data.fqn_to_file_index_mapping
        storage_data: dict[str, Any] = plan.storage_data
        storage_plan: Optional[dict[str, int]] = None
        shard_index: Optional[int] = None
        if "fqn_to_file_mapping" in storage_data:
            storage_plan = storage_data["fqn_to_file_mapping"]
        if "shard_index" in storage_data:
            shard_index = storage_data["shard_index"]

        buckets = self._split_by_storage_plan(storage_plan, plan.items)
        highest_index = max(storage_plan.values())
        highest_index = max(storage_plan.values()) if storage_plan is not None else 1

        file_queue: queue.Queue = queue.Queue()
        for file_index, write_items in buckets.items():
            file_name = self._gen_file_name(file_index, highest_index)
            file_name = _gen_file_name(file_index, highest_index, shard_index)
            file_queue.put(
                (self.fs.concat_path(self.path, file_name), file_name, write_items)
            )
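For reference, a sketch of the names _gen_file_name produces, based on the FILE_NAME and SHARDED_FILE_NAME templates and the five-digit zero padding defined earlier; the indices are arbitrary:

# _gen_file_name(index=1, largest_index=2)                 -> "model-00001-of-00002.safetensors"
# _gen_file_name(index=2, largest_index=2)                 -> "model-00002-of-00002.safetensors"
# _gen_file_name(index=1, largest_index=1, shard_index=3)  -> "shard-00003-model-00001-of-00001.safetensors"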
@@ -116,6 +150,9 @@ class _HuggingFaceStorageWriter(FsspecWriter):
        return super()._write_data(planner, file_queue)

    def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None:
        if self._save_sharded:
            return

        metadata_to_write = {}
        storage_md = {}
        total_size = 0
@@ -132,9 +169,12 @@ class _HuggingFaceStorageWriter(FsspecWriter):
            json.dump(metadata_to_write, metadata_file, indent=2)

    def _split_by_storage_plan(
        self, storage_plan: dict[str, int], items: list[WriteItem]
        self, storage_plan: Optional[dict[str, int]], items: list[WriteItem]
    ) -> dict[int, list[WriteItem]]:
        # storage_plan is a map from key to index
        if storage_plan is None:
            return {1: items}

        buckets = {}
        for item in items:
            key = item.index.fqn
@@ -146,14 +186,6 @@ class _HuggingFaceStorageWriter(FsspecWriter):

        return buckets

    def _gen_file_name(self, index: int, largest_index: int) -> str:
        return (
            FILE_NAME.format(
                cpt_idx=f"{index}".zfill(5), num_shards=f"{largest_index}".zfill(5)
            )
            + SUFFIX
        )

    @property
    def metadata_path(self) -> str:
        return _metadata_fn
@@ -170,48 +202,61 @@ class _HuggingFaceStorageReader(FsspecReader):
        Initialize the huggingface reader pointing to path.

        Args:
            path: hf directory where the checkpoint will be read from. Should begin with hf://.
            path: hf directory where the checkpoint will be read from.
                Needs to have .safetensors file, but can be from any fsspec supported storage,
                including localFS and hf://.
            token: The token to use to authenticate with huggingface hub.
        """
        from huggingface_hub import HfFileSystem  # type: ignore[import-not-found]

        if HfFileSystem.protocol not in fsspec.available_protocols():
            fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem)

        if token is not None:
            super().__init__(path=path, token=token)
        else:
            super().__init__(path=path)

        self.storage_data: dict[str, str] = {}

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        from safetensors.torch import load  # type: ignore[import-not-found]

        from safetensors import deserialize  # type: ignore[import-not-found]

        per_file: dict[str, list[ReadItem]] = {}

        for read_item in plan.items:
            file_name = self.storage_data[read_item.storage_index.fqn]
            item_md: _StorageInfo = self.storage_data[read_item.storage_index]
            file_name = item_md.relative_path
            per_file.setdefault(file_name, []).append(read_item)

        for file_name, reqs in per_file.items():
            new_path = self.fs.concat_path(self.path, file_name)
            with self.fs.create_stream(new_path, "rb") as stream:
                loaded_tensors = load(stream.read())
                for req in reqs:
                    tensor = loaded_tensors[req.dest_index.fqn]
            with self.fs.create_stream(file_name, "rb") as stream:
                # TODO: make this more efficient by doing offset reads instead of a
                # full deserialization of the file
                deserialized: list[tuple[str, dict[str, Any]]] = deserialize(stream.read())
                deserialized_dict: dict[str, dict[str, Any]] = {tensor_info[0]: tensor_info[1] for tensor_info in deserialized}

                for req in reqs:
                    tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY]
                    tensor = torch.frombuffer(tensor_bytes, dtype=planner.metadata.state_dict_metadata[req.dest_index.fqn].properties.dtype)
                    # TODO: update this to req.lengths once allow_tensor_resize is removed; we shouldn't need to look
                    # at the deserialized dict for metadata as we've already done that in read_metadata
                    tensor = tensor.reshape(deserialized_dict[req.dest_index.fqn][SHAPE_KEY])

                    target_tensor = planner.resolve_tensor(req)
                    if (
                        isinstance(planner, _HuggingFaceLoadPlanner)
                        and planner.allow_tensor_resize
                    ):
                        # this is to support the case when users are calling load on
                        # an empty state dict without specifying the correct size of the tensors
                        # in the state dict. Resizing is a hacky way to support this use case.
                        # But will migrate users to _load_state_dict_from_keys method and deprecate this.
                        target_tensor = planner.resolve_tensor(req)
                        target_tensor.resize_(tensor.size())
                        target_tensor = target_tensor.detach()
                    else:
                        assert target_tensor.size() == tensor.size(), (
                            f"Tensor size mismatch for {req.dest_index.fqn}: {target_tensor.size()} != {tensor.size()}"
                        tensor = narrow_tensor_by_index(
                            tensor, req.storage_offsets, req.lengths
                        )
                        target_tensor = target_tensor.detach()
                        target_tensor = planner.resolve_tensor(req).detach()

                    assert target_tensor.size() == tensor.size(), (
                        f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}"
                    )

                    target_tensor.copy_(tensor)
                    planner.commit_tensor(req, target_tensor)
@@ -220,37 +265,57 @@ class _HuggingFaceStorageReader(FsspecReader):
        return fut

    def read_metadata(self) -> Metadata:
        metadata_path = self.fs.concat_path(self.path, _metadata_fn)

        state_dict_metadata: dict[str, STORAGE_TYPES] = {}
        storage_data: dict[str, str] = {}
        storage_data: dict[MetadataIndex, _StorageInfo] = {}

        if not self.fs.exists(metadata_path):
            # if metadata file doesn't exist, create it from the safetensors file
            safetensors_files = []
            for file in self.fs.ls(self.path):
                if file.endswith(SUFFIX):
                    safetensors_files.append(file)
        safetensors_files = []
        for file in self.fs.ls(self.path):
            if file.endswith(SUFFIX):
                safetensors_files.append(file)

            if len(safetensors_files) != 1:
                raise ValueError(
                    f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files"
                )
            storage_data = {}
            with self.fs.create_stream(safetensors_files[0], "rb") as f:
                keys = _get_safetensors_file_keys(f)
        for safetensor_file in safetensors_files:
            with self.fs.create_stream(safetensor_file, "rb") as f:
                safetensors_metadata = _get_safetensors_file_metadata(f)
                custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY)

            for key in keys:
                state_dict_metadata[key] = BytesStorageMetadata()
                storage_data[key] = os.path.basename(safetensors_files[0])
        else:
            with self.fs.create_stream(metadata_path, "r") as metadata_file:
                metadata = json.load(metadata_file)
                dcp_sharding_info = None
                if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY):
                    dcp_sharding_info = json.loads(custom_metadata.get(CUSTOM_METADATA_KEY))

            for key in metadata["weight_map"].keys():
                state_dict_metadata[key] = BytesStorageMetadata()
            storage_data = metadata["weight_map"]
                for key, val in safetensors_metadata.items():
                    if key == DEFAULT_EXTRA_METADATA_KEY:
                        continue

                    # construct state_dict_metadata
                    if dcp_sharding_info is not None:
                        offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY]
                    else:
                        offset = [0] * len(val[SHAPE_KEY])

                    if key not in state_dict_metadata:
                        state_dict_metadata[key] = TensorStorageMetadata(
                            properties=TensorProperties(dtype=_get_dtype(val[DTYPE_KEY])),
                            size=torch.Size([saved + offset for saved, offset in zip(val[SHAPE_KEY], offset)]),
                            chunks=[ChunkStorageMetadata(offsets=torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]))],
                        )
                    else:
                        state_dict_metadata[key].chunks.append(ChunkStorageMetadata(torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY])))
                        size = list(state_dict_metadata[key].size)
                        for i in range(len(size)):
                            size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i])
                        state_dict_metadata[key].size = torch.Size(size)

                    # construct storage data
                    if dcp_sharding_info is not None:
                        metadata_index = MetadataIndex(fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY])
                    else:
                        metadata_index = MetadataIndex(fqn=key, offset=[0] * len(val[SHAPE_KEY]))
                    storage_data[metadata_index] = _StorageInfo(
                        safetensor_file,
                        val[DATA_OFFSETS_KEY][0],
                        val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0],
                    )

        metadata = Metadata(
            state_dict_metadata=state_dict_metadata,
            storage_data=storage_data,
@@ -262,8 +327,21 @@ class _HuggingFaceStorageReader(FsspecReader):

        return metadata

def _gen_file_name(index: int, largest_index: int, shard_index: Optional[int] = None) -> str:
    if shard_index is not None:
        return SHARDED_FILE_NAME.format(
            shard_idx=f"{shard_index}".zfill(5), cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5)
        ) + SUFFIX
    else:
        return (
            FILE_NAME.format(
                cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5)
            )
            + SUFFIX
        )


def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]:

def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> Any:
    # this uses the same logic that's done in HF code base
    # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308
    # and follows their documentation on how their files are serialized
@@ -273,4 +351,19 @@ def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]:
    header_len = struct.unpack("<Q", header_len_bytes)[0]
    header_json = file_bytes.read(header_len)
    metadata = json.loads(header_json)
    return list(metadata.keys())
    return metadata

def _get_dtype(dtype: str) -> torch.dtype:
    try:
        dtype = DTYPE_MAP[dtype]
    except KeyError:
        dtype = torch.get_default_dtype()

    return dtype

def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]:
    if DEFAULT_EXTRA_METADATA_KEY in metadata:
        custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY]
        if CUSTOM_METADATA_KEY in custom_metadata:
            return json.loads(custom_metadata[CUSTOM_METADATA_KEY])
    return None
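As a sanity check on the header logic in _get_safetensors_file_metadata, a small self-contained sketch that builds a buffer in the same layout (an 8-byte little-endian length prefix followed by the JSON header) and parses it back; the tensor entry and sharding info are made up:

import io
import json
import struct

header = {
    "tensor_0": {"dtype": "F32", "shape": [5, 10], "data_offsets": [0, 200]},
    "__metadata__": {
        "DCP_SHARDING_INFO": json.dumps({"tensor_0": {"saved_offsets": [0, 0]}})
    },
}
header_bytes = json.dumps(header).encode("utf-8")
buf = io.BytesIO(struct.pack("<Q", len(header_bytes)) + header_bytes)

# Same parsing steps as above: read the length prefix, then the JSON header.
header_len = struct.unpack("<Q", buf.read(8))[0]
parsed = json.loads(buf.read(header_len))
assert parsed["tensor_0"]["shape"] == [5, 10]
assert "DCP_SHARDING_INFO" in parsed["__metadata__"]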
@@ -442,7 +442,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):

            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
            if metadata.planner_data is not None and k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v
@@ -451,7 +451,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):


def create_default_local_load_plan(
    state_dict: dict[str, Any], metadata: Metadata, strict: bool = True
    state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, check_md_size: bool = True
) -> LoadPlan:
    requests = []
    """
@@ -477,6 +477,7 @@ def create_default_local_load_plan(
            isinstance(md, TensorStorageMetadata)
            and getattr(obj, "size", None) is not None
            and md.size != obj.size()
            and check_md_size
        ):
            raise ValueError(
                f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}",
@@ -2,6 +2,7 @@
import collections
import dataclasses
import io
import json
import operator
import os
import pickle
@@ -416,6 +417,7 @@ def _write_files_from_queue(
                )

            tensor_dict = {}
            metadata_dict = {}
            for tensor, write_item in loader.values():
                assert tensor.is_cpu
                write_results.append(
@@ -429,11 +431,12 @@ def _write_files_from_queue(
                    )
                )
                tensor_dict[write_item.index.fqn] = tensor
                metadata_dict[write_item.index.fqn] = {"saved_offsets": write_item.tensor_data.chunk.offsets}

            if serialization_format == SerializationFormat.SAFETENSORS:
                from safetensors.torch import save  # type: ignore[import-not-found]

                stream.write(save(tensor_dict))
                stream.write(save(tensor_dict, metadata={"DCP_SHARDING_INFO": json.dumps(metadata_dict), "DCP_VERSION": "1.0"}))

            if use_fsync:
                try:
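A minimal sketch of how the custom metadata written above can be read back with stock safetensors tooling; it assumes safetensors.torch.save_file accepts a string-to-string metadata dict and that safe_open(...).metadata() returns it, and the file name and tensor are illustrative:

import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

sharding_info = {"linear_1.weight": {"saved_offsets": [0, 0]}}
save_file(
    {"linear_1.weight": torch.zeros(2, 5)},
    "shard-00001-model-00001-of-00001.safetensors",
    metadata={"DCP_SHARDING_INFO": json.dumps(sharding_info), "DCP_VERSION": "1.0"},
)

# The custom keys come back under the file's "__metadata__" section.
with safe_open("shard-00001-model-00001-of-00001.safetensors", framework="pt") as f:
    assert json.loads(f.metadata()["DCP_SHARDING_INFO"]) == sharding_info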
@ -0,0 +1,658 @@
|
||||
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.
|
||||
|
||||
# pyre-strict
|
||||
|
||||
"""
|
||||
This script consolidates distributed checkpoint (DCP) HuggingFace safetensors files.
|
||||
|
||||
It takes sharded safetensors files created by DCP and combines them into one or more
|
||||
consolidated files. This is useful for converting distributed model checkpoints into
|
||||
a format that can be loaded by standard HuggingFace tools. The combination is done
|
||||
through a simple metadata parsing and tensor data copying process.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import math
|
||||
import struct
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import fsspec
|
||||
import torch
|
||||
|
||||
from fsspec.core import url_to_fs
|
||||
from safetensors import deserialize
|
||||
from torch.distributed.checkpoint._hf_storage import (
|
||||
_gen_file_name,
|
||||
_get_dcp_custom_metadata,
|
||||
_get_dtype,
|
||||
_get_safetensors_file_metadata,
|
||||
DATA_KEY,
|
||||
DATA_OFFSETS_KEY,
|
||||
DEFAULT_EXTRA_METADATA_KEY,
|
||||
DTYPE_KEY,
|
||||
SAVED_OFFSETS_KEY,
|
||||
SHAPE_KEY,
|
||||
SUFFIX,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _FqnData:
|
||||
"""
|
||||
Dataclass to store information about a tensor (identified by its fully qualified name).
|
||||
|
||||
Attributes:
|
||||
offset_in_file: Byte offset where this tensor's data begins in the output file
|
||||
shape_in_file: Shape of the tensor in the output file
|
||||
dtype_size: Size of the tensor's data type in bytes
|
||||
dtype_str: String representation of the tensor's data type
|
||||
"""
|
||||
|
||||
offset_in_file: int = 0
|
||||
shape_in_file: list[int] = field(default_factory=list)
|
||||
dtype_size: int = 0
|
||||
dtype_str: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OutputFileData:
|
||||
"""
|
||||
Dataclass to store information about an output safetensors file.
|
||||
|
||||
Attributes:
|
||||
metadata_size: Size of the metadata section in bytes
|
||||
fqn_data: Dictionary mapping tensor names to their metadata
|
||||
"""
|
||||
|
||||
metadata_size: int = 0
|
||||
fqn_data: dict[str, _FqnData] = field(default_factory=dict)


def _parse_input_metadata(
    safetensors_metadatas: List[Any], output_files_data: dict[str, _OutputFileData]
) -> None:
    """
    Parse metadata from input safetensors files to determine the full tensor shapes and types.

    This function analyzes the metadata from all input files to determine the complete shape
    of each tensor after consolidation. It updates the output_files_data with this information.

    Args:
        safetensors_metadatas: List of metadata from input safetensors files
        output_files_data: Dictionary mapping output file paths to their metadata

    Raises:
        ValueError: If no DCP custom metadata is found in a safetensors file
    """
    # Dictionary to track the full size of each tensor across all shards
    fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {}

    for safetensors_metadata in safetensors_metadatas:
        dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata)
        if not dcp_sharding_info:
            raise ValueError(
                "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated."
            )

        for key, val in safetensors_metadata.items():
            if key == DEFAULT_EXTRA_METADATA_KEY:
                continue

            # Get the shape of this tensor shard and its offset in the full tensor
            sizes = val[SHAPE_KEY]
            offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY]

            if key not in fqn_to_size_mapping:
                # First time seeing this tensor - calculate its full size by adding offsets to dimensions
                cur_size = [size + offset for size, offset in zip(sizes, offsets)]
                fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY])
            else:
                # We've seen this tensor before - update its size if this shard extends beyond current known dimensions
                cur_size = fqn_to_size_mapping[key][0]
                for i in range(len(sizes)):
                    cur_size[i] = max(cur_size[i], sizes[i] + offsets[i])

    # Now that we know the full size of each tensor, populate the output file data
    for fqn, tensor_info in fqn_to_size_mapping.items():
        tensor_size = tensor_info[0]
        dtype_str = tensor_info[1]
        for _, output_data in output_files_data.items():
            # Add this tensor to the output file if it's already assigned there or if we're using a single output file
            if fqn in output_data.fqn_data or len(output_files_data) == 1:
                output_data.fqn_data[fqn] = _FqnData(
                    shape_in_file=tensor_size,
                    dtype_size=torch.finfo(_get_dtype(dtype_str)).bits
                    // 8,  # Convert bits to bytes
                    dtype_str=dtype_str,
                )
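# Worked example of the size logic above, assuming two row-wise shards of one
# tensor (shapes and offsets illustrative):
#   shard A: shape [2, 4] at offsets [0, 0] -> extends to [2, 4]
#   shard B: shape [2, 4] at offsets [2, 0] -> extends to [4, 4]
# Per dimension the full size is max(shard size + shard offset), so the
# consolidated tensor is recorded with shape [4, 4].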


def _write_metadata(
    fs: fsspec.AbstractFileSystem,
    output_files_data: dict[str, _OutputFileData],
) -> None:
    """
    Write metadata to the beginning of each output safetensors file.

    This function writes the metadata section to each output file, including information
    about tensor shapes, data types, and offsets. It also updates the offset_in_file
    field for each tensor in the output_files_data.

    Args:
        fs: Filesystem interface for file operations
        output_files_data: Dictionary mapping output file paths to their metadata
    """
    # Process each output file
    for file_path, output_data in output_files_data.items():
        with fs.open(file_path, "wb") as f:
            metadata = {}
            curr_offset = 0

            # Calculate offsets for each tensor in the file
            for fqn, fqn_data in output_data.fqn_data.items():
                # Calculate the end offset by multiplying all dimensions and the data type size
                end_offset = (
                    curr_offset
                    + math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size
                )

                # Store metadata for this tensor
                metadata[fqn] = {
                    SHAPE_KEY: fqn_data.shape_in_file,
                    DTYPE_KEY: fqn_data.dtype_str,
                    DATA_OFFSETS_KEY: [
                        curr_offset,
                        end_offset,
                    ],  # Start and end byte offsets
                }
                # Store the offset for later use when writing the actual tensor data
                fqn_data.offset_in_file = curr_offset

                # Update current offset for the next tensor
                curr_offset = end_offset

            # Convert metadata to JSON and encode as bytes
            json_metadata = json.dumps(metadata)
            json_bytes = json_metadata.encode("utf-8")

            # Write the metadata size as an 8-byte unsigned integer (little-endian)
            size_in_bytes = len(json_bytes)
            header_len = struct.pack("<Q", size_in_bytes)

            # Write the header length and metadata to the file
            f.write(header_len)
            f.write(json_bytes)

            # Store the total metadata size (header + JSON) for later use
            output_data.metadata_size = f.tell()
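# Sketch of the resulting byte layout for a single fp32 [4, 4] tensor
# (illustrative; only the safetensors header convention used above is assumed):
#   math.prod([4, 4]) * 4 = 64 bytes of data, so the header entry is roughly
#     {"shape": [4, 4], "dtype": "F32", "data_offsets": [0, 64]}
#   and the file is laid out as
#     [8-byte little-endian header length][JSON header][64 bytes of tensor data]
#   with output_data.metadata_size == 8 + len(json_bytes).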


def _write_data(
    input_fs: fsspec.AbstractFileSystem,
    output_fs: fsspec.AbstractFileSystem,
    input_safetensors_files: List[str],
    input_metadatas: dict[str, Any],
    output_files_data: dict[str, _OutputFileData],
) -> None:
    """
    Write tensor data from input files to the output files.

    This function reads tensor data from each input file and writes it to the appropriate
    position in the output files based on the tensor's offsets.

    Args:
        input_fs: Filesystem interface for reading the input files
        output_fs: Filesystem interface for writing the output files
        input_safetensors_files: List of input safetensors file paths
        input_metadatas: Dictionary mapping input file paths to their metadata
        output_files_data: Dictionary mapping output file paths to their metadata
    """
    # Process each input safetensors file
    for safetensors_file in input_safetensors_files:
        with input_fs.open(safetensors_file, "rb") as f:
            # Deserialize the safetensors file to get tensor data
            deserialized = deserialize(f.read())

            # Process each tensor in the file
            for fqn, val in deserialized:
                # Get the tensor data as bytes
                data_to_write = val[DATA_KEY]

                # Get the offsets of this tensor shard within the full tensor
                offsets_of_tensor_being_read = _get_dcp_custom_metadata(
                    input_metadatas[safetensors_file]
                )[fqn][SAVED_OFFSETS_KEY]

                # Find which output file(s) this tensor belongs to
                for output_file, output_data in output_files_data.items():
                    # Skip if this tensor doesn't belong in this output file
                    if fqn not in output_data.fqn_data:
                        continue

                    # Get metadata for this tensor in the output file
                    fqn_data = output_data.fqn_data[fqn]

                    # Write this tensor shard to the appropriate position in the output file
                    _write_sub_tensor_to_file(
                        output_fs,
                        data_to_write,
                        fqn_data.dtype_size,  # Size of each element in bytes
                        fqn_data.shape_in_file,  # Full tensor shape
                        offsets_of_tensor_being_read,  # Where this shard belongs in the full tensor
                        val[SHAPE_KEY],  # Shape of this shard
                        output_file,
                        # Calculate the exact byte position where this tensor data should start
                        output_data.metadata_size + fqn_data.offset_in_file,
                    )


def _write_row_wise_tensor(
    fs: fsspec.AbstractFileSystem,
    sub_tensor_bytes: bytearray,
    element_size: int,
    full_tensor_strides: list[int],
    sub_tensor_strides: list[int],
    sub_tensor_offsets: list[int],
    sub_tensor_shape: list[int],
    output_file_path: str,
    output_start_byte: int,
):
    """
    Writes a row-wise sharded tensor to the output file.

    This is an optimized path for tensors that are sharded along the first dimension,
    with all other dimensions being complete. This allows writing entire rows at once.

    Args:
        fs: Filesystem interface for file operations
        sub_tensor_bytes: Byte array containing the sub-tensor data
        element_size: The size of each element in bytes
        full_tensor_strides: Strides of the full tensor
        sub_tensor_strides: Strides of the sub-tensor
        sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
        sub_tensor_shape: The shape of the sub-tensor
        output_file_path: The path to the file where the full tensor is stored
        output_start_byte: The starting byte of the full tensor in the file
    """
    # Open the output file in read+binary mode to allow seeking and writing
    with fs.open(output_file_path, "r+b") as out_f:
        # Calculate the number of elements in each row
        elements_per_row = full_tensor_strides[
            0
        ]  # This is the stride of the first dimension

        # For each row in the sub-tensor
        for row_idx in range(sub_tensor_shape[0]):
            # Calculate the row index in the full tensor
            full_row_idx = sub_tensor_offsets[0] + row_idx

            # Calculate the position in the full tensor
            full_pos = full_row_idx * full_tensor_strides[0]
            full_byte_offset = output_start_byte + full_pos * element_size

            # Calculate the position in the sub-tensor
            sub_pos = row_idx * sub_tensor_strides[0]
            sub_byte_offset = sub_pos * element_size

            # Extract the row data from the sub-tensor
            row_size = elements_per_row * element_size
            row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size]

            # Seek to the correct position in the output file and write the data
            out_f.seek(full_byte_offset)
            out_f.write(row_data)
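# Quick sanity check of the row arithmetic above, assuming a full fp32 [4, 4]
# tensor and a shard covering global rows 2-3 (numbers illustrative):
#   elements_per_row = full_tensor_strides[0] = 4, so each row is 4 * 4 = 16 bytes
#   local row 0 is global row 2 -> output_start_byte + 2 * 4 * 4 = +32
#   local row 1 is global row 3 -> output_start_byte + 48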


def _write_column_wise_tensor(
    fs: fsspec.AbstractFileSystem,
    sub_tensor_bytes: bytearray,
    element_size: int,
    tensor_shape: list[int],
    sub_tensor_offsets: list[int],
    sub_tensor_shape: list[int],
    output_file_path: str,
    output_start_byte: int,
):
    """
    Writes a column-wise sharded 2D tensor to the output file.

    This is an optimized path for 2D tensors that are sharded along the second dimension,
    with the first dimension being complete. This requires writing column by column.

    Args:
        fs: Filesystem interface for file operations
        sub_tensor_bytes: Byte array containing the sub-tensor data
        element_size: The size of each element in bytes
        tensor_shape: The shape of the overall tensor
        sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
        sub_tensor_shape: The shape of the sub-tensor
        output_file_path: The path to the file where the full tensor is stored
        output_start_byte: The starting byte of the full tensor in the file
    """
    # Open the output file in read+binary mode to allow seeking and writing
    with fs.open(output_file_path, "r+b") as out_f:
        # For each column in the sub-tensor
        for col_idx in range(sub_tensor_shape[1]):
            # Calculate the column index in the full tensor
            full_col_idx = sub_tensor_offsets[1] + col_idx

            # For each row in the column
            for row_idx in range(sub_tensor_shape[0]):
                # Calculate the position in the full tensor
                full_pos = row_idx * tensor_shape[1] + full_col_idx
                full_byte_offset = output_start_byte + full_pos * element_size

                # Calculate the position in the sub-tensor
                sub_pos = row_idx * sub_tensor_shape[1] + col_idx
                sub_byte_offset = sub_pos * element_size

                # Extract the element data from the sub-tensor
                element_data = sub_tensor_bytes[
                    sub_byte_offset : sub_byte_offset + element_size
                ]

                # Seek to the correct position in the output file and write the data
                out_f.seek(full_byte_offset)
                out_f.write(element_data)


def _write_element_by_element(
    fs: fsspec.AbstractFileSystem,
    sub_tensor_bytes: bytearray,
    element_size: int,
    tensor_shape: list[int],
    full_tensor_strides: list[int],
    sub_tensor_strides: list[int],
    sub_tensor_offsets: list[int],
    sub_tensor_shape: list[int],
    output_file_path: str,
    output_start_byte: int,
):
    """
    Writes a sub-tensor to the output file using a general element-by-element approach.

    This is a general approach that works for any sharding pattern, but is less efficient
    than the specialized approaches for row-wise or column-wise sharding.

    Args:
        fs: Filesystem interface for file operations
        sub_tensor_bytes: Byte array containing the sub-tensor data
        element_size: The size of each element in bytes
        tensor_shape: The shape of the overall tensor
        full_tensor_strides: Strides of the full tensor
        sub_tensor_strides: Strides of the sub-tensor
        sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor
        sub_tensor_shape: The shape of the sub-tensor
        output_file_path: The path to the file where the full tensor is stored
        output_start_byte: The starting byte of the full tensor in the file
    """
    # Open the output file in read+binary mode to allow seeking and writing
    with fs.open(output_file_path, "r+b") as out_f:
        # Create a list to hold the current indices for each dimension
        indices = [0] * len(tensor_shape)

        # Calculate the total number of elements in the sub-tensor
        total_elements = 1
        for dim_size in sub_tensor_shape:
            total_elements *= dim_size

        # Process each element in the sub-tensor
        for element_idx in range(total_elements):
            # Calculate the indices for this element in the sub-tensor
            sub_idx = element_idx
            for dim in range(len(sub_tensor_shape) - 1, -1, -1):
                indices[dim] = sub_idx % sub_tensor_shape[dim]
                sub_idx //= sub_tensor_shape[dim]

            # Calculate the position of this element in the sub-tensor's byte array
            sub_pos = 0
            for dim in range(len(sub_tensor_shape)):
                sub_pos += indices[dim] * sub_tensor_strides[dim]
            sub_byte_offset = sub_pos * element_size

            # Calculate the position of this element in the full tensor
            full_pos = 0
            for dim in range(len(tensor_shape)):
                # The global index is the local index plus the offset for this dimension
                global_idx = indices[dim] + sub_tensor_offsets[dim]
                full_pos += global_idx * full_tensor_strides[dim]
            full_byte_offset = output_start_byte + full_pos * element_size

            # Extract the element data from the sub-tensor
            element_data = sub_tensor_bytes[
                sub_byte_offset : sub_byte_offset + element_size
            ]

            # Seek to the correct position in the output file and write the data
            out_f.seek(full_byte_offset)
            out_f.write(element_data)


def _write_sub_tensor_to_file(
    fs: fsspec.AbstractFileSystem,
    sub_tensor_bytes: bytearray,
    element_size: int,
    tensor_shape: list[int],
    sub_tensor_offsets: list[int],
    sub_tensor_shape: list[int],
    output_file_path: str,
    output_start_byte: int,
):
    """
    Writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets.

    This function handles the complex task of placing a tensor shard (sub-tensor) at the correct
    position within the consolidated tensor file. It works by calculating the exact byte offsets
    for each slice of data and writing them to the appropriate positions. This implementation
    supports tensors of any dimensionality with optimized paths for common sharding patterns:
    - Row-wise sharding (optimized path)
    - Column-wise sharding for 2D tensors (optimized path)
    - Any other arbitrary sharding pattern (general element-by-element approach)

    Args:
        fs: Filesystem interface for file operations
        sub_tensor_bytes: Byte array containing the sub-tensor data
        element_size: The size of each element in bytes
        tensor_shape: The shape of the overall tensor (list)
        sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list)
        sub_tensor_shape: The shape of the sub-tensor (list)
        output_file_path: The path to the file where the full tensor is stored
        output_start_byte: The starting byte of the full tensor in the file
    """
    # Handle the case of empty tensors
    if not tensor_shape or not sub_tensor_shape:
        return

    # Calculate strides for the full tensor (row-major order, C-style)
    # Stride is the number of elements to skip to move to the next element in that dimension
    full_tensor_strides = [1] * len(tensor_shape)
    for i in range(len(tensor_shape) - 2, -1, -1):
        full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1]

    # Calculate strides for the sub-tensor (row-major order, C-style)
    sub_tensor_strides = [1] * len(sub_tensor_shape)
    for i in range(len(sub_tensor_shape) - 2, -1, -1):
        sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1]

    # Check if this is a row-wise sharded tensor
    # Row-wise sharding is detected when the last dimension is complete
    # and only the first dimension is partial
    is_row_wise = False
    if len(tensor_shape) >= 2:
        # Check if all dimensions except the first are complete
        all_other_dims_complete = True
        for i in range(1, len(tensor_shape)):
            if sub_tensor_shape[i] != tensor_shape[i]:
                all_other_dims_complete = False
                break

        # Row-wise sharding: first dimension is partial, all others are complete
        is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0]

    # Check if this is a column-wise sharded 2D tensor
    # Column-wise sharding is detected when the first dimension is complete
    # and the second dimension is partial (only for 2D tensors)
    is_column_wise = False
    if len(tensor_shape) == 2:
        is_column_wise = (
            sub_tensor_shape[0] == tensor_shape[0]
            and sub_tensor_shape[1] < tensor_shape[1]
        )

    # Call the appropriate function based on the sharding pattern
    if is_row_wise:
        _write_row_wise_tensor(
            fs,
            sub_tensor_bytes,
            element_size,
            full_tensor_strides,
            sub_tensor_strides,
            sub_tensor_offsets,
            sub_tensor_shape,
            output_file_path,
            output_start_byte,
        )
    elif is_column_wise:
        _write_column_wise_tensor(
            fs,
            sub_tensor_bytes,
            element_size,
            tensor_shape,
            sub_tensor_offsets,
            sub_tensor_shape,
            output_file_path,
            output_start_byte,
        )
    else:
        _write_element_by_element(
            fs,
            sub_tensor_bytes,
            element_size,
            tensor_shape,
            full_tensor_strides,
            sub_tensor_strides,
            sub_tensor_offsets,
            sub_tensor_shape,
            output_file_path,
            output_start_byte,
        )
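# The stride computation above is plain C-order arithmetic; for example, for a
# shape of [2, 3, 4] the strides are filled right-to-left:
#   strides[2] = 1, strides[1] = 1 * 4 = 4, strides[0] = 4 * 3 = 12
# so moving one step along dimension 0 skips 12 elements.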


def consolidate_safetensors_files(
    input_dir: str,
    output_dir: str,
    fqn_to_index_mapping: Optional[dict[str, int]] = None,
) -> None:
    """
    Main function to consolidate sharded safetensors files into one or more output files.

    This function orchestrates the entire consolidation process:
    1. Sets up the output file structure based on the fqn_to_index_mapping
    2. Finds all safetensors files in the input directory
    3. Parses metadata from all input files
    4. Writes metadata to the output files
    5. Writes tensor data from input files to output files

    Args:
        input_dir: Directory containing sharded safetensors files
        output_dir: Directory where consolidated files will be written
        fqn_to_index_mapping: Optional mapping of tensor names to output file indices.
            If None, all tensors will be consolidated into a single file.
    """
    # Create filesystem using fsspec for file operations
    input_fs, _ = url_to_fs(input_dir)
    output_fs, _ = url_to_fs(output_dir)

    # Initialize the output file structure
    output_files_data: dict[str, _OutputFileData] = {}
    if fqn_to_index_mapping is None:
        # If no mapping is provided, create a single output file
        file_name = _gen_file_name(1, 1)  # Generate a name like "model-00001-of-00001.safetensors"
        output_path = f"{output_dir}/{file_name}"
        output_files_data[output_path] = _OutputFileData()
    else:
        # Create multiple output files based on the provided mapping
        for fqn, index in fqn_to_index_mapping.items():
            # Generate names like "model-00001-of-00005.safetensors"
            file_name = _gen_file_name(index, max(fqn_to_index_mapping.values()))
            output_path = f"{output_dir}/{file_name}"

            # Create output file data structure if it doesn't exist yet
            if output_path not in output_files_data:
                output_files_data[output_path] = _OutputFileData(
                    fqn_data={fqn: _FqnData()}
                )
            else:
                output_files_data[output_path].fqn_data[fqn] = _FqnData()

    # Find all safetensors files in the input directory
    safetensors_files = []
    for file in input_fs.ls(input_dir, detail=False):
        if file.endswith(SUFFIX):
            safetensors_files.append(file)

    # Read metadata from all input files
    input_safetensors_metadatas = {}
    for safetensor_file in safetensors_files:
        with input_fs.open(safetensor_file, "rb") as f:
            input_safetensors_metadatas[safetensor_file] = (
                _get_safetensors_file_metadata(f)
            )

    # Step 1: Parse metadata to determine tensor shapes and types
    _parse_input_metadata(input_safetensors_metadatas.values(), output_files_data)

    # Step 2: Write metadata headers to output files
    _write_metadata(output_fs, output_files_data)

    # Step 3: Write actual tensor data from input files to output files
    _write_data(
        input_fs,
        output_fs,
        safetensors_files,
        input_safetensors_metadatas,
        output_files_data,
    )
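# Illustrative multi-file call (tensor names and file count are hypothetical):
#
#   consolidate_safetensors_files(
#       "/tmp/sharded",
#       "/tmp/merged",
#       fqn_to_index_mapping={"embedding.weight": 1, "lm_head.weight": 2},
#   )
#
# With a maximum index of 2 this writes model-00001-of-00002.safetensors and
# model-00002-of-00002.safetensors, each holding the tensors mapped to its index.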


def main() -> None:
    """
    Command-line entry point for the consolidation script.

    Parses command-line arguments and calls consolidate_safetensors_files with the provided parameters.
    """
    # Set up command-line argument parser
    parser = argparse.ArgumentParser(
        description="Consolidate DCP sharded HuggingFace safetensors files"
    )

    # Define required and optional arguments
    parser.add_argument(
        "input_path",
        type=str,
        help="Path to directory containing sharded safetensors files",
    )
    parser.add_argument(
        "output_path",
        type=str,
        help="Path to write consolidated safetensors files. Must be different from input path",
    )
    parser.add_argument(
        "fqn_to_index_mapping",
        type=json.loads,
        nargs="?",
        default=None,
        help="Mapping of which tensor names should belong to which consolidated file, given as a JSON string. If not provided, all tensors will be consolidated into one file. Expects numbers from 1 to N, where N is the number of files.",
    )

    # Parse arguments and call the main function
    args = parser.parse_args()
    consolidate_safetensors_files(
        args.input_path, args.output_path, args.fqn_to_index_mapping
    )


if __name__ == "__main__":
    main()
195
torch/distributed/checkpoint/scripts/benchmarking.py
Normal file
195
torch/distributed/checkpoint/scripts/benchmarking.py
Normal file
@ -0,0 +1,195 @@
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# pyre-strict

"""
This script demonstrates how to:
1. Load tensors
2. Split them into d-tensors (distributed tensors)
3. Save them using DCP (Distributed Checkpoint) with HuggingFace integration
4. Run the consolidation function to consolidate the sharded files back to full tensors
"""

import argparse
import os
import time

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import consolidate_safetensors_files
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
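# Illustrative launch (not part of this diff): because the script initializes a
# process group per rank, it would typically be started with torchrun, e.g.
#   torchrun --nproc_per_node=8 torch/distributed/checkpoint/scripts/benchmarking.py
# where the GPU count is an assumption about the benchmarking host.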


def init_distributed():
    """Initialize the distributed environment."""
    if not dist.is_available() or not dist.is_initialized():
        if torch.cuda.is_available():
            dist.init_process_group(backend="nccl")
        else:
            dist.init_process_group(backend="gloo")

    # Set device to current rank
    if torch.cuda.is_available():
        torch.cuda.set_device(dist.get_rank())

    return dist.get_world_size(), dist.get_rank()


def load_and_split_tensors(model_path: str, device_mesh):
    """
    Load tensors from a model and split them into d-tensors.

    Args:
        model_path: Path to the model to load
        device_mesh: Device mesh to distribute tensors on

    Returns:
        state_dict with distributed tensors
    """
    # Load the model
    if model_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        state_dict = load_file(model_path)
    else:
        # Fail fast so state_dict is always defined before it is used below
        raise ValueError(f"Expected a .safetensors file, got: {model_path}")

    # Create distributed tensors
    distributed_state_dict = {}
    for key, tensor in state_dict.items():
        # Choose sharding strategy based on tensor size
        if tensor.dim() >= 2 and tensor.size(0) > 10:
            # Shard along the first dimension for large tensors
            placements = [Shard(0)]
        else:
            # Replicate small tensors
            placements = [Replicate()]

        # Distribute the tensor
        dtensor = distribute_tensor(tensor, device_mesh, placements=placements)
        distributed_state_dict[key] = dtensor

    return distributed_state_dict, state_dict


def save_with_dcp_huggingface(state_dict, output_dir: str):
    """
    Save the distributed state dict using DCP with HuggingFace integration.

    Args:
        state_dict: State dict with distributed tensors
        output_dir: Directory to save the checkpoint
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save the distributed checkpoint
    dist_cp.save(
        state_dict=state_dict,
        storage_writer=dist_cp._HuggingFaceStorageWriter(
            path=output_dir,
            save_sharded=True,  # Save as sharded files
        ),
    )

    print(f"Saved distributed checkpoint to {output_dir}")


def save_safetensors(state_dict, output_dir: str):
    """
    Save the state dict as safetensors.

    Args:
        state_dict: State dict with tensors
        output_dir: Directory to save the checkpoint
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save the checkpoint
    from safetensors.torch import save_file

    save_file(state_dict, os.path.join(output_dir, "checkpoint.safetensors"))


def consolidate_checkpoint(input_dir: str, output_dir: str, fqn_to_index_mapping=None):
    """
    Consolidate the sharded checkpoint files into full tensors.

    Args:
        input_dir: Directory containing sharded checkpoint files
        output_dir: Directory to save the consolidated checkpoint
        fqn_to_index_mapping: Optional mapping of tensor names to output file indices
    """
    os.makedirs(output_dir, exist_ok=True)

    # Consolidate the checkpoint
    consolidate_safetensors_files(
        input_dir=input_dir,
        output_dir=output_dir,
        fqn_to_index_mapping=fqn_to_index_mapping,
    )

    print(f"Consolidated checkpoint saved to {output_dir}")


def main():
    parser = argparse.ArgumentParser(
        description="Split tensors into d-tensors, save with DCP HuggingFace, and consolidate"
    )
    """
    parser.add_argument(
        "--model-path", type=str, default="", help="Path to the model to load"
    )
    parser.add_argument(
        "--sharded-output-dir",
        type=str,
        default="./sharded_checkpoint",
        help="Directory to save the sharded checkpoint",
    )
    parser.add_argument(
        "--consolidated-output-dir",
        type=str,
        default="./consolidated_checkpoint",
        help="Directory to save the consolidated checkpoint",
    )
    args = parser.parse_args()
    """

    # Initialize distributed environment
    world_size, rank = init_distributed()
    print(f"Running with world_size={world_size}, rank={rank}")

    model_path = "/home/ankitageorge/.cache/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/model-00001-of-00030.safetensors"
    base_dir = "/data/users/ankitageorge/testing/"
    sharded_output_dir = os.path.join(base_dir, "sharded_checkpoint")
    dcp_consolidated_output_dir = os.path.join(base_dir, "dcp_consolidated_checkpoint")
    consolidated_output_dir = os.path.join(base_dir, "consolidated_checkpoint")

    # Initialize device mesh
    device_type = "cuda" if torch.cuda.is_available() else "cpu"
    device_mesh = init_device_mesh(device_type, (world_size,))

    # Load and split tensors
    distributed_state_dict, state_dict = load_and_split_tensors(model_path, device_mesh)
    start_time = time.time()
    if rank == 0:
        save_safetensors(state_dict, consolidated_output_dir)
    print("time to save as safetensors ", time.time() - start_time)

    # Save with DCP HuggingFace
    start_time = time.time()
    save_with_dcp_huggingface(distributed_state_dict, sharded_output_dir)

    # Make sure all processes have finished saving
    dist.barrier()
    print("time to save with DCP HuggingFace ", time.time() - start_time)

    # Only rank 0 needs to consolidate the checkpoint
    start_time = time.time()
    if rank == 0:
        consolidate_checkpoint(sharded_output_dir, dcp_consolidated_output_dir)

    print("Time to consolidate checkpoint: ", time.time() - start_time)


if __name__ == "__main__":
    main()