PyTorch MergeBot
2025-10-08 15:10:38 +00:00
parent b5e93ffdcf
commit fd4bde430a
4 changed files with 0 additions and 105 deletions

View File

@@ -64,16 +64,6 @@ The entrypoints to load and save a checkpoint are the following:
.. autofunction:: load_state_dict
```
The following APIs can be used to inspect the metadata of a checkpoint / stored state-dict.
```{eval-rst}
.. automodule:: torch.distributed.checkpoint.metadata_utils
```
```{eval-rst}
.. autofunction:: list_stored_state_dict
```
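A minimal usage sketch of the inspection API above, mirroring the test in this diff (a single process is assumed and the checkpoint path is illustrative):
```python
import io

import torch
import torch.distributed.checkpoint as dist_cp

# Save a small state-dict containing a tensor and a raw blob.
state_dict_to_save = {
    "weight": torch.arange(16, dtype=torch.float).view(4, 4),
    "bytes_data": io.BytesIO(b"TrainingEpoch:4"),
}
dist_cp.save(state_dict_to_save, checkpoint_id="/tmp/ckpt")  # illustrative path

# Inspect what was stored without materializing any tensors.
md = dist_cp.list_stored_state_dict(checkpoint_id="/tmp/ckpt")
print(md["weight"].size)                # torch.Size([4, 4]) -- TensorStorageMetadata
print(md["weight"].properties.dtype)    # torch.float32
print(type(md["bytes_data"]).__name__)  # BytesStorageMetadata
```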
The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):
```{eval-rst}

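For context, a minimal sketch of the asynchronous entrypoint that this staging customizes, using the default stager (single process assumed; the path and state-dict are illustrative):
```python
import torch
import torch.distributed.checkpoint as dist_cp

state_dict = {"weight": torch.randn(4, 4)}

# async_save stages (copies) the state-dict and returns a future right away,
# so training can continue while the write happens in the background.
ckpt_future = dist_cp.async_save(state_dict, checkpoint_id="/tmp/ckpt_async")  # illustrative path
# ... continue training ...
ckpt_future.result()  # block until the checkpoint is fully persisted
```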
View File

@@ -1,61 +0,0 @@
# Owner(s): ["oncall: distributed checkpointing"]
import io
import torch
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    TensorStorageMetadata,
)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir


class TestListStateDict(DTensorTestBase):
    @with_comms
    @skip_if_lt_x_gpu(4)
    @with_temp_dir
    def test_list_stored_sd_metadata(self) -> None:
        CHECKPOINT_DIR = self.temp_dir
        global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
        mesh_shape = (2, self.world_size // 2)
        mesh_2d = init_device_mesh(self.device_type, mesh_shape)
        # Save 1) DTensor 2) Blob
        dtensor = distribute_tensor(
            global_tensor,
            mesh_2d,
            placements=[Shard(0), Replicate()],
        )
        state_dict_to_save = {
            "distributed_weight": dtensor,
            "bytes_data": io.BytesIO(b"TrainingEpoch:4"),
        }
        dist_cp.save(
            state_dict=state_dict_to_save,
            storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
            planner=dist_cp.DefaultSavePlanner(),
        )
        md = dist_cp.list_stored_state_dict(checkpoint_id=CHECKPOINT_DIR)
        # Verify DTensor
        self.assertTrue("distributed_weight" in md)
        self.assertTrue(isinstance(md["distributed_weight"], TensorStorageMetadata))
        self.assertEqual(md["distributed_weight"].size, torch.Size([4, 4]))
        self.assertEqual(md["distributed_weight"].properties.dtype, torch.float32)
        # Verify Blob
        self.assertTrue("bytes_data" in md)
        self.assertTrue(isinstance(md["bytes_data"], BytesStorageMetadata))


if __name__ == "__main__":
    run_tests()

View File

@@ -9,7 +9,6 @@ from .metadata import (
    Metadata,
    TensorStorageMetadata,
)
from .metadata_utils import list_stored_state_dict
from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader

View File

@@ -1,33 +0,0 @@
import os
from typing import cast, Optional, Union

from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    TensorStorageMetadata,
)

from ._storage_utils import _storage_setup
from .storage import StorageReader


__all__ = ["list_stored_state_dict"]


def list_stored_state_dict(
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_reader: Optional[StorageReader] = None,
) -> dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]:
    """
    List the stored checkpoint metadata.

    NB: The returned state-dict keys are flattened.
    """
    storage_reader = cast(
        StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
    )
    md = storage_reader.read_metadata()
    sd = md.state_dict_metadata  # flattened dict.
    for v in sd.values():
        if isinstance(v, TensorStorageMetadata):
            v.chunks = []  # Limit exposing storage/impl details.
    return sd
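A hedged illustration of the flattening note in the docstring above: nested state-dict keys come back as dotted keys, one metadata entry per leaf (names and path are illustrative; the default save planner is assumed):
```python
import io

import torch
import torch.distributed.checkpoint as dist_cp

# A nested state-dict is flattened to dotted keys by the default save planner.
to_save = {"model": {"fc.weight": torch.zeros(2, 3)}, "epoch": io.BytesIO(b"4")}
dist_cp.save(to_save, checkpoint_id="/tmp/ckpt_flat")  # illustrative path

md = dist_cp.list_stored_state_dict(checkpoint_id="/tmp/ckpt_flat")
print(sorted(md))  # e.g. ['epoch', 'model.fc.weight'] -- flattened keys
```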