Adds main in format utils. Usage: `python -m torch.distributed.checkpoint.format_utils dcp_to_torch dcp_dir torch_file.pt` or `python -m torch.distributed.checkpoint.format_utils torch_to_dcp torch_file.pt dcp_dir` Differential Revision: [D53791355](https://our.internmc.facebook.com/intern/diff/D53791355/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/120128 Approved by: https://github.com/fegin, https://github.com/wz337
import argparse
import os
from enum import Enum
from typing import cast, Dict, List, Optional, Union

import torch
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
from torch.distributed.checkpoint._traverse import set_element
from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
    Metadata,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorProperties,
    TensorStorageMetadata,
)
from torch.distributed.checkpoint.planner import LoadItemType, LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import _create_chunk_list
from torch.distributed.checkpoint.state_dict_loader import _load_state_dict
from torch.distributed.checkpoint.state_dict_saver import _save_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.futures import Future

__all__ = [
    "dcp_to_torch_save",
    "torch_save_to_dcp",
    "BroadcastingTorchSaveReader",
    "DynamicMetaLoadPlanner",
]


class _EmptyStateDictLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which rebuilds the state_dict from the saved
    metadata. Useful for loading a state_dict without first initializing a model,
    such as when converting a DCP checkpoint into a Torch save file.

    .. note:: `state_dict` must be an empty dictionary when used with this LoadPlanner.

    .. warning::
        Because the entire state dict is initialized, it's recommended to only use
        this LoadPlanner on a single rank or process to avoid OOM.

    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        assert not state_dict

        # rebuild the state dict from the metadata
        for k, v in metadata.state_dict_metadata.items():
            if isinstance(v, TensorStorageMetadata):
                v = torch.empty(v.size, dtype=v.properties.dtype)  # type: ignore[assignment]
            if k in metadata.planner_data:
                set_element(state_dict, metadata.planner_data[k], v)
            else:
                state_dict[k] = v

        super().set_up_planner(state_dict, metadata, is_coordinator)
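
# A minimal sketch (commented out so the module stays import-safe) of materializing
# a state_dict straight from a DCP checkpoint's metadata, with no model instantiated;
# "checkpoint_dir" is a hypothetical path. dcp_to_torch_save() below uses exactly
# this pattern:
#
#     sd: STATE_DICT_TYPE = {}
#     _load_state_dict(
#         sd,
#         storage_reader=FileSystemReader("checkpoint_dir"),
#         planner=_EmptyStateDictLoadPlanner(),
#         no_dist=True,
#     )
#     # `sd` now holds fully materialized tensors, keyed as in the original model
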
class BroadcastingTorchSaveReader(StorageReader):
    """
    StorageReader for reading a Torch Save file. This reader reads the entire
    checkpoint on the coordinator rank, and then broadcasts and shards each tensor
    to all ranks.

    .. note:: Intended to be used with DynamicMetaLoadPlanner

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def __init__(
        self,
        checkpoint_id: Optional[Union[str, os.PathLike]] = None,
        coordinator_rank: int = 0,
    ) -> None:
        self.checkpoint_id = checkpoint_id
        self.coordinator_rank = coordinator_rank

    def read_metadata(self) -> Metadata:
        """Extends the default StorageReader to support building the metadata file"""
        # Metadata is built in planner.set_up_planner, since we are not actually
        # reading metadata from the disk
        return Metadata(state_dict_metadata={})

    def read_data(self, plan: LoadPlan, planner: LoadPlanner) -> Future[None]:
        """
        Reads torch save data on the coordinator rank, and broadcasts it afterwards.
        This incurs a communication cost, but avoids having to load the entire
        checkpoint on each rank, hopefully preventing OOM issues.
        """
        planner = cast(DefaultLoadPlanner, planner)

        # data is read in on the coordinator rank, and broadcast afterwards
        # this incurs a communication cost, but it avoids having to load
        # the entire checkpoint on each rank, hopefully preventing OOM issues
        # TODO: read on each host, instead of only the coordinator
        if self.is_coordinator:
            assert self.checkpoint_id is not None
            torch_state_dict = torch.load(self.checkpoint_id, map_location="cpu")
            if planner.flatten_state_dict:
                torch_state_dict, _ = flatten_state_dict(torch_state_dict)
        else:
            torch_state_dict = None

        for req in plan.items:
            if req.type == LoadItemType.BYTE_IO:
                raise RuntimeError(
                    f"Non-tensor value identified at {req.storage_index.fqn}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            # Broadcast the tensor from the coordinator rank
            if self.is_coordinator:
                tensor = torch_state_dict[req.storage_index.fqn].cuda()
            else:
                tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
            dist.broadcast(tensor, src=self.coordinator_rank, async_op=False)

            # each rank keeps only its local shard of the broadcast tensor
            tensor = narrow_tensor_by_index(tensor, req.storage_offsets, req.lengths)
            target_tensor = planner.resolve_tensor(req).detach()
            assert target_tensor.size() == tensor.size(), (
                f"req {req.storage_index} mismatched sizes, "
                f"{target_tensor.size()} vs {tensor.size()}"
            )
            target_tensor.copy_(tensor)
            planner.commit_tensor(req, target_tensor)

        fut: Future = Future()
        fut.set_result(None)
        return fut

    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
        """Implementation of the StorageReader method"""
        self.is_coordinator = is_coordinator
        if self.is_coordinator:
            assert dist.get_rank() == self.coordinator_rank

        assert self.checkpoint_id is not None

    def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
        """Implementation of the StorageReader method"""
        return plan

    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
        """Implementation of the StorageReader method"""
        return global_plan

    def reset(self, checkpoint_id: Union[str, os.PathLike, None] = None) -> None:
        """Implementation of the StorageReader method"""
        self.checkpoint_id = checkpoint_id

    @classmethod
    def validate_checkpoint_id(cls, checkpoint_id: Union[str, os.PathLike]) -> bool:
        """Implementation of the StorageReader method"""
        return os.path.isfile(checkpoint_id)
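
# A sketch of the broadcast-then-shard pattern read_data() uses above, stripped of
# the checkpoint machinery (commented out; assumes an initialized NCCL process
# group, hence the .cuda(); the offsets/lengths are made-up values):
#
#     full = torch.arange(16.0).reshape(4, 4).cuda()
#     dist.broadcast(full, src=0, async_op=False)  # every rank now holds the full tensor
#     local = narrow_tensor_by_index(full, [2, 0], [2, 4])  # keep only rows 2..3
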
class DynamicMetaLoadPlanner(DefaultLoadPlanner):
    """
    Extension of DefaultLoadPlanner, which creates a new Metadata object based on the
    passed-in state dict, avoiding the need to read metadata from disk. This is useful
    when reading formats which don't have a metadata file, like Torch Save files.

    .. note:: Intended to be used with BroadcastingTorchSaveReader

    .. warning::
        Current implementation only supports loading Tensors.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> sd = {"model": model}
    >>> dcp.load(
    >>>     sd,
    >>>     storage_reader=BroadcastingTorchSaveReader(),
    >>>     planner=DynamicMetaLoadPlanner(),
    >>>     checkpoint_id="path_to_model.pt"
    >>> )
    """

    def set_up_planner(
        self,
        state_dict: STATE_DICT_TYPE,
        metadata: Metadata,
        is_coordinator: bool,
    ) -> None:
        """Sets up the planner, extending default behavior by creating the Metadata object from the state dict"""
        super().set_up_planner(state_dict, metadata, is_coordinator)

        state_dict_metadata: Dict[str, STORAGE_TYPES] = {}
        for key, tensor in self.state_dict.items():
            if not torch.is_tensor(tensor):
                raise RuntimeError(
                    f"Non-tensor value identified at {key}. "
                    f"At this time {type(self).__name__} only supports loading Tensors."
                )

            state_dict_metadata[key] = TensorStorageMetadata(
                TensorProperties(dtype=tensor.dtype),
                tensor.size(),
                _create_chunk_list(tensor),
            )
        self.metadata = Metadata(state_dict_metadata=state_dict_metadata)
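
# A sketch of the metadata this planner synthesizes for a toy state_dict (commented
# out; no reader or process group involved):
#
#     planner = DynamicMetaLoadPlanner()
#     planner.set_up_planner(
#         state_dict={"w": torch.zeros(4, 4)},
#         metadata=Metadata(state_dict_metadata={}),
#         is_coordinator=True,
#     )
#     meta = planner.metadata.state_dict_metadata["w"]
#     # a TensorStorageMetadata with size torch.Size([4, 4]), dtype torch.float32,
#     # and a chunk list covering the tensor
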
def dcp_to_torch_save(
    dcp_checkpoint_dir: Union[str, os.PathLike],
    torch_save_path: Union[str, os.PathLike],
):
    """
    Given a directory containing a DCP checkpoint, this function will convert it into a
    Torch save file.

    Args:
        dcp_checkpoint_dir: Directory containing the DCP checkpoint.
        torch_save_path: Filename to store the converted Torch save file.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """
    sd: STATE_DICT_TYPE = {}
    storage_reader = FileSystemReader(dcp_checkpoint_dir)

    _load_state_dict(
        sd,
        storage_reader=storage_reader,
        planner=_EmptyStateDictLoadPlanner(),
        no_dist=True,
    )
    torch.save(sd, torch_save_path)
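
# Example (commented out; hypothetical paths). Per the warning above, guard the call
# so it runs on a single rank in a distributed job:
#
#     if dist.get_rank() == 0:
#         dcp_to_torch_save("my_dcp_checkpoint/", "model.pt")
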
def torch_save_to_dcp(
    torch_save_path: Union[str, os.PathLike],
    dcp_checkpoint_dir: Union[str, os.PathLike],
):
    """
    Given the location of a torch save file, converts it into a DCP checkpoint.

    Args:
        torch_save_path: Filename of the Torch save file to convert.
        dcp_checkpoint_dir: Directory to store the DCP checkpoint.

    .. warning::
        To avoid OOM, it's recommended to only run this function on a single rank.
    """

    state_dict = torch.load(torch_save_path)
    # save the state_dict without collectives, mirroring the no_dist load above;
    # dcp.save() does not accept a no_dist argument, so use the internal helper
    _save_state_dict(
        state_dict, storage_writer=FileSystemWriter(dcp_checkpoint_dir), no_dist=True
    )
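
# Example round trip (commented out; hypothetical paths): convert a torch.save file
# into a DCP checkpoint and back again.
#
#     torch_save_to_dcp("model.pt", "my_dcp_checkpoint/")
#     dcp_to_torch_save("my_dcp_checkpoint/", "model_roundtrip.pt")
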
if __name__ == "__main__":

    class FormatMode(Enum):
        TORCH_TO_DCP = "torch_to_dcp"
        DCP_TO_TORCH = "dcp_to_torch"

    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "mode",
        type=str,
        help="Conversion mode",
        choices=[m.value for m in FormatMode],
        default=FormatMode.TORCH_TO_DCP.value,
    )
    parser.add_argument("src", type=str, help="Path to the source model")
    parser.add_argument("dst", type=str, help="Path to the destination model")
    args = parser.parse_args()

    print(
        f"Converting checkpoint from {args.src} to {args.dst} using method: '{args.mode}'"
    )
    checkpoint_missing_warning = (
        f"No checkpoint found at {args.src}. Skipping conversion."
    )
    # args.mode is a plain string, so compare against the enum values, not the members
    if args.mode == FormatMode.TORCH_TO_DCP.value:
        if os.path.isfile(args.src):
            torch_save_to_dcp(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)
    elif args.mode == FormatMode.DCP_TO_TORCH.value:
        if os.path.isdir(args.src):
            dcp_to_torch_save(args.src, args.dst)
        else:
            print(checkpoint_missing_warning)