mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Fixes partially #164878 Pull Request resolved: https://github.com/pytorch/pytorch/pull/165256 Approved by: https://github.com/albanD
488 lines
18 KiB
Python
488 lines
18 KiB
Python
# mypy: allow-untyped-decorators
|
|
# mypy: allow-untyped-defs
|
|
import inspect
|
|
import os
|
|
import warnings
|
|
from concurrent.futures import Future
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import cast, Optional, Union
|
|
from typing_extensions import deprecated
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
from torch.distributed._state_dict_utils import STATE_DICT_TYPE
|
|
from torch.distributed.checkpoint._async_executor import ( # noqa: TC001
|
|
_AsyncCheckpointExecutor,
|
|
)
|
|
from torch.distributed.checkpoint._async_process_executor import (
|
|
_ProcessBasedAsyncCheckpointExecutor,
|
|
)
|
|
from torch.distributed.checkpoint._async_thread_executor import (
|
|
_ThreadBasedAsyncCheckpointExecutor,
|
|
)
|
|
from torch.distributed.checkpoint._storage_utils import _storage_setup
|
|
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
|
|
from torch.distributed.checkpoint.logger import _dcp_method_logger
|
|
from torch.distributed.checkpoint.metadata import Metadata
|
|
from torch.distributed.checkpoint.planner import SavePlan, SavePlanner
|
|
from torch.distributed.checkpoint.staging import (
|
|
AsyncStager,
|
|
DefaultStager,
|
|
StagingOptions,
|
|
)
|
|
from torch.distributed.checkpoint.stateful import Stateful
|
|
from torch.distributed.checkpoint.storage import StorageWriter, WriteResult
|
|
from torch.distributed.distributed_c10d import _get_default_group
|
|
|
|
from .utils import _api_bc_check, _DistWrapper, _profile
|
|
|
|
|
|
# Names exported by ``from <this module> import *``; everything else in the
# file (underscore-prefixed helpers in particular) is internal.
__all__ = [
    "save_state_dict",
    "save",
    "async_save",
    "AsyncCheckpointerType",
    "AsyncSaveResponse",
]
|
|
|
|
|
|
class AsyncCheckpointerType(Enum):
    """Selects which background executor ``async_save`` uses for the upload.

    ``async_save`` picks the process-based executor for ``PROCESS`` and the
    thread-based executor for any other value.
    """

    # Upload runs through the thread-based async checkpoint executor.
    THREAD = "thread"
    # Upload runs through the process-based async checkpoint executor.
    PROCESS = "process"
|
|
|
|
|
|
@deprecated(
    # The two literals are concatenated, so the first one must end with a
    # space to keep the sentences apart in the emitted warning.
    "`save_state_dict` is deprecated and will be removed in future versions. "
    "Please use `save` instead.",
    category=FutureWarning,
)
def save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
) -> Metadata:
    """This method is deprecated. Please switch to 'save'.

    Resets ``storage_writer`` and then forwards every argument unchanged to
    ``_save_state_dict`` inside a ``_profile()`` context.

    Args:
        state_dict: The state_dict to save.
        storage_writer: Writer used to perform the save; ``reset()`` is called
            on it before anything else happens.
        process_group: Forwarded to ``_save_state_dict``. (Default: ``None``)
        coordinator_rank: Forwarded to ``_save_state_dict``. (Default: ``0``)
        no_dist: Forwarded to ``_save_state_dict``. (Default: ``False``)
        planner: Forwarded to ``_save_state_dict``. (Default: ``None``)

    Returns:
        Metadata: Whatever ``_save_state_dict`` returns for this checkpoint.
    """
    storage_writer.reset()

    # TODO: test returning `save` here instead.
    with _profile():
        return _save_state_dict(
            state_dict=state_dict,
            storage_writer=storage_writer,
            process_group=process_group,
            coordinator_rank=coordinator_rank,
            no_dist=no_dist,
            planner=planner,
        )
