list_stored_sd_metadata API. (#160610)

Summary:
1\ Certain checkpoint load use cases are not aware of the properties of the data/tensors they want to load.
2\ These usecases include data loader checkpoints, reading data for post processing (when the original model definition is not available).
3\ There, we have to use saved checkpoint  (metadata) as our source of truth.
4\ This RFC proposal exposes the checkpoint metadata using a public API.

In this proposal we expose the stored state-dict metadata  (minus associated storage/chunk metadata).

Chunk/storage details should not be exposed to the users and is a impl detail of the storage writer/reader.

Test Plan:
UT.

Rollback Plan:

Differential Revision: D80231457

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160610
Approved by: https://github.com/saumishr
This commit is contained in:
Pradeep Fernando
2025-10-08 04:33:51 +00:00
committed by PyTorch MergeBot
parent f76fdcaaf8
commit da903b6a8b
4 changed files with 105 additions and 0 deletions

View File

@ -64,6 +64,16 @@ The entrypoints to load and save a checkpoint are the following:
.. autofunction:: load_state_dict
```
Following APIs can be used for inspecting the metadata of a checkpoint/stored state-dict.
```{eval-rst}
.. automodule:: torch.distributed.checkpoint.metadata_utils
```
```{eval-rst}
.. autofunction:: list_stored_state_dict
```
The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):
```{eval-rst}

View File

@ -0,0 +1,61 @@
# Owner(s): ["oncall: distributed checkpointing"]
import io
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
TensorStorageMetadata,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
class TestListStateDict(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(4)
@with_temp_dir
def test_list_stored_sd_metadata(self) -> None:
CHECKPOINT_DIR = self.temp_dir
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
mesh_shape = (2, self.world_size // 2)
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
# Save 1) DTensor 2) Blob
dtensor = distribute_tensor(
global_tensor,
mesh_2d,
placements=[Shard(0), Replicate()],
)
state_dict_to_save = {
"distributed_weight": dtensor,
"bytes_data": io.BytesIO(b"TrainingEpoch:4"),
}
dist_cp.save(
state_dict=state_dict_to_save,
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
planner=dist_cp.DefaultSavePlanner(),
)
md = dist_cp.list_stored_state_dict(checkpoint_id=CHECKPOINT_DIR)
# Verify DTensor
self.assertTrue("distributed_weight" in md)
self.assertTrue(isinstance(md["distributed_weight"], TensorStorageMetadata))
self.assertEqual(md["distributed_weight"].size, torch.Size([4, 4]))
self.assertEqual(md["distributed_weight"].properties.dtype, torch.float32)
# Verify Blob
self.assertTrue("bytes_data" in md)
self.assertTrue(isinstance(md["bytes_data"], BytesStorageMetadata))
if __name__ == "__main__":
run_tests()

View File

@ -9,6 +9,7 @@ from .metadata import (
Metadata,
TensorStorageMetadata,
)
from .metadata_utils import list_stored_state_dict
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader

View File

@ -0,0 +1,33 @@
import os
from typing import cast, Optional, Union
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
TensorStorageMetadata,
)
from ._storage_utils import _storage_setup
from .storage import StorageReader
__all__ = ["list_stored_state_dict"]
def list_stored_state_dict(
checkpoint_id: Union[str, os.PathLike, None] = None,
storage_reader: Optional[StorageReader] = None,
) -> dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]:
"""
List the stored checkpoint metadata.
NB: The returned state-dict keys are flattened.
"""
storage_reader = cast(
StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
)
md = storage_reader.read_metadata()
sd = md.state_dict_metadata # flattened dict.
for v in sd.values():
if isinstance(v, TensorStorageMetadata):
v.chunks = [] # Limit exposing storage/impl details.
return sd