mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-29 03:04:55 +08:00
as title

Differential Revision: [D53718042](https://our.internmc.facebook.com/intern/diff/D53718042/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119814
Approved by: https://github.com/fegin
ghstack dependencies: #119813
31 lines
1002 B
Python
31 lines
1002 B
Python
# Owner(s): ["oncall: distributed"]
|
|
|
|
import torch
|
|
import torch.distributed.checkpoint as dcp
|
|
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
|
|
from torch.testing._internal.common_utils import run_tests
|
|
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
|
DTensorTestBase,
|
|
ModelArgs,
|
|
Transformer,
|
|
)
|
|
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
|
|
|
|
|
class TestFormatUtils(DTensorTestBase):
    @with_temp_dir
    def test_dcp_to_torch_save(self) -> None:
        """Round-trip a DCP checkpoint through ``dcp_to_torch_save``.

        Saves a model with distributed checkpoint (DCP), converts the
        resulting checkpoint directory into a single ``torch.save`` file,
        and verifies the loaded contents match the model's state dict.
        """
        # A Transformer provides a 'complicated enough' state dict with
        # nested modules to exercise the conversion path.
        transformer = Transformer(ModelArgs())
        dcp.save({"model": transformer}, checkpoint_id=self.temp_dir)

        # Convert the DCP checkpoint directory into one torch.save file.
        torch_path = self.temp_dir + "/model.pt"
        dcp_to_torch_save(self.temp_dir, torch_path)

        self.assertEqual(torch.load(torch_path), {"model": transformer.state_dict()})
|
|
|
|
|
|
# Allow running this test file directly as a script.
if __name__ == "__main__":
    run_tests()