Files
pytorch/test/distributed/checkpoint/test_format_utils.py
2024-02-22 16:55:58 +00:00

31 lines
1002 B
Python

# Owner(s): ["oncall: distributed"]
import torch
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.format_utils import dcp_to_torch_save
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
ModelArgs,
Transformer,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestFormatUtils(DTensorTestBase):
    """Tests for the checkpoint format conversion helpers in ``format_utils``."""

    @with_temp_dir
    def test_dcp_to_torch_save(self) -> None:
        """Check that a DCP checkpoint converts to a loadable ``torch.save`` file.

        The model is saved with ``dcp.save`` into ``self.temp_dir``
        (presumably provisioned by ``@with_temp_dir`` -- confirm against the
        decorator), converted with ``dcp_to_torch_save``, reloaded with
        ``torch.load`` and compared against the model's own ``state_dict()``.
        """
        # A transformer is used so the state dict is 'complicated enough',
        # i.e. it contains nested modules.
        transformer = Transformer(ModelArgs())
        checkpoint_dir = self.temp_dir
        dcp.save({"model": transformer}, checkpoint_id=checkpoint_dir)

        # The converted single-file checkpoint lives next to the DCP files.
        converted_path = checkpoint_dir + "/model.pt"
        dcp_to_torch_save(checkpoint_dir, converted_path)

        restored = torch.load(converted_path)
        expected = {"model": transformer.state_dict()}
        self.assertEqual(restored, expected)
# Script entry point: hand control to the test runner imported from
# torch.testing._internal.common_utils when this file is executed directly.
if __name__ == "__main__":
    run_tests()