mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Summary: DCP metadata collectives become prohibitively expensive as job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off, for now, is the loss of dedupe and re-sharding; support for these will be added soon.

Differential Revision: D70112642
Pull Request resolved: https://github.com/pytorch/pytorch/pull/147758
Approved by: https://github.com/meetv18
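The executor interface below exposes `no_dist` and `use_collectives` flags. As a hedged sketch of what rank-local checkpointing might look like from the caller's side, assuming (and only assuming; this file shows just the private executor interface) that these flags are forwarded through the public `async_save` entry point:

import torch
import torch.distributed.checkpoint as dcp

state_dict = {"weight": torch.zeros(4)}  # toy state dict for illustration

# Assumption: async_save forwards the use_collectives flag shown in the
# executor interface below. The checkpoint path is hypothetical.
future = dcp.async_save(
    state_dict,
    checkpoint_id="/tmp/ckpt/step_100",
    use_collectives=False,  # rank-local: skip metadata collectives
)
future.result()  # block until the background save completes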
35 lines
1.2 KiB
Python
# pyre-strict

# mypy: allow-untyped-defs
import abc
import os
from concurrent.futures import Future
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _AsyncCheckpointExecutor(abc.ABC):
    @abc.abstractmethod
    def execute_save(
        self,
        staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]],
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        no_dist: bool = False,
        # Per the PR summary, disabling collectives enables rank-local checkpointing.
        use_collectives: bool = True,
    ) -> Future:
        """
        Execute the checkpoint save request asynchronously.

        This method is intended to be used as an abstraction for
        implementing async checkpointing. The actual checkpoint save
        operation is executed in a separate thread or process depending
        on the implementation of this interface.
        """
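The docstring leaves the execution strategy (thread vs. process) to each subclass. Below is a minimal thread-based sketch of one way a subclass could satisfy this interface; it is not PyTorch's shipped implementation, and the names introduced here are illustrative:

# Illustrative only: a thread-based subclass of the interface above.
# This is NOT PyTorch's shipped implementation.
from concurrent.futures import ThreadPoolExecutor

import torch.distributed.checkpoint as dcp


class _ThreadedCheckpointExecutor(_AsyncCheckpointExecutor):
    def __init__(self) -> None:
        # One worker thread: saves submitted to this executor are serialized.
        self._pool = ThreadPoolExecutor(max_workers=1)

    def execute_save(
        self,
        staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]],
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        no_dist: bool = False,
        use_collectives: bool = True,
    ) -> Future:
        def _save():
            # If staging is itself asynchronous, wait for the staged copy.
            state_dict = (
                staging_future_or_state_dict.result()
                if isinstance(staging_future_or_state_dict, Future)
                else staging_future_or_state_dict
            )
            # dcp.save is the synchronous save entry point; no_dist and
            # use_collectives are intentionally not forwarded here because
            # their plumbing past this interface is not shown in this file.
            return dcp.save(
                state_dict,
                checkpoint_id=checkpoint_id,
                storage_writer=storage_writer,
                planner=planner,
                process_group=process_group,
            )

        # Run the blocking save off the caller's thread and hand back a Future.
        return self._pool.submit(_save)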