Revert "list_stored_sd_metadata API. (#160610)"
This reverts commit da903b6a8be422529d47649e89c0d50bb95c37ca. Reverted https://github.com/pytorch/pytorch/pull/160610 on behalf of https://github.com/jeffdaily due to broke ROCm CI, but flaky also on CUDA CI https://hud.pytorch.org/failure?name=periodic%20%2F%20linux-jammy-rocm-py3.10%20%2F%20test%20(distributed%2C%202%2C%203%2C%20linux.rocm.gpu.mi250.4%2C%20module%3Arocm%2C%20oncall%3Adistributed)&jobName=undefined&failureCaptures=distributed%2Fcheckpoint%2Ftest_list_stored_state_dict.py%3A%3ATestListStateDict%3A%3Atest_list_stored_sd_metadata ([comment](https://github.com/pytorch/pytorch/pull/160610#issuecomment-3382023022))
@@ -64,16 +64,6 @@ The entrypoints to load and save a checkpoint are the following:
 .. autofunction:: load_state_dict
 ```
-
-Following APIs can be used for inspecting the metadata of a checkpoint/stored state-dict.
-
-```{eval-rst}
-.. automodule:: torch.distributed.checkpoint.metadata_utils
-```
-
-```{eval-rst}
-.. autofunction:: list_stored_state_dict
-```
 
 The following module is also useful for additional customization of the staging mechanisms used for asynchronous checkpointing (`torch.distributed.checkpoint.async_save`):
 
 ```{eval-rst}
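For reference, the documentation removed above described `list_stored_state_dict` as a helper for inspecting the metadata of a stored checkpoint. A minimal usage sketch, assuming the pre-revert API from #160610 and a placeholder checkpoint directory; the attribute names (`size`, `properties.dtype`) come from the test and implementation removed further down:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import (
    BytesStorageMetadata,
    TensorStorageMetadata,
)

# Placeholder path: any directory previously written by dist_cp.save().
CHECKPOINT_DIR = "/tmp/dcp_checkpoint"

# list_stored_state_dict only exists while #160610 is landed; after this
# revert the call below is no longer available.
md = dist_cp.list_stored_state_dict(checkpoint_id=CHECKPOINT_DIR)
for key, entry in md.items():
    if isinstance(entry, TensorStorageMetadata):
        print(f"{key}: tensor of size {tuple(entry.size)}, dtype {entry.properties.dtype}")
    elif isinstance(entry, BytesStorageMetadata):
        print(f"{key}: opaque bytes blob")
```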
@ -1,61 +0,0 @@
|
||||
# Owner(s): ["oncall: distributed checkpointing"]
|
||||
|
||||
import io
|
||||
|
||||
import torch
|
||||
import torch.distributed.checkpoint as dist_cp
|
||||
from torch.distributed.checkpoint.metadata import (
|
||||
BytesStorageMetadata,
|
||||
TensorStorageMetadata,
|
||||
)
|
||||
from torch.distributed.device_mesh import init_device_mesh
|
||||
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
DTensorTestBase,
|
||||
skip_if_lt_x_gpu,
|
||||
with_comms,
|
||||
)
|
||||
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
|
||||
|
||||
|
||||
class TestListStateDict(DTensorTestBase):
|
||||
@with_comms
|
||||
@skip_if_lt_x_gpu(4)
|
||||
@with_temp_dir
|
||||
def test_list_stored_sd_metadata(self) -> None:
|
||||
CHECKPOINT_DIR = self.temp_dir
|
||||
global_tensor = torch.arange(16, dtype=torch.float).view(4, 4)
|
||||
mesh_shape = (2, self.world_size // 2)
|
||||
mesh_2d = init_device_mesh(self.device_type, mesh_shape)
|
||||
# Save 1) DTensor 2) Blob
|
||||
dtensor = distribute_tensor(
|
||||
global_tensor,
|
||||
mesh_2d,
|
||||
placements=[Shard(0), Replicate()],
|
||||
)
|
||||
state_dict_to_save = {
|
||||
"distributed_weight": dtensor,
|
||||
"bytes_data": io.BytesIO(b"TrainingEpoch:4"),
|
||||
}
|
||||
|
||||
dist_cp.save(
|
||||
state_dict=state_dict_to_save,
|
||||
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
|
||||
planner=dist_cp.DefaultSavePlanner(),
|
||||
)
|
||||
md = dist_cp.list_stored_state_dict(checkpoint_id=CHECKPOINT_DIR)
|
||||
|
||||
# Verify DTensor
|
||||
self.assertTrue("distributed_weight" in md)
|
||||
self.assertTrue(isinstance(md["distributed_weight"], TensorStorageMetadata))
|
||||
self.assertEqual(md["distributed_weight"].size, torch.Size([4, 4]))
|
||||
self.assertEqual(md["distributed_weight"].properties.dtype, torch.float32)
|
||||
|
||||
# Verify Blob
|
||||
self.assertTrue("bytes_data" in md)
|
||||
self.assertTrue(isinstance(md["bytes_data"], BytesStorageMetadata))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
@@ -9,7 +9,6 @@ from .metadata import (
     Metadata,
     TensorStorageMetadata,
 )
-from .metadata_utils import list_stored_state_dict
 from .optimizer import load_sharded_optimizer_state_dict
 from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
 from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
@@ -1,33 +0,0 @@
-import os
-from typing import cast, Optional, Union
-
-from torch.distributed.checkpoint.metadata import (
-    BytesStorageMetadata,
-    TensorStorageMetadata,
-)
-
-from ._storage_utils import _storage_setup
-from .storage import StorageReader
-
-
-__all__ = ["list_stored_state_dict"]
-
-
-def list_stored_state_dict(
-    checkpoint_id: Union[str, os.PathLike, None] = None,
-    storage_reader: Optional[StorageReader] = None,
-) -> dict[str, Union[TensorStorageMetadata, BytesStorageMetadata]]:
-    """
-    List the stored checkpoint metadata.
-    NB: The returned state-dict keys are flattened.
-    """
-    storage_reader = cast(
-        StorageReader, _storage_setup(storage_reader, checkpoint_id, reader=True)
-    )
-    md = storage_reader.read_metadata()
-    sd = md.state_dict_metadata  # flattened dict.
-    for v in sd.values():
-        if isinstance(v, TensorStorageMetadata):
-            v.chunks = []  # Limit exposing storage/impl details.
-
-    return sd
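The deleted `metadata_utils.py` above shows that the helper was a thin wrapper around the storage reader's metadata. Roughly the same information remains reachable after the revert by reading the checkpoint metadata directly; a minimal sketch, assuming a filesystem checkpoint at a placeholder path and using only `FileSystemReader.read_metadata()` and `Metadata.state_dict_metadata`:

```python
import torch.distributed.checkpoint as dist_cp
from torch.distributed.checkpoint.metadata import TensorStorageMetadata

# Placeholder path: a directory written by dist_cp.save() / FileSystemWriter.
CHECKPOINT_DIR = "/tmp/dcp_checkpoint"

# read_metadata() returns the Metadata object whose state_dict_metadata field
# the removed helper exposed (with per-tensor chunks cleared).
reader = dist_cp.FileSystemReader(CHECKPOINT_DIR)
metadata = reader.read_metadata()
for key, entry in metadata.state_dict_metadata.items():
    if isinstance(entry, TensorStorageMetadata):
        print(key, entry.size, entry.properties.dtype)
    else:
        print(key, type(entry).__name__)
```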