Compare commits
Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 20:34:54 +08:00)
6 commits: ciflow/tru... ... benchmarki...

| SHA1 |
|---|
| 0f49e915a9 |
| 2f1217f944 |
| e0bf01e87b |
| 3b5ae0e9fc |
| 5f5f654a3e |
| 21931cbbc6 |

test/distributed/checkpoint/test_consolidate_hf_safetensors.py  (new file, 135 lines)
							| @ -0,0 +1,135 @@ | ||||
| # Owner(s): ["oncall: distributed_checkpointing"] | ||||
|  | ||||
| import os | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch import distributed as dist | ||||
| from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import ( | ||||
|     consolidate_safetensors_files, | ||||
| ) | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import DTensor, Shard | ||||
| from torch.testing._internal.common_utils import run_tests | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     DTensorTestBase, | ||||
|     skip_if_lt_x_gpu, | ||||
|     with_comms, | ||||
| ) | ||||
| from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir | ||||
|  | ||||
|  | ||||
| class TestConsolidateHFSafeTensors(DTensorTestBase): | ||||
|     def _create_d_tensors(self) -> None: | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|         mesh_shape = (self.world_size,) | ||||
|         mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|         # Create local tensor with row-wise sharding | ||||
|         rows_per_rank = global_tensor.shape[0] // self.world_size | ||||
|         start_row = self.rank * rows_per_rank | ||||
|         end_row = start_row + rows_per_rank | ||||
|         local_tensor = global_tensor[start_row:end_row].clone() | ||||
|  | ||||
|         # Create DTensor with row-wise sharding | ||||
|         dtensor = DTensor.from_local( | ||||
|             local_tensor, | ||||
|             device_mesh=mesh_1d, | ||||
|             placements=[Shard(0)], | ||||
|             shape=global_tensor.shape, | ||||
|             stride=(4, 1), | ||||
|         ) | ||||
|  | ||||
|         # Create local tensor with column-wise sharding | ||||
|         cols_per_rank = global_tensor.shape[1] // self.world_size | ||||
|         start_col = self.rank * cols_per_rank | ||||
|         end_col = start_col + cols_per_rank | ||||
|         local_tensor_col = global_tensor[:, start_col:end_col].clone() | ||||
|  | ||||
|         # Create DTensor with column-wise sharding | ||||
|         dtensor_col = DTensor.from_local( | ||||
|             local_tensor_col, | ||||
|             device_mesh=mesh_1d, | ||||
|             placements=[Shard(1)],  # Column-wise sharding | ||||
|             shape=global_tensor.shape, | ||||
|             stride=(4, 1), | ||||
|         ) | ||||
|  | ||||
|         state_dict_to_save = {"dtensor": dtensor, "dtensor_col": dtensor_col} | ||||
|         dist_cp.save( | ||||
|             state_dict=state_dict_to_save, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter( | ||||
|                 path=self.temp_dir, save_sharded=True | ||||
|             ), | ||||
|         ) | ||||
|         dist.barrier() | ||||
|         os.sync() | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_consolidate_to_one_file(self) -> None: | ||||
|         try: | ||||
|             import safetensors.torch  # safetensors.torch.load_file is used below | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         checkpoint_dir = self.temp_dir | ||||
|         output_dir = os.path.join(checkpoint_dir, "consolidated") | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|         self._create_d_tensors() | ||||
|  | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|  | ||||
|         if self.rank == 0: | ||||
|             consolidate_safetensors_files(checkpoint_dir, output_dir) | ||||
|  | ||||
|             file_path = os.path.join(output_dir, "model-00001-of-00001.safetensors") | ||||
|             loaded_dict = safetensors.torch.load_file(file_path) | ||||
|             self.assertEqual(loaded_dict.keys(), {"dtensor", "dtensor_col"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor_col"], global_tensor)) | ||||
|         dist.barrier() | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_consolidate_to_two_files(self): | ||||
|         try: | ||||
|             import safetensors.torch  # safetensors.torch.load_file is used below | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         checkpoint_dir = self.temp_dir | ||||
|         output_dir = os.path.join(checkpoint_dir, "consolidated") | ||||
|         os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|         self._create_d_tensors() | ||||
|  | ||||
|         global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|  | ||||
|         if self.rank == 0: | ||||
|             fqn_to_index_mapping = {"dtensor": 1, "dtensor_col": 2} | ||||
|             consolidate_safetensors_files( | ||||
|                 checkpoint_dir, output_dir, fqn_to_index_mapping | ||||
|             ) | ||||
|  | ||||
|             file1_path = os.path.join(output_dir, "model-00001-of-00002.safetensors") | ||||
|             file2_path = os.path.join(output_dir, "model-00002-of-00002.safetensors") | ||||
|  | ||||
|             loaded_dict = safetensors.torch.load_file(file1_path) | ||||
|             self.assertEqual(loaded_dict.keys(), {"dtensor"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict["dtensor"], global_tensor)) | ||||
|  | ||||
|             loaded_dict_col = safetensors.torch.load_file(file2_path) | ||||
|             self.assertEqual(loaded_dict_col.keys(), {"dtensor_col"}) | ||||
|             self.assertTrue(torch.equal(loaded_dict_col["dtensor_col"], global_tensor)) | ||||
|         dist.barrier() | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
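
The test above exercises a two-step flow: each rank saves its own shard via `_HuggingFaceStorageWriter(save_sharded=True)`, and rank 0 then merges the per-rank files with `consolidate_safetensors_files`. A condensed sketch of that flow, assuming a process group and a sharded `state_dict` of DTensors already exist (the paths and FQN-to-file mapping below are illustrative):

```python
import os

import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.scripts._consolidate_hf_safetensors import (
    consolidate_safetensors_files,
)

checkpoint_dir = "/tmp/ckpt"  # illustrative paths
output_dir = os.path.join(checkpoint_dir, "consolidated")

# Every rank writes its own shard-*.safetensors file.
dist_cp.save(
    state_dict=state_dict,
    storage_writer=dist_cp._HuggingFaceStorageWriter(
        path=checkpoint_dir, save_sharded=True
    ),
)
dist.barrier()

if dist.get_rank() == 0:
    os.makedirs(output_dir, exist_ok=True)
    # The optional third argument maps each FQN to an output file index;
    # omit it to consolidate everything into a single file.
    consolidate_safetensors_files(
        checkpoint_dir, output_dir, {"dtensor": 1, "dtensor_col": 2}
    )
```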
							
								
								
									
test/distributed/checkpoint/test_hf_safetensor_e2e.py  (new file, 420 lines)
							| @ -0,0 +1,420 @@ | ||||
| # Owner(s): ["oncall: distributed_checkpointing"] | ||||
|  | ||||
| import sys | ||||
|  | ||||
| import torch | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch.distributed.checkpoint import _HuggingFaceLoadPlanner | ||||
| from torch.distributed.checkpoint.default_planner import _EmptyStateDictLoadPlanner | ||||
| from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import distribute_tensor, Replicate, Shard, zeros | ||||
| from torch.testing._internal.common_utils import ( | ||||
|     instantiate_parametrized_tests, | ||||
|     run_tests, | ||||
|     TestCase, | ||||
| ) | ||||
| from torch.testing._internal.distributed._tensor.common_dtensor import ( | ||||
|     DTensorTestBase, | ||||
|     skip_if_lt_x_gpu, | ||||
|     with_comms, | ||||
| ) | ||||
| from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir | ||||
|  | ||||
|  | ||||
| CHECKPOINT_DIR = "checkpoint" | ||||
|  | ||||
|  | ||||
| class MyTestModule(torch.nn.Module): | ||||
|     def __init__(self) -> None: | ||||
|         super().__init__() | ||||
|         self.linear_1 = torch.nn.Linear(5, 5) | ||||
|         self.linear_2 = torch.nn.Linear(5, 1) | ||||
|         self.emb = torch.nn.EmbeddingBag(5, 10) | ||||
|  | ||||
| class TestSingleRankSaveLoad(TestCase): | ||||
|     @with_temp_dir | ||||
|     def test_save(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import load_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         dist_cp.save( | ||||
|             state_dict=state_dict_to_save, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         state_dict_loaded = load_file(CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         state_dict_to_load = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict_to_load, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load_into_empty_dict(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         state_dict_loaded = _load_state_dict_from_keys( | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_loaded.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_loaded[key])) | ||||
|  | ||||
|     @with_temp_dir | ||||
|     def test_load_allowing_resize(self) -> None: | ||||
|         try: | ||||
|             from safetensors.torch import save_file | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         state_dict_to_save = MyTestModule().state_dict() | ||||
|         save_file(state_dict_to_save, CHECKPOINT_DIR + "/model-00001-of-00001.safetensors") | ||||
|  | ||||
|         state_dict_to_load = {} | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             state_dict_to_load[key] = torch.zeros(1) | ||||
|  | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict_to_load, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(path=CHECKPOINT_DIR), | ||||
|             planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True), | ||||
|         ) | ||||
|  | ||||
|         self.assertEqual(sorted(state_dict_to_save.keys()), sorted(state_dict_to_load.keys())) | ||||
|         for key in state_dict_to_save.keys(): | ||||
|             self.assertTrue(torch.equal(state_dict_to_save[key], state_dict_to_load[key])) | ||||
|  | ||||
| ONE_D_PLACEMENTS = [ | ||||
|     [Shard(0)], | ||||
|     [Replicate()], | ||||
| ] | ||||
| ONE_D_TO_ONE_D_PLACEMENTS = [ | ||||
|     ([Replicate()], [Shard(0)]), | ||||
|     ([Shard(0)], [Replicate()]), | ||||
| ] | ||||
|  | ||||
| TWO_D_PLACEMENTS = [ | ||||
|     [Replicate(), Replicate()], | ||||
|     [Replicate(), Shard(0)], | ||||
|     [Shard(0), Replicate()], | ||||
|     [Shard(0), Shard(0)], | ||||
| ] | ||||
| TWO_D_TO_TWO_D_PLACEMENTS = [] | ||||
| for p1 in TWO_D_PLACEMENTS: | ||||
|     for p2 in TWO_D_PLACEMENTS: | ||||
|         if p1 != p2: | ||||
|             TWO_D_TO_TWO_D_PLACEMENTS.append((p1, p2)) | ||||
|  | ||||
|  | ||||
| @instantiate_parametrized_tests | ||||
| class TestDTensorReshardPlacementChange(DTensorTestBase): | ||||
|     """ | ||||
|     Test DCP resharding for DTensor with placement changes, without changing world_size or mesh_tensor. | ||||
|     """ | ||||
|  | ||||
|     @with_comms | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     @with_temp_dir | ||||
|     def test_1d_to_1d_reshard_placement_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|  | ||||
|         for one_d_to_one_d_placements in ONE_D_TO_ONE_D_PLACEMENTS: | ||||
|             original_placement, new_placement = one_d_to_one_d_placements | ||||
|  | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (self.world_size,) | ||||
|             device_mesh = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, device_mesh, placements=original_placement | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter( | ||||
|                     path=CHECKPOINT_DIR, | ||||
|                     save_sharded=True, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             zero_dtensor = zeros( | ||||
|                 [4, 4], device_mesh=device_mesh, placements=new_placement | ||||
|             ) | ||||
|             state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|             dist_cp.load( | ||||
|                 state_dict=state_dict_to_load, | ||||
|                 storage_reader=dist_cp._HuggingFaceStorageReader( | ||||
|                     CHECKPOINT_DIR, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             # materialize the whole tensor to compare with the original global_tensor | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 device_mesh, | ||||
|                 placements=[Replicate()], | ||||
|             ) | ||||
|             self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) | ||||
|  | ||||
|             # redistribute the tensor back to its original placement for comparison. | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 device_mesh, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 state_dict_to_save["dtensor"].to_local(), | ||||
|                 state_dict_to_load["dtensor"].to_local(), | ||||
|             ) | ||||
|  | ||||
|     @with_comms | ||||
|     @skip_if_lt_x_gpu(4) | ||||
|     @with_temp_dir | ||||
|     def test_2d_to_2d_reshard_placement_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for two_d_to_two_d_placements in TWO_D_TO_TWO_D_PLACEMENTS: | ||||
|             original_placement, new_placement = two_d_to_two_d_placements | ||||
|  | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (2, self.world_size // 2) | ||||
|             mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, | ||||
|                 mesh_2d, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|                 planner=dist_cp.DefaultSavePlanner(), | ||||
|             ) | ||||
|  | ||||
|             zero_dtensor = zeros([4, 4], device_mesh=mesh_2d, placements=new_placement) | ||||
|             state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|             dist_cp.load( | ||||
|                 state_dict=state_dict_to_load, | ||||
|                 storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|             ) | ||||
|  | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 mesh_2d, | ||||
|                 placements=[Replicate(), Replicate()], | ||||
|             ) | ||||
|             self.assertEqual(global_tensor, state_dict_to_load["dtensor"].to_local()) | ||||
|  | ||||
|             state_dict_to_load["dtensor"] = state_dict_to_load["dtensor"].redistribute( | ||||
|                 mesh_2d, | ||||
|                 placements=original_placement, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 state_dict_to_save["dtensor"].to_local(), | ||||
|                 state_dict_to_load["dtensor"].to_local(), | ||||
|             ) | ||||
|  | ||||
|  | ||||
| class TestDTensorReshardMeshChange(DTensorTestBase): | ||||
|     """ | ||||
|     Test DCP resharding for DTensor with placement and mesh_tensor changes. | ||||
|     """ | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_1d_to_2d_reshard_mesh_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for placements_1d in ONE_D_PLACEMENTS: | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (self.world_size,) | ||||
|             mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, mesh_1d, placements=placements_1d | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|             ) | ||||
|  | ||||
|             for placements_2d in TWO_D_PLACEMENTS: | ||||
|                 mesh_shape = (2, self.world_size // 2) | ||||
|                 mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|                 zero_dtensor = zeros( | ||||
|                     [4, 4], device_mesh=mesh_2d, placements=placements_2d | ||||
|                 ) | ||||
|                 state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|                 dist_cp.load( | ||||
|                     state_dict=state_dict_to_load, | ||||
|                     storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|                     planner=dist_cp.DefaultLoadPlanner(), | ||||
|                 ) | ||||
|  | ||||
|                 # materialize the whole tensor to compare with the original global_tensor | ||||
|                 state_dict_to_load["dtensor"] = state_dict_to_load[ | ||||
|                     "dtensor" | ||||
|                 ].redistribute( | ||||
|                     mesh_2d, | ||||
|                     placements=[Replicate(), Replicate()], | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     global_tensor, state_dict_to_load["dtensor"].to_local() | ||||
|                 ) | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(4) | ||||
|     def test_2d_to_1d_reshard_mesh_change(self) -> None: | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         CHECKPOINT_DIR = self.temp_dir | ||||
|         for placements_2d in TWO_D_PLACEMENTS: | ||||
|             global_tensor = torch.arange(16, dtype=torch.float).view(4, 4) | ||||
|             mesh_shape = (2, self.world_size // 2) | ||||
|             mesh_2d = init_device_mesh(self.device_type, mesh_shape) | ||||
|             dtensor = distribute_tensor( | ||||
|                 global_tensor, mesh_2d, placements=placements_2d | ||||
|             ) | ||||
|             state_dict_to_save = {"dtensor": dtensor} | ||||
|  | ||||
|             dist_cp.save( | ||||
|                 state_dict=state_dict_to_save, | ||||
|                 storage_writer=dist_cp._HuggingFaceStorageWriter(path=CHECKPOINT_DIR, save_sharded=True), | ||||
|                 planner=dist_cp.DefaultSavePlanner(), | ||||
|             ) | ||||
|  | ||||
|             for placements_1d in ONE_D_PLACEMENTS: | ||||
|                 mesh_shape = (self.world_size,) | ||||
|                 mesh_1d = init_device_mesh(self.device_type, mesh_shape) | ||||
|  | ||||
|                 zero_dtensor = zeros( | ||||
|                     [4, 4], device_mesh=mesh_1d, placements=placements_1d | ||||
|                 ) | ||||
|                 state_dict_to_load = {"dtensor": zero_dtensor} | ||||
|  | ||||
|                 dist_cp.load( | ||||
|                     state_dict=state_dict_to_load, | ||||
|                     storage_reader=dist_cp._HuggingFaceStorageReader(CHECKPOINT_DIR), | ||||
|                     planner=dist_cp.DefaultLoadPlanner(), | ||||
|                 ) | ||||
|  | ||||
|                 # materialize the whole tensor to compare with the original global_tensor | ||||
|                 state_dict_to_load["dtensor"] = state_dict_to_load[ | ||||
|                     "dtensor" | ||||
|                 ].redistribute( | ||||
|                     mesh_1d, | ||||
|                     placements=[Replicate()], | ||||
|                 ) | ||||
|                 self.assertEqual( | ||||
|                     global_tensor, state_dict_to_load["dtensor"].to_local() | ||||
|                 ) | ||||
|  | ||||
|     @with_comms | ||||
|     @with_temp_dir | ||||
|     @skip_if_lt_x_gpu(2) | ||||
|     def test_dtensor_checkpoint_resharding_with_empty_shard(self): | ||||
|         """ | ||||
|         Test dtensor checkpoint resharding with dtensor containing empty shards. | ||||
|         """ | ||||
|         try: | ||||
|             import safetensors | ||||
|         except ImportError: | ||||
|             print("safetensors not installed") | ||||
|             sys.exit(0) | ||||
|  | ||||
|         tensor = torch.rand(1).cuda() | ||||
|         mesh = init_device_mesh(self.device_type, (self.world_size,)) | ||||
|         dtensor = distribute_tensor(tensor, mesh, [Shard(0)]) | ||||
|         ref_state_dict = {"dtensor": dtensor} | ||||
|  | ||||
|         dist_cp.save( | ||||
|             state_dict=ref_state_dict, | ||||
|             storage_writer=dist_cp._HuggingFaceStorageWriter(path=self.temp_dir, save_sharded=True), | ||||
|         ) | ||||
|  | ||||
|         tensor = torch.rand(1).cuda() | ||||
|         mesh_2 = init_device_mesh(self.device_type, (2, self.world_size // 2)) | ||||
|         dtensor = distribute_tensor(tensor, mesh_2, [Shard(0), Shard(0)]) | ||||
|         state_dict = {"dtensor": dtensor} | ||||
|         dist_cp.load( | ||||
|             state_dict=state_dict, | ||||
|             storage_reader=dist_cp._HuggingFaceStorageReader(self.temp_dir), | ||||
|         ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     run_tests() | ||||
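
One pattern from the tests above worth highlighting is loading an entire safetensors checkpoint without building a state dict first, via `_load_state_dict_from_keys`. A minimal sketch (the checkpoint path is a placeholder):

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict_from_keys

# Reads every tensor found in the checkpoint directory into a fresh dict,
# without requiring pre-sized destination tensors.
state_dict = _load_state_dict_from_keys(
    storage_reader=dist_cp._HuggingFaceStorageReader(path="/tmp/ckpt"),
)
```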
test/distributed/checkpoint/test_hf_storage.py

| @ -8,10 +8,7 @@ import tempfile | ||||
| from unittest.mock import MagicMock | ||||
|  | ||||
| import torch | ||||
| from torch.distributed.checkpoint._hf_planner import ( | ||||
|     _FqnToFileMapping, | ||||
|     _HuggingFaceLoadPlanner, | ||||
| ) | ||||
| from torch.distributed.checkpoint import DefaultLoadPlanner | ||||
| from torch.distributed.checkpoint._hf_storage import ( | ||||
|     _HuggingFaceStorageReader, | ||||
|     _HuggingFaceStorageWriter, | ||||
| @ -21,24 +18,25 @@ from torch.distributed.checkpoint.default_planner import DefaultSavePlanner | ||||
| from torch.distributed.checkpoint.filesystem import _StorageInfo, FileSystem | ||||
| from torch.distributed.checkpoint.metadata import ( | ||||
|     BytesStorageMetadata, | ||||
|     ChunkStorageMetadata, | ||||
|     Metadata, | ||||
|     MetadataIndex, | ||||
|     TensorProperties, | ||||
|     TensorStorageMetadata, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner import LoadPlan, SavePlan | ||||
| from torch.distributed.checkpoint.planner_helpers import ( | ||||
|     _create_read_items, | ||||
|     _create_write_item_for_tensor, | ||||
| from torch.distributed.checkpoint.planner import ( | ||||
|     LoadItemType, | ||||
|     LoadPlan, | ||||
|     ReadItem, | ||||
|     SavePlan, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor | ||||
| from torch.distributed.checkpoint.storage import WriteResult | ||||
| from torch.testing._internal.common_utils import run_tests, TestCase | ||||
|  | ||||
|  | ||||
| class TestHfStorage(TestCase): | ||||
|     def test_write_data_hf(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["safetensors"] = mock_module | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.save.return_value = b"" | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
| @ -46,7 +44,7 @@ class TestHfStorage(TestCase): | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 1}, | ||||
|                 fqn_to_index_mapping={"tensor_0": 1, "tensor_1": 2}, | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|  | ||||
| @ -59,7 +57,7 @@ class TestHfStorage(TestCase): | ||||
|  | ||||
|             save_plan = SavePlan( | ||||
|                 [write_item_1, write_item_2], | ||||
|                 storage_data=_FqnToFileMapping({"tensor_0": 1, "tensor_1": 1}), | ||||
|                 storage_data={"fqn_to_file_mapping": {"tensor_0": 1, "tensor_1": 2}}, | ||||
|             ) | ||||
|             save_planner = DefaultSavePlanner() | ||||
|             save_planner.set_up_planner(state_dict=state_dict) | ||||
| @ -76,7 +74,7 @@ class TestHfStorage(TestCase): | ||||
|                     ), | ||||
|                     size_in_bytes=tensor0.numel() * tensor0.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="model-00001-of-00001.safetensors", | ||||
|                         relative_path="model-00001-of-00002.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor0.numel() * tensor0.element_size(), | ||||
|                     ), | ||||
| @ -87,7 +85,68 @@ class TestHfStorage(TestCase): | ||||
|                     ), | ||||
|                     size_in_bytes=tensor1.numel() * tensor1.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="model-00001-of-00001.safetensors", | ||||
|                         relative_path="model-00002-of-00002.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor1.numel() * tensor1.element_size(), | ||||
|                     ), | ||||
|                 ), | ||||
|             ] | ||||
|  | ||||
|             self.assertEqual( | ||||
|                 actual_write_results, | ||||
|                 expected_write_results, | ||||
|             ) | ||||
|  | ||||
|     def test_write_data_with_sharding(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.save.return_value = b"" | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
|  | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 save_sharded=True, | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|  | ||||
|             tensor0 = torch.rand(4) | ||||
|             tensor1 = torch.rand(10) | ||||
|             write_item_1 = _create_write_item_for_tensor("tensor_0", tensor0) | ||||
|             write_item_2 = _create_write_item_for_tensor("tensor_1", tensor1) | ||||
|  | ||||
|             state_dict = {"tensor_0": tensor0, "tensor_1": tensor1} | ||||
|  | ||||
|             save_plan = SavePlan( | ||||
|                 [write_item_1, write_item_2], | ||||
|                 storage_data={"shard_index": 1}, | ||||
|             ) | ||||
|             save_planner = DefaultSavePlanner() | ||||
|             save_planner.set_up_planner(state_dict=state_dict) | ||||
|  | ||||
|             write_results = writer.write_data(save_plan, save_planner) | ||||
|  | ||||
|             write_results.wait() | ||||
|             actual_write_results = write_results.value() | ||||
|  | ||||
|             expected_write_results = [ | ||||
|                 WriteResult( | ||||
|                     index=MetadataIndex( | ||||
|                         fqn="tensor_0", offset=torch.Size([0]), index=None | ||||
|                     ), | ||||
|                     size_in_bytes=tensor0.numel() * tensor0.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="shard-00001-model-00001-of-00001.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor0.numel() * tensor0.element_size(), | ||||
|                     ), | ||||
|                 ), | ||||
|                 WriteResult( | ||||
|                     index=MetadataIndex( | ||||
|                         fqn="tensor_1", offset=torch.Size([0]), index=None | ||||
|                     ), | ||||
|                     size_in_bytes=tensor1.numel() * tensor1.element_size(), | ||||
|                     storage_data=_StorageInfo( | ||||
|                         relative_path="shard-00001-model-00001-of-00001.safetensors", | ||||
|                         offset=0, | ||||
|                         length=tensor1.numel() * tensor1.element_size(), | ||||
|                     ), | ||||
| @ -100,43 +159,84 @@ class TestHfStorage(TestCase): | ||||
|             ) | ||||
|  | ||||
|     def test_read_data_hf(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["safetensors"] = mock_module | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|         mock_safetensors = MagicMock() | ||||
|         sys.modules["safetensors"] = mock_safetensors | ||||
|  | ||||
|         name = "tensor_0" | ||||
|         tensor_0 = torch.rand(4) | ||||
|         mock_module = MagicMock() | ||||
|         mock_module.load.return_value = {name: tensor_0} | ||||
|         sys.modules["safetensors.torch"] = mock_module | ||||
|         # Create test tensors | ||||
|         tensor_0 = torch.tensor([1.0, 2.0, 3.0, 4.0]) | ||||
|  | ||||
|         # Mock the deserialize function to return our test tensors | ||||
|         # The format matches what's expected in the read_data method | ||||
|         mock_safetensors.deserialize.return_value = [ | ||||
|             ("tensor_0", { | ||||
|                 "data": tensor_0.numpy().tobytes(), | ||||
|                 "dtype": "F32", | ||||
|                 "shape": [4] | ||||
|             }), | ||||
|         ] | ||||
|  | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             # Create the reader | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|             reader.fs = FileSystem() | ||||
|             file_name = "model-00001-of-00001" | ||||
|  | ||||
|             pathlib.Path(os.path.join(path, file_name)).touch() | ||||
|             # Create test file | ||||
|             file_name = "model-00001-of-00001.safetensors" | ||||
|             file_path = os.path.join(path, file_name) | ||||
|             pathlib.Path(file_path).touch() | ||||
|  | ||||
|             reader.set_up_storage_reader( | ||||
|                 Metadata( | ||||
|                     state_dict_metadata={name: BytesStorageMetadata()}, | ||||
|                     storage_data={name: file_name}, | ||||
|                 ), | ||||
|                 is_coordinator=True, | ||||
|             ) | ||||
|             # Set up storage data with _StorageInfo objects | ||||
|             storage_data = { | ||||
|                 "tensor_0": _StorageInfo(file_path, 0, tensor_0.numel() * tensor_0.element_size()), | ||||
|             } | ||||
|  | ||||
|             read_items = _create_read_items(name, BytesStorageMetadata(), file_name) | ||||
|  | ||||
|             reader.storage_data = storage_data | ||||
|  | ||||
|             # Create target tensors that will be updated by read_data | ||||
|             target_tensor_0 = torch.zeros(4) | ||||
|             state_dict = { | ||||
|                 "tensor_0": target_tensor_0, | ||||
|             } | ||||
|  | ||||
|             # Create read items for the load plan | ||||
|             read_items = [] | ||||
|             for name, tensor in state_dict.items(): | ||||
|                 storage_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None) | ||||
|                 dest_index = MetadataIndex(fqn=name, offset=torch.Size([0]), index=None) | ||||
|                 read_items.append( | ||||
|                     ReadItem( | ||||
|                         type=LoadItemType.TENSOR, | ||||
|                         storage_index=storage_index, | ||||
|                         dest_index=dest_index, | ||||
|                         storage_offsets=[0, 0], | ||||
|                         dest_offsets=[0, 0], | ||||
|                         lengths=tensor.size(), | ||||
|                     ) | ||||
|                 ) | ||||
|  | ||||
|             # Create load plan and planner | ||||
|             load_plan = LoadPlan(read_items) | ||||
|             load_planner = _HuggingFaceLoadPlanner() | ||||
|             load_planner.set_up_planner(state_dict={name: torch.rand(4)}) | ||||
|             load_planner = DefaultLoadPlanner() | ||||
|             load_planner.set_up_planner( | ||||
|                 state_dict=state_dict, | ||||
|                 metadata=Metadata( | ||||
|                     state_dict_metadata={ | ||||
|                         "tensor_0": TensorStorageMetadata( | ||||
|                             properties=TensorProperties(dtype=torch.float32), | ||||
|                             size=torch.Size([4]), | ||||
|                             chunks=[ | ||||
|                                 ChunkStorageMetadata(offsets=[0], sizes=torch.Size([4])) | ||||
|                             ], | ||||
|                         ) | ||||
|                     }, | ||||
|                     storage_data=storage_data, | ||||
|                 ), | ||||
|             ) | ||||
|  | ||||
|             read_data = reader.read_data(load_plan, load_planner) | ||||
|             read_data.wait() | ||||
|             # Call read_data | ||||
|             future = reader.read_data(load_plan, load_planner) | ||||
|             future.wait() | ||||
|  | ||||
|             loaded_tensor = load_planner.original_state_dict[name] | ||||
|             self.assertEqual(loaded_tensor, tensor_0) | ||||
|             # Verify results - the target tensors should now contain the values from our test tensor | ||||
|             self.assertTrue(torch.equal(state_dict["tensor_0"], tensor_0)) | ||||
|  | ||||
|     def test_metadata_hf(self) -> None: | ||||
|     def test_write_metadata_hf(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
| @ -160,7 +260,6 @@ class TestHfStorage(TestCase): | ||||
|  | ||||
|             writer = _HuggingFaceStorageWriter( | ||||
|                 path=path, | ||||
|                 fqn_to_index_mapping=_FqnToFileMapping({}), | ||||
|             ) | ||||
|             writer.fs = FileSystem() | ||||
|             writer.finish( | ||||
| @ -185,26 +284,16 @@ class TestHfStorage(TestCase): | ||||
|                 metadata = json.load(f) | ||||
|                 self.assertEqual(metadata, expected_metadata) | ||||
|  | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|             reader.fs = FileSystem() | ||||
|             metadata = reader.read_metadata() | ||||
|             self.assertEqual(metadata.storage_data, expected_metadata["weight_map"]) | ||||
|  | ||||
|     def test_read_metadata_when_metadata_file_does_not_exist(self) -> None: | ||||
|         mock_module = MagicMock() | ||||
|         sys.modules["huggingface_hub"] = mock_module | ||||
|  | ||||
|     def test_read_metadata_hf(self): | ||||
|         with tempfile.TemporaryDirectory() as path: | ||||
|             reader = _HuggingFaceStorageReader(path=path) | ||||
|             reader.fs = FileSystem() | ||||
|             # there is one safetensor file, but no metadata file, | ||||
|             # so we create metadata from the safetensor file | ||||
|             keys = ["tensor_0", "tensor_1"] | ||||
|  | ||||
|             key = "tensor_0" | ||||
|             file_name = "test.safetensors" | ||||
|             with open(os.path.join(path, file_name), "wb") as f: | ||||
|                 # write metadata the same way it would be in safetensors file | ||||
|                 metadata_contents = json.dumps( | ||||
|                     {"tensor_0": "value_0", "tensor_1": "value_1"} | ||||
|                     {'tensor_0': {'dtype': "F32", "shape": [5, 10], "data_offsets": [0, 200]}} | ||||
|                 ) | ||||
|                 metadata_bytes = metadata_contents.encode("utf-8") | ||||
|  | ||||
| @ -216,13 +305,16 @@ class TestHfStorage(TestCase): | ||||
|             self.assertEqual( | ||||
|                 metadata.state_dict_metadata, | ||||
|                 { | ||||
|                     keys[0]: BytesStorageMetadata(), | ||||
|                     keys[1]: BytesStorageMetadata(), | ||||
|                     key: TensorStorageMetadata( | ||||
|                             properties=TensorProperties(dtype=torch.float32), | ||||
|                             size=torch.Size([5, 10]), | ||||
|                             chunks=[ChunkStorageMetadata(offsets=[0, 0], sizes=torch.Size([5, 10]))], | ||||
|                     ), | ||||
|                 }, | ||||
|             ) | ||||
|             self.assertEqual( | ||||
|                 metadata.storage_data, | ||||
|                 {keys[0]: file_name, keys[1]: file_name}, | ||||
|                 {key: _StorageInfo(os.path.join(path, file_name), 0, 200, transform_descriptors=None)}, | ||||
|             ) | ||||
|  | ||||
|  | ||||
|  | ||||
torch/distributed/checkpoint/__init__.py

| @ -12,6 +12,7 @@ from .metadata import ( | ||||
| ) | ||||
| from .optimizer import load_sharded_optimizer_state_dict | ||||
| from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem | ||||
| from .scripts._consolidate_hf_safetensors import consolidate_safetensors_files | ||||
| from .state_dict_loader import load, load_state_dict | ||||
| from .state_dict_saver import async_save, save, save_state_dict | ||||
| from .storage import StorageReader, StorageWriter | ||||
|  | ||||
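
With the export added above, the consolidation helper is also reachable from the package root. A small usage sketch (directory names are placeholders):

```python
import torch.distributed.checkpoint as dist_cp

# Merge per-rank sharded .safetensors files under checkpoint/ into
# consolidated model-*.safetensors files under consolidated/.
dist_cp.consolidate_safetensors_files("checkpoint/", "consolidated/")
```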
torch/distributed/checkpoint/_dedup_save_plans.py

| @ -62,25 +62,3 @@ def dedup_save_plans( | ||||
|         ) | ||||
|         for plan, item_indexes in zip(all_plans, plan_to_item_indices) | ||||
|     ] | ||||
|  | ||||
|  | ||||
| def dedup_save_plans_with_fqn_to_index_mapping( | ||||
|     all_plans: list[SavePlan], fqn_to_index_mapping: dict[str, int] | ||||
| ) -> list[SavePlan]: | ||||
|     num_plans = len(all_plans) | ||||
|  | ||||
|     to_remove: list[set] = [set() for _ in range(len(all_plans))] | ||||
|     for plan_idx, plan in enumerate(all_plans): | ||||
|         for item_idx, item in enumerate(plan.items): | ||||
|             if (fqn_to_index_mapping[item.index.fqn] - 1) % num_plans != plan_idx: | ||||
|                 to_remove[plan_idx].add(item_idx) | ||||
|  | ||||
|     for plan_idx, remove_set in enumerate(to_remove): | ||||
|         new_items = [ | ||||
|             write_item | ||||
|             for item_idx, write_item in enumerate(all_plans[plan_idx].items) | ||||
|             if item_idx not in remove_set | ||||
|         ] | ||||
|         all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items) | ||||
|  | ||||
|     return all_plans | ||||
|  | ||||
torch/distributed/checkpoint/_hf_planner.py

| @ -1,49 +1,34 @@ | ||||
| # mypy: allow-untyped-defs | ||||
| from dataclasses import dataclass | ||||
|  | ||||
| from torch.distributed.checkpoint._dedup_save_plans import ( | ||||
|     dedup_save_plans_with_fqn_to_index_mapping, | ||||
| ) | ||||
| from torch.distributed.checkpoint.default_planner import ( | ||||
|     create_default_local_load_plan, | ||||
|     DefaultLoadPlanner, | ||||
|     DefaultSavePlanner, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner import ReadItem, SavePlan | ||||
| from torch.distributed.checkpoint.planner import LoadPlan | ||||
|  | ||||
|  | ||||
| __all__ = ["_HuggingFaceSavePlanner", "_HuggingFaceLoadPlanner"] | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class _FqnToFileMapping: | ||||
|     fqn_to_file_index_mapping: dict[str, int] | ||||
|  | ||||
|  | ||||
| class _HuggingFaceSavePlanner(DefaultSavePlanner): | ||||
|     """ | ||||
|     A save planner that dedups the save plans based on the fqn to file index mapping. | ||||
|     A planner to work with HuggingFace's safetensors format. | ||||
|     This is a placeholder, as it is likely that the DefaultSavePlanner is enough. | ||||
|     """ | ||||
|  | ||||
|     def _dedup_save_plans(self, all_plans: list[SavePlan]) -> list[SavePlan]: | ||||
|         assert len(all_plans) > 0, "all_plans should not be empty" | ||||
|         assert all_plans[0].storage_data is not None, "storage_data should not be None" | ||||
|         assert isinstance(all_plans[0].storage_data, _FqnToFileMapping), ( | ||||
|             "storage_data should be of type _FqnToFileMapping" | ||||
|         ) | ||||
|  | ||||
|         fqn_to_index_mapping: dict[str, int] = all_plans[ | ||||
|             0 | ||||
|         ].storage_data.fqn_to_file_index_mapping | ||||
|  | ||||
|         return dedup_save_plans_with_fqn_to_index_mapping( | ||||
|             all_plans, fqn_to_index_mapping | ||||
|         ) | ||||
|  | ||||
|  | ||||
| class _HuggingFaceLoadPlanner(DefaultLoadPlanner): | ||||
|     def __init__(self, allow_tensor_resize: bool = False): | ||||
|         super().__init__() | ||||
|         self.allow_tensor_resize = allow_tensor_resize | ||||
|  | ||||
|     def resolve_tensor(self, read_item: ReadItem): | ||||
|         return self.lookup_tensor(read_item.dest_index) | ||||
|     def create_local_plan(self) -> LoadPlan: | ||||
|         assert self.metadata is not None | ||||
|  | ||||
|         # check_md_size is added to skip the size check when tensor resizing is allowed. | ||||
|         # This will be deprecated in favor of _load_state_dict_from_keys, and then we | ||||
|         # can remove this planner altogether. | ||||
|         return create_default_local_load_plan( | ||||
|             self.state_dict, | ||||
|             self.metadata, | ||||
|             not self.allow_partial_load, | ||||
|             check_md_size=not self.allow_tensor_resize, | ||||
|         ) | ||||
|  | ||||
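
After this change `_HuggingFaceLoadPlanner` exists only to support resizing placeholder tensors during load and, per the comment above, is slated for deprecation in favor of `_load_state_dict_from_keys`. A brief sketch mirroring `test_load_allowing_resize` (the checkpoint path and keys are placeholders):

```python
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint import _HuggingFaceLoadPlanner

# Destination tensors with the wrong sizes; allow_tensor_resize=True lets
# read_data resize them in place to match the checkpointed shapes.
state_dict = {"linear_1.weight": torch.zeros(1), "linear_1.bias": torch.zeros(1)}

dist_cp.load(
    state_dict=state_dict,
    storage_reader=dist_cp._HuggingFaceStorageReader(path="/tmp/ckpt"),
    planner=_HuggingFaceLoadPlanner(allow_tensor_resize=True),
)
```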
torch/distributed/checkpoint/_hf_storage.py

| @ -2,24 +2,24 @@ | ||||
| import dataclasses | ||||
| import io | ||||
| import json | ||||
| import os | ||||
| import queue | ||||
| import struct | ||||
| from typing import Optional | ||||
| from typing import Any, Optional | ||||
|  | ||||
| import fsspec  # type: ignore[import-untyped] | ||||
| import torch | ||||
| from torch.distributed._shard._utils import narrow_tensor_by_index | ||||
|  | ||||
| from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter | ||||
| from torch.distributed.checkpoint._hf_planner import ( | ||||
|     _FqnToFileMapping, | ||||
|     _HuggingFaceLoadPlanner, | ||||
| ) | ||||
| from torch.distributed.checkpoint.filesystem import SerializationFormat | ||||
| from torch.distributed.checkpoint._hf_planner import _HuggingFaceLoadPlanner | ||||
| from torch.distributed.checkpoint.filesystem import _StorageInfo, SerializationFormat | ||||
| from torch.distributed.checkpoint.metadata import ( | ||||
|     BytesStorageMetadata, | ||||
|     ChunkStorageMetadata, | ||||
|     Metadata, | ||||
|     MetadataIndex, | ||||
|     STORAGE_TYPES, | ||||
|     StorageMeta, | ||||
|     TensorProperties, | ||||
|     TensorStorageMetadata, | ||||
| ) | ||||
| from torch.distributed.checkpoint.planner import ( | ||||
|     LoadPlan, | ||||
| @ -37,9 +37,30 @@ __all__ = ["_HuggingFaceStorageWriter", "_HuggingFaceStorageReader"] | ||||
|  | ||||
| _metadata_fn: str = "model.safetensors.index.json" | ||||
|  | ||||
| FILE_NAME = "model-{cpt_idx}-of-{num_shards}" | ||||
| FILE_NAME = "model-{cpt_idx}-of-{num_files}" | ||||
| SHARDED_FILE_NAME = "shard-{shard_idx}-model-{cpt_idx}-of-{num_files}" | ||||
| SUFFIX = ".safetensors" | ||||
|  | ||||
| # metadata keys | ||||
| CUSTOM_METADATA_KEY = "DCP_SHARDING_INFO" | ||||
| DEFAULT_EXTRA_METADATA_KEY = "__metadata__" | ||||
| SAVED_OFFSETS_KEY = "saved_offsets" | ||||
| SHAPE_KEY = "shape" | ||||
| DATA_KEY = "data" | ||||
| DTYPE_KEY = "dtype" | ||||
| DATA_OFFSETS_KEY = "data_offsets" | ||||
|  | ||||
| DTYPE_MAP = { | ||||
|     "F16": torch.float16, | ||||
|     "F32": torch.float32, | ||||
|     "F64": torch.float64, | ||||
|     "I8": torch.int8, | ||||
|     "U8": torch.uint8, | ||||
|     "I16": torch.int16, | ||||
|     "I32": torch.int32, | ||||
|     "I64": torch.int64, | ||||
|     "BF16": torch.bfloat16, | ||||
| } | ||||
|  | ||||
| class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|     """ | ||||
| @ -50,23 +71,25 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|     def __init__( | ||||
|         self, | ||||
|         path: str, | ||||
|         fqn_to_index_mapping: dict[str, int], | ||||
|         fqn_to_index_mapping: Optional[dict[str, int]] = None, | ||||
|         token: Optional[str] = None, | ||||
|         save_sharded: bool = False, | ||||
|     ) -> None: | ||||
|         """ | ||||
|         Initialize the huggingface writer pointing to path. | ||||
|  | ||||
|         Args: | ||||
|             path: hf directory where the checkpoint will be written to. Should begin with hf://. | ||||
|             token: The token to use to authenticate with huggingface hub. | ||||
|             path: directory where the checkpoint will be written to. Can be any | ||||
|             fsspec-supported storage, including local FS and hf://. | ||||
|             fqn_to_index_mapping: A mapping from tensor FQN to the index of the file that the tensor should be written to. | ||||
|                               Indices are from 1 to N, where N is the number of files. | ||||
|                               Indices are from 1 to N, where N is the number of files. If not provided, | ||||
|                               the tensors will be written to a single file. | ||||
|             token: The token to use to authenticate with huggingface hub. | ||||
|             save_sharded: If True, save the checkpoint as a sharded checkpoint where every rank saves its own shard. | ||||
|                         Default is False which assumes full tensors are being saved. | ||||
|  | ||||
|         """ | ||||
|         from huggingface_hub import HfFileSystem  # type: ignore[import-not-found] | ||||
|  | ||||
|         if HfFileSystem.protocol not in fsspec.available_protocols(): | ||||
|             fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) | ||||
|  | ||||
|         if token is not None: | ||||
|             super().__init__( | ||||
| @ -79,16 +102,21 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|                 path=path, | ||||
|                 serialization_format=SerializationFormat.SAFETENSORS, | ||||
|             ) | ||||
|         self._fqn_to_index_mapping: dict[str, int] = fqn_to_index_mapping | ||||
|  | ||||
|     def prepare_local_plan(self, plan: SavePlan) -> SavePlan: | ||||
|         plan = super().prepare_local_plan(plan) | ||||
|         return dataclasses.replace( | ||||
|             plan, storage_data=_FqnToFileMapping(self._fqn_to_index_mapping) | ||||
|         ) | ||||
|         self._fqn_to_index_mapping: Optional[dict[str, int]] = fqn_to_index_mapping | ||||
|         self._save_sharded = save_sharded | ||||
|  | ||||
|     def prepare_global_plan(self, plans: list[SavePlan]) -> list[SavePlan]: | ||||
|         return plans | ||||
|         new_plans = [] | ||||
|         for i, plan in enumerate(plans, start=1): | ||||
|             storage_data = {} | ||||
|             if self._fqn_to_index_mapping is not None: | ||||
|                 storage_data["fqn_to_file_mapping"] = self._fqn_to_index_mapping | ||||
|             if self._save_sharded: | ||||
|                 storage_data["shard_index"] = i | ||||
|  | ||||
|             new_plans.append(dataclasses.replace(plan, storage_data=storage_data)) | ||||
|  | ||||
|         return new_plans | ||||
|  | ||||
|     def write_data( | ||||
|         self, | ||||
| @ -101,14 +129,20 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|             return fut | ||||
|  | ||||
|         # storage_plan is a map from key to file index | ||||
|         storage_plan: dict[str, int] = plan.storage_data.fqn_to_file_index_mapping | ||||
|         storage_data: dict[str, Any] = plan.storage_data | ||||
|         storage_plan: Optional[dict[str, int]] = None | ||||
|         shard_index: Optional[int] = None | ||||
|         if "fqn_to_file_mapping" in storage_data: | ||||
|             storage_plan = storage_data["fqn_to_file_mapping"] | ||||
|         if "shard_index" in storage_data: | ||||
|             shard_index = storage_data["shard_index"] | ||||
|  | ||||
|         buckets = self._split_by_storage_plan(storage_plan, plan.items) | ||||
|         highest_index = max(storage_plan.values()) | ||||
|         highest_index = max(storage_plan.values()) if storage_plan is not None else 1 | ||||
|  | ||||
|         file_queue: queue.Queue = queue.Queue() | ||||
|         for file_index, write_items in buckets.items(): | ||||
|             file_name = self._gen_file_name(file_index, highest_index) | ||||
|             file_name = _gen_file_name(file_index, highest_index, shard_index) | ||||
|             file_queue.put( | ||||
|                 (self.fs.concat_path(self.path, file_name), file_name, write_items) | ||||
|             ) | ||||
| @ -116,6 +150,9 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|         return super()._write_data(planner, file_queue) | ||||
|  | ||||
|     def finish(self, metadata: Metadata, results: list[list[WriteResult]]) -> None: | ||||
|         if self._save_sharded: | ||||
|             return | ||||
|  | ||||
|         metadata_to_write = {} | ||||
|         storage_md = {} | ||||
|         total_size = 0 | ||||
| @ -132,9 +169,12 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|             json.dump(metadata_to_write, metadata_file, indent=2) | ||||
|  | ||||
|     def _split_by_storage_plan( | ||||
|         self, storage_plan: dict[str, int], items: list[WriteItem] | ||||
|         self, storage_plan: Optional[dict[str, int]], items: list[WriteItem] | ||||
|     ) -> dict[int, list[WriteItem]]: | ||||
|         # storage_plan is a map from key to index | ||||
|         if storage_plan is None: | ||||
|             return {1: items} | ||||
|  | ||||
|         buckets = {} | ||||
|         for item in items: | ||||
|             key = item.index.fqn | ||||
| @ -146,14 +186,6 @@ class _HuggingFaceStorageWriter(FsspecWriter): | ||||
|  | ||||
|         return buckets | ||||
|  | ||||
|     def _gen_file_name(self, index: int, largest_index: int) -> str: | ||||
|         return ( | ||||
|             FILE_NAME.format( | ||||
|                 cpt_idx=f"{index}".zfill(5), num_shards=f"{largest_index}".zfill(5) | ||||
|             ) | ||||
|             + SUFFIX | ||||
|         ) | ||||
|  | ||||
|     @property | ||||
|     def metadata_path(self) -> str: | ||||
|         return _metadata_fn | ||||
| @ -170,48 +202,61 @@ class _HuggingFaceStorageReader(FsspecReader): | ||||
|         Initialize the huggingface reader pointing to path. | ||||
|  | ||||
|         Args: | ||||
|             path: hf directory where the checkpoint will be read from. Should begin with hf://. | ||||
|             path: hf directory where the checkpoint will be read from.  | ||||
|             Needs to have .safetensors files, but can be from any fsspec-supported storage, | ||||
|             including localFS and hf://. | ||||
|             token: The token to use to authenticate with huggingface hub. | ||||
|         """ | ||||
|         from huggingface_hub import HfFileSystem  # type: ignore[import-not-found] | ||||
|  | ||||
|         if HfFileSystem.protocol not in fsspec.available_protocols(): | ||||
|             fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) | ||||
|  | ||||
|         if token is not None: | ||||
|             super().__init__(path=path, token=token) | ||||
|         else: | ||||
|             super().__init__(path=path) | ||||
|  | ||||
|         self.storage_data: dict[str, str] = {} | ||||
|  | ||||
|     def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]: | ||||
|         from safetensors.torch import load  # type: ignore[import-not-found] | ||||
|  | ||||
|         from safetensors import deserialize  # type: ignore[import-not-found] | ||||
|         per_file: dict[str, list[ReadItem]] = {} | ||||
|  | ||||
|         for read_item in plan.items: | ||||
|             file_name = self.storage_data[read_item.storage_index.fqn] | ||||
|             item_md: _StorageInfo = self.storage_data[read_item.storage_index] | ||||
|             file_name = item_md.relative_path | ||||
|             per_file.setdefault(file_name, []).append(read_item) | ||||
|  | ||||
|         for file_name, reqs in per_file.items(): | ||||
|             new_path = self.fs.concat_path(self.path, file_name) | ||||
|             with self.fs.create_stream(new_path, "rb") as stream: | ||||
|                 loaded_tensors = load(stream.read()) | ||||
|                 for req in reqs: | ||||
|                     tensor = loaded_tensors[req.dest_index.fqn] | ||||
|             with self.fs.create_stream(file_name, "rb") as stream: | ||||
|                 # TODO: make this more efficient by doing offset reads instead of a | ||||
|                 # full deserialization of the file | ||||
|                 deserialized: list[tuple[str, dict[str, Any]]] = deserialize(stream.read()) | ||||
|                 deserialized_dict: dict[str, dict[str, Any]] = {tensor_info[0]: tensor_info[1] for tensor_info in deserialized} | ||||
|  | ||||
|                 for req in reqs: | ||||
|                     tensor_bytes = deserialized_dict[req.dest_index.fqn][DATA_KEY] | ||||
|                     tensor = torch.frombuffer(tensor_bytes, dtype=planner.metadata.state_dict_metadata[req.dest_index.fqn].properties.dtype) | ||||
|                     # TODO: update this to req.lengths once allow_tensor_resize is removed; we shouldn't | ||||
|                     # need to look at the deserialized dict for metadata since read_metadata has already parsed it | ||||
|                     tensor = tensor.reshape(deserialized_dict[req.dest_index.fqn][SHAPE_KEY]) | ||||
|  | ||||
|                     target_tensor = planner.resolve_tensor(req) | ||||
|                     if ( | ||||
|                         isinstance(planner, _HuggingFaceLoadPlanner) | ||||
|                         and planner.allow_tensor_resize | ||||
|                     ): | ||||
|                         # this is to support the case when users are calling load on | ||||
|                         # an empty state dict without specifying the correct size of the tensors | ||||
|                         # in the state dict. Resizing is a hacky way to support this use case, | ||||
|                         # but we will migrate users to the _load_state_dict_from_keys method and deprecate this. | ||||
|                         target_tensor = planner.resolve_tensor(req) | ||||
|                         target_tensor.resize_(tensor.size()) | ||||
|                         target_tensor = target_tensor.detach() | ||||
|                     else: | ||||
|                         tensor = narrow_tensor_by_index( | ||||
|                             tensor, req.storage_offsets, req.lengths | ||||
|                         ) | ||||
|                         target_tensor = planner.resolve_tensor(req).detach() | ||||
|  | ||||
|                         assert target_tensor.size() == tensor.size(), ( | ||||
|                             f"req {req.storage_index} mismatch sizes {target_tensor.size()} vs {tensor.size()}" | ||||
|                         ) | ||||
|  | ||||
|                     target_tensor.copy_(tensor) | ||||
|                     planner.commit_tensor(req, target_tensor) | ||||
|  | ||||
| @ -220,37 +265,57 @@ class _HuggingFaceStorageReader(FsspecReader): | ||||
|         return fut | ||||
|  | ||||
|     def read_metadata(self) -> Metadata: | ||||
|         metadata_path = self.fs.concat_path(self.path, _metadata_fn) | ||||
|  | ||||
|         state_dict_metadata: dict[str, STORAGE_TYPES] = {} | ||||
|         storage_data: dict[str, str] = {} | ||||
|         storage_data: dict[MetadataIndex, _StorageInfo] = {} | ||||
|  | ||||
|         if not self.fs.exists(metadata_path): | ||||
|             # if metadata file doesn't exist, create it from the safetensors file | ||||
|             safetensors_files = [] | ||||
|             for file in self.fs.ls(self.path): | ||||
|                 if file.endswith(SUFFIX): | ||||
|                     safetensors_files.append(file) | ||||
|         safetensors_files = [] | ||||
|         for file in self.fs.ls(self.path): | ||||
|             if file.endswith(SUFFIX): | ||||
|                 safetensors_files.append(file) | ||||
|  | ||||
|             if len(safetensors_files) != 1: | ||||
|                 raise ValueError( | ||||
|                     f"Need exactly one safetensors file to load without metadata, found {len(safetensors_files)} files" | ||||
|                 ) | ||||
|             storage_data = {} | ||||
|             with self.fs.create_stream(safetensors_files[0], "rb") as f: | ||||
|                 keys = _get_safetensors_file_keys(f) | ||||
|         for safetensor_file in safetensors_files: | ||||
|             with self.fs.create_stream(safetensor_file, "rb") as f: | ||||
|                 safetensors_metadata = _get_safetensors_file_metadata(f) | ||||
|                 custom_metadata = safetensors_metadata.get(DEFAULT_EXTRA_METADATA_KEY) | ||||
|  | ||||
|             for key in keys: | ||||
|                 state_dict_metadata[key] = BytesStorageMetadata() | ||||
|                 storage_data[key] = os.path.basename(safetensors_files[0]) | ||||
|         else: | ||||
|             with self.fs.create_stream(metadata_path, "r") as metadata_file: | ||||
|                 metadata = json.load(metadata_file) | ||||
|                 dcp_sharding_info = None | ||||
|                 if custom_metadata and custom_metadata.get(CUSTOM_METADATA_KEY): | ||||
|                     dcp_sharding_info = json.loads(custom_metadata.get(CUSTOM_METADATA_KEY)) | ||||
|  | ||||
|             for key in metadata["weight_map"].keys(): | ||||
|                 state_dict_metadata[key] = BytesStorageMetadata() | ||||
|             storage_data = metadata["weight_map"] | ||||
|                 for key, val in safetensors_metadata.items(): | ||||
|                     if key == DEFAULT_EXTRA_METADATA_KEY: | ||||
|                         continue | ||||
|  | ||||
|                     # construct state_dict_metadata | ||||
|                     if dcp_sharding_info is not None: | ||||
|                         offset = dcp_sharding_info[key][SAVED_OFFSETS_KEY] | ||||
|                     else: | ||||
|                         offset = [0] * len(val[SHAPE_KEY]) | ||||
|  | ||||
|                     if key not in state_dict_metadata: | ||||
|                         state_dict_metadata[key] = TensorStorageMetadata( | ||||
|                                 properties=TensorProperties(dtype=_get_dtype(val[DTYPE_KEY])), | ||||
|                                 size=torch.Size([saved + offset for saved, offset in zip(val[SHAPE_KEY], offset)]), | ||||
|                                 chunks=[ChunkStorageMetadata(offsets=torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]))], | ||||
|                             ) | ||||
|                     else: | ||||
|                         state_dict_metadata[key].chunks.append(ChunkStorageMetadata(torch.Size(offset), sizes=torch.Size(val[SHAPE_KEY]))) | ||||
|                         size = list(state_dict_metadata[key].size) | ||||
|                         for i in range(len(size)): | ||||
|                             size[i] = max(size[i], val[SHAPE_KEY][i] + offset[i]) | ||||
|                         state_dict_metadata[key].size = torch.Size(size) | ||||
|  | ||||
|                     # construct storage data | ||||
|                     if dcp_sharding_info is not None: | ||||
|                         metadata_index = MetadataIndex(fqn=key, offset=dcp_sharding_info[key][SAVED_OFFSETS_KEY]) | ||||
|                     else: | ||||
|                         metadata_index = MetadataIndex(fqn=key, offset=[0] * len(val[SHAPE_KEY])) | ||||
|                     storage_data[metadata_index] = _StorageInfo( | ||||
|                         safetensor_file, | ||||
|                         val[DATA_OFFSETS_KEY][0], | ||||
|                         val[DATA_OFFSETS_KEY][1] - val[DATA_OFFSETS_KEY][0], | ||||
|                     ) | ||||
|          | ||||
|         metadata = Metadata( | ||||
|             state_dict_metadata=state_dict_metadata, | ||||
|             storage_data=storage_data, | ||||
| @ -262,8 +327,21 @@ class _HuggingFaceStorageReader(FsspecReader): | ||||
|  | ||||
|         return metadata | ||||
|  | ||||
| def _gen_file_name(index: int, largest_index: int, shard_index: Optional[int] = None) -> str: | ||||
|     if shard_index is not None: | ||||
|         return ( | ||||
|             SHARDED_FILE_NAME.format( | ||||
|                 shard_idx=f"{shard_index}".zfill(5), cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) | ||||
|             ) | ||||
|             + SUFFIX | ||||
|         ) | ||||
|     else: | ||||
|         return ( | ||||
|             FILE_NAME.format( | ||||
|                 cpt_idx=f"{index}".zfill(5), num_files=f"{largest_index}".zfill(5) | ||||
|             ) | ||||
|             + SUFFIX | ||||
|         ) | ||||
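|  | ||||
| # Illustrative sketch of the names this helper is expected to produce, assuming FILE_NAME and | ||||
| # SHARDED_FILE_NAME follow the HuggingFace-style templates referenced by the comments in the | ||||
| # consolidation script below (e.g. "model-00001-of-00005.safetensors"); values are examples only: | ||||
| #   _gen_file_name(1, 5)                -> "model-00001-of-00005.safetensors" | ||||
| #   _gen_file_name(1, 5, shard_index=3) -> a per-rank sharded variant that also embeds "shard-00003" | ||||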
|  | ||||
| def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]: | ||||
|  | ||||
| def _get_safetensors_file_metadata(file_bytes: io.IOBase) -> Any: | ||||
|     # this uses the same logic that's done in HF code base | ||||
|     # https://github.com/2404589803/huggingface_hub/blob/main/src/huggingface_hub/hf_api.py#L5308 | ||||
|     # and follows their documentation on how their files are serialized | ||||
| @ -273,4 +351,19 @@ def _get_safetensors_file_keys(file_bytes: io.IOBase) -> list[str]: | ||||
|     header_len = struct.unpack("<Q", header_len_bytes)[0] | ||||
|     header_json = file_bytes.read(header_len) | ||||
|     metadata = json.loads(header_json) | ||||
|     return list(metadata.keys()) | ||||
|     return metadata | ||||
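|  | ||||
| # Illustrative sketch of the parsed safetensors header returned above for a single small tensor | ||||
| # (field names follow the safetensors on-disk format; the tensor name and values are examples only): | ||||
| # { | ||||
| #     "__metadata__": {"DCP_SHARDING_INFO": "...", "DCP_VERSION": "1.0"}, | ||||
| #     "linear.weight": {"dtype": "F32", "shape": [2, 4], "data_offsets": [0, 32]}, | ||||
| # } | ||||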
|  | ||||
| def _get_dtype(dtype: str) -> torch.dtype: | ||||
|     try: | ||||
|         dtype = DTYPE_MAP[dtype] | ||||
|     except KeyError: | ||||
|         dtype = torch.get_default_dtype() | ||||
|  | ||||
|     return dtype | ||||
|  | ||||
| def _get_dcp_custom_metadata(metadata: Any) -> Optional[Any]: | ||||
|     if DEFAULT_EXTRA_METADATA_KEY in metadata: | ||||
|         custom_metadata = metadata[DEFAULT_EXTRA_METADATA_KEY] | ||||
|         if CUSTOM_METADATA_KEY in custom_metadata: | ||||
|             return json.loads(custom_metadata[CUSTOM_METADATA_KEY]) | ||||
|     return None | ||||
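|  | ||||
| # Illustrative sketch of the DCP custom metadata this helper extracts, assuming the checkpoint was | ||||
| # written by the sharded safetensors save path (which records per-tensor "saved_offsets" under the | ||||
| # "DCP_SHARDING_INFO" key); the tensor name and offsets are examples only: | ||||
| #   _get_dcp_custom_metadata(header) -> {"linear.weight": {"saved_offsets": [2, 0]}} | ||||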
|  | ||||
| @ -442,7 +442,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): | ||||
|  | ||||
|             if isinstance(v, TensorStorageMetadata): | ||||
|                 v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment] | ||||
|             if k in metadata.planner_data: | ||||
|             if metadata.planner_data is not None and k in metadata.planner_data: | ||||
|                 set_element(state_dict, metadata.planner_data[k], v) | ||||
|             else: | ||||
|                 state_dict[k] = v | ||||
| @ -451,7 +451,7 @@ class _EmptyStateDictLoadPlanner(DefaultLoadPlanner): | ||||
|  | ||||
|  | ||||
| def create_default_local_load_plan( | ||||
|     state_dict: dict[str, Any], metadata: Metadata, strict: bool = True | ||||
|     state_dict: dict[str, Any], metadata: Metadata, strict: bool = True, check_md_size: bool = True | ||||
| ) -> LoadPlan: | ||||
|     requests = [] | ||||
|     """ | ||||
| @ -477,6 +477,7 @@ def create_default_local_load_plan( | ||||
|             isinstance(md, TensorStorageMetadata) | ||||
|             and getattr(obj, "size", None) is not None | ||||
|             and md.size != obj.size() | ||||
|             and check_md_size | ||||
|         ): | ||||
|             raise ValueError( | ||||
|                 f"Size mismatch between saved {md.size} and current: {obj.size()} for {fqn}", | ||||
|  | ||||
| @ -2,6 +2,7 @@ | ||||
| import collections | ||||
| import dataclasses | ||||
| import io | ||||
| import json | ||||
| import operator | ||||
| import os | ||||
| import pickle | ||||
| @ -416,6 +417,7 @@ def _write_files_from_queue( | ||||
|                     ) | ||||
|  | ||||
|                 tensor_dict = {} | ||||
|                 metadata_dict = {} | ||||
|                 for tensor, write_item in loader.values(): | ||||
|                     assert tensor.is_cpu | ||||
|                     write_results.append( | ||||
| @ -429,11 +431,12 @@ def _write_files_from_queue( | ||||
|                         ) | ||||
|                     ) | ||||
|                     tensor_dict[write_item.index.fqn] = tensor | ||||
|                     metadata_dict[write_item.index.fqn] = {"saved_offsets": write_item.tensor_data.chunk.offsets} | ||||
|  | ||||
|                 if serialization_format == SerializationFormat.SAFETENSORS: | ||||
|                     from safetensors.torch import save  # type: ignore[import-not-found] | ||||
|  | ||||
|                     stream.write(save(tensor_dict)) | ||||
|                     stream.write(save(tensor_dict, metadata={"DCP_SHARDING_INFO": json.dumps(metadata_dict), "DCP_VERSION": "1.0"})) | ||||
|  | ||||
|                 if use_fsync: | ||||
|                     try: | ||||
|  | ||||
| @ -0,0 +1,658 @@ | ||||
| # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||||
|  | ||||
| # pyre-strict | ||||
|  | ||||
| """ | ||||
| This script consolidates distributed checkpoint (DCP) HuggingFace safetensors files. | ||||
|  | ||||
| It takes sharded safetensors files created by DCP and combines them into one or more | ||||
| consolidated files. This is useful for converting distributed model checkpoints into | ||||
| a format that can be loaded by standard HuggingFace tools. The combination is done | ||||
| through a simple metadata parsing and tensor data copying process. | ||||
| """ | ||||
|  | ||||
| import argparse | ||||
| import json | ||||
| import math | ||||
| import struct | ||||
| from dataclasses import dataclass, field | ||||
| from typing import Any, List, Optional | ||||
|  | ||||
| import fsspec | ||||
| import torch | ||||
|  | ||||
| from fsspec.core import url_to_fs | ||||
| from safetensors import deserialize | ||||
| from torch.distributed.checkpoint._hf_storage import ( | ||||
|     _gen_file_name, | ||||
|     _get_dcp_custom_metadata, | ||||
|     _get_dtype, | ||||
|     _get_safetensors_file_metadata, | ||||
|     DATA_KEY, | ||||
|     DATA_OFFSETS_KEY, | ||||
|     DEFAULT_EXTRA_METADATA_KEY, | ||||
|     DTYPE_KEY, | ||||
|     SAVED_OFFSETS_KEY, | ||||
|     SHAPE_KEY, | ||||
|     SUFFIX, | ||||
| ) | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class _FqnData: | ||||
|     """ | ||||
|     Dataclass to store information about a tensor (identified by its fully qualified name). | ||||
|  | ||||
|     Attributes: | ||||
|         offset_in_file: Byte offset where this tensor's data begins in the output file | ||||
|         shape_in_file: Shape of the tensor in the output file | ||||
|         dtype_size: Size of the tensor's data type in bytes | ||||
|         dtype_str: String representation of the tensor's data type | ||||
|     """ | ||||
|  | ||||
|     offset_in_file: int = 0 | ||||
|     shape_in_file: list[int] = field(default_factory=list) | ||||
|     dtype_size: int = 0 | ||||
|     dtype_str: str = "" | ||||
|  | ||||
|  | ||||
| @dataclass | ||||
| class _OutputFileData: | ||||
|     """ | ||||
|     Dataclass to store information about an output safetensors file. | ||||
|  | ||||
|     Attributes: | ||||
|         metadata_size: Size of the metadata section in bytes | ||||
|         fqn_data: Dictionary mapping tensor names to their metadata | ||||
|     """ | ||||
|  | ||||
|     metadata_size: int = 0 | ||||
|     fqn_data: dict[str, _FqnData] = field(default_factory=dict) | ||||
|  | ||||
|  | ||||
| def _parse_input_metadata( | ||||
|     safetensors_metadatas: List[Any], output_files_data: dict[str, _OutputFileData] | ||||
| ) -> None: | ||||
|     """ | ||||
|     Parse metadata from input safetensors files to determine the full tensor shapes and types. | ||||
|  | ||||
|     This function analyzes the metadata from all input files to determine the complete shape | ||||
|     of each tensor after consolidation. It updates the output_files_data with this information. | ||||
|  | ||||
|     Args: | ||||
|         safetensors_metadatas: List of metadata from input safetensors files | ||||
|         output_files_data: Dictionary mapping output file paths to their metadata | ||||
|  | ||||
|     Raises: | ||||
|         ValueError: If no DCP custom metadata is found in a safetensors file | ||||
|     """ | ||||
|     # Dictionary to track the full size of each tensor across all shards | ||||
|     fqn_to_size_mapping: dict[str, tuple[list[int], str]] = {} | ||||
|  | ||||
|     for safetensors_metadata in safetensors_metadatas: | ||||
|         dcp_sharding_info = _get_dcp_custom_metadata(safetensors_metadata) | ||||
|         if not dcp_sharding_info: | ||||
|             raise ValueError( | ||||
|                 "No DCP custom metadata found in safetensors file. The file must be saved with DCP to be consolidated." | ||||
|             ) | ||||
|  | ||||
|         for key, val in safetensors_metadata.items(): | ||||
|             if key == DEFAULT_EXTRA_METADATA_KEY: | ||||
|                 continue | ||||
|  | ||||
|             # Get the shape of this tensor shard and its offset in the full tensor | ||||
|             sizes = val[SHAPE_KEY] | ||||
|             offsets = dcp_sharding_info[key][SAVED_OFFSETS_KEY] | ||||
|  | ||||
|             if key not in fqn_to_size_mapping: | ||||
|                 # First time seeing this tensor - calculate its full size by adding offsets to dimensions | ||||
|                 cur_size = [size + offset for size, offset in zip(sizes, offsets)] | ||||
|                 fqn_to_size_mapping[key] = (cur_size, val[DTYPE_KEY]) | ||||
|             else: | ||||
|                 # We've seen this tensor before - update its size if this shard extends beyond current known dimensions | ||||
|                 cur_size = fqn_to_size_mapping[key][0] | ||||
|                 for i in range(len(sizes)): | ||||
|                     cur_size[i] = max(cur_size[i], sizes[i] + offsets[i]) | ||||
|  | ||||
|     # Now that we know the full size of each tensor, populate the output file data | ||||
|     for fqn, tensor_info in fqn_to_size_mapping.items(): | ||||
|         tensor_size = tensor_info[0] | ||||
|         dtype_str = tensor_info[1] | ||||
|         for _, output_data in output_files_data.items(): | ||||
|             # Add this tensor to the output file if it's already assigned there or if we're using a single output file | ||||
|             if fqn in output_data.fqn_data or len(output_files_data) == 1: | ||||
|                 output_data.fqn_data[fqn] = _FqnData( | ||||
|                     shape_in_file=tensor_size, | ||||
|                     dtype_size=torch.finfo(_get_dtype(dtype_str)).bits | ||||
|                     // 8,  # Convert bits to bytes | ||||
|                     dtype_str=dtype_str, | ||||
|                 ) | ||||
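|  | ||||
| # Worked example for _parse_input_metadata above (illustrative values): two shards of | ||||
| # "linear.weight", each of shape [2, 4], saved at offsets [0, 0] and [2, 0], consolidate into a | ||||
| # full tensor of shape [4, 4], since every dimension takes max(offset + shard_size) over all shards. | ||||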
|  | ||||
|  | ||||
| def _write_metadata( | ||||
|     fs: fsspec.AbstractFileSystem, | ||||
|     output_files_data: dict[str, _OutputFileData], | ||||
| ) -> None: | ||||
|     """ | ||||
|     Write metadata to the beginning of each output safetensors file. | ||||
|  | ||||
|     This function writes the metadata section to each output file, including information | ||||
|     about tensor shapes, data types, and offsets. It also updates the offset_in_file | ||||
|     field for each tensor in the output_files_data. | ||||
|  | ||||
|     Args: | ||||
|         fs: Filesystem interface for file operations | ||||
|         output_files_data: Dictionary mapping output file paths to their metadata | ||||
|     """ | ||||
|     # Process each output file | ||||
|     for file_path, output_data in output_files_data.items(): | ||||
|         with fs.open(file_path, "wb") as f: | ||||
|             metadata = {} | ||||
|             curr_offset = 0 | ||||
|  | ||||
|             # Calculate offsets for each tensor in the file | ||||
|             for fqn, fqn_data in output_data.fqn_data.items(): | ||||
|                 # Calculate the end offset by multiplying all dimensions and the data type size | ||||
|                 end_offset = ( | ||||
|                     curr_offset | ||||
|                     + math.prod(fqn_data.shape_in_file) * fqn_data.dtype_size | ||||
|                 ) | ||||
|  | ||||
|                 # Store metadata for this tensor | ||||
|                 metadata[fqn] = { | ||||
|                     SHAPE_KEY: fqn_data.shape_in_file, | ||||
|                     DTYPE_KEY: fqn_data.dtype_str, | ||||
|                     DATA_OFFSETS_KEY: [ | ||||
|                         curr_offset, | ||||
|                         end_offset, | ||||
|                     ],  # Start and end byte offsets | ||||
|                 } | ||||
|                 # Store the offset for later use when writing the actual tensor data | ||||
|                 fqn_data.offset_in_file = curr_offset | ||||
|  | ||||
|                 # Update current offset for the next tensor | ||||
|                 curr_offset = end_offset | ||||
|  | ||||
|             # Convert metadata to JSON and encode as bytes | ||||
|             json_metadata = json.dumps(metadata) | ||||
|             json_bytes = json_metadata.encode("utf-8") | ||||
|  | ||||
|             # Write the metadata size as an 8-byte unsigned integer (little-endian) | ||||
|             size_in_bytes = len(json_bytes) | ||||
|             header_len = struct.pack("<Q", size_in_bytes) | ||||
|  | ||||
|             # Write the header length and metadata to the file | ||||
|             f.write(header_len) | ||||
|             f.write(json_bytes) | ||||
|  | ||||
|             # Store the total metadata size (header + JSON) for later use | ||||
|             output_data.metadata_size = f.tell() | ||||
|  | ||||
|  | ||||
| def _write_data( | ||||
|     input_fs: fsspec.AbstractFileSystem, | ||||
|     output_fs: fsspec.AbstractFileSystem, | ||||
|     input_safetensors_files: List[str], | ||||
|     input_metadatas: dict[str, Any], | ||||
|     output_files_data: dict[str, _OutputFileData], | ||||
| ) -> None: | ||||
|     """ | ||||
|     Write tensor data from input files to the output files. | ||||
|  | ||||
|     This function reads tensor data from each input file and writes it to the appropriate | ||||
|     position in the output files based on the tensor's offsets. | ||||
|  | ||||
|     Args: | ||||
|         input_fs: Filesystem interface for reading the input files | ||||
|         output_fs: Filesystem interface for writing the output files | ||||
|         input_safetensors_files: List of input safetensors file paths | ||||
|         input_metadatas: Dictionary mapping input file paths to their metadata | ||||
|         output_files_data: Dictionary mapping output file paths to their metadata | ||||
|     """ | ||||
|     # Process each input safetensors file | ||||
|     for safetensors_file in input_safetensors_files: | ||||
|         with input_fs.open(safetensors_file, "rb") as f: | ||||
|             # Deserialize the safetensors file to get tensor data | ||||
|             deserialized = deserialize(f.read()) | ||||
|  | ||||
|             # Process each tensor in the file | ||||
|             for fqn, val in deserialized: | ||||
|                 # Get the tensor data as bytes | ||||
|                 data_to_write = val[DATA_KEY] | ||||
|  | ||||
|                 # Get the offsets of this tensor shard within the full tensor | ||||
|                 offsets_of_tensor_being_read = _get_dcp_custom_metadata( | ||||
|                     input_metadatas[safetensors_file] | ||||
|                 )[fqn][SAVED_OFFSETS_KEY] | ||||
|  | ||||
|                 # Find which output file(s) this tensor belongs to | ||||
|                 for output_file, output_data in output_files_data.items(): | ||||
|                     # Skip if this tensor doesn't belong in this output file | ||||
|                     if fqn not in output_data.fqn_data: | ||||
|                         continue | ||||
|  | ||||
|                     # Get metadata for this tensor in the output file | ||||
|                     fqn_data = output_data.fqn_data[fqn] | ||||
|  | ||||
|                     # Write this tensor shard to the appropriate position in the output file | ||||
|                     _write_sub_tensor_to_file( | ||||
|                         output_fs, | ||||
|                         data_to_write, | ||||
|                         fqn_data.dtype_size,  # Size of each element in bytes | ||||
|                         fqn_data.shape_in_file,  # Full tensor shape | ||||
|                         offsets_of_tensor_being_read,  # Where this shard belongs in the full tensor | ||||
|                         val[SHAPE_KEY],  # Shape of this shard | ||||
|                         output_file, | ||||
|                         # Calculate the exact byte position where this tensor data should start | ||||
|                         output_data.metadata_size + fqn_data.offset_in_file, | ||||
|                     ) | ||||
|  | ||||
|  | ||||
| def _write_row_wise_tensor( | ||||
|     fs: fsspec.AbstractFileSystem, | ||||
|     sub_tensor_bytes: bytearray, | ||||
|     element_size: int, | ||||
|     full_tensor_strides: list[int], | ||||
|     sub_tensor_strides: list[int], | ||||
|     sub_tensor_offsets: list[int], | ||||
|     sub_tensor_shape: list[int], | ||||
|     output_file_path: str, | ||||
|     output_start_byte: int, | ||||
| ): | ||||
|     """ | ||||
|     Writes a row-wise sharded tensor to the output file. | ||||
|  | ||||
|     This is an optimized path for tensors that are sharded along the first dimension, | ||||
|     with all other dimensions being complete. This allows writing entire rows at once. | ||||
|  | ||||
|     Args: | ||||
|         fs: Filesystem interface for file operations | ||||
|         sub_tensor_bytes: Byte array containing the sub-tensor data | ||||
|         element_size: The size of each element in bytes | ||||
|         full_tensor_strides: Strides of the full tensor | ||||
|         sub_tensor_strides: Strides of the sub-tensor | ||||
|         sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor | ||||
|         sub_tensor_shape: The shape of the sub-tensor | ||||
|         output_file_path: The path to the file where the full tensor is stored | ||||
|         output_start_byte: The starting byte of the full tensor in the file | ||||
|     """ | ||||
|     # Open the output file in read+binary mode to allow seeking and writing | ||||
|     with fs.open(output_file_path, "r+b") as out_f: | ||||
|         # Calculate the number of elements in each row | ||||
|         elements_per_row = full_tensor_strides[ | ||||
|             0 | ||||
|         ]  # This is the stride of the first dimension | ||||
|  | ||||
|         # For each row in the sub-tensor | ||||
|         for row_idx in range(sub_tensor_shape[0]): | ||||
|             # Calculate the row index in the full tensor | ||||
|             full_row_idx = sub_tensor_offsets[0] + row_idx | ||||
|  | ||||
|             # Calculate the position in the full tensor | ||||
|             full_pos = full_row_idx * full_tensor_strides[0] | ||||
|             full_byte_offset = output_start_byte + full_pos * element_size | ||||
|  | ||||
|             # Calculate the position in the sub-tensor | ||||
|             sub_pos = row_idx * sub_tensor_strides[0] | ||||
|             sub_byte_offset = sub_pos * element_size | ||||
|  | ||||
|             # Extract the row data from the sub-tensor | ||||
|             row_size = elements_per_row * element_size | ||||
|             row_data = sub_tensor_bytes[sub_byte_offset : sub_byte_offset + row_size] | ||||
|  | ||||
|             # Seek to the correct position in the output file and write the data | ||||
|             out_f.seek(full_byte_offset) | ||||
|             out_f.write(row_data) | ||||
|  | ||||
|  | ||||
| def _write_column_wise_tensor( | ||||
|     fs: fsspec.AbstractFileSystem, | ||||
|     sub_tensor_bytes: bytearray, | ||||
|     element_size: int, | ||||
|     tensor_shape: list[int], | ||||
|     sub_tensor_offsets: list[int], | ||||
|     sub_tensor_shape: list[int], | ||||
|     output_file_path: str, | ||||
|     output_start_byte: int, | ||||
| ): | ||||
|     """ | ||||
|     Writes a column-wise sharded 2D tensor to the output file. | ||||
|  | ||||
|     This is an optimized path for 2D tensors that are sharded along the second dimension, | ||||
|     with the first dimension being complete. This requires writing column by column. | ||||
|  | ||||
|     Args: | ||||
|         fs: Filesystem interface for file operations | ||||
|         sub_tensor_bytes: Byte array containing the sub-tensor data | ||||
|         element_size: The size of each element in bytes | ||||
|         tensor_shape: The shape of the overall tensor | ||||
|         sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor | ||||
|         sub_tensor_shape: The shape of the sub-tensor | ||||
|         output_file_path: The path to the file where the full tensor is stored | ||||
|         output_start_byte: The starting byte of the full tensor in the file | ||||
|     """ | ||||
|     # Open the output file in read+binary mode to allow seeking and writing | ||||
|     with fs.open(output_file_path, "r+b") as out_f: | ||||
|         # For each column in the sub-tensor | ||||
|         for col_idx in range(sub_tensor_shape[1]): | ||||
|             # Calculate the column index in the full tensor | ||||
|             full_col_idx = sub_tensor_offsets[1] + col_idx | ||||
|  | ||||
|             # For each row in the column | ||||
|             for row_idx in range(sub_tensor_shape[0]): | ||||
|                 # Calculate the position in the full tensor | ||||
|                 full_pos = row_idx * tensor_shape[1] + full_col_idx | ||||
|                 full_byte_offset = output_start_byte + full_pos * element_size | ||||
|  | ||||
|                 # Calculate the position in the sub-tensor | ||||
|                 sub_pos = row_idx * sub_tensor_shape[1] + col_idx | ||||
|                 sub_byte_offset = sub_pos * element_size | ||||
|  | ||||
|                 # Extract the element data from the sub-tensor | ||||
|                 element_data = sub_tensor_bytes[ | ||||
|                     sub_byte_offset : sub_byte_offset + element_size | ||||
|                 ] | ||||
|  | ||||
|                 # Seek to the correct position in the output file and write the data | ||||
|                 out_f.seek(full_byte_offset) | ||||
|                 out_f.write(element_data) | ||||
|  | ||||
|  | ||||
| def _write_element_by_element( | ||||
|     fs: fsspec.AbstractFileSystem, | ||||
|     sub_tensor_bytes: bytearray, | ||||
|     element_size: int, | ||||
|     tensor_shape: list[int], | ||||
|     full_tensor_strides: list[int], | ||||
|     sub_tensor_strides: list[int], | ||||
|     sub_tensor_offsets: list[int], | ||||
|     sub_tensor_shape: list[int], | ||||
|     output_file_path: str, | ||||
|     output_start_byte: int, | ||||
| ): | ||||
|     """ | ||||
|     Writes a sub-tensor to the output file using a general element-by-element approach. | ||||
|  | ||||
|     This is a general approach that works for any sharding pattern, but is less efficient | ||||
|     than the specialized approaches for row-wise or column-wise sharding. | ||||
|  | ||||
|     Args: | ||||
|         fs: Filesystem interface for file operations | ||||
|         sub_tensor_bytes: Byte array containing the sub-tensor data | ||||
|         element_size: The size of each element in bytes | ||||
|         tensor_shape: The shape of the overall tensor | ||||
|         full_tensor_strides: Strides of the full tensor | ||||
|         sub_tensor_strides: Strides of the sub-tensor | ||||
|         sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor | ||||
|         sub_tensor_shape: The shape of the sub-tensor | ||||
|         output_file_path: The path to the file where the full tensor is stored | ||||
|         output_start_byte: The starting byte of the full tensor in the file | ||||
|     """ | ||||
|     # Open the output file in read+binary mode to allow seeking and writing | ||||
|     with fs.open(output_file_path, "r+b") as out_f: | ||||
|         # Create a list to hold the current indices for each dimension | ||||
|         indices = [0] * len(tensor_shape) | ||||
|  | ||||
|         # Calculate the total number of elements in the sub-tensor | ||||
|         total_elements = 1 | ||||
|         for dim_size in sub_tensor_shape: | ||||
|             total_elements *= dim_size | ||||
|  | ||||
|         # Process each element in the sub-tensor | ||||
|         for element_idx in range(total_elements): | ||||
|             # Calculate the indices for this element in the sub-tensor | ||||
|             sub_idx = element_idx | ||||
|             for dim in range(len(sub_tensor_shape) - 1, -1, -1): | ||||
|                 indices[dim] = sub_idx % sub_tensor_shape[dim] | ||||
|                 sub_idx //= sub_tensor_shape[dim] | ||||
|  | ||||
|             # Calculate the position of this element in the sub-tensor's byte array | ||||
|             sub_pos = 0 | ||||
|             for dim in range(len(sub_tensor_shape)): | ||||
|                 sub_pos += indices[dim] * sub_tensor_strides[dim] | ||||
|             sub_byte_offset = sub_pos * element_size | ||||
|  | ||||
|             # Calculate the position of this element in the full tensor | ||||
|             full_pos = 0 | ||||
|             for dim in range(len(tensor_shape)): | ||||
|                 # The global index is the local index plus the offset for this dimension | ||||
|                 global_idx = indices[dim] + sub_tensor_offsets[dim] | ||||
|                 full_pos += global_idx * full_tensor_strides[dim] | ||||
|             full_byte_offset = output_start_byte + full_pos * element_size | ||||
|  | ||||
|             # Extract the element data from the sub-tensor | ||||
|             element_data = sub_tensor_bytes[ | ||||
|                 sub_byte_offset : sub_byte_offset + element_size | ||||
|             ] | ||||
|  | ||||
|             # Seek to the correct position in the output file and write the data | ||||
|             out_f.seek(full_byte_offset) | ||||
|             out_f.write(element_data) | ||||
|  | ||||
|  | ||||
| def _write_sub_tensor_to_file( | ||||
|     fs: fsspec.AbstractFileSystem, | ||||
|     sub_tensor_bytes: bytearray, | ||||
|     element_size: int, | ||||
|     tensor_shape: list[int], | ||||
|     sub_tensor_offsets: list[int], | ||||
|     sub_tensor_shape: list[int], | ||||
|     output_file_path: str, | ||||
|     output_start_byte: int, | ||||
| ): | ||||
|     """ | ||||
|     Writes a sub-tensor from a byte array into a file representing the full tensor at specified offsets. | ||||
|  | ||||
|     This function handles the complex task of placing a tensor shard (sub-tensor) at the correct | ||||
|     position within the consolidated tensor file. It works by calculating the exact byte offsets | ||||
|     for each slice of data and writing them to the appropriate positions. This implementation | ||||
|     supports tensors of any dimensionality with optimized paths for common sharding patterns: | ||||
|     - Row-wise sharding (optimized path) | ||||
|     - Column-wise sharding for 2D tensors (optimized path) | ||||
|     - Any other arbitrary sharding pattern (general element-by-element approach) | ||||
|  | ||||
|     Args: | ||||
|         fs: Filesystem interface for file operations | ||||
|         sub_tensor_bytes: Byte array containing the sub-tensor data | ||||
|         element_size: The size of each element in bytes | ||||
|         tensor_shape: The shape of the overall tensor (list) | ||||
|         sub_tensor_offsets: The starting offsets of the sub-tensor within the full tensor (list) | ||||
|         sub_tensor_shape: The shape of the sub-tensor (list) | ||||
|         output_file_path: The path to the file where the full tensor is stored | ||||
|         output_start_byte: The starting byte of the full tensor in the file | ||||
|     """ | ||||
|     # Handle the case of empty tensors | ||||
|     if not tensor_shape or not sub_tensor_shape: | ||||
|         return | ||||
|  | ||||
|     # Calculate strides for the full tensor (row-major order, C-style) | ||||
|     # Stride is the number of elements to skip to move to the next element in that dimension | ||||
|     full_tensor_strides = [1] * len(tensor_shape) | ||||
|     for i in range(len(tensor_shape) - 2, -1, -1): | ||||
|         full_tensor_strides[i] = full_tensor_strides[i + 1] * tensor_shape[i + 1] | ||||
|  | ||||
|     # Calculate strides for the sub-tensor (row-major order, C-style) | ||||
|     sub_tensor_strides = [1] * len(sub_tensor_shape) | ||||
|     for i in range(len(sub_tensor_shape) - 2, -1, -1): | ||||
|         sub_tensor_strides[i] = sub_tensor_strides[i + 1] * sub_tensor_shape[i + 1] | ||||
|  | ||||
|     # Check if this is a row-wise sharded tensor | ||||
|     # Row-wise sharding is detected when the last dimension is complete | ||||
|     # and only the first dimension is partial | ||||
|     is_row_wise = False | ||||
|     if len(tensor_shape) >= 2: | ||||
|         # Check if all dimensions except the first are complete | ||||
|         all_other_dims_complete = True | ||||
|         for i in range(1, len(tensor_shape)): | ||||
|             if sub_tensor_shape[i] != tensor_shape[i]: | ||||
|                 all_other_dims_complete = False | ||||
|                 break | ||||
|  | ||||
|         # Row-wise sharding: first dimension is partial, all others are complete | ||||
|         is_row_wise = all_other_dims_complete and sub_tensor_shape[0] < tensor_shape[0] | ||||
|  | ||||
|     # Check if this is a column-wise sharded 2D tensor | ||||
|     # Column-wise sharding is detected when the first dimension is complete | ||||
|     # and the second dimension is partial (only for 2D tensors) | ||||
|     is_column_wise = False | ||||
|     if len(tensor_shape) == 2: | ||||
|         is_column_wise = ( | ||||
|             sub_tensor_shape[0] == tensor_shape[0] | ||||
|             and sub_tensor_shape[1] < tensor_shape[1] | ||||
|         ) | ||||
|  | ||||
|     # Call the appropriate function based on the sharding pattern | ||||
|     if is_row_wise: | ||||
|         _write_row_wise_tensor( | ||||
|             fs, | ||||
|             sub_tensor_bytes, | ||||
|             element_size, | ||||
|             full_tensor_strides, | ||||
|             sub_tensor_strides, | ||||
|             sub_tensor_offsets, | ||||
|             sub_tensor_shape, | ||||
|             output_file_path, | ||||
|             output_start_byte, | ||||
|         ) | ||||
|     elif is_column_wise: | ||||
|         _write_column_wise_tensor( | ||||
|             fs, | ||||
|             sub_tensor_bytes, | ||||
|             element_size, | ||||
|             tensor_shape, | ||||
|             sub_tensor_offsets, | ||||
|             sub_tensor_shape, | ||||
|             output_file_path, | ||||
|             output_start_byte, | ||||
|         ) | ||||
|     else: | ||||
|         _write_element_by_element( | ||||
|             fs, | ||||
|             sub_tensor_bytes, | ||||
|             element_size, | ||||
|             tensor_shape, | ||||
|             full_tensor_strides, | ||||
|             sub_tensor_strides, | ||||
|             sub_tensor_offsets, | ||||
|             sub_tensor_shape, | ||||
|             output_file_path, | ||||
|             output_start_byte, | ||||
|         ) | ||||
|  | ||||
|  | ||||
| def consolidate_safetensors_files( | ||||
|     input_dir: str, | ||||
|     output_dir: str, | ||||
|     fqn_to_index_mapping: Optional[dict[str, int]] = None, | ||||
| ) -> None: | ||||
|     """ | ||||
|     Main function to consolidate sharded safetensors files into one or more output files. | ||||
|  | ||||
|     This function orchestrates the entire consolidation process: | ||||
|     1. Sets up the output file structure based on the fqn_to_index_mapping | ||||
|     2. Finds all safetensors files in the input directory | ||||
|     3. Parses metadata from all input files | ||||
|     4. Writes metadata to the output files | ||||
|     5. Writes tensor data from input files to output files | ||||
|  | ||||
|     Args: | ||||
|         input_dir: Directory containing sharded safetensors files | ||||
|         output_dir: Directory where consolidated files will be written | ||||
|         fqn_to_index_mapping: Optional mapping of tensor names to output file indices. | ||||
|                              If None, all tensors will be consolidated into a single file. | ||||
|     """ | ||||
|     # Create filesystem using fsspec for file operations | ||||
|     input_fs, _ = url_to_fs(input_dir) | ||||
|     output_fs, _ = url_to_fs(output_dir) | ||||
|  | ||||
|     # Initialize the output file structure | ||||
|     output_files_data: dict[str, _OutputFileData] = {} | ||||
|     if fqn_to_index_mapping is None: | ||||
|         # If no mapping is provided, create a single output file | ||||
|         file_name = _gen_file_name(1, 1)  # Generates a name like "model-00001-of-00001.safetensors" | ||||
|         output_path = f"{output_dir}/{file_name}" | ||||
|         output_files_data[output_path] = _OutputFileData() | ||||
|     else: | ||||
|         # Create multiple output files based on the provided mapping | ||||
|         for fqn, index in fqn_to_index_mapping.items(): | ||||
|             # Generate names like "model-00001-of-00005.safetensors" | ||||
|             file_name = _gen_file_name(index, max(fqn_to_index_mapping.values())) | ||||
|             output_path = f"{output_dir}/{file_name}" | ||||
|  | ||||
|             # Create output file data structure if it doesn't exist yet | ||||
|             if output_path not in output_files_data: | ||||
|                 output_files_data[output_path] = _OutputFileData( | ||||
|                     fqn_data={fqn: _FqnData()} | ||||
|                 ) | ||||
|             else: | ||||
|                 output_files_data[output_path].fqn_data[fqn] = _FqnData() | ||||
|  | ||||
|     # Find all safetensors files in the input directory | ||||
|     safetensors_files = [] | ||||
|     for file in input_fs.ls(input_dir, detail=False): | ||||
|         if file.endswith(SUFFIX): | ||||
|             safetensors_files.append(file) | ||||
|  | ||||
|     # Read metadata from all input files | ||||
|     input_safetensors_metadatas = {} | ||||
|     for safetensor_file in safetensors_files: | ||||
|         with input_fs.open(safetensor_file, "rb") as f: | ||||
|             input_safetensors_metadatas[safetensor_file] = ( | ||||
|                 _get_safetensors_file_metadata(f) | ||||
|             ) | ||||
|  | ||||
|     # Step 1: Parse metadata to determine tensor shapes and types | ||||
|     _parse_input_metadata(input_safetensors_metadatas.values(), output_files_data) | ||||
|  | ||||
|     # Step 2: Write metadata headers to output files | ||||
|     _write_metadata(output_fs, output_files_data) | ||||
|  | ||||
|     # Step 3: Write actual tensor data from input files to output files | ||||
|     _write_data( | ||||
|         input_fs, | ||||
|         output_fs, | ||||
|         safetensors_files, | ||||
|         input_safetensors_metadatas, | ||||
|         output_files_data, | ||||
|     ) | ||||
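|  | ||||
|     # Minimal usage sketch (illustrative paths; the mapping keys are hypothetical tensor FQNs): | ||||
|     # | ||||
|     #   consolidate_safetensors_files( | ||||
|     #       input_dir="/tmp/sharded_checkpoint", | ||||
|     #       output_dir="/tmp/consolidated_checkpoint", | ||||
|     #       fqn_to_index_mapping={"embed.weight": 1, "lm_head.weight": 2}, | ||||
|     #   ) | ||||
|     # | ||||
|     # Omitting fqn_to_index_mapping writes every tensor into a single consolidated file. | ||||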
|  | ||||
|  | ||||
| def main() -> None: | ||||
|     """ | ||||
|     Command-line entry point for the consolidation script. | ||||
|  | ||||
|     Parses command-line arguments and calls consolidate_safetensors_files with the provided parameters. | ||||
|     """ | ||||
|     # Set up command-line argument parser | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Consolidate DCP sharded HuggingFace safetensors files" | ||||
|     ) | ||||
|  | ||||
|     # Define required and optional arguments | ||||
|     parser.add_argument( | ||||
|         "input_path", | ||||
|         type=str, | ||||
|         help="Path to directory containing sharded safetensors files", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "output_path", | ||||
|         type=str, | ||||
|         help="Path to write consolidated safetensors files. Must be different from input path", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "fqn_to_index_mapping", | ||||
|         type=json.loads, | ||||
|         nargs="?", | ||||
|         default=None, | ||||
|         help="Optional JSON mapping of which tensor names should belong to which consolidated file. If not provided, all tensors will be consolidated into one file. Expects numbers from 1 to N, where N is the number of files.", | ||||
|     ) | ||||
|  | ||||
|     # Parse arguments and call the main function | ||||
|     args = parser.parse_args() | ||||
|     consolidate_safetensors_files( | ||||
|         args.input_path, args.output_path, args.fqn_to_index_mapping | ||||
|     ) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||
							
								
								
									
										195
								torch/distributed/checkpoint/scripts/benchmarking.py
										Normal file
							| @ -0,0 +1,195 @@ | ||||
| # (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary. | ||||
|  | ||||
| # pyre-strict | ||||
|  | ||||
| """ | ||||
| This script demonstrates how to: | ||||
| 1. Load tensors | ||||
| 2. Split them into d-tensors (distributed tensors) | ||||
| 3. Save them using DCP (Distributed Checkpoint) with HuggingFace integration | ||||
| 4. Run the consolidation function to consolidate the sharded files back to full tensors | ||||
| """ | ||||
|  | ||||
| import argparse | ||||
| import os | ||||
| import time | ||||
|  | ||||
| import torch | ||||
| import torch.distributed as dist | ||||
| import torch.distributed.checkpoint as dist_cp | ||||
| from torch.distributed.checkpoint import consolidate_safetensors_files | ||||
| from torch.distributed.device_mesh import init_device_mesh | ||||
| from torch.distributed.tensor import distribute_tensor, Replicate, Shard | ||||
|  | ||||
|  | ||||
| def init_distributed(): | ||||
|     """Initialize the distributed environment.""" | ||||
|     if dist.is_available() and not dist.is_initialized(): | ||||
|         if torch.cuda.is_available(): | ||||
|             dist.init_process_group(backend="nccl") | ||||
|         else: | ||||
|             dist.init_process_group(backend="gloo") | ||||
|  | ||||
|         # Set device to current rank | ||||
|         if torch.cuda.is_available(): | ||||
|             torch.cuda.set_device(dist.get_rank()) | ||||
|  | ||||
|     return dist.get_world_size(), dist.get_rank() | ||||
|  | ||||
|  | ||||
| def load_and_split_tensors(model_path: str, device_mesh): | ||||
|     """ | ||||
|     Load tensors from a model and split them into d-tensors. | ||||
|  | ||||
|     Args: | ||||
|         model_path: Path to the model to load | ||||
|         device_mesh: Device mesh to distribute tensors on | ||||
|  | ||||
|     Returns: | ||||
|         Tuple of (state dict with distributed tensors, the original state dict) | ||||
|     """ | ||||
|     # Load the model | ||||
|     if model_path.endswith(".safetensors"): | ||||
|         from safetensors.torch import load_file | ||||
|  | ||||
|         state_dict = load_file(model_path) | ||||
|     else: | ||||
|         raise ValueError(f"Expected a .safetensors file, got {model_path}") | ||||
|  | ||||
|     # Create distributed tensors | ||||
|     distributed_state_dict = {} | ||||
|     for key, tensor in state_dict.items(): | ||||
|         # Choose sharding strategy based on tensor size | ||||
|         if tensor.dim() >= 2 and tensor.size(0) > 10: | ||||
|             # Shard along the first dimension for large tensors | ||||
|             placements = [Shard(0)] | ||||
|         else: | ||||
|             # Replicate small tensors | ||||
|             placements = [Replicate()] | ||||
|  | ||||
|         # Distribute the tensor | ||||
|         dtensor = distribute_tensor(tensor, device_mesh, placements=placements) | ||||
|         distributed_state_dict[key] = dtensor | ||||
|  | ||||
|     return distributed_state_dict, state_dict | ||||
|  | ||||
|  | ||||
| def save_with_dcp_huggingface(state_dict, output_dir: str): | ||||
|     """ | ||||
|     Save the distributed state dict using DCP with HuggingFace integration. | ||||
|  | ||||
|     Args: | ||||
|         state_dict: State dict with distributed tensors | ||||
|         output_dir: Directory to save the checkpoint | ||||
|     """ | ||||
|     os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|     # Save the distributed checkpoint | ||||
|     dist_cp.save( | ||||
|         state_dict=state_dict, | ||||
|         storage_writer=dist_cp._HuggingFaceStorageWriter( | ||||
|             path=output_dir, | ||||
|             save_sharded=True,  # Save as sharded files | ||||
|         ), | ||||
|     ) | ||||
|  | ||||
|     print(f"Saved distributed checkpoint to {output_dir}") | ||||
|  | ||||
|  | ||||
| def save_safetensors(state_dict, output_dir: str): | ||||
|     """ | ||||
|     Save the state dict as safetensors. | ||||
|  | ||||
|     Args: | ||||
|         state_dict: State dict with tensors | ||||
|         output_dir: Directory to save the checkpoint | ||||
|     """ | ||||
|     os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|     # Save the checkpoint | ||||
|     from safetensors.torch import save_file | ||||
|  | ||||
|     save_file(state_dict, os.path.join(output_dir, "checkpoint.safetensors")) | ||||
|  | ||||
|  | ||||
| def consolidate_checkpoint(input_dir: str, output_dir: str, fqn_to_index_mapping=None): | ||||
|     """ | ||||
|     Consolidate the sharded checkpoint files into full tensors. | ||||
|  | ||||
|     Args: | ||||
|         input_dir: Directory containing sharded checkpoint files | ||||
|         output_dir: Directory to save the consolidated checkpoint | ||||
|         fqn_to_index_mapping: Optional mapping of tensor names to output file indices | ||||
|     """ | ||||
|     os.makedirs(output_dir, exist_ok=True) | ||||
|  | ||||
|     # Consolidate the checkpoint | ||||
|     consolidate_safetensors_files( | ||||
|         input_dir=input_dir, | ||||
|         output_dir=output_dir, | ||||
|         fqn_to_index_mapping=fqn_to_index_mapping, | ||||
|     ) | ||||
|  | ||||
|     print(f"Consolidated checkpoint saved to {output_dir}") | ||||
|  | ||||
|  | ||||
| def main(): | ||||
|     parser = argparse.ArgumentParser( | ||||
|         description="Split tensors into d-tensors, save with DCP HuggingFace, and consolidate" | ||||
|     ) | ||||
|     """ | ||||
|     parser.add_argument( | ||||
|         "--model-path", type=str, default="", help="Path to the model to load" | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--sharded-output-dir", | ||||
|         type=str, | ||||
|         default="./sharded_checkpoint", | ||||
|         help="Directory to save the sharded checkpoint", | ||||
|     ) | ||||
|     parser.add_argument( | ||||
|         "--consolidated-output-dir", | ||||
|         type=str, | ||||
|         default="./consolidated_checkpoint", | ||||
|         help="Directory to save the consolidated checkpoint", | ||||
|     ) | ||||
|     args = parser.parse_args() | ||||
|     """ | ||||
|  | ||||
|     # Initialize distributed environment | ||||
|     world_size, rank = init_distributed() | ||||
|     print(f"Running with world_size={world_size}, rank={rank}") | ||||
|  | ||||
|     model_path = "/home/ankitageorge/.cache/huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b/model-00001-of-00030.safetensors" | ||||
|     base_dir = "/data/users/ankitageorge/testing/" | ||||
|     sharded_output_dir = os.path.join(base_dir, "sharded_checkpoint") | ||||
|     dcp_consolidated_output_dir = os.path.join(base_dir, "dcp_consolidated_checkpoint") | ||||
|     consolidated_output_dir = os.path.join(base_dir, "consolidated_checkpoint") | ||||
|  | ||||
|     # Initialize device mesh | ||||
|     device_type = "cuda" if torch.cuda.is_available() else "cpu" | ||||
|     device_mesh = init_device_mesh(device_type, (world_size,)) | ||||
|  | ||||
|     # Load and split tensors | ||||
|     distributed_state_dict, state_dict = load_and_split_tensors(model_path, device_mesh) | ||||
|     start_time = time.time() | ||||
|     if rank == 0: | ||||
|         save_safetensors(state_dict, consolidated_output_dir) | ||||
|         print("time to save as safetensors ", time.time() - start_time) | ||||
|  | ||||
|     # Save with DCP HuggingFace | ||||
|     start_time = time.time() | ||||
|     save_with_dcp_huggingface(distributed_state_dict, sharded_output_dir) | ||||
|  | ||||
|     # Make sure all processes have finished saving | ||||
|     dist.barrier() | ||||
|     print("time to save with DCP HuggingFace ", time.time() - start_time) | ||||
|  | ||||
|     # Only rank 0 needs to consolidate the checkpoint | ||||
|     start_time = time.time() | ||||
|     if rank == 0: | ||||
|         consolidate_checkpoint(sharded_output_dir, dcp_consolidated_output_dir) | ||||
|  | ||||
|         print("Time to consolidate checkpoint: ", time.time() - start_time) | ||||
|  | ||||
|  | ||||
| if __name__ == "__main__": | ||||
|     main() | ||||