[DCP] Adds utility for converting dcp to torch save format (#119814)
As title.

Differential Revision: [D53718042](https://our.internmc.facebook.com/intern/diff/D53718042/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119814
Approved by: https://github.com/fegin
ghstack dependencies: #119813
Committed by PyTorch MergeBot (commit 1ab441a7dd, parent e0a7b024b0)
@@ -95,3 +95,11 @@ an experimental feature and is subject to change.

.. autoclass:: torch.distributed.checkpoint.state_dict.StateDictOptions
   :members:

For users accustomed to using and sharing models in the `torch.save` format, the following utilities are provided:

.. automodule:: torch.distributed.checkpoint.format_utils

.. currentmodule:: torch.distributed.checkpoint.format_utils

.. autofunction:: dcp_to_torch_save
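
For example, an existing DCP checkpoint can be converted into a single `torch.save` file with one call (a minimal sketch; ``CHECKPOINT_DIR`` is a placeholder for a directory that already holds a DCP checkpoint)::

    from torch.distributed.checkpoint.format_utils import dcp_to_torch_save

    CHECKPOINT_DIR = "checkpoint/"  # placeholder: directory holding an existing DCP checkpoint

    # Convert the sharded DCP checkpoint into a single torch.save file.
    # Run this on a single rank to avoid OOM (see the warning in the docstring).
    dcp_to_torch_save(CHECKPOINT_DIR, "model.pt")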
test/distributed/checkpoint/test_format_utils.py (new file, 30 lines)
@@ -0,0 +1,30 @@
# Owner(s): ["oncall: distributed"]

import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    ModelArgs,
    Transformer,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class TestFormatUtils(DTensorTestBase):
    @with_temp_dir
    def test_dcp_to_torch_save(self) -> None:
        # Using a transformer model to simulate a 'complicated enough' state dict w/ nested modules
        model = Transformer(ModelArgs())
        dcp.save({"model": model}, checkpoint_id=self.temp_dir)

        torch_path = self.temp_dir + "/model.pt"
        dcp_to_torch_save(self.temp_dir, torch_path)

        loaded_sd = torch.load(torch_path)
        self.assertEqual(loaded_sd, {"model": model.state_dict()})


if __name__ == "__main__":
    run_tests()
torch/distributed/checkpoint/format_utils.py (new file, 80 lines)
@@ -0,0 +1,80 @@
import os
from typing import Union

import torch
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

__all__ = ["dcp_to_torch_save"]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds state_dict from the saved metadata.
    Useful for loading a state_dict without first initializing a model, such as
    when converting a DCP checkpoint into a Torch save file.

    .. note:: `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it is recommended to only use
        this LoadPlanner on a single rank or process to avoid OOM.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # Rebuild the state dict from the metadata, allocating an empty tensor
        # for each tensor entry so the storage reader has somewhere to load into.
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)


def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it is recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    storage_reader = FileSystemReader(dcp_checkpoint_dir)

    _load_state_dict(
        sd,
        storage_reader=storage_reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
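
Because `_EmptyStateDictLoadPlanner` reconstructs the state dict from checkpoint metadata alone, the same pattern can materialize a checkpoint in memory without ever instantiating a model. A minimal sketch built on the internal call used above (note that `_load_state_dict` and `_EmptyStateDictLoadPlanner` are private APIs, so this is illustrative rather than a supported entry point, and ``CHECKPOINT_DIR`` is a hypothetical path)::

    from torch.distributed.checkpoint import FileSystemReader
    from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
    from torch.distributed.checkpoint.state_dict_loader import _load_state_dict

    CHECKPOINT_DIR = "checkpoint/"  # hypothetical directory holding an existing DCP checkpoint

    # Materialize the full state dict on this process without initializing a model.
    state_dict = {}  # must be empty when used with _EmptyStateDictLoadPlanner
    _load_state_dict(
        state_dict,
        storage_reader=FileSystemReader(CHECKPOINT_DIR),
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,  # load locally, no process group required
    )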