mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-31 04:04:57 +08:00)
	[DCP] Adds utility for converting dcp to torch save format (#119814)
as title

Differential Revision: [D53718042](https://our.internmc.facebook.com/intern/diff/D53718042/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119814
Approved by: https://github.com/fegin
ghstack dependencies: #119813
committed by PyTorch MergeBot
parent e0a7b024b0
commit 1ab441a7dd

torch/distributed/checkpoint/format_utils.py (new file, 80 lines added)
@@ -0,0 +1,80 @@
import os
from typing import Union

import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

__all__ = ["dcp_to_torch_save"]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    .. note:: `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only utilize
        this LoadPlanner on a single rank or process to avoid OOM.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                # allocate an empty tensor matching the saved shape and dtype
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                # planner_data holds the original (possibly nested) key path,
                # so the value is re-inserted at its original location
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """

    sd: STATE_DICT_TYPE = {}
    storage_reader = FileSystemReader(dcp_checkpoint_dir)

    # load into an empty state dict, letting the planner rebuild it from
    # checkpoint metadata; no_dist=True allows running without a process group
    _load_state_dict(
        sd,
        storage_reader=storage_reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
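
As a usage sketch (the checkpoint directory and output path below are hypothetical placeholders, not part of this commit), converting a DCP checkpoint and loading the result back might look like this:

# Usage sketch: paths are hypothetical placeholders.
import torch
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

# Convert a DCP checkpoint directory into a single torch.save file.
# Per the warning above, run this on a single rank/process to avoid OOM.
dcp_to_torch_save("checkpoints/step_1000", "checkpoints/step_1000.pt")

# The output is an ordinary torch.save file, loadable without any
# distributed setup or model construction.
state_dict = torch.load("checkpoints/step_1000.pt")

Because _EmptyStateDictLoadPlanner materializes every tensor from the checkpoint metadata alone, no model needs to be instantiated before loading, which is what makes the standalone conversion possible.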