mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
list_stored_sd_metadata API. (#160610)
Summary: 1\ Certain checkpoint load use cases are not aware of the properties of the data/tensors they want to load. 2\ These usecases include data loader checkpoints, reading data for post processing (when the original model definition is not available). 3\ There, we have to use saved checkpoint (metadata) as our source of truth. 4\ This RFC proposal exposes the checkpoint metadata using a public API. In this proposal we expose the stored state-dict metadata (minus associated storage/chunk metadata). Chunk/storage details should not be exposed to the users and is a impl detail of the storage writer/reader. Test Plan: UT. Rollback Plan: Differential Revision: D80231457 Pull Request resolved: https://github.com/pytorch/pytorch/pull/160610 Approved by: https://github.com/saumishr
This commit is contained in:
committed by
PyTorch MergeBot
parent
f76fdcaaf8
commit
da903b6a8b
@ -64,6 +64,16 @@ The entrypoints to load and save a checkpoint are the following:
|
||||
.. autofunction:: load_state_dict
|
||||
```
|
||||
|
||||
Following APIs can be used for inspecting the metadata of a checkpoint/stored state-dict.
|
||||
|
||||
```{eval-rst}
|
||||
.. automodule:: torch.distributed.checkpoint.metadata_utils
|
||||
```
|
||||
|
||||
```{eval-rst}
|
||||
.. autofunction:: list_stored_state_dict
|
||||
```
|
||||
|
||||
The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):
|
||||
|
||||
```{eval-rst}
|
||||
|
61
test/distributed/checkpoint/test_list_stored_state_dict.py
Normal file
61
test/distributed/checkpoint/test_list_stored_state_dict.py
Normal file
@ -0,0 +1,61 @@
|
||||
# Owner(s): ["oncall: distributed checkpointing"]
|
||||
|
||||
import io
|
||||
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
class TestListStateDict(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_list_stored_sd_metadata(self) -> None:
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
|
||||
# Save 1) DTensor 2) Blob
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor,
|
||||
mesh_2d,
|
||||
placements=[Shard(0), Replicate()],
|
||||
)
|
||||
state_dict_to_save = {
|
||||
"distributed_weight": dtensor,
|
||||
"bytes_data": io.BytesIO(b"TrainingEpoch:4"),
|
||||
}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
|
||||
planner=dist_cp.DefaultSavePlanner(),
|
||||
)
|
||||
md = dist_cp.list_stored_state_dict(checkpoint_id=CHECKPOINT_DIR)
|
||||
|
||||
# Verify DTensor
|
||||
self.assertTrue("distributed_weight" in md)
|
||||
self.assertTrue(isinstance(md["distributed_weight"], TensorStorageMetadata))
|
||||
self.assertEqual(md["distributed_weight"].size, torch.Size([4, 4]))
|
||||
self.assertEqual(md["distributed_weight"].properties.dtype, torch.float32)
|
||||
|
||||
# Verify Blob
|
||||
self.assertTrue("bytes_data" in md)
|
||||
self.assertTrue(isinstance(md["bytes_data"], BytesStorageMetadata))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@ -9,6 +9,7 @@ from .metadata import (
|
||||
Metadata,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
from .metadata_utils import list_stored_state_dict
|
||||
from .optimizer import load_sharded_optimizer_state_dict
|
||||
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
|
||||
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
|
||||
|
33
torch/distributed/checkpoint/metadata_utils.py
Normal file
33
torch/distributed/checkpoint/metadata_utils.py
Normal file
@ -0,0 +1,33 @@
|
||||
import os
|
||||
from typing import cast, Optional, Union
|
||||
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
|
||||
from ._storage_utils import _storage_setup
|
||||
from .storage import StorageReader
|
||||
|
||||
|
||||
__all__ = ["list_stored_state_dict"]
|
||||
|
||||
|
||||
def list_stored_state_dict(
|
||||
checkpoint_id: Union[str, os.PathLike, None] = None,
|
||||
storage_reader: Optional[StorageReader] = None,
|
||||
) -> dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]:
|
||||
"""
|
||||
List the stored checkpoint metadata.
|
||||
NB: The returned state-dict keys are flattened.
|
||||
"""
|
||||
storage_reader = cast(
|
||||
StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
|
||||
)
|
||||
md = storage_reader.read_metadata()
|
||||
sd = md.state_dict_metadata # flattened dict.
|
||||
for v in sd.values():
|
||||
if isinstance(v, TensorStorageMetadata):
|
||||
v.chunks = [] # Limit exposing storage/impl details.
|
||||
|
||||
return sd
|
Reference in New Issue
Block a user