[DCP] Adds utility for converting dcp to torch save format (#119814)

as title

Differential Revision: [D53718042](https://our.internmc.facebook.com/intern/diff/D53718042/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119814
Approved by: https://github.com/fegin
ghstack dependencies: #119813
This commit is contained in:
Lucas Pasqualin
2024-02-20 08:56:57 -08:00
committed by PyTorch MergeBot
parent e0a7b024b0
commit 1ab441a7dd
3 changed files with 118 additions and 0 deletions

View File

@ -95,3 +95,11 @@ an experimental feature and is subject to change.
.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions
:members:
For users which are used to using and sharing models in the `torch.save` format, the following utilities are pvoided:
.. automodule:: torch.distributed.checkpoint.format_utils
.. currentmodule:: torch.distributed.checkpoint.format_utils
.. autofunction:: dcp_to_torch_save

View File

@ -0,0 +1,30 @@
# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
ModelArgs,
Transformer,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestFormatUtils(DTensorTestBase):
@with_temp_dir
def test_dcp_to_torch_save(self) -> None:
# Using a transformer model to simulate a 'complicated enough' state dict w/ nested modules
model = Transformer(ModelArgs())
dcp.save({"model": model}, checkpoint_id=self.temp_dir)
torch_path = self.temp_dir + "/model.pt"
dcp_to_torch_save(self.temp_dir, torch_path)
loaded_sd = torch.load(torch_path)
self.assertEqual(loaded_sd, {"model": model.state_dict()})
if __name__ == "__main__":
run_tests()

View File

@ -0,0 +1,80 @@
import os
from typing import Union
import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
Metadata,
STATE_DICT_TYPE,
TensorStorageMetadata,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
__all__ = ["dcp_to_torch_save"]
class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
"""
Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
Useful for loading in state_dict without first initializing a model, such as
when converting a DCP checkpoint into a Torch save file.
. N.B. `state_dict` must be an empty dictionary when used with this LoadPlanner
.. warning::
Because the entire state dict is initialized, It's recommended to only utilize
this LoadPlanner on a single rank or process to avoid OOM.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def set_up_planner(
self,
state_dict: STATE_DICT_TYPE,
metadata: Metadata,
is_coordinator: bool,
) -> None:
assert not state_dict
# rebuild the state dict from the metadata
for k, v in metadata.state_dict_metadata.items():
if isinstance(v, TensorStorageMetadata):
v = torch.empty(v.size, dtype=v.properties.dtype) # type: ignore[assignment]
if k in metadata.planner_data:
set_element(state_dict, metadata.planner_data[k], v)
else:
state_dict[k] = v
super().set_up_planner(state_dict, metadata, is_coordinator)
def dcp_to_torch_save(
dcp_checkpoint_dir: Union[str, os.PathLike],
torch_save_path: Union[str, os.PathLike],
):
"""
Given a directory containing a DCP checkpoint, this function will convert it into a
Torch save file.
Args:
dcp_checkpoint_dir: Directory containing the DCP checkpoint.
torch_save_path: Filename to store the converted Torch save file.
.. warning::
To avoid OOM, it's recommended to only run this function on a single rank.
"""
sd: STATE_DICT_TYPE = {}
storage_reader = FileSystemReader(dcp_checkpoint_dir)
_load_state_dict(
sd,
storage_reader=storage_reader,
planner=_EmptyStateDictLoadPlanner(),
no_dist=True,
)
torch.save(sd, torch_save_path)