pytorch/torch/distributed/checkpoint/_async_executor.py
Saurabh Mishra 6ee175195a [DCP][OSS] Rank local checkpointing in DCP without collectives (#147758)
Summary:
DCP metadata collectives become prohibitively expensive as the job scale grows. This PR introduces rank-local checkpointing, which saves and loads the checkpoint without any collectives. The trade-off for now is that dedupe and re-sharding are not supported; support for both will be introduced soon.

Differential Revision: D70112642

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147758
Approved by: https://github.com/meetv18
2025-08-13 16:20:28 +00:00
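
As a rough sketch of how the rank-local path described in the summary above might be requested through the interface in this file: the executor instance and checkpoint path below are hypothetical (concrete executor implementations live elsewhere in DCP), but the keyword arguments match the abstract signature shown further down.

    # Hypothetical usage sketch. `executor` stands in for a concrete
    # _AsyncCheckpointExecutor implementation, which this file does not define.
    save_future = executor.execute_save(
        state_dict,
        checkpoint_id="/tmp/ckpt/step_100",  # illustrative path
        use_collectives=False,  # skip the metadata collectives (rank-local save)
    )
    save_future.result()  # block until the background save finishes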

35 lines
1.2 KiB
Python

# pyre-strict
# mypy: allow-untyped-defs
import abc
import os
from concurrent.futures import Future
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _AsyncCheckpointExecutor(abc.ABC):
    @abc.abstractmethod
    def execute_save(
        self,
        staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]],
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        no_dist: bool = False,
        use_collectives: bool = True,
    ) -> Future:
        """
        Execute the checkpoint save request asynchronously.

        This method is intended to be used as an abstraction for
        implementing async checkpointing. The actual checkpoint save
        operation is executed in a separate thread or process depending
        on the implementation of this interface.
        """