|
|
|
|
|
|
@_dcp_method_logger(log_exceptions=True)  # type: ignore[arg-type]
@_api_bc_check
def save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Metadata:
    """
    Save a distributed model in SPMD style.

    This function is different from ``torch.save()`` as it handles
    ``ShardedTensor`` and ``DTensor`` by having each rank only save their local shards.

    For each ``Stateful`` object (having both a ``state_dict`` and a ``load_state_dict``),
    save will call ``state_dict`` before serialization.

    .. warning::
        There are no guarantees of Backwards Compatibility across PyTorch versions
        for saved state_dicts.

    .. warning::
        If using the `process_group` argument, make sure that only its ranks
        call `save` and that all data in state_dict belong to it.

    .. note::
        When saving checkpoint for FSDP's `ShardingStrategy.HYBRID_SHARD`, only one of
        the shard_group should be calling `save` and the corresponding process
        group needs to be passed in.

    .. note::
        If no process group is available, this function assumes the intention is to save the
        state_dict in the local process.

    .. note::
        Rank 0 is assumed to be the coordinator rank.


    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform writes. If this is not
            specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)
        no_dist (bool):
            If ``True``, this function will assume the intent is to save
            a checkpoint on a single rank/process. It is also treated as
            ``True`` when ``torch.distributed`` is unavailable or uninitialized.
            (Default: ``False``)
        use_collectives (bool):
            If ``False``, this function will assume the intent is to save
            a checkpoint without using cross-rank synchronization.
            This configuration is experimental and should be used with caution.
            It will change the format of the saved checkpoint and may not be
            backward compatible. (Default: ``True``)

    Returns:
        Metadata: Metadata object for the saved checkpoint.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
        ...     "/checkpoint/1"
        ... )
        >>> torch.distributed.checkpoint.save(
        ...     state_dict=state_dict,
        ...     storage_writer=fs_storage_writer,
        ... )

    .. note::
        save uses collectives to coordinate writes across ranks.
        For NCCL-based process groups, internal tensor representations of
        objects must be moved to the GPU device before communication takes place.
        In this case, the device used is given by ``torch.cuda.current_device()``
        and it is the user's responsibility to ensure that this is set so that
        each rank has an individual GPU, via ``torch.cuda.set_device()``.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save")

    # Fall back to single-process saving whenever torch.distributed cannot be
    # used, even if the caller did not request it explicitly.
    no_dist = no_dist or (not dist.is_available()) or (not dist.is_initialized())
    if no_dist:
        warnings.warn(
            "torch.distributed is disabled, unavailable or uninitialized, assuming the intent is to save in a single process."
        )

    with _profile():
        # Resolve the writer from the explicit argument and/or checkpoint_id.
        storage_writer = cast(
            StorageWriter, _storage_setup(storage_writer, checkpoint_id, reader=False)
        )

        # Stateful entries are replaced by their state_dict() before saving.
        return _save_state_dict(
            state_dict=_stateful_to_state_dict(state_dict),
            storage_writer=storage_writer,
            process_group=process_group,
            no_dist=no_dist,
            planner=planner,
            use_collectives=use_collectives,
        )
|
|
|
|
|
|
@dataclass
class AsyncSaveResponse:
    """Pair of futures handed back by ``async_save()``.

    Callers can wait on the two phases of an asynchronous checkpoint
    independently:

    * ``staging_completion`` resolves when the local copy of the state_dict
      is complete.
    * ``upload_completion`` resolves when the checkpoint has completed saving.
    """

    # Completes (with ``None``) once the state_dict has been staged locally.
    staging_completion: Future[None]
    # Completes once the checkpoint save itself has finished.
    upload_completion: Future[None]
|
|
|
|
|
|
@_dcp_method_logger(log_exceptions=True)
def async_save(
    state_dict: STATE_DICT_TYPE,
    *,
    checkpoint_id: Union[str, os.PathLike, None] = None,
    storage_writer: Optional[StorageWriter] = None,
    planner: Optional[SavePlanner] = None,
    process_group: Optional[dist.ProcessGroup] = None,
    async_checkpointer_type: AsyncCheckpointerType = AsyncCheckpointerType.THREAD,
    async_stager: Optional[AsyncStager] = None,
    no_dist: bool = False,
    use_collectives: bool = True,
) -> Union[Future, AsyncSaveResponse]:
    """Asynchronous version of ``save``. This code first stages the state_dict on to the
    staging storage (defaults to CPU memory), and then calls ``save`` in a separate
    thread or process, depending on ``async_checkpointer_type``.

    .. warning::
        This feature is experimental and subject to change.
        MUST CALL CLOSE AFTER LAST CHECKPOINT IS SAVED

    Args:
        state_dict (Dict[str, Any]): The state_dict to save.
        checkpoint_id (Union[str, os.PathLike, None]):
            The ID of this checkpoint instance. The meaning of the checkpoint_id
            depends on the storage. It can be a path to a folder or to a file.
            It can also be a key if the storage is a key-value store.
            (Default: ``None``)
        storage_writer (Optional[StorageWriter]):
            Instance of StorageWriter used to perform 'stage' and 'save'. If
            this is not specified, DCP will automatically infer the writer based on the
            checkpoint_id. If checkpoint_id is also None, an exception will
            be raised. (Default: ``None``)
        planner (Optional[SavePlanner]):
            Instance of SavePlanner. If this is not specified, the default
            planner will be used. (Default: ``None``)
        process_group (Optional[ProcessGroup]):
            ProcessGroup to be used for cross-rank synchronization.
            (Default: ``None``)
        async_checkpointer_type (AsyncCheckpointerType):
            Whether to do the checkpoint in a separate thread or process.
            (Default: ``AsyncCheckpointerType.THREAD``)
        async_stager (Optional[AsyncStager]):
            Provides the staging implementation. If it is not given and
            storage_writer implements AsyncStager, storage_writer is used for
            staging; otherwise a ``DefaultStager`` is created. If both are
            provided, async_stager is used. (Default: ``None``)
        no_dist (bool):
            If ``True``, this function will assume the intent is to save
            a checkpoint on a single rank/process.
            (Default: ``False``)
        use_collectives (bool):
            If ``False``, save the checkpoint without rank coordination.
            This configuration is experimental and should be used with caution.
            It will change the format of the saved checkpoint and may not be
            backward compatible. (Default: ``True``)

    Returns:
        Union[Future, AsyncSaveResponse]: If the stager returns the staged
        state_dict directly, the future produced by the upload executor for
        the background ``save``. If the stager returns a ``Future``, an
        ``AsyncSaveResponse`` carrying separate futures for staging completion
        and upload completion.

    Raises:
        AssertionError: If torch.distributed is initialized and the process
            group in use has no CPU backend.

    Example:
        >>> # xdoctest: +SKIP
        >>> my_model = MyModule()

        >>> state_dict = {"model": my_model}

        >>> fs_storage_writer = torch.distributed.checkpoint.FileSystemWriter(
        ...     "/checkpoint/1"
        ... )
        >>> checkpoint_future = torch.distributed.checkpoint.async_save(
        ...     state_dict=state_dict,
        ...     storage_writer=fs_storage_writer,
        ... )
        >>>
        >>> # ... do some work ...
        >>>
        >>> checkpoint_future.result()

    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.async_save")

    if dist.is_available() and dist.is_initialized():
        pg = process_group or _get_default_group()
        if torch.device("cpu") not in pg._device_types:
            raise AssertionError(
                "A CPU backend must be enabled for async save; try initializing process group with 'cpu:gloo,cuda:nccl'"
            )

    if async_stager is None:
        if storage_writer is not None and isinstance(storage_writer, AsyncStager):
            # bwc with old storage_writers
            async_stager = storage_writer
        else:
            async_stager = DefaultStager(
                StagingOptions(
                    False,
                    False,
                    False,
                    False,
                )
            )

    state_dict = _stateful_to_state_dict(state_dict)

    @_dcp_method_logger(log_exceptions=True)
    def stage_state_dict() -> Union[Future[STATE_DICT_TYPE], STATE_DICT_TYPE]:
        return async_stager.stage(state_dict)

    staging_future_or_state_dict = stage_state_dict()

    upload_executor: _AsyncCheckpointExecutor = (
        _ProcessBasedAsyncCheckpointExecutor()
        if async_checkpointer_type == AsyncCheckpointerType.PROCESS
        else _ThreadBasedAsyncCheckpointExecutor()
    )

    upload_future: Future = upload_executor.execute_save(
        staging_future_or_state_dict,
        checkpoint_id=checkpoint_id,
        # pyrefly: ignore # bad-argument-type
        storage_writer=storage_writer,
        planner=planner,
        process_group=process_group,
        no_dist=no_dist,
        use_collectives=use_collectives,
    )

    if isinstance(staging_future_or_state_dict, Future):
        staging_future = staging_future_or_state_dict
        return_staging_future: Future[None] = Future()

        def callback(
            original_staging_future: Future[STATE_DICT_TYPE],
            return_staging_future: Future[None] = return_staging_future,
        ):
            # Mirror only success/failure of staging; the staged state_dict
            # itself is deliberately not exposed through this future.
            try:
                original_staging_future.result()
                return_staging_future.set_result(None)
            except Exception as e:
                return_staging_future.set_exception(e)

        # Always register the callback: add_done_callback runs it immediately
        # when the future has already finished. Special-casing a finished
        # future as success would hide a staging failure from the caller.
        staging_future.add_done_callback(callback)

        # return new AsyncSaveResponse for users using new ZOC implementation
        return AsyncSaveResponse(
            staging_completion=return_staging_future, upload_completion=upload_future
        )
    else:

        @_dcp_method_logger(log_exceptions=True)
        def maybe_synchronize_staging():
            if async_stager.should_synchronize_after_execute:
                async_stager.synchronize_staging()

        maybe_synchronize_staging()
        return upload_future
