pytorch/torch/distributed/checkpoint/_async_executor.py

# pyre-strict
# mypy: allow-untyped-defs
import abc
import os
from concurrent.futures import Future
from typing import Optional, Union

import torch.distributed as dist
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.checkpoint.planner import SavePlanner
from torch.distributed.checkpoint.storage import StorageWriter


class _AsyncCheckpointExecutor(abc.ABC):
    @abc.abstractmethod
    def execute_save(
        self,
        staging_future_or_state_dict: Union[STATE_DICT_TYPE, Future[STATE_DICT_TYPE]],
        *,
        checkpoint_id: Union[str, os.PathLike, None] = None,
        storage_writer: Optional[StorageWriter] = None,
        planner: Optional[SavePlanner] = None,
        process_group: Optional[dist.ProcessGroup] = None,
        no_dist: bool = False,
        use_collectives: bool = True,
    ) -> Future:
        """
        Execute the checkpoint save request asynchronously.

        This method is intended as an abstraction for implementing async
        checkpointing. The actual checkpoint save operation runs in a
        separate thread or process, depending on the implementation of
        this interface.
        """