|
|
|
|
|
|
@_dcp_method_logger(log_exceptions=True)
def _stateful_to_state_dict(state_dict: STATE_DICT_TYPE) -> STATE_DICT_TYPE:
    """Return a shallow copy of ``state_dict`` with ``Stateful`` values expanded.

    Every value that is a ``Stateful`` instance is replaced by the result of
    its ``state_dict()`` call; all other values are carried over as-is. The
    input mapping itself is not modified.
    """
    return {
        name: value.state_dict() if isinstance(value, Stateful) else value
        for name, value in state_dict.items()
    }
|
|
|
|
|
|
def _save_state_dict(
    state_dict: STATE_DICT_TYPE,
    storage_writer: StorageWriter,
    process_group: Optional[dist.ProcessGroup] = None,
    coordinator_rank: int = 0,
    no_dist: bool = False,
    planner: Optional[SavePlanner] = None,
    use_collectives: bool = True,
) -> Metadata:
    """Plan a checkpoint and write it through ``storage_writer``.

    The save runs in two phases, each built from a pair of callbacks:

    1. Planning: ``local_step`` sets up the planner and storage writer and
       creates the local plan; ``global_step`` turns the collected local
       plans into the global plan and the checkpoint metadata.
    2. Writing: ``write_data`` writes the finalized plan; ``finish_checkpoint``
       hands the metadata and write results to ``storage_writer.finish``.

    With ``use_collectives=True`` each pair is passed to the ``_DistWrapper``
    (``reduce_scatter`` for planning, ``all_reduce`` for writing). Otherwise
    both pairs are invoked directly in this process, with the local plan /
    local write results as the only entry, followed by ``distW.barrier()``.

    Args:
        state_dict: The state_dict to save.
        storage_writer: Writer that stores the data and the metadata.
        process_group: Passed to ``_DistWrapper``. (Default: ``None``)
        coordinator_rank: Passed to ``_DistWrapper``. (Default: ``0``)
        no_dist: If ``True``, the ``_DistWrapper`` is created with
            distribution disabled. (Default: ``False``)
        planner: Planner to use; a ``DefaultSavePlanner`` is created when
            ``None``. (Default: ``None``)
        use_collectives: Whether the two phases go through the
            ``_DistWrapper`` collectives. (Default: ``True``)

    Returns:
        Metadata: The value produced by the writing phase (``finish_checkpoint``
        returns the global metadata).

    Raises:
        AssertionError: If ``write_data`` runs without a plan, or
            ``finish_checkpoint`` runs before the global metadata was created.
    """
    torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")

    distW = _DistWrapper(process_group, not no_dist, coordinator_rank)

    # Bind the planner once to a non-Optional local. The nested steps below
    # close over it, so no per-step "planner is None" guard is needed.
    active_planner: SavePlanner = (
        planner if planner is not None else DefaultSavePlanner()
    )

    # Filled in by global_step and consumed by finish_checkpoint.
    global_metadata: Optional[Metadata] = None

    # Extra fields for the per-step logger; only populated when the writer
    # exposes a checkpoint_id.
    ckpt_kwargs = {}
    if (ckpt_id := getattr(storage_writer, "checkpoint_id", None)) is not None:
        ckpt_kwargs["checkpoint_id"] = ckpt_id
        ckpt_kwargs["process_group"] = distW.group

    @_dcp_method_logger(**ckpt_kwargs)
    def local_step():
        storage_meta = storage_writer.storage_meta()
        # Older SavePlanner implementations do not accept storage_meta; keep
        # them working but tell the user to update.
        if (
            "storage_meta"
            not in inspect.signature(active_planner.set_up_planner).parameters
        ):
            warnings.warn(
                "The function definition for SavePlanner.set_up_planner has been updated"
                " to include the storage_meta argument. Please update your implementation"
                " to include this parameter."
            )
            active_planner.set_up_planner(state_dict, distW.is_coordinator)  # type: ignore[call-arg, arg-type]
        else:
            active_planner.set_up_planner(
                state_dict=state_dict,
                storage_meta=storage_meta,
                is_coordinator=distW.is_coordinator,
            )

        # Only writers whose set_up_storage_writer declares a parameter named
        # "kwargs" receive the extra rank / use_collectives arguments.
        if (
            "kwargs"
            in inspect.signature(storage_writer.set_up_storage_writer).parameters
        ):
            storage_writer.set_up_storage_writer(
                distW.is_coordinator,
                rank=distW.rank,
                use_collectives=use_collectives,
            )
        else:
            storage_writer.set_up_storage_writer(distW.is_coordinator)

        local_plan = active_planner.create_local_plan()
        local_plan = storage_writer.prepare_local_plan(local_plan)
        return local_plan

    @_dcp_method_logger(**ckpt_kwargs)
    def global_step(all_local_plans):
        nonlocal global_metadata

        all_local_plans, global_metadata = active_planner.create_global_plan(
            all_local_plans
        )
        all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
        return all_local_plans

    central_plan: Optional[SavePlan] = None
    if use_collectives:
        central_plan = distW.reduce_scatter("plan", local_step, global_step)
    else:
        # No cross-rank planning: the "global" plan is built from this
        # process's local plan alone.
        local_plan: SavePlan = local_step()
        global_plan: list[SavePlan] = global_step([local_plan])
        central_plan = global_plan[0]

    @_dcp_method_logger(**ckpt_kwargs)
    def write_data():
        if central_plan is None:
            raise AssertionError("central_plan is None")
        final_local_plan = active_planner.finish_plan(central_plan)
        all_writes = storage_writer.write_data(final_local_plan, active_planner)

        all_writes.wait()
        return all_writes.value()

    @_dcp_method_logger(**ckpt_kwargs)
    def finish_checkpoint(all_results):
        if global_metadata is None:
            raise AssertionError("global_metadata is None")
        storage_writer.finish(metadata=global_metadata, results=all_results)
        return global_metadata

    if use_collectives:
        metadata = distW.all_reduce("write", write_data, finish_checkpoint)
    else:
        write_results: list[WriteResult] = write_data()
        metadata = finish_checkpoint([write_results])
        distW.barrier()

    return metadata